├── Log ├── record_pretrain.py ├── record_semisup.py └── record_sup.py ├── Main ├── __init__.py ├── augmentation.py ├── dataset.py ├── main(pretrain).py ├── main(semisup).py ├── main(sup).py ├── model.py ├── pargs.py ├── sort.py ├── utils.py └── word2vec.py ├── README.md └── centrality.py /Log/record_pretrain.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/5/31 17:49 3 | # @Author : 4 | # @Email : 5 | # @File : record_pretrain.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import sys 10 | import json 11 | import math 12 | 13 | dirname = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(dirname, '..')) 15 | 16 | 17 | def get_value(acc_list): 18 | mean = round(sum(acc_list) / len(acc_list), 3) 19 | sd = round(math.sqrt(sum([(x - mean) ** 2 for x in acc_list]) / len(acc_list)), 3) 20 | maxx = max(acc_list) 21 | return 'test acc: {:.3f}±{:.3f}'.format(mean, sd), 'max acc: {:.3f}'.format(maxx) 22 | 23 | 24 | if __name__ == '__main__': 25 | log_dir_path = os.path.join(dirname, '..', 'Log') 26 | 27 | for filename in os.listdir(log_dir_path): 28 | if filename[-4:] == 'json': 29 | print(f'【{filename[:-5]}】') 30 | filepath = os.path.join(log_dir_path, filename) 31 | 32 | log = json.load(open(filepath, 'r', encoding='utf-8')) 33 | print('dataset:', log['dataset']) 34 | print('unsup dataset:', log['unsup dataset']) 35 | print('tokenize mode:', log['tokenize mode']) 36 | print('unsup train size:', log['unsup train size']) 37 | print('batch size:', log['batch size']) 38 | print('undirected:', log['undirected']) 39 | if 'model' in log.keys(): 40 | print('model:', log['model']) 41 | print('n layers feat:', log['n layers feat']) 42 | print('n layers conv:', log['n layers conv']) 43 | print('n layers fc:', log['n layers fc']) 44 | print('vector size:', log['vector size']) 45 | print('hidden size:', log['hidden']) 46 | print('global pool:', log['global pool']) 47 | print('skip connection:', log['skip connection']) 48 | print('res branch:', log['res branch']) 49 | print('dropout:', log['dropout']) 50 | print('edge norm:', log['edge norm']) 51 | print('lr:', log['lr']) 52 | print('ft lr:', log['ft_lr']) 53 | print('epochs:', log['epochs']) 54 | print('ft epochs:', log['ft_epochs']) 55 | print('weight decay:', log['weight decay']) 56 | print('centrality:', log['centrality']) 57 | print('aug1:', log['aug1']) 58 | print('aug2:', log['aug2']) 59 | 60 | record = log['record'] 61 | acc_lists = {10: [], 20: [], 40: [], 80: [], 100: [], 200: [], 300: [], 500: [], 10000: []} 62 | 63 | for run_record in record: 64 | for re in run_record['record']: 65 | acc_lists[re['k']].append(re['mean acc']) 66 | for key in acc_lists.keys(): 67 | acc, max_acc = get_value(acc_lists[key]) 68 | print(f'k: {key}, {acc}, {max_acc}') 69 | print() 70 | -------------------------------------------------------------------------------- /Log/record_semisup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/5/7 20:26 3 | # @Author : 4 | # @Email : 5 | # @File : record.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import sys 10 | import json 11 | import math 12 | import numpy as np 13 | 14 | dirname = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(os.path.join(dirname, '..')) 16 | 17 | cal_mean = -10 18 | 19 | if __name__ == '__main__': 20 | log_dir_path = os.path.join(dirname, '..', 'Log') 21 | 22 | for filename in 
os.listdir(log_dir_path): 23 | if filename[-4:] == 'json': 24 | print(f'【{filename[:-5]}】') 25 | filepath = os.path.join(log_dir_path, filename) 26 | 27 | log = json.load(open(filepath, 'r', encoding='utf-8')) 28 | print('dataset:', log['dataset']) 29 | print('unsup dataset:', log['unsup dataset']) 30 | print('tokenize mode:', log['tokenize mode']) 31 | # print('unsup train size:', log['unsup train size']) 32 | # print('batch size:', log['batch size']) 33 | # print('unsup bs ratio:', log['unsup_bs_ratio']) 34 | # print('undirected:', log['undirected']) 35 | # if 'model' in log.keys(): 36 | # print('model:', log['model']) 37 | # print('n layers feat:', log['n layers feat']) 38 | # print('n layers conv:', log['n layers conv']) 39 | # print('n layers fc:', log['n layers fc']) 40 | # print('vector size:', log['vector size']) 41 | # print('hidden size:', log['hidden']) 42 | # print('global pool:', log['global pool']) 43 | # print('skip connection:', log['skip connection']) 44 | # print('res branch:', log['res branch']) 45 | # print('dropout:', log['dropout']) 46 | # print('edge norm:', log['edge norm']) 47 | # print('lr:', log['lr']) 48 | # print('epochs:', log['epochs']) 49 | # print('weight decay:', log['weight decay']) 50 | # print('lamda:', log['lamda']) 51 | print('centrality:', log['centrality']) 52 | print('aug1:', log['aug1']) 53 | print('aug2:', log['aug2']) 54 | # print('k:', log['k']) 55 | 56 | acc_list = [] 57 | max_epoch_acc_list = [] 58 | for run in log['record']: 59 | # mean_acc = run['mean acc'] 60 | mean_acc = round(np.mean(run['test accs'][cal_mean:]), 3) 61 | max_epoch_acc = round(np.max(run['test accs']), 3) 62 | acc_list.append(mean_acc) 63 | max_epoch_acc_list.append(max_epoch_acc) 64 | 65 | mean = round(sum(acc_list) / len(acc_list), 3) 66 | mean_max_epoch = round(sum(max_epoch_acc_list) / len(max_epoch_acc_list), 3) 67 | sd = round(math.sqrt(sum([(x - mean) ** 2 for x in acc_list]) / len(acc_list)), 3) 68 | sd_max_epoch = round( 69 | math.sqrt(sum([(x - mean_max_epoch) ** 2 for x in max_epoch_acc_list]) / len(max_epoch_acc_list)), 3) 70 | maxx = max(acc_list) 71 | maxx_max_epoch = max(max_epoch_acc_list) 72 | print('test acc | max acc: {:.3f}±{:.3f} | {:.3f}'.format(mean, sd, maxx)) 73 | print('test acc | max acc (max epoch): {:.3f}±{:.3f} | {:.3f}'.format(mean_max_epoch, sd_max_epoch, 74 | maxx_max_epoch)) 75 | print() 76 | -------------------------------------------------------------------------------- /Log/record_sup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2022/9/20 14:38 3 | # @Author : 4 | # @Email : 5 | # @File : record_sup.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import sys 10 | import json 11 | import math 12 | import numpy as np 13 | 14 | dirname = os.path.dirname(os.path.abspath(__file__)) 15 | sys.path.append(os.path.join(dirname, '..')) 16 | 17 | cal_mean = -10 18 | 19 | if __name__ == '__main__': 20 | # log_dir_path = os.path.join(dirname, '..', 'Log') 21 | log_dir_path = os.path.join(dirname, '..', 'Log', 'recorded') 22 | 23 | for filename in os.listdir(log_dir_path): 24 | if filename[-4:] == 'json': 25 | print(f'【{filename[:-5]}】') 26 | filepath = os.path.join(log_dir_path, filename) 27 | 28 | log = json.load(open(filepath, 'r', encoding='utf-8')) 29 | print('dataset:', log['dataset']) 30 | # print('tokenize mode:', log['tokenize mode']) 31 | # print('unsup train size:', log['unsup train size']) 32 | # print('batch size:', log['batch size']) 33 | # 
print('undirected:', log['undirected']) 34 | # if 'model' in log.keys(): 35 | # print('model:', log['model']) 36 | # print('n layers feat:', log['n layers feat']) 37 | # print('n layers conv:', log['n layers conv']) 38 | # print('n layers fc:', log['n layers fc']) 39 | # print('vector size:', log['vector size']) 40 | # print('hidden size:', log['hidden']) 41 | # print('global pool:', log['global pool']) 42 | # print('skip connection:', log['skip connection']) 43 | # print('res branch:', log['res branch']) 44 | # print('dropout:', log['dropout']) 45 | # print('edge norm:', log['edge norm']) 46 | # print('lr:', log['lr']) 47 | # print('epochs:', log['epochs']) 48 | # print('weight decay:', log['weight decay']) 49 | # print('lamda:', log['lamda']) 50 | # print('centrality:', log['centrality']) 51 | # print('aug1:', log['aug1']) 52 | # print('aug2:', log['aug2']) 53 | # print('use unlabel:', log['use unlabel']) 54 | # print('use unsup loss:', log['use unsup loss']) 55 | # print('k:', log['k']) 56 | 57 | acc_list = [] 58 | max_epoch_acc_list = [] 59 | for run in log['record']: 60 | # mean_acc = run['mean acc'] 61 | mean_acc = round(np.mean(run['test accs'][cal_mean:]), 3) 62 | max_epoch_acc = round(np.max(run['test accs']), 3) 63 | acc_list.append(mean_acc) 64 | max_epoch_acc_list.append(max_epoch_acc) 65 | 66 | mean = round(sum(acc_list) / len(acc_list), 3) 67 | mean_max_epoch = round(sum(max_epoch_acc_list) / len(max_epoch_acc_list), 3) 68 | sd = round(math.sqrt(sum([(x - mean) ** 2 for x in acc_list]) / len(acc_list)), 3) 69 | sd_max_epoch = round( 70 | math.sqrt(sum([(x - mean_max_epoch) ** 2 for x in max_epoch_acc_list]) / len(max_epoch_acc_list)), 3) 71 | maxx = max(acc_list) 72 | maxx_max_epoch = max(max_epoch_acc_list) 73 | print('test acc | max acc: {:.3f}±{:.3f} | {:.3f}'.format(mean, sd, maxx)) 74 | print('test acc | max acc (max epoch): {:.3f}±{:.3f} | {:.3f}'.format(mean_max_epoch, sd_max_epoch, 75 | maxx_max_epoch)) 76 | print('{:.3f}±{:.3f} | {:.3f} | {:.3f}±{:.3f} | {:.3f}'.format(mean, sd, maxx, mean_max_epoch, sd_max_epoch, 77 | maxx_max_epoch)) 78 | print() 79 | -------------------------------------------------------------------------------- /Main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CcQunResearch/RAGCL/bda6a1885cd00fe174ac8fae4a629f1289fa496b/Main/__init__.py -------------------------------------------------------------------------------- /Main/augmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : augmentation.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import torch 9 | from torch_scatter import scatter 10 | import networkx as nx 11 | from torch_geometric.utils import degree, to_undirected, to_networkx 12 | import torch_geometric.utils as tg_utils 13 | from torch_geometric.data.batch import Batch 14 | 15 | 16 | def dege_drop_weights(data, aggr='mean', norm=True): 17 | centrality = data.centrality 18 | w_row = centrality[data.edge_index[0]].to(torch.float32) 19 | w_col = centrality[data.edge_index[1]].to(torch.float32) 20 | s_row = torch.log(w_row) if norm else w_row 21 | s_col = torch.log(w_col) if norm else w_col 22 | if aggr == 'sink': 23 | s = s_col 24 | elif aggr == 'source': 25 | s = s_row 26 | elif aggr == 'mean': 27 | s = (s_col + s_row) * 0.5 28 | weights = (s.max() - s) / (s.max() - s.mean()) 29 | return weights 30 | 31 | 32 | def 
drop_edge_weighted(data, edge_weights, p, threshold): 33 | edge_weights = edge_weights / edge_weights.mean() * p 34 | edge_weights = edge_weights.where(edge_weights < threshold, torch.ones_like(edge_weights) * threshold) 35 | sel_mask = torch.bernoulli(1. - edge_weights).to(torch.bool) 36 | return data.edge_index[:, sel_mask] 37 | 38 | 39 | def node_aug_weights(centrality, norm=True): 40 | s = torch.log(centrality) if norm else centrality 41 | weights = (s.max() - s) / (s.max() - s.mean()) 42 | return weights 43 | 44 | 45 | def aug_node_weighted(node_weights, p, threshold): 46 | node_weights = node_weights / node_weights.mean() * p 47 | node_weights = node_weights.where(node_weights < threshold, torch.ones_like(node_weights) * threshold) 48 | sel_mask = torch.bernoulli(1. - node_weights).to(torch.bool) 49 | return sel_mask 50 | 51 | 52 | def drop_edge(batch_data, aggr, p, threshold): 53 | aug_data = batch_data.clone() 54 | aug_data_list = aug_data.to_data_list() 55 | for i in range(aug_data.num_graphs): 56 | if aug_data_list[i].num_nodes > 1: 57 | edge_weights = dege_drop_weights(aug_data_list[i], aggr=aggr) 58 | aug_edge_index = drop_edge_weighted(aug_data_list[i], edge_weights, p, threshold) 59 | aug_data_list[i].edge_index = aug_edge_index 60 | return Batch.from_data_list(aug_data_list).to(aug_data.x.device) 61 | 62 | 63 | def drop_node(batch_data, p, threshold): 64 | aug_data = batch_data.clone() 65 | aug_data_list = aug_data.to_data_list() 66 | for i in range(aug_data.num_graphs): 67 | node_weights = node_aug_weights(aug_data_list[i].centrality) 68 | sel_mask = aug_node_weighted(node_weights, p, threshold) 69 | sel_mask[0] = True 70 | aug_edge_index, _ = tg_utils.subgraph(sel_mask, aug_data_list[i].edge_index, relabel_nodes=True, 71 | num_nodes=aug_data_list[i].num_nodes) 72 | aug_data_list[i].x = aug_data_list[i].x[sel_mask] 73 | aug_data_list[i].edge_index = aug_edge_index 74 | aug_data_list[i].__num_nodes__ = aug_data_list[i].x.shape[0] 75 | return Batch.from_data_list(aug_data_list).to(aug_data.x.device) 76 | 77 | 78 | def mask_attr(batch_data, p, threshold): 79 | aug_data = batch_data.clone() 80 | aug_data_list = aug_data.to_data_list() 81 | for i in range(aug_data.num_graphs): 82 | node_weights = node_aug_weights(aug_data_list[i].centrality) 83 | sel_mask = aug_node_weighted(node_weights, p, threshold) 84 | sel_mask[0] = True 85 | # mask_token = aug_data_list[i].x.mean(dim=0) 86 | mask_token = torch.zeros_like(aug_data_list[i].x[0], dtype=torch.float) 87 | aug_data_list[i].x[sel_mask] = mask_token 88 | return Batch.from_data_list(aug_data_list).to(aug_data.x.device) 89 | 90 | 91 | def augment(batch_data, augs): 92 | first_aug = augs[0] 93 | first_argu = first_aug.split(',') 94 | if first_argu[0] == 'DropEdge': 95 | aug_data = drop_edge(batch_data, first_argu[1], float(first_argu[2]), float(first_argu[3])) 96 | elif first_argu[0] == 'NodeDrop': 97 | aug_data = drop_node(batch_data, float(first_argu[1]), float(first_argu[2])) 98 | elif first_argu[0] == 'AttrMask': 99 | aug_data = mask_attr(batch_data, float(first_argu[1]), float(first_argu[2])) 100 | if len(augs) > 1: 101 | last_aug = augs[1] 102 | last_argu = last_aug.split(',') 103 | if last_argu[0] == 'NodeDrop': 104 | aug_data = drop_node(batch_data, float(last_argu[1]), float(last_argu[2])) 105 | elif last_argu[0] == 'AttrMask': 106 | aug_data = mask_attr(batch_data, float(last_argu[1]), float(last_argu[2])) 107 | 108 | return aug_data 109 | -------------------------------------------------------------------------------- 
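A minimal usage sketch of the augmentation entry point above (the toy graph, centrality values, and parameter strings are illustrative only, and the `Main.augmentation` import assumes the repository root is on `sys.path`): the training scripts pass each view as a `||`-joined string of comma-separated specs — `DropEdge,<aggr>,<p>,<threshold>`, `NodeDrop,<p>,<threshold>`, or `AttrMask,<p>,<threshold>` — split it, and hand the resulting list to `augment`. Every graph must carry a positive per-node `centrality` attribute, since the drop/mask probabilities are derived from it.

import torch
from torch_geometric.data import Data, Batch
from Main.augmentation import augment

# Toy propagation tree: a root (node 0) with two replies, plus arbitrary
# positive centrality scores of the kind dataset.py attaches to each graph.
g = Data(x=torch.randn(3, 8),
         edge_index=torch.tensor([[0, 0], [1, 2]]),
         centrality=torch.tensor([3.0, 2.0, 1.0]))
batch = Batch.from_data_list([g, g.clone()])

augs1 = 'DropEdge,mean,0.2,0.7||NodeDrop,0.2,0.7'.split('||')  # first view: DropEdge,<aggr>,<p>,<threshold> then NodeDrop,<p>,<threshold>
augs2 = 'AttrMask,0.2,0.7'.split('||')                         # second view: AttrMask,<p>,<threshold>

view1 = augment(batch, augs1)  # centrality-weighted edge dropping / node dropping
view2 = augment(batch, augs2)  # centrality-weighted attribute masking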
/Main/dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : dataset.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import json 10 | import torch 11 | from torch_geometric.data import Data, InMemoryDataset 12 | from torch_geometric.utils import to_undirected 13 | 14 | 15 | class TreeDataset(InMemoryDataset): 16 | def __init__(self, root, word_embedding, word2vec, centrality_metric, undirected, transform=None, pre_transform=None, 17 | pre_filter=None): 18 | self.word_embedding = word_embedding 19 | self.word2vec = word2vec 20 | self.centrality_metric = centrality_metric 21 | self.undirected = undirected 22 | super().__init__(root, transform, pre_transform, pre_filter) 23 | self.data, self.slices = torch.load(self.processed_paths[0]) 24 | 25 | @property 26 | def raw_file_names(self): 27 | return os.listdir(self.raw_dir) 28 | 29 | @property 30 | def processed_file_names(self): 31 | return ['data.pt'] 32 | 33 | def download(self): 34 | pass 35 | 36 | def process(self): 37 | data_list = [] 38 | raw_file_names = self.raw_file_names 39 | 40 | for filename in raw_file_names: 41 | centrality = None 42 | y = [] 43 | row = [] 44 | col = [] 45 | no_root_row = [] 46 | no_root_col = [] 47 | 48 | filepath = os.path.join(self.raw_dir, filename) 49 | post = json.load(open(filepath, 'r', encoding='utf-8')) 50 | if self.word_embedding == 'word2vec': 51 | x = self.word2vec.get_sentence_embedding(post['source']['content']).view(1, -1) 52 | elif self.word_embedding == 'tfidf': 53 | tfidf = post['source']['content'] 54 | indices = [[0, int(index_freq.split(':')[0])] for index_freq in tfidf.split()] 55 | values = [int(index_freq.split(':')[1]) for index_freq in tfidf.split()] 56 | if 'label' in post['source'].keys(): 57 | y.append(post['source']['label']) 58 | for i, comment in enumerate(post['comment']): 59 | if self.word_embedding == 'word2vec': 60 | x = torch.cat( 61 | [x, self.word2vec.get_sentence_embedding(comment['content']).view(1, -1)], 0) 62 | elif self.word_embedding == 'tfidf': 63 | indices += [[i + 1, int(index_freq.split(':')[0])] for index_freq in comment['content'].split()] 64 | values += [int(index_freq.split(':')[1]) for index_freq in comment['content'].split()] 65 | if comment['parent'] != -1: 66 | no_root_row.append(comment['parent'] + 1) 67 | no_root_col.append(comment['comment id'] + 1) 68 | row.append(comment['parent'] + 1) 69 | col.append(comment['comment id'] + 1) 70 | 71 | if self.centrality_metric == "Degree": 72 | centrality = torch.tensor(post['centrality']['Degree'], dtype=torch.float32) 73 | elif self.centrality_metric == "PageRank": 74 | centrality = torch.tensor(post['centrality']['Pagerank'], dtype=torch.float32) 75 | elif self.centrality_metric == "Eigenvector": 76 | centrality = torch.tensor(post['centrality']['Eigenvector'], dtype=torch.float32) 77 | elif self.centrality_metric == "Betweenness": 78 | centrality = torch.tensor(post['centrality']['Betweenness'], dtype=torch.float32) 79 | edge_index = [row, col] 80 | no_root_edge_index = [no_root_row, no_root_col] 81 | y = torch.LongTensor(y) 82 | edge_index = to_undirected(torch.LongTensor(edge_index)) if self.undirected else torch.LongTensor(edge_index) 83 | no_root_edge_index = torch.LongTensor(no_root_edge_index) 84 | if self.word_embedding == 'tfidf': 85 | x = torch.sparse_coo_tensor(torch.tensor(indices).t(), values, (len(post['comment']) + 1, 5000), 86 | dtype=torch.float32).to_dense() 
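            # `x` now holds one feature row per node (source post + comments): word2vec
            # sentence embeddings, or a dense [num_nodes, 5000] matrix densified from the
            # "index:frequency" tf-idf pairs above. The Data object built next bundles it
            # with the label y (when present), the reply edges, a no-root copy that drops
            # edges attached to the source post, and the per-node centrality scores
            # consumed by the augmentation weighting in augmentation.py.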
87 | one_data = Data(x=x, y=y, edge_index=edge_index, no_root_edge_index=no_root_edge_index, 88 | centrality=centrality) if 'label' in post['source'].keys() else \ 89 | Data(x=x, edge_index=edge_index, no_root_edge_index=no_root_edge_index, centrality=centrality) 90 | data_list.append(one_data) 91 | 92 | if self.pre_filter is not None: 93 | data_list = [data for data in data_list if self.pre_filter(data)] 94 | if self.pre_transform is not None: 95 | data_list = [self.pre_transform(data) for data in data_list] 96 | all_data, slices = self.collate(data_list) 97 | torch.save((all_data, slices), self.processed_paths[0]) 98 | -------------------------------------------------------------------------------- /Main/main(pretrain).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : main(pretrain).py 6 | # @Software: PyCharm 7 | # @Note : 8 | import sys 9 | import os 10 | import os.path as osp 11 | import warnings 12 | 13 | warnings.filterwarnings("ignore") 14 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 15 | dirname = osp.dirname(osp.abspath(__file__)) 16 | sys.path.append(osp.join(dirname, '..')) 17 | 18 | import numpy as np 19 | import time 20 | import torch 21 | import torch.nn.functional as F 22 | from torch.optim import Adam 23 | from torch.optim.lr_scheduler import ReduceLROnPlateau 24 | from torch_geometric.loader import DataLoader 25 | from Main.pargs import pargs 26 | from Main.dataset import TreeDataset 27 | from Main.word2vec import Embedding, collect_sentences, train_word2vec 28 | from Main.sort import sort_dataset 29 | from Main.model import ResGCN_graphcl, BiGCN_graphcl 30 | from Main.utils import create_log_dict_pretrain, write_log, write_json 31 | from Main.augmentation import augment 32 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 33 | 34 | 35 | def pre_train(dataloader, aug1, aug2, model, optimizer, device): 36 | model.train() 37 | total_loss = 0 38 | 39 | augs1 = aug1.split('||') 40 | augs2 = aug2.split('||') 41 | 42 | for data in dataloader: 43 | optimizer.zero_grad() 44 | data = data.to(device) 45 | 46 | aug_data1 = augment(data, augs1) 47 | aug_data2 = augment(data, augs2) 48 | 49 | out1 = model.forward_graphcl(aug_data1) 50 | out2 = model.forward_graphcl(aug_data2) 51 | loss = model.loss_graphcl(out1, out2) 52 | loss.backward() 53 | optimizer.step() 54 | total_loss += loss.item() * data.num_graphs 55 | 56 | return total_loss / len(dataloader.dataset) 57 | 58 | 59 | def fine_tuning(model, aug1, aug2, lamda, optimizer, dataloader, device): 60 | model.train() 61 | 62 | augs1 = aug1.split('||') 63 | augs2 = aug2.split('||') 64 | 65 | total_loss = 0 66 | for data in dataloader: 67 | optimizer.zero_grad() 68 | data = data.to(device) 69 | out = model(data) 70 | 71 | sup_loss = F.nll_loss(out, data.y.long().view(-1)) 72 | 73 | aug_data1 = augment(data, augs1) 74 | aug_data2 = augment(data, augs2) 75 | 76 | out1 = model.forward_graphcl(aug_data1) 77 | out2 = model.forward_graphcl(aug_data2) 78 | unsup_loss = model.loss_graphcl(out1, out2) 79 | 80 | loss = sup_loss + lamda * unsup_loss 81 | 82 | loss.backward() 83 | total_loss += loss.item() * data.num_graphs 84 | optimizer.step() 85 | return total_loss / len(dataloader.dataset) 86 | 87 | 88 | def test(model, dataloader, num_classes, device): 89 | model.eval() 90 | error = 0 91 | 92 | y_true = [] 93 | y_pred = [] 94 | for data in dataloader: 95 | data = data.to(device) 96 | 
pred = model(data) 97 | error += F.nll_loss(pred, data.y.long().view(-1)).item() * data.num_graphs 98 | y_true += data.y.tolist() 99 | y_pred += pred.max(1).indices.tolist() 100 | 101 | y_true = np.array(y_true) 102 | y_pred = np.array(y_pred) 103 | 104 | acc = round(accuracy_score(y_true, y_pred), 4) 105 | precs = [] 106 | recs = [] 107 | f1s = [] 108 | for label in range(num_classes): 109 | precs.append(round(precision_score(y_true == label, y_pred == label, labels=True), 4)) 110 | recs.append(round(recall_score(y_true == label, y_pred == label, labels=True), 4)) 111 | f1s.append(round(f1_score(y_true == label, y_pred == label, labels=True), 4)) 112 | micro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 113 | micro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 114 | micro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 115 | 116 | macro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 117 | macro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 118 | macro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 119 | return error / len(dataloader.dataset), acc, precs, recs, f1s, \ 120 | [micro_p, micro_r, micro_f1], [macro_p, macro_r, macro_f1] 121 | 122 | 123 | def test_and_log(model, val_loader, test_loader, num_classes, device, epoch, lr, loss, train_acc, ft_log_record): 124 | val_error, val_acc, val_precs, val_recs, val_f1s, val_micro_metric, val_macro_metric = \ 125 | test(model, val_loader, num_classes, device) 126 | test_error, test_acc, test_precs, test_recs, test_f1s, test_micro_metric, test_macro_metric = \ 127 | test(model, test_loader, num_classes, device) 128 | log_info = 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Val ERROR: {:.7f}, Test ERROR: {:.7f}\n Train ACC: {:.4f}, Validation ACC: {:.4f}, Test ACC: {:.4f}\n' \ 129 | .format(epoch, lr, loss, val_error, test_error, train_acc, val_acc, test_acc) \ 130 | + f' Test PREC: {test_precs}, Test REC: {test_recs}, Test F1: {test_f1s}\n' \ 131 | + f' Test Micro Metric(PREC, REC, F1):{test_micro_metric}, Test Macro Metric(PREC, REC, F1):{test_macro_metric}' 132 | 133 | ft_log_record['val accs'].append(val_acc) 134 | ft_log_record['test accs'].append(test_acc) 135 | ft_log_record['test precs'].append(test_precs) 136 | ft_log_record['test recs'].append(test_recs) 137 | ft_log_record['test f1s'].append(test_f1s) 138 | ft_log_record['test micro metric'].append(test_micro_metric) 139 | ft_log_record['test macro metric'].append(test_macro_metric) 140 | return val_error, log_info, ft_log_record 141 | 142 | 143 | if __name__ == '__main__': 144 | args = pargs() 145 | 146 | unsup_train_size = args.unsup_train_size 147 | dataset = args.dataset 148 | unsup_dataset = args.unsup_dataset 149 | vector_size = args.vector_size 150 | device = args.gpu if args.cuda else 'cpu' 151 | runs = args.runs 152 | ft_runs = args.ft_runs 153 | 154 | word_embedding = 'tfidf' if 'tfidf' in dataset else 'word2vec' 155 | lang = 'ch' if 'Weibo' in dataset else 'en' 156 | tokenize_mode = args.tokenize_mode 157 | 158 | split = args.split 159 | batch_size = args.batch_size 160 | undirected = args.undirected 161 | centrality = args.centrality 162 | 163 | weight_decay = args.weight_decay 164 | epochs = args.epochs 165 | ft_epochs = args.ft_epochs 166 | 167 | label_source_path = osp.join(dirname, '..', 'Data', dataset, 'source') 168 | 
label_dataset_path = osp.join(dirname, '..', 'Data', dataset, 'dataset') 169 | train_path = osp.join(label_dataset_path, 'train') 170 | val_path = osp.join(label_dataset_path, 'val') 171 | test_path = osp.join(label_dataset_path, 'test') 172 | unlabel_dataset_path = osp.join(dirname, '..', 'Data', unsup_dataset, 'dataset') 173 | model_path = osp.join(dirname, '..', 'Model', 174 | f'w2v_{dataset}_{tokenize_mode}_{unsup_train_size}_{vector_size}.model') 175 | 176 | log_name = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime(time.time())) 177 | log_path = osp.join(dirname, '..', 'Log', f'{log_name}.log') 178 | log_json_path = osp.join(dirname, '..', 'Log', f'{log_name}.json') 179 | weight_path = osp.join(dirname, '..', 'Model', f'{log_name}.pt') 180 | 181 | log = open(log_path, 'w') 182 | log_dict = create_log_dict_pretrain(args) 183 | 184 | if not osp.exists(model_path) and word_embedding == 'word2vec': 185 | sentences = collect_sentences(label_source_path, unlabel_dataset_path, unsup_train_size, lang, tokenize_mode) 186 | w2v_model = train_word2vec(sentences, vector_size) 187 | w2v_model.save(model_path) 188 | 189 | word2vec = Embedding(model_path, lang, tokenize_mode) if word_embedding == 'word2vec' else None 190 | 191 | for run in range(runs): 192 | unlabel_dataset = TreeDataset(unlabel_dataset_path, word_embedding, word2vec, centrality, undirected) 193 | unsup_train_loader = DataLoader(unlabel_dataset, batch_size, shuffle=True) 194 | 195 | num_classes = 4 if 'Twitter' in dataset or dataset == 'PHEME' else 2 196 | if args.model == 'ResGCN': 197 | model = ResGCN_graphcl(dataset=unlabel_dataset, num_classes=num_classes, hidden=args.hidden, 198 | num_feat_layers=args.n_layers_feat, num_conv_layers=args.n_layers_conv, 199 | num_fc_layers=args.n_layers_fc, gfn=False, collapse=False, 200 | residual=args.skip_connection,res_branch=args.res_branch, 201 | global_pool=args.global_pool, dropout=args.dropout, 202 | edge_norm=args.edge_norm).to(device) 203 | elif args.model == 'BiGCN': 204 | model = BiGCN_graphcl(unlabel_dataset.num_features, args.hidden, args.hidden, num_classes).to(device) 205 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=weight_decay) 206 | 207 | write_log(log, f'runs:{run}') 208 | log_record = { 209 | 'run': run, 210 | 'record': [] 211 | } 212 | 213 | for epoch in range(1, epochs + 1): 214 | pretrain_loss = pre_train(unsup_train_loader, args.aug1, args.aug2, model, optimizer, device) 215 | 216 | log_info = 'Epoch: {:03d}, Loss: {:.7f}'.format(epoch, pretrain_loss) 217 | write_log(log, log_info) 218 | 219 | torch.save(model.state_dict(), weight_path) 220 | write_log(log, '') 221 | 222 | # ks = [10, 20, 40, 80, 100, 200, 300, 500, 10000] 223 | ks = [10000] 224 | for k in ks: 225 | for r in range(ft_runs): 226 | ft_lr = args.ft_lr 227 | write_log(log, f'k:{k}, r:{r}') 228 | 229 | ft_log_record = {'k': k, 'r': r, 'val accs': [], 'test accs': [], 'test precs': [], 'test recs': [], 230 | 'test f1s': [], 'test micro metric': [], 'test macro metric': []} 231 | 232 | sort_dataset(label_source_path, label_dataset_path, k_shot=k, split=split) 233 | 234 | train_dataset = TreeDataset(train_path, word_embedding, word2vec, centrality, undirected) 235 | val_dataset = TreeDataset(val_path, word_embedding, word2vec, centrality, undirected) 236 | test_dataset = TreeDataset(test_path, word_embedding, word2vec, centrality, undirected) 237 | 238 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 239 | test_loader = DataLoader(test_dataset, 
batch_size=batch_size) 240 | val_loader = DataLoader(val_dataset, batch_size=batch_size) 241 | 242 | model.load_state_dict(torch.load(weight_path)) 243 | optimizer = Adam(model.parameters(), lr=args.ft_lr, weight_decay=weight_decay) 244 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.000001) 245 | 246 | val_error, log_info, ft_log_record = test_and_log(model, val_loader, test_loader, num_classes, 247 | device, 0, args.ft_lr, 0, 0, ft_log_record) 248 | write_log(log, log_info) 249 | 250 | for epoch in range(1, ft_epochs + 1): 251 | ft_lr = scheduler.optimizer.param_groups[0]['lr'] 252 | _ = fine_tuning(model, args.aug1, args.aug2, args.lamda, optimizer, train_loader, device) 253 | 254 | train_error, train_acc, _, _, _, _, _ = test(model, train_loader, num_classes, device) 255 | val_error, log_info, ft_log_record = test_and_log(model, val_loader, test_loader, num_classes, 256 | device, epoch, ft_lr, train_error, train_acc, 257 | ft_log_record) 258 | write_log(log, log_info) 259 | 260 | if split == '622': 261 | scheduler.step(val_error) 262 | 263 | ft_log_record['mean acc'] = round(np.mean(ft_log_record['test accs'][-10:]), 3) 264 | log_record['record'].append(ft_log_record) 265 | write_log(log, '') 266 | 267 | log_dict['record'].append(log_record) 268 | write_json(log_dict, log_json_path) 269 | -------------------------------------------------------------------------------- /Main/main(semisup).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : main(semisup).py 6 | # @Software: PyCharm 7 | # @Note : 8 | import sys 9 | import os 10 | import os.path as osp 11 | import warnings 12 | 13 | warnings.filterwarnings("ignore") 14 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 15 | dirname = osp.dirname(osp.abspath(__file__)) 16 | sys.path.append(osp.join(dirname, '..')) 17 | 18 | import numpy as np 19 | import time 20 | import torch 21 | import torch.nn.functional as F 22 | from torch.optim import Adam 23 | from torch.optim.lr_scheduler import ReduceLROnPlateau 24 | from torch_geometric.loader import DataLoader 25 | from torch_geometric.data.batch import Batch 26 | from Main.pargs import pargs 27 | from Main.dataset import TreeDataset 28 | from Main.word2vec import Embedding, collect_sentences, train_word2vec 29 | from Main.sort import sort_dataset 30 | from Main.model import ResGCN_graphcl, BiGCN_graphcl 31 | from Main.utils import create_log_dict_semisup, write_log, write_json 32 | from Main.augmentation import augment 33 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 34 | 35 | 36 | def semisup_train(unsup_train_loader, train_loader, aug1, aug2, model, optimizer, device, lamda): 37 | model.train() 38 | total_loss = 0 39 | nrlabel = 0 40 | 41 | augs1 = aug1.split('||') 42 | augs2 = aug2.split('||') 43 | 44 | for sup_data, unsup_data in zip(train_loader, unsup_train_loader): 45 | optimizer.zero_grad() 46 | sup_data = sup_data.to(device) 47 | unsup_data = unsup_data.to(device) 48 | 49 | out1 = model(sup_data) 50 | sup_loss = F.nll_loss(out1, sup_data.y.long().view(-1)) 51 | 52 | sup_data_list = sup_data.to_data_list() 53 | unsup_data_list = unsup_data.to_data_list() 54 | for item in unsup_data_list: 55 | item.y = torch.LongTensor([nrlabel]).to(device) 56 | data = Batch.from_data_list(sup_data_list + unsup_data_list).to(device) 57 | 58 | out2 = model.forward_graphcl(data) 59 | cl_loss = 
model.contrastive_loss(out2, data.y, device) 60 | 61 | list = [] 62 | for item in data.to_data_list(): 63 | if item.y.item() == nrlabel: 64 | list.append(item) 65 | data2 = Batch.from_data_list(list).to(device) 66 | aug_data2_1 = augment(data2, augs1) 67 | aug_data2_2 = augment(data2, augs2) 68 | aug_data2_1_out1 = model.forward_graphcl(aug_data2_1) 69 | aug_data2_2_out2 = model.forward_graphcl(aug_data2_2) 70 | nrloss = model.loss_graphcl(aug_data2_1_out1, aug_data2_2_out2) 71 | 72 | loss = sup_loss + 0.001 * cl_loss + 0.001 * nrloss 73 | loss.backward() 74 | optimizer.step() 75 | total_loss += loss.item() * sup_data.num_graphs 76 | return total_loss / len(train_loader.dataset) 77 | 78 | def test(model, dataloader, num_classes, device): 79 | model.eval() 80 | error = 0 81 | 82 | y_true = [] 83 | y_pred = [] 84 | for data in dataloader: 85 | data = data.to(device) 86 | pred = model(data) 87 | error += F.nll_loss(pred, data.y.long().view(-1)).item() * data.num_graphs 88 | y_true += data.y.tolist() 89 | y_pred += pred.max(1).indices.tolist() 90 | 91 | y_true = np.array(y_true) 92 | y_pred = np.array(y_pred) 93 | 94 | acc = round(accuracy_score(y_true, y_pred), 4) 95 | precs = [] 96 | recs = [] 97 | f1s = [] 98 | for label in range(num_classes): 99 | precs.append(round(precision_score(y_true == label, y_pred == label, labels=True), 4)) 100 | recs.append(round(recall_score(y_true == label, y_pred == label, labels=True), 4)) 101 | f1s.append(round(f1_score(y_true == label, y_pred == label, labels=True), 4)) 102 | micro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 103 | micro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 104 | micro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 105 | 106 | macro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 107 | macro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 108 | macro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 109 | return error / len(dataloader.dataset), acc, precs, recs, f1s, \ 110 | [micro_p, micro_r, micro_f1], [macro_p, macro_r, macro_f1] 111 | 112 | 113 | def test_and_log(model, val_loader, test_loader, num_classes, device, epoch, lr, loss, train_acc, log_record): 114 | val_error, val_acc, val_precs, val_recs, val_f1s, val_micro_metric, val_macro_metric = \ 115 | test(model, val_loader, num_classes, device) 116 | test_error, test_acc, test_precs, test_recs, test_f1s, test_micro_metric, test_macro_metric = \ 117 | test(model, test_loader, num_classes, device) 118 | log_info = 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Val ERROR: {:.7f}, Test ERROR: {:.7f}\n Train ACC: {:.4f}, Validation ACC: {:.4f}, Test ACC: {:.4f}\n' \ 119 | .format(epoch, lr, loss, val_error, test_error, train_acc, val_acc, test_acc) \ 120 | + f' Test PREC: {test_precs}, Test REC: {test_recs}, Test F1: {test_f1s}\n' \ 121 | + f' Test Micro Metric(PREC, REC, F1):{test_micro_metric}, Test Macro Metric(PREC, REC, F1):{test_macro_metric}' 122 | 123 | log_record['val accs'].append(val_acc) 124 | log_record['test accs'].append(test_acc) 125 | log_record['test precs'].append(test_precs) 126 | log_record['test recs'].append(test_recs) 127 | log_record['test f1s'].append(test_f1s) 128 | log_record['test micro metric'].append(test_micro_metric) 129 | log_record['test macro metric'].append(test_macro_metric) 130 | return val_error, log_info, 
log_record 131 | 132 | 133 | if __name__ == '__main__': 134 | args = pargs() 135 | 136 | unsup_train_size = args.unsup_train_size 137 | dataset = args.dataset 138 | unsup_dataset = args.unsup_dataset 139 | vector_size = args.vector_size 140 | device = args.gpu if args.cuda else 'cpu' 141 | runs = args.runs 142 | k = args.k 143 | 144 | word_embedding = 'tfidf' if 'tfidf' in dataset else 'word2vec' 145 | lang = 'ch' if 'Weibo' in dataset else 'en' 146 | tokenize_mode = args.tokenize_mode 147 | 148 | split = args.split 149 | batch_size = args.batch_size 150 | unsup_bs_ratio = args.unsup_bs_ratio 151 | undirected = args.undirected 152 | centrality = args.centrality 153 | 154 | weight_decay = args.weight_decay 155 | lamda = args.lamda 156 | epochs = args.epochs 157 | 158 | label_source_path = osp.join(dirname, '..', 'Data', dataset, 'source') 159 | label_dataset_path = osp.join(dirname, '..', 'Data', dataset, 'dataset') 160 | train_path = osp.join(label_dataset_path, 'train') 161 | val_path = osp.join(label_dataset_path, 'val') 162 | test_path = osp.join(label_dataset_path, 'test') 163 | unlabel_dataset_path = osp.join(dirname, '..', 'Data', unsup_dataset, 'dataset') 164 | model_path = osp.join(dirname, '..', 'Model', 165 | f'w2v_{dataset}_{tokenize_mode}_{unsup_train_size}_{vector_size}.model') 166 | 167 | log_name = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime(time.time())) 168 | log_path = osp.join(dirname, '..', 'Log', f'{log_name}.log') 169 | log_json_path = osp.join(dirname, '..', 'Log', f'{log_name}.json') 170 | 171 | log = open(log_path, 'w') 172 | log_dict = create_log_dict_semisup(args) 173 | 174 | if not osp.exists(model_path) and word_embedding == 'word2vec': 175 | sentences = collect_sentences(label_source_path, unlabel_dataset_path, unsup_train_size, lang, tokenize_mode) 176 | w2v_model = train_word2vec(sentences, vector_size) 177 | w2v_model.save(model_path) 178 | 179 | for run in range(runs): 180 | write_log(log, f'run:{run}') 181 | log_record = {'run': run, 'val accs': [], 'test accs': [], 'test precs': [], 'test recs': [], 'test f1s': [], 182 | 'test micro metric': [], 'test macro metric': []} 183 | 184 | word2vec = Embedding(model_path, lang, tokenize_mode) if word_embedding == 'word2vec' else None 185 | unlabel_dataset = TreeDataset(unlabel_dataset_path, word_embedding, word2vec, centrality, undirected) 186 | unsup_train_loader = DataLoader(unlabel_dataset, batch_size * unsup_bs_ratio, shuffle=True) 187 | 188 | sort_dataset(label_source_path, label_dataset_path, k_shot=k, split=split) 189 | 190 | train_dataset = TreeDataset(train_path, word_embedding, word2vec, centrality, undirected) 191 | val_dataset = TreeDataset(val_path, word_embedding, word2vec, centrality, undirected) 192 | test_dataset = TreeDataset(test_path, word_embedding, word2vec, centrality, undirected) 193 | 194 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 195 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 196 | val_loader = DataLoader(val_dataset, batch_size=batch_size) 197 | 198 | num_classes = train_dataset.num_classes 199 | if args.model == 'ResGCN': 200 | model = ResGCN_graphcl(dataset=train_dataset, num_classes=num_classes, hidden=args.hidden, 201 | num_feat_layers=args.n_layers_feat, num_conv_layers=args.n_layers_conv, 202 | num_fc_layers=args.n_layers_fc, gfn=False, collapse=False, 203 | residual=args.skip_connection, res_branch=args.res_branch, 204 | global_pool=args.global_pool, dropout=args.dropout, 205 | edge_norm=args.edge_norm).to(device) 
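        # Both encoder choices are driven through the same interface in this script:
        # model(data) returns log-softmax class scores for the supervised loss, while
        # model.forward_graphcl(...) returns projection-head embeddings consumed by
        # contrastive_loss / loss_graphcl in semisup_train above.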
206 | elif args.model == 'BiGCN': 207 | model = BiGCN_graphcl(train_dataset.num_features, args.hidden, args.hidden, num_classes).to(device) 208 | 209 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=weight_decay) 210 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.000001) 211 | 212 | val_error, log_info, log_record = test_and_log(model, val_loader, test_loader, num_classes, 213 | device, 0, args.lr, 0, 0, log_record) 214 | write_log(log, log_info) 215 | 216 | for epoch in range(1, epochs + 1): 217 | lr = scheduler.optimizer.param_groups[0]['lr'] 218 | _ = semisup_train(unsup_train_loader, train_loader, args.aug1, args.aug2, model, optimizer, device, lamda) 219 | 220 | train_error, train_acc, _, _, _, _, _ = test(model, train_loader, num_classes, device) 221 | val_error, log_info, log_record = test_and_log(model, val_loader, test_loader, num_classes, device, 222 | epoch, lr, train_error, train_acc, log_record) 223 | write_log(log, log_info) 224 | 225 | if split == '622': 226 | scheduler.step(val_error) 227 | 228 | log_record['mean acc'] = round(np.mean(log_record['test accs'][-10:]), 3) 229 | write_log(log, '') 230 | 231 | log_dict['record'].append(log_record) 232 | write_json(log_dict, log_json_path) 233 | -------------------------------------------------------------------------------- /Main/main(sup).py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : main(sup).py.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import sys 9 | import os 10 | import os.path as osp 11 | import warnings 12 | 13 | warnings.filterwarnings("ignore") 14 | os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 15 | dirname = osp.dirname(osp.abspath(__file__)) 16 | sys.path.append(osp.join(dirname, '..')) 17 | 18 | import numpy as np 19 | import time 20 | import torch.nn.functional as F 21 | from torch.optim import Adam 22 | from torch.optim.lr_scheduler import ReduceLROnPlateau 23 | from torch_geometric.loader import DataLoader 24 | from Main.pargs import pargs 25 | from Main.dataset import TreeDataset 26 | from Main.word2vec import Embedding, collect_sentences, train_word2vec 27 | from Main.sort import sort_dataset 28 | from Main.model import ResGCN_graphcl, BiGCN_graphcl 29 | from Main.utils import create_log_dict_sup, write_log, write_json 30 | from Main.augmentation import augment 31 | from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score 32 | 33 | 34 | def sup_train(train_loader, aug1, aug2, model, optimizer, device, lamda, use_unsup_loss): 35 | model.train() 36 | total_loss = 0 37 | 38 | augs1 = aug1.split('||') 39 | augs2 = aug2.split('||') 40 | 41 | for data in train_loader: 42 | optimizer.zero_grad() 43 | data = data.to(device) 44 | 45 | out = model(data) 46 | sup_loss = F.nll_loss(out, data.y.long().view(-1)) 47 | 48 | if use_unsup_loss: 49 | aug_data1 = augment(data, augs1) 50 | aug_data2 = augment(data, augs2) 51 | 52 | out1 = model.forward_graphcl(aug_data1) 53 | out2 = model.forward_graphcl(aug_data2) 54 | unsup_loss = model.loss_graphcl(out1, out2) 55 | 56 | loss = sup_loss + lamda * unsup_loss 57 | else: 58 | loss = sup_loss 59 | 60 | loss.backward() 61 | optimizer.step() 62 | total_loss += loss.item() * data.num_graphs 63 | 64 | return total_loss / len(train_loader.dataset) 65 | 66 | 67 | def test(model, dataloader, num_classes, device): 68 | model.eval() 69 | error = 0 70 | 71 | y_true = [] 72 | y_pred = [] 
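    # Evaluation below mirrors the other scripts: accumulate the graph-weighted NLL
    # loss and the predicted classes, then report accuracy, one-vs-rest
    # precision/recall/F1 per class, and micro/macro-averaged metrics.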
73 | 74 | for data in dataloader: 75 | data = data.to(device) 76 | pred = model(data) 77 | error += F.nll_loss(pred, data.y.long().view(-1)).item() * data.num_graphs 78 | y_true += data.y.tolist() 79 | y_pred += pred.max(1).indices.tolist() 80 | 81 | y_true = np.array(y_true) 82 | y_pred = np.array(y_pred) 83 | 84 | acc = round(accuracy_score(y_true, y_pred), 4) 85 | precs = [] 86 | recs = [] 87 | f1s = [] 88 | for label in range(num_classes): 89 | precs.append(round(precision_score(y_true == label, y_pred == label, labels=True), 4)) 90 | recs.append(round(recall_score(y_true == label, y_pred == label, labels=True), 4)) 91 | f1s.append(round(f1_score(y_true == label, y_pred == label, labels=True), 4)) 92 | micro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 93 | micro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 94 | micro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='micro'), 4) 95 | 96 | macro_p = round(precision_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 97 | macro_r = round(recall_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 98 | macro_f1 = round(f1_score(y_true, y_pred, labels=range(num_classes), average='macro'), 4) 99 | return error / len(dataloader.dataset), acc, precs, recs, f1s, \ 100 | [micro_p, micro_r, micro_f1], [macro_p, macro_r, macro_f1] 101 | 102 | 103 | def test_and_log(model, val_loader, test_loader, num_classes, device, epoch, lr, loss, train_acc, log_record): 104 | val_error, val_acc, val_precs, val_recs, val_f1s, val_micro_metric, val_macro_metric = \ 105 | test(model, val_loader, num_classes, device) 106 | test_error, test_acc, test_precs, test_recs, test_f1s, test_micro_metric, test_macro_metric = \ 107 | test(model, test_loader, num_classes, device) 108 | log_info = 'Epoch: {:03d}, LR: {:7f}, Loss: {:.7f}, Val ERROR: {:.7f}, Test ERROR: {:.7f}\n Train ACC: {:.4f}, Validation ACC: {:.4f}, Test ACC: {:.4f}\n' \ 109 | .format(epoch, lr, loss, val_error, test_error, train_acc, val_acc, test_acc) \ 110 | + f' Test PREC: {test_precs}, Test REC: {test_recs}, Test F1: {test_f1s}\n' \ 111 | + f' Test Micro Metric(PREC, REC, F1):{test_micro_metric}, Test Macro Metric(PREC, REC, F1):{test_macro_metric}' 112 | 113 | log_record['val accs'].append(val_acc) 114 | log_record['test accs'].append(test_acc) 115 | log_record['test precs'].append(test_precs) 116 | log_record['test recs'].append(test_recs) 117 | log_record['test f1s'].append(test_f1s) 118 | log_record['test micro metric'].append(test_micro_metric) 119 | log_record['test macro metric'].append(test_macro_metric) 120 | return val_error, log_info, log_record 121 | 122 | 123 | if __name__ == '__main__': 124 | args = pargs() 125 | 126 | unsup_train_size = args.unsup_train_size 127 | dataset = args.dataset 128 | unsup_dataset = args.unsup_dataset 129 | vector_size = args.vector_size 130 | device = args.gpu if args.cuda else 'cpu' 131 | runs = args.runs 132 | k = args.k 133 | 134 | word_embedding = 'tfidf' if 'tfidf' in dataset else 'word2vec' 135 | lang = 'ch' if 'Weibo' in dataset else 'en' 136 | tokenize_mode = args.tokenize_mode 137 | 138 | split = args.split 139 | batch_size = args.batch_size 140 | undirected = args.undirected 141 | centrality = args.centrality 142 | 143 | weight_decay = args.weight_decay 144 | lamda = args.lamda 145 | epochs = args.epochs 146 | use_unsup_loss = args.use_unsup_loss 147 | 148 | label_source_path = osp.join(dirname, '..', 'Data', 
dataset, 'source') 149 | label_dataset_path = osp.join(dirname, '..', 'Data', dataset, 'dataset') 150 | train_path = osp.join(label_dataset_path, 'train') 151 | val_path = osp.join(label_dataset_path, 'val') 152 | test_path = osp.join(label_dataset_path, 'test') 153 | unlabel_dataset_path = osp.join(dirname, '..', 'Data', unsup_dataset, 'dataset') 154 | model_path = osp.join(dirname, '..', 'Model', 155 | f'w2v_{dataset}_{tokenize_mode}_{unsup_train_size}_{vector_size}.model') 156 | 157 | log_name = time.strftime("%Y-%m-%d %H-%M-%S", time.localtime(time.time())) 158 | log_path = osp.join(dirname, '..', 'Log', f'{log_name}.log') 159 | log_json_path = osp.join(dirname, '..', 'Log', f'{log_name}.json') 160 | 161 | log = open(log_path, 'w') 162 | log_dict = create_log_dict_sup(args) 163 | 164 | if not osp.exists(model_path) and word_embedding == 'word2vec': 165 | sentences = collect_sentences(label_source_path, unlabel_dataset_path, unsup_train_size, lang, tokenize_mode) 166 | w2v_model = train_word2vec(sentences, vector_size) 167 | w2v_model.save(model_path) 168 | 169 | for run in range(runs): 170 | write_log(log, f'run:{run}') 171 | log_record = {'run': run, 'val accs': [], 'test accs': [], 'test precs': [], 'test recs': [], 'test f1s': [], 172 | 'test micro metric': [], 'test macro metric': []} 173 | 174 | word2vec = Embedding(model_path, lang, tokenize_mode) if word_embedding == 'word2vec' else None 175 | 176 | sort_dataset(label_source_path, label_dataset_path, k_shot=k, split=split) 177 | 178 | train_dataset = TreeDataset(train_path, word_embedding, word2vec, centrality, undirected) 179 | val_dataset = TreeDataset(val_path, word_embedding, word2vec, centrality, undirected) 180 | test_dataset = TreeDataset(test_path, word_embedding, word2vec, centrality, undirected) 181 | 182 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 183 | test_loader = DataLoader(test_dataset, batch_size=batch_size) 184 | val_loader = DataLoader(val_dataset, batch_size=batch_size) 185 | 186 | num_classes = train_dataset.num_classes 187 | if args.model == 'ResGCN': 188 | model = ResGCN_graphcl(dataset=train_dataset, num_classes=num_classes, hidden=args.hidden, 189 | num_feat_layers=args.n_layers_feat, num_conv_layers=args.n_layers_conv, 190 | num_fc_layers=args.n_layers_fc, gfn=False, collapse=False, 191 | residual=args.skip_connection, 192 | res_branch=args.res_branch, global_pool=args.global_pool, dropout=args.dropout, 193 | edge_norm=args.edge_norm).to(device) 194 | elif args.model == 'BiGCN': 195 | model = BiGCN_graphcl(train_dataset.num_features, args.hidden, args.hidden, num_classes).to(device) 196 | optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=weight_decay) 197 | scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.7, patience=5, min_lr=0.000001) 198 | 199 | val_error, log_info, log_record = test_and_log(model, val_loader, test_loader, num_classes, 200 | device, 0, args.lr, 0, 0, log_record) 201 | write_log(log, log_info) 202 | 203 | for epoch in range(1, epochs + 1): 204 | lr = scheduler.optimizer.param_groups[0]['lr'] 205 | _ = sup_train(train_loader, args.aug1, args.aug2, model, optimizer, device, lamda, use_unsup_loss) 206 | 207 | train_error, train_acc, _, _, _, _, _ = test(model, train_loader, num_classes, device) 208 | val_error, log_info, log_record = test_and_log(model, val_loader, test_loader, num_classes, device, epoch, 209 | lr, train_error, train_acc, log_record) 210 | write_log(log, log_info) 211 | 212 | if split == '622': 213 | 
scheduler.step(val_error) 214 | 215 | log_record['mean acc'] = round(np.mean(log_record['test accs'][-10:]), 3) 216 | write_log(log, '') 217 | 218 | log_dict['record'].append(log_record) 219 | write_json(log_dict, log_json_path) 220 | -------------------------------------------------------------------------------- /Main/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : model.py 6 | # @Software: PyCharm 7 | # @Note : 8 | from functools import partial 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.nn import Linear, BatchNorm1d, Parameter 13 | from torch_scatter import scatter_add, scatter_mean 14 | from torch_geometric.nn import global_mean_pool, global_add_pool 15 | from torch_geometric.nn.conv import MessagePassing 16 | from torch_geometric.utils import remove_self_loops, add_self_loops 17 | from torch_geometric.nn.inits import glorot, zeros 18 | import numpy as np 19 | import random 20 | import copy 21 | 22 | 23 | class GCNConv(MessagePassing): 24 | r"""The graph convolutional operator from the `"Semi-supervised 25 | Classfication with Graph Convolutional Networks" 26 | `_ paper 27 | 28 | .. math:: 29 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} 30 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, 31 | 32 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the 33 | adjacency matrix with inserted self-loops and 34 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. 35 | 36 | Args: 37 | in_channels (int): Size of each input sample. 38 | out_channels (int): Size of each output sample. 39 | improved (bool, optional): If set to :obj:`True`, the layer computes 40 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. 41 | (default: :obj:`False`) 42 | cached (bool, optional): If set to :obj:`True`, the layer will cache 43 | the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} 44 | \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}`. 45 | (default: :obj:`False`) 46 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 47 | an additive bias. (default: :obj:`True`) 48 | edge_norm (bool, optional): whether or not to normalize adj matrix. 49 | (default: :obj:`True`) 50 | gfn (bool, optional): If `True`, only linear transform (1x1 conv) is 51 | applied to every nodes. 
(default: :obj:`False`) 52 | """ 53 | 54 | def __init__(self, 55 | in_channels, 56 | out_channels, 57 | improved=False, 58 | cached=False, 59 | bias=True, 60 | edge_norm=True, 61 | gfn=False): 62 | super(GCNConv, self).__init__('add') 63 | 64 | self.in_channels = in_channels 65 | self.out_channels = out_channels 66 | self.improved = improved 67 | self.cached = cached 68 | self.cached_result = None 69 | self.edge_norm = edge_norm 70 | self.gfn = gfn 71 | 72 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 73 | 74 | if bias: 75 | self.bias = Parameter(torch.Tensor(out_channels)) 76 | else: 77 | self.register_parameter('bias', None) 78 | 79 | self.reset_parameters() 80 | 81 | def reset_parameters(self): 82 | glorot(self.weight) 83 | zeros(self.bias) 84 | self.cached_result = None 85 | 86 | @staticmethod 87 | def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None): 88 | if edge_weight is None: 89 | edge_weight = torch.ones((edge_index.size(1),), 90 | dtype=dtype, 91 | device=edge_index.device) 92 | edge_weight = edge_weight.view(-1) 93 | assert edge_weight.size(0) == edge_index.size(1) 94 | 95 | edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) 96 | edge_index = add_self_loops(edge_index, num_nodes=num_nodes) 97 | # Add edge_weight for loop edges. 98 | loop_weight = torch.full((num_nodes,), 99 | 1 if not improved else 2, 100 | dtype=edge_weight.dtype, 101 | device=edge_weight.device) 102 | edge_weight = torch.cat([edge_weight, loop_weight], dim=0) 103 | 104 | edge_index = edge_index[0] 105 | row, col = edge_index 106 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 107 | deg_inv_sqrt = deg.pow(-0.5) 108 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 109 | 110 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 111 | 112 | def forward(self, x, edge_index, edge_weight=None): 113 | """""" 114 | x = torch.matmul(x, self.weight) 115 | if self.gfn: 116 | return x 117 | 118 | if not self.cached or self.cached_result is None: 119 | if self.edge_norm: 120 | edge_index, norm = GCNConv.norm( 121 | edge_index, x.size(0), edge_weight, self.improved, x.dtype) 122 | else: 123 | norm = None 124 | self.cached_result = edge_index, norm 125 | 126 | edge_index, norm = self.cached_result 127 | return self.propagate(edge_index, x=x, norm=norm) 128 | 129 | def message(self, x_j, norm): 130 | if self.edge_norm: 131 | return norm.view(-1, 1) * x_j 132 | else: 133 | return x_j 134 | 135 | def update(self, aggr_out): 136 | if self.bias is not None: 137 | aggr_out = aggr_out + self.bias 138 | return aggr_out 139 | 140 | def __repr__(self): 141 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 142 | self.out_channels) 143 | 144 | 145 | class ResGCN(torch.nn.Module): 146 | """GCN with BN and residual connection.""" 147 | 148 | def __init__(self, dataset=None, num_classes=2, hidden=128, num_feat_layers=1, num_conv_layers=3, 149 | num_fc_layers=2, gfn=False, collapse=False, residual=False, 150 | res_branch="BNConvReLU", global_pool="sum", dropout=0, 151 | edge_norm=True): 152 | super(ResGCN, self).__init__() 153 | assert num_feat_layers == 1, "more feat layers are not now supported" 154 | self.num_classes = num_classes 155 | self.conv_residual = residual 156 | self.fc_residual = False # no skip-connections for fc layers. 
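        # The rest of the constructor selects the sum/mean (optionally gated) readout,
        # then builds the BatchNorm1d + GCNConv stacks dictated by `res_branch` and the
        # fully connected classifier head; GCNConv layers created with gfn=True skip
        # message passing in forward() and reduce to a plain linear feature transform.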
157 | self.res_branch = res_branch 158 | self.collapse = collapse 159 | assert "sum" in global_pool or "mean" in global_pool, global_pool 160 | if "sum" in global_pool: 161 | self.global_pool = global_add_pool 162 | else: 163 | self.global_pool = global_mean_pool 164 | self.dropout = dropout 165 | GConv = partial(GCNConv, edge_norm=edge_norm, gfn=gfn) 166 | 167 | self.use_xg = False 168 | if "xg" in dataset[0]: # Utilize graph level features. 169 | self.use_xg = True 170 | self.bn1_xg = BatchNorm1d(dataset[0].xg.size(1)) 171 | self.lin1_xg = Linear(dataset[0].xg.size(1), hidden) 172 | self.bn2_xg = BatchNorm1d(hidden) 173 | self.lin2_xg = Linear(hidden, hidden) 174 | 175 | hidden_in = dataset.num_features 176 | if collapse: 177 | self.bn_feat = BatchNorm1d(hidden_in) 178 | self.bns_fc = torch.nn.ModuleList() 179 | self.lins = torch.nn.ModuleList() 180 | if "gating" in global_pool: 181 | self.gating = torch.nn.Sequential( 182 | Linear(hidden_in, hidden_in), 183 | torch.nn.ReLU(), 184 | Linear(hidden_in, 1), 185 | torch.nn.Sigmoid()) 186 | else: 187 | self.gating = None 188 | for i in range(num_fc_layers - 1): 189 | self.bns_fc.append(BatchNorm1d(hidden_in)) 190 | self.lins.append(Linear(hidden_in, hidden)) 191 | hidden_in = hidden 192 | self.lin_class = Linear(hidden_in, self.num_classes) 193 | else: 194 | self.bn_feat = BatchNorm1d(hidden_in) 195 | feat_gfn = True # set true so GCNConv is feat transform 196 | self.conv_feat = GCNConv(hidden_in, hidden, gfn=feat_gfn) 197 | if "gating" in global_pool: 198 | self.gating = torch.nn.Sequential( 199 | Linear(hidden, hidden), 200 | torch.nn.ReLU(), 201 | Linear(hidden, 1), 202 | torch.nn.Sigmoid()) 203 | else: 204 | self.gating = None 205 | self.bns_conv = torch.nn.ModuleList() 206 | self.convs = torch.nn.ModuleList() 207 | if self.res_branch == "resnet": 208 | for i in range(num_conv_layers): 209 | self.bns_conv.append(BatchNorm1d(hidden)) 210 | self.convs.append(GCNConv(hidden, hidden, gfn=feat_gfn)) 211 | self.bns_conv.append(BatchNorm1d(hidden)) 212 | self.convs.append(GConv(hidden, hidden)) 213 | self.bns_conv.append(BatchNorm1d(hidden)) 214 | self.convs.append(GCNConv(hidden, hidden, gfn=feat_gfn)) 215 | else: 216 | for i in range(num_conv_layers): 217 | self.bns_conv.append(BatchNorm1d(hidden)) 218 | self.convs.append(GConv(hidden, hidden)) 219 | self.bn_hidden = BatchNorm1d(hidden) 220 | self.bns_fc = torch.nn.ModuleList() 221 | self.lins = torch.nn.ModuleList() 222 | for i in range(num_fc_layers - 1): 223 | self.bns_fc.append(BatchNorm1d(hidden)) 224 | self.lins.append(Linear(hidden, hidden)) 225 | self.lin_class = Linear(hidden, self.num_classes) 226 | 227 | # BN initialization. 228 | for m in self.modules(): 229 | if isinstance(m, (torch.nn.BatchNorm1d)): 230 | torch.nn.init.constant_(m.weight, 1) 231 | torch.nn.init.constant_(m.bias, 0.0001) 232 | 233 | def reset_parameters(self): 234 | raise NotImplemented( 235 | "This is prune to bugs (e.g. lead to training on test set in " 236 | "cross validation setting). 
Create a new model instance instead.") 237 | 238 | def forward(self, data): 239 | x, edge_index, batch = data.x, data.edge_index, data.batch 240 | if self.use_xg: 241 | # xg is (batch_size x its feat dim) 242 | xg = self.bn1_xg(data.xg) 243 | xg = F.relu(self.lin1_xg(xg)) 244 | xg = self.bn2_xg(xg) 245 | xg = F.relu(self.lin2_xg(xg)) 246 | else: 247 | xg = None 248 | 249 | if self.collapse: 250 | return self.forward_collapse(x, edge_index, batch, xg) 251 | elif self.res_branch == "BNConvReLU": 252 | return self.forward_BNConvReLU(x, edge_index, batch, xg) 253 | elif self.res_branch == "BNReLUConv": 254 | return self.forward_BNReLUConv(x, edge_index, batch, xg) 255 | elif self.res_branch == "ConvReLUBN": 256 | return self.forward_ConvReLUBN(x, edge_index, batch, xg) 257 | elif self.res_branch == "resnet": 258 | return self.forward_resnet(x, edge_index, batch, xg) 259 | else: 260 | raise ValueError("Unknown res_branch %s" % self.res_branch) 261 | 262 | def forward_collapse(self, x, edge_index, batch, xg=None): 263 | x = self.bn_feat(x) 264 | gate = 1 if self.gating is None else self.gating(x) 265 | x = self.global_pool(x * gate, batch) 266 | x = x if xg is None else x + xg 267 | for i, lin in enumerate(self.lins): 268 | x_ = self.bns_fc[i](x) 269 | x_ = F.relu(lin(x_)) 270 | x = x + x_ if self.fc_residual else x_ 271 | x = self.lin_class(x) 272 | return F.log_softmax(x, dim=-1) 273 | 274 | def forward_BNConvReLU(self, x, edge_index, batch, xg=None): 275 | x = self.bn_feat(x) 276 | x = F.relu(self.conv_feat(x, edge_index)) 277 | for i, conv in enumerate(self.convs): 278 | x_ = self.bns_conv[i](x) 279 | x_ = F.relu(conv(x_, edge_index)) 280 | x = x + x_ if self.conv_residual else x_ 281 | gate = 1 if self.gating is None else self.gating(x) 282 | x = self.global_pool(x * gate, batch) 283 | x = x if xg is None else x + xg 284 | for i, lin in enumerate(self.lins): 285 | x_ = self.bns_fc[i](x) 286 | x_ = F.relu(lin(x_)) 287 | x = x + x_ if self.fc_residual else x_ 288 | x = self.bn_hidden(x) 289 | if self.dropout > 0: 290 | x = F.dropout(x, p=self.dropout, training=self.training) 291 | x = self.lin_class(x) 292 | return F.log_softmax(x, dim=-1) 293 | 294 | def forward_BNReLUConv(self, x, edge_index, batch, xg=None): 295 | x = self.bn_feat(x) 296 | x = self.conv_feat(x, edge_index) 297 | for i, conv in enumerate(self.convs): 298 | x_ = F.relu(self.bns_conv[i](x)) 299 | x_ = conv(x_, edge_index) 300 | x = x + x_ if self.conv_residual else x_ 301 | x = self.global_pool(x, batch) 302 | x = x if xg is None else x + xg 303 | for i, lin in enumerate(self.lins): 304 | x_ = F.relu(self.bns_fc[i](x)) 305 | x_ = lin(x_) 306 | x = x + x_ if self.fc_residual else x_ 307 | x = F.relu(self.bn_hidden(x)) 308 | if self.dropout > 0: 309 | x = F.dropout(x, p=self.dropout, training=self.training) 310 | x = self.lin_class(x) 311 | return F.log_softmax(x, dim=-1) 312 | 313 | def forward_ConvReLUBN(self, x, edge_index, batch, xg=None): 314 | x = self.bn_feat(x) 315 | x = F.relu(self.conv_feat(x, edge_index)) 316 | x = self.bn_hidden(x) 317 | for i, conv in enumerate(self.convs): 318 | x_ = F.relu(conv(x, edge_index)) 319 | x_ = self.bns_conv[i](x_) 320 | x = x + x_ if self.conv_residual else x_ 321 | x = self.global_pool(x, batch) 322 | x = x if xg is None else x + xg 323 | for i, lin in enumerate(self.lins): 324 | x_ = F.relu(lin(x)) 325 | x_ = self.bns_fc[i](x_) 326 | x = x + x_ if self.fc_residual else x_ 327 | if self.dropout > 0: 328 | x = F.dropout(x, p=self.dropout, training=self.training) 329 | x = 
self.lin_class(x) 330 | return F.log_softmax(x, dim=-1) 331 | 332 | def forward_resnet(self, x, edge_index, batch, xg=None): 333 | # this mimics resnet architecture in cv. 334 | x = self.bn_feat(x) 335 | x = self.conv_feat(x, edge_index) 336 | for i in range(len(self.convs) // 3): 337 | x_ = x 338 | x_ = F.relu(self.bns_conv[i * 3 + 0](x_)) 339 | x_ = self.convs[i * 3 + 0](x_, edge_index) 340 | x_ = F.relu(self.bns_conv[i * 3 + 1](x_)) 341 | x_ = self.convs[i * 3 + 1](x_, edge_index) 342 | x_ = F.relu(self.bns_conv[i * 3 + 2](x_)) 343 | x_ = self.convs[i * 3 + 2](x_, edge_index) 344 | x = x + x_ 345 | x = self.global_pool(x, batch) 346 | x = x if xg is None else x + xg 347 | for i, lin in enumerate(self.lins): 348 | x_ = F.relu(self.bns_fc[i](x)) 349 | x_ = lin(x_) 350 | x = x + x_ 351 | x = F.relu(self.bn_hidden(x)) 352 | if self.dropout > 0: 353 | x = F.dropout(x, p=self.dropout, training=self.training) 354 | x = self.lin_class(x) 355 | return F.log_softmax(x, dim=-1) 356 | 357 | def __repr__(self): 358 | return self.__class__.__name__ 359 | 360 | 361 | class ResGCN_graphcl(ResGCN): 362 | def __init__(self, **kargs): 363 | super(ResGCN_graphcl, self).__init__(**kargs) 364 | hidden = kargs['hidden'] 365 | self.proj_head = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden)) 366 | 367 | def forward_graphcl(self, data): 368 | x, edge_index, batch = data.x, data.edge_index, data.batch 369 | if self.use_xg: 370 | # xg is (batch_size x its feat dim) 371 | xg = self.bn1_xg(data.xg) 372 | xg = F.relu(self.lin1_xg(xg)) 373 | xg = self.bn2_xg(xg) 374 | xg = F.relu(self.lin2_xg(xg)) 375 | else: 376 | xg = None 377 | 378 | x = self.bn_feat(x) 379 | x = F.relu(self.conv_feat(x, edge_index)) 380 | for i, conv in enumerate(self.convs): 381 | x_ = self.bns_conv[i](x) 382 | x_ = F.relu(conv(x_, edge_index)) 383 | x = x + x_ if self.conv_residual else x_ 384 | gate = 1 if self.gating is None else self.gating(x) 385 | x = self.global_pool(x * gate, batch) 386 | x = x if xg is None else x + xg 387 | for i, lin in enumerate(self.lins): 388 | x_ = self.bns_fc[i](x) 389 | x_ = F.relu(lin(x_)) 390 | x = x + x_ if self.fc_residual else x_ 391 | x = self.bn_hidden(x) 392 | if self.dropout > 0: 393 | x = F.dropout(x, p=self.dropout, training=self.training) 394 | x = self.proj_head(x) 395 | return x 396 | 397 | def loss_graphcl(self, x1, x2, mean=True): 398 | T = 0.5 399 | batch_size, _ = x1.size() 400 | 401 | x1_abs = x1.norm(dim=1) 402 | x2_abs = x2.norm(dim=1) 403 | 404 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 405 | sim_matrix = torch.exp(sim_matrix / T) 406 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 407 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 408 | loss = - torch.log(loss) 409 | if mean: 410 | loss = loss.mean() 411 | return loss 412 | 413 | 414 | ###################################################### 415 | 416 | class vgae_encoder(ResGCN): 417 | def __init__(self, **kargs): 418 | super(vgae_encoder, self).__init__(**kargs) 419 | hidden = kargs['hidden'] 420 | self.encoder_mean = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden)) 421 | self.encoder_std = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, hidden), 422 | nn.Softplus()) 423 | 424 | def forward(self, data): 425 | x, edge_index = data.x, data.edge_index 426 | if self.use_xg: 427 | # xg is (batch_size x its feat dim) 428 | xg = self.bn1_xg(data.xg) 429 | xg = 
F.relu(self.lin1_xg(xg)) 430 | xg = self.bn2_xg(xg) 431 | xg = F.relu(self.lin2_xg(xg)) 432 | else: 433 | xg = None 434 | 435 | x = self.bn_feat(x) 436 | x = F.relu(self.conv_feat(x, edge_index)) 437 | for i, conv in enumerate(self.convs): 438 | x_ = self.bns_conv[i](x) 439 | x_ = F.relu(conv(x_, edge_index)) 440 | x = x + x_ if self.conv_residual else x_ 441 | 442 | x_mean = self.encoder_mean(x) 443 | x_std = self.encoder_std(x) 444 | gaussian_noise = torch.randn(x_mean.shape).to(x.device) 445 | x = gaussian_noise * x_std + x_mean 446 | return x, x_mean, x_std 447 | 448 | 449 | class vgae_decoder(torch.nn.Module): 450 | def __init__(self, hidden=128): 451 | super(vgae_decoder, self).__init__() 452 | self.decoder = nn.Sequential(nn.ReLU(inplace=True), nn.Linear(hidden, hidden), nn.ReLU(inplace=True), 453 | nn.Linear(hidden, 1)) 454 | self.sigmoid = nn.Sigmoid() 455 | self.bceloss = nn.BCELoss(reduction='none') 456 | self.pool = global_mean_pool 457 | self.add_pool = global_add_pool 458 | 459 | def forward(self, x, x_mean, x_std, batch, edge_index, edge_index_batch, edge_index_neg, edge_index_neg_batch, 460 | reward): 461 | edge_pos_pred = self.sigmoid(self.decoder(x[edge_index[0]] * x[edge_index[1]])) 462 | edge_neg_pred = self.sigmoid(self.decoder(x[edge_index_neg[0]] * x[edge_index_neg[1]])) 463 | 464 | # for link prediction 465 | import numpy as np 466 | from sklearn.metrics import roc_auc_score, average_precision_score 467 | edge_pred = torch.cat((edge_pos_pred, edge_neg_pred)).detach().cpu().numpy() 468 | edge_auroc = roc_auc_score(np.concatenate((np.ones(edge_pos_pred.shape[0]), np.zeros(edge_neg_pred.shape[0]))), 469 | edge_pred) 470 | edge_auprc = average_precision_score( 471 | np.concatenate((np.ones(edge_pos_pred.shape[0]), np.zeros(edge_neg_pred.shape[0]))), edge_pred) 472 | if True: 473 | return edge_auroc, edge_auprc 474 | # end link prediction 475 | 476 | loss_edge_pos = self.bceloss(edge_pos_pred, torch.ones(edge_pos_pred.shape).to(edge_pos_pred.device)) 477 | loss_edge_neg = self.bceloss(edge_neg_pred, torch.zeros(edge_neg_pred.shape).to(edge_neg_pred.device)) 478 | loss_pos = self.pool(loss_edge_pos, edge_index_batch) 479 | loss_neg = self.pool(loss_edge_neg, edge_index_neg_batch) 480 | loss_rec = loss_pos + loss_neg 481 | if not reward is None: 482 | loss_rec = loss_rec * reward 483 | 484 | # reference: https://github.com/DaehanKim/vgae_pytorch 485 | kl_divergence = - 0.5 * (1 + 2 * torch.log(x_std) - x_mean ** 2 - x_std ** 2).sum(dim=1) 486 | kl_ones = torch.ones(kl_divergence.shape).to(kl_divergence.device) 487 | kl_divergence = self.pool(kl_divergence, batch) 488 | kl_double_norm = 1 / self.add_pool(kl_ones, batch) 489 | kl_divergence = kl_divergence * kl_double_norm 490 | 491 | loss = (loss_rec + kl_divergence).mean() 492 | return loss 493 | 494 | 495 | class vgae(torch.nn.Module): 496 | def __init__(self, encoder, decoder): 497 | super(vgae, self).__init__() 498 | self.encoder = encoder 499 | self.decoder = decoder 500 | 501 | def forward(self, data, reward=None): 502 | x, x_mean, x_std = self.encoder(data) 503 | loss = self.decoder(x, x_mean, x_std, data.batch, data.edge_index, data.edge_index_batch, data.edge_index_neg, 504 | data.edge_index_neg_batch, reward) 505 | return loss 506 | 507 | # for one graph 508 | def generate(self, data): 509 | x, _, _ = self.encoder(data) 510 | prob = torch.einsum('nd,md->nmd', x, x) 511 | prob = self.decoder.decoder(prob).squeeze() 512 | 513 | prob = torch.exp(prob) 514 | prob[torch.isinf(prob)] = 1e10 515 | 
prob[list(range(x.shape[0])), list(range(x.shape[0]))] = 0 516 | prob = torch.einsum('nm,n->nm', prob, 1 / prob.sum(dim=1)) 517 | 518 | # sparsify 519 | prob[prob < 1e-1] = 0 520 | prob[prob.sum(dim=1) == 0] = 1 521 | prob[list(range(x.shape[0])), list(range(x.shape[0]))] = 0 522 | prob = torch.einsum('nm,n->nm', prob, 1 / prob.sum(dim=1)) 523 | return prob 524 | 525 | 526 | ###################################################### 527 | 528 | 529 | class TDrumorGCN(torch.nn.Module): 530 | def __init__(self, in_feats, hid_feats, out_feats, tddroprate): 531 | super(TDrumorGCN, self).__init__() 532 | self.tddroprate = tddroprate 533 | self.conv1 = GCNConv(in_feats, hid_feats) 534 | self.conv2 = GCNConv(hid_feats + in_feats, out_feats) 535 | 536 | def forward(self, data): 537 | device = data.x.device 538 | x, edge_index = data.x, data.edge_index 539 | 540 | edge_index_list = edge_index.tolist() 541 | if self.tddroprate > 0: 542 | length = len(edge_index_list[0]) 543 | poslist = random.sample(range(length), int(length * (1 - self.tddroprate))) 544 | poslist = sorted(poslist) 545 | tdrow = list(np.array(edge_index_list[0])[poslist]) 546 | tdcol = list(np.array(edge_index_list[1])[poslist]) 547 | edge_index = torch.LongTensor([tdrow, tdcol]).to(device) 548 | 549 | x1 = copy.copy(x.float()) 550 | x = self.conv1(x, edge_index) 551 | x2 = copy.copy(x) 552 | root_extend = torch.zeros(len(data.batch), x1.size(1)).to(device) 553 | batch_size = max(data.batch) + 1 554 | for num_batch in range(batch_size): 555 | index = (torch.eq(data.batch, num_batch)) 556 | root_extend[index] = x1[index][0] 557 | x = torch.cat((x, root_extend), 1) 558 | x = F.relu(x) 559 | x = F.dropout(x, training=self.training) 560 | x = self.conv2(x, edge_index) 561 | x = F.relu(x) 562 | root_extend = torch.zeros(len(data.batch), x2.size(1)).to(device) 563 | for num_batch in range(batch_size): 564 | index = (torch.eq(data.batch, num_batch)) 565 | root_extend[index] = x2[index][0] 566 | x = torch.cat((x, root_extend), 1) 567 | x = scatter_mean(x, data.batch, dim=0) 568 | return x 569 | 570 | 571 | class BUrumorGCN(torch.nn.Module): 572 | def __init__(self, in_feats, hid_feats, out_feats, budroprate): 573 | super(BUrumorGCN, self).__init__() 574 | self.budroprate = budroprate 575 | self.conv1 = GCNConv(in_feats, hid_feats) 576 | self.conv2 = GCNConv(hid_feats + in_feats, out_feats) 577 | 578 | def forward(self, data): 579 | device = data.x.device 580 | x = data.x 581 | edge_index = data.edge_index.clone() 582 | edge_index[0], edge_index[1] = data.edge_index[1], data.edge_index[0] 583 | 584 | edge_index_list = edge_index.tolist() 585 | if self.budroprate > 0: 586 | length = len(edge_index_list[0]) 587 | poslist = random.sample(range(length), int(length * (1 - self.budroprate))) 588 | poslist = sorted(poslist) 589 | burow = list(np.array(edge_index_list[0])[poslist]) 590 | bucol = list(np.array(edge_index_list[1])[poslist]) 591 | edge_index = torch.LongTensor([burow, bucol]).to(device) 592 | 593 | x1 = copy.copy(x.float()) 594 | x = self.conv1(x, edge_index) 595 | x2 = copy.copy(x) 596 | root_extend = torch.zeros(len(data.batch), x1.size(1)).to(device) 597 | batch_size = max(data.batch) + 1 598 | for num_batch in range(batch_size): 599 | index = (torch.eq(data.batch, num_batch)) 600 | root_extend[index] = x1[index][0] 601 | x = torch.cat((x, root_extend), 1) 602 | x = F.relu(x) 603 | x = F.dropout(x, training=self.training) 604 | x = self.conv2(x, edge_index) 605 | x = F.relu(x) 606 | root_extend = torch.zeros(len(data.batch), 
x2.size(1)).to(device) 607 | for num_batch in range(batch_size): 608 | index = (torch.eq(data.batch, num_batch)) 609 | root_extend[index] = x2[index][0] 610 | x = torch.cat((x, root_extend), 1) 611 | x = scatter_mean(x, data.batch, dim=0) 612 | return x 613 | 614 | 615 | class BiGCN_graphcl(torch.nn.Module): 616 | def __init__(self, in_feats, hid_feats, out_feats, num_classes, tddroprate=0.0, budroprate=0.0): 617 | super(BiGCN_graphcl, self).__init__() 618 | self.TDrumorGCN = TDrumorGCN(in_feats, hid_feats, out_feats, tddroprate) 619 | self.BUrumorGCN = BUrumorGCN(in_feats, hid_feats, out_feats, budroprate) 620 | self.proj_head = torch.nn.Linear((out_feats + hid_feats) * 2, out_feats) 621 | self.fc = torch.nn.Linear((out_feats + hid_feats) * 2, num_classes) 622 | 623 | def forward(self, data): 624 | TD_x = self.TDrumorGCN(data) 625 | BU_x = self.BUrumorGCN(data) 626 | x = torch.cat((BU_x, TD_x), 1) 627 | x = self.fc(x) 628 | return F.log_softmax(x, dim=-1) 629 | 630 | def forward_graphcl(self, data): 631 | TD_x = self.TDrumorGCN(data) 632 | BU_x = self.BUrumorGCN(data) 633 | x = torch.cat((BU_x, TD_x), 1) 634 | x = self.proj_head(x) 635 | return x 636 | 637 | def loss_graphcl(self, x1, x2, mean=True): 638 | T = 0.5 639 | batch_size, _ = x1.size() 640 | 641 | x1_abs = x1.norm(dim=1) 642 | x2_abs = x2.norm(dim=1) 643 | 644 | sim_matrix = torch.einsum('ik,jk->ij', x1, x2) / torch.einsum('i,j->ij', x1_abs, x2_abs) 645 | sim_matrix = torch.exp(sim_matrix / T) 646 | pos_sim = sim_matrix[range(batch_size), range(batch_size)] 647 | loss = pos_sim / (sim_matrix.sum(dim=1) - pos_sim) 648 | loss = - torch.log(loss) 649 | if mean: 650 | loss = loss.mean() 651 | return loss 652 | -------------------------------------------------------------------------------- /Main/pargs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def pargs(): 5 | str2bool = lambda x: x.lower() == "true" 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument('--dataset', type=str, default='Weibo') 9 | parser.add_argument('--unsup_dataset', type=str, default='UWeiboV1') 10 | parser.add_argument('--tokenize_mode', type=str, default='naive') 11 | 12 | parser.add_argument('--vector_size', type=int, help='word embedding size', default=200) 13 | parser.add_argument('--unsup_train_size', type=int, help='word embedding unlabel data train size', default=20000) 14 | parser.add_argument('--runs', type=int, default=10) 15 | parser.add_argument('--ft_runs', type=int, default=10) 16 | 17 | parser.add_argument('--cuda', type=str2bool, default=True) 18 | parser.add_argument('--gpu', type=int, default=0) 19 | 20 | # 622 or 802 21 | parser.add_argument('--split', type=str, default='802') 22 | parser.add_argument('--batch_size', type=int, default=32) 23 | parser.add_argument('--unsup_bs_ratio', type=int, default=1) 24 | parser.add_argument('--undirected', type=str2bool, default=True) 25 | 26 | # ResGCN or BiGCN 27 | parser.add_argument('--model', type=str, default='ResGCN') 28 | parser.add_argument('--n_layers_feat', type=int, default=1) 29 | parser.add_argument('--n_layers_conv', type=int, default=3) 30 | parser.add_argument('--n_layers_fc', type=int, default=2) 31 | parser.add_argument('--hidden', type=int, default=128) 32 | parser.add_argument('--global_pool', type=str, default="sum") 33 | parser.add_argument('--skip_connection', type=str2bool, default=True) 34 | parser.add_argument('--res_branch', type=str, default="BNConvReLU") 35 | parser.add_argument('--dropout', 
type=float, default=0.3) 36 | parser.add_argument('--edge_norm', type=str2bool, default=True) 37 | 38 | parser.add_argument('--lr', type=float, default=0.001) 39 | parser.add_argument('--ft_lr', type=float, default=0.001) 40 | parser.add_argument('--epochs', type=int, default=100) 41 | parser.add_argument('--ft_epochs', type=int, default=100) 42 | parser.add_argument('--weight_decay', type=float, default=0) 43 | parser.add_argument('--lamda', dest='lamda', type=float, default=0.001) 44 | 45 | # Node centrality metric can be chosen from "Degree", "PageRank", "Eigenvector", "Betweenness". 46 | parser.add_argument('--centrality', type=str, default="PageRank") 47 | # Augmentation can be chosen from "DropEdge,mean,0.3,0.7", "NodeDrop,0.3,0.7", "AttrMask,0.3,0.7", 48 | # or augmentation combination "DropEdge,mean,0.3,0.7||NodeDrop,0.3,0.7", "DropEdge,mean,0.3,0.7||AttrMask,0.3,0.7". 49 | # Str like "DropEdge,mean,0.3,0.7" means "AugName,[aggr,]p,threshold". 50 | parser.add_argument('--aug1', type=str, default="DropEdge,mean,0.2,0.7") 51 | parser.add_argument('--aug2', type=str, default="NodeDrop,0.2,0.7") 52 | 53 | parser.add_argument('--use_unlabel', type=str2bool, default=False) 54 | parser.add_argument('--use_unsup_loss', type=str2bool, default=True) 55 | 56 | parser.add_argument('--k', type=int, default=10000) 57 | 58 | args = parser.parse_args() 59 | return args 60 | -------------------------------------------------------------------------------- /Main/sort.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : sort.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import json 10 | import random 11 | from Main.utils import write_post, dataset_makedirs 12 | 13 | 14 | def sort_dataset(label_source_path, label_dataset_path, k_shot=10000, split='622'): 15 | if split == '622': 16 | train_split = 0.6 17 | test_split = 0.8 18 | elif split == '802': 19 | train_split = 0.8 20 | test_split = 0.8 21 | 22 | train_path, val_path, test_path = dataset_makedirs(label_dataset_path) 23 | 24 | label_file_paths = [] 25 | for filename in os.listdir(label_source_path): 26 | label_file_paths.append(os.path.join(label_source_path, filename)) 27 | 28 | all_post = [] 29 | for filepath in label_file_paths: 30 | post = json.load(open(filepath, 'r', encoding='utf-8')) 31 | all_post.append((post['source']['tweet id'], post)) 32 | 33 | random.seed(1234) 34 | random.shuffle(all_post) 35 | train_post = [] 36 | 37 | multi_class = False 38 | for post in all_post: 39 | if post[1]['source']['label'] == 2 or post[1]['source']['label'] == 3: 40 | multi_class = True 41 | 42 | num0 = 0 43 | num1 = 0 44 | num2 = 0 45 | num3 = 0 46 | for post in all_post[:int(len(all_post) * train_split)]: 47 | if post[1]['source']['label'] == 0 and num0 != k_shot: 48 | train_post.append(post) 49 | num0 += 1 50 | if post[1]['source']['label'] == 1 and num1 != k_shot: 51 | train_post.append(post) 52 | num1 += 1 53 | if post[1]['source']['label'] == 2 and num2 != k_shot: 54 | train_post.append(post) 55 | num2 += 1 56 | if post[1]['source']['label'] == 3 and num3 != k_shot: 57 | train_post.append(post) 58 | num3 += 1 59 | if multi_class: 60 | if num0 == k_shot and num1 == k_shot and num2 == k_shot and num3 == k_shot: 61 | break 62 | else: 63 | if num0 == k_shot and num1 == k_shot: 64 | break 65 | if split == '622': 66 | val_post = all_post[int(len(all_post) * train_split):int(len(all_post) * test_split)] 67 | test_post = 
all_post[int(len(all_post) * test_split):] 68 | elif split == '802': 69 | val_post = all_post[-1:] 70 | test_post = all_post[int(len(all_post) * test_split):] 71 | write_post(train_post, train_path) 72 | write_post(val_post, val_path) 73 | write_post(test_post, test_path) 74 | -------------------------------------------------------------------------------- /Main/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : utils.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import json 9 | import os 10 | import shutil 11 | import jieba 12 | import nltk 13 | 14 | nltk.download('punkt') 15 | from nltk.tokenize import MWETokenizer 16 | 17 | mwe_tokenizer = MWETokenizer([('<', '@', 'user', '>'), ('<', 'url', '>')], separator='') 18 | 19 | 20 | def word_tokenizer(sentence, lang='en', mode='naive'): 21 | if lang == 'en': 22 | if mode == 'nltk': 23 | return mwe_tokenizer.tokenize(nltk.word_tokenize(sentence)) 24 | elif mode == 'naive': 25 | return sentence.split() 26 | if lang == 'ch': 27 | if mode == 'jieba': 28 | return jieba.lcut(sentence) 29 | elif mode == 'naive': 30 | return sentence 31 | 32 | 33 | def write_json(dict, path): 34 | with open(path, 'w', encoding='utf-8') as file_obj: 35 | json.dump(dict, file_obj, indent=4, ensure_ascii=False) 36 | 37 | 38 | def write_post(post_list, path): 39 | for post in post_list: 40 | write_json(post[1], os.path.join(path, f'{post[0]}.json')) 41 | 42 | 43 | def write_log(log, str): 44 | log.write(f'{str}\n') 45 | log.flush() 46 | 47 | 48 | def dataset_makedirs(dataset_path): 49 | train_path = os.path.join(dataset_path, 'train', 'raw') 50 | val_path = os.path.join(dataset_path, 'val', 'raw') 51 | test_path = os.path.join(dataset_path, 'test', 'raw') 52 | 53 | if os.path.exists(dataset_path): 54 | shutil.rmtree(dataset_path) 55 | os.makedirs(train_path) 56 | os.makedirs(val_path) 57 | os.makedirs(test_path) 58 | os.makedirs(os.path.join(dataset_path, 'train', 'processed')) 59 | os.makedirs(os.path.join(dataset_path, 'val', 'processed')) 60 | os.makedirs(os.path.join(dataset_path, 'test', 'processed')) 61 | 62 | return train_path, val_path, test_path 63 | 64 | 65 | def create_log_dict_pretrain(args): 66 | log_dict = {} 67 | log_dict['dataset'] = args.dataset 68 | log_dict['unsup dataset'] = args.unsup_dataset 69 | log_dict['tokenize mode'] = args.tokenize_mode 70 | 71 | log_dict['unsup train size'] = args.unsup_train_size 72 | log_dict['runs'] = args.runs 73 | log_dict['batch size'] = args.batch_size 74 | log_dict['undirected'] = args.undirected 75 | log_dict['model'] = args.model 76 | log_dict['n layers feat'] = args.n_layers_feat 77 | log_dict['n layers conv'] = args.n_layers_conv 78 | log_dict['n layers fc'] = args.n_layers_fc 79 | log_dict['vector size'] = args.vector_size 80 | log_dict['hidden'] = args.hidden 81 | log_dict['global pool'] = args.global_pool 82 | log_dict['skip connection'] = args.skip_connection 83 | log_dict['res branch'] = args.res_branch 84 | log_dict['dropout'] = args.dropout 85 | log_dict['edge norm'] = args.edge_norm 86 | 87 | log_dict['lr'] = args.lr 88 | log_dict['ft_lr'] = args.ft_lr 89 | log_dict['epochs'] = args.epochs 90 | log_dict['ft_epochs'] = args.ft_epochs 91 | log_dict['weight decay'] = args.weight_decay 92 | 93 | log_dict['centrality'] = args.centrality 94 | log_dict['aug1'] = args.aug1 95 | log_dict['aug2'] = args.aug2 96 | 97 | log_dict['record'] = [] 98 | return log_dict 99 | 100 | 101 | def 
create_log_dict_semisup(args): 102 | log_dict = {} 103 | log_dict['dataset'] = args.dataset 104 | log_dict['unsup dataset'] = args.unsup_dataset 105 | log_dict['tokenize mode'] = args.tokenize_mode 106 | 107 | log_dict['unsup train size'] = args.unsup_train_size 108 | log_dict['runs'] = args.runs 109 | log_dict['batch size'] = args.batch_size 110 | log_dict['unsup_bs_ratio'] = args.unsup_bs_ratio 111 | log_dict['undirected'] = args.undirected 112 | log_dict['model'] = args.model 113 | log_dict['n layers feat'] = args.n_layers_feat 114 | log_dict['n layers conv'] = args.n_layers_conv 115 | log_dict['n layers fc'] = args.n_layers_fc 116 | log_dict['vector size'] = args.vector_size 117 | log_dict['hidden'] = args.hidden 118 | log_dict['global pool'] = args.global_pool 119 | log_dict['skip connection'] = args.skip_connection 120 | log_dict['res branch'] = args.res_branch 121 | log_dict['dropout'] = args.dropout 122 | log_dict['edge norm'] = args.edge_norm 123 | 124 | log_dict['lr'] = args.lr 125 | log_dict['epochs'] = args.epochs 126 | log_dict['weight decay'] = args.weight_decay 127 | log_dict['lamda'] = args.lamda 128 | 129 | log_dict['centrality'] = args.centrality 130 | log_dict['aug1'] = args.aug1 131 | log_dict['aug2'] = args.aug2 132 | 133 | log_dict['k'] = args.k 134 | 135 | log_dict['record'] = [] 136 | return log_dict 137 | 138 | 139 | def create_log_dict_sup(args): 140 | log_dict = {} 141 | log_dict['dataset'] = args.dataset 142 | log_dict['unsup train size'] = args.unsup_train_size 143 | log_dict['tokenize mode'] = args.tokenize_mode 144 | 145 | log_dict['runs'] = args.runs 146 | log_dict['batch size'] = args.batch_size 147 | log_dict['undirected'] = args.undirected 148 | log_dict['model'] = args.model 149 | log_dict['n layers feat'] = args.n_layers_feat 150 | log_dict['n layers conv'] = args.n_layers_conv 151 | log_dict['n layers fc'] = args.n_layers_fc 152 | log_dict['vector size'] = args.vector_size 153 | log_dict['hidden'] = args.hidden 154 | log_dict['global pool'] = args.global_pool 155 | log_dict['skip connection'] = args.skip_connection 156 | log_dict['res branch'] = args.res_branch 157 | log_dict['dropout'] = args.dropout 158 | log_dict['edge norm'] = args.edge_norm 159 | 160 | log_dict['lr'] = args.lr 161 | log_dict['epochs'] = args.epochs 162 | log_dict['weight decay'] = args.weight_decay 163 | log_dict['lamda'] = args.lamda 164 | 165 | log_dict['centrality'] = args.centrality 166 | log_dict['aug1'] = args.aug1 167 | log_dict['aug2'] = args.aug2 168 | 169 | log_dict['use unlabel'] = args.use_unlabel 170 | log_dict['use unsup loss'] = args.use_unsup_loss 171 | 172 | log_dict['k'] = args.k 173 | 174 | log_dict['record'] = [] 175 | return log_dict 176 | -------------------------------------------------------------------------------- /Main/word2vec.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : word2vec.py 6 | # @Software: PyCharm 7 | # @Note : 8 | import os 9 | import os.path as osp 10 | import json 11 | import random 12 | import torch 13 | from gensim.models import Word2Vec 14 | from utils import word_tokenizer 15 | 16 | 17 | class Embedding(): 18 | def __init__(self, w2v_path, lang, tokenize_mode): 19 | self.w2v_path = w2v_path 20 | self.lang = lang 21 | self.tokenize_mode = tokenize_mode 22 | self.idx2word = [] 23 | self.word2idx = {} 24 | self.embedding_matrix = self.make_embedding() 25 | 26 | def add_embedding(self, word): 27 | vector = 
torch.empty(1, self.embedding_dim) 28 | torch.nn.init.uniform_(vector) 29 | self.word2idx[word] = len(self.word2idx) 30 | self.idx2word.append(word) 31 | self.embedding_matrix = torch.cat([self.embedding_matrix, vector], 0) 32 | 33 | def make_embedding(self): 34 | self.embedding_matrix = [] 35 | self.embedding = Word2Vec.load(self.w2v_path) 36 | self.embedding_dim = self.embedding.vector_size 37 | for i, word in enumerate(self.embedding.wv.key_to_index): 38 | # e.g. self.word2index['魯'] = 1 39 | # e.g. self.index2word[1] = '魯' 40 | self.word2idx[word] = len(self.word2idx) 41 | self.idx2word.append(word) 42 | self.embedding_matrix.append(self.embedding.wv.get_vector(word, norm=True)) 43 | self.embedding_matrix = torch.tensor(self.embedding_matrix) 44 | self.add_embedding("") 45 | print("total words: {}".format(len(self.embedding_matrix))) 46 | return self.embedding_matrix 47 | 48 | def sentence_word2idx(self, sen): 49 | sentence_idx = [] 50 | for word in word_tokenizer(sen, self.lang, self.tokenize_mode): 51 | if (word in self.word2idx.keys()): 52 | sentence_idx.append(self.word2idx[word]) 53 | else: 54 | sentence_idx.append(self.word2idx[""]) 55 | return sentence_idx 56 | 57 | def get_word_embedding(self, sen): 58 | sentence_idx = self.sentence_word2idx(sen) 59 | word_embedding = self.embedding_matrix[sentence_idx] 60 | return word_embedding 61 | 62 | def get_sentence_embedding(self, sen): 63 | word_embedding = self.get_word_embedding(sen) 64 | sen_embedding = torch.sum(word_embedding, dim=0) 65 | return sen_embedding 66 | 67 | def labels_to_tensor(self, y): 68 | y = [int(label) for label in y] 69 | return torch.LongTensor(y) 70 | 71 | 72 | def collect_sentences(label_source_path, unlabel_dataset_path, unsup_train_size, lang, tokenize_mode): 73 | unlabel_path = osp.join(unlabel_dataset_path, 'raw') 74 | sentences = collect_label_sentences(label_source_path) + collect_unlabel_sentences(unlabel_path, unsup_train_size) 75 | # sentences = collect_label_sentences(label_source_path) 76 | sentences = [word_tokenizer(sentence, lang=lang, mode=tokenize_mode) for sentence in sentences] 77 | return sentences 78 | 79 | 80 | def collect_label_sentences(path): 81 | sentences = [] 82 | for filename in os.listdir(path): 83 | filepath = osp.join(path, filename) 84 | post = json.load(open(filepath, 'r', encoding='utf-8')) 85 | sentences.append(post['source']['content']) 86 | for commnet in post['comment']: 87 | sentences.append(commnet['content']) 88 | return sentences 89 | 90 | 91 | def collect_unlabel_sentences(path, unsup_train_size): 92 | sentences = [] 93 | filenames = os.listdir(path) 94 | random.shuffle(filenames) 95 | for i, filename in enumerate(filenames): 96 | if i == unsup_train_size: 97 | break 98 | filepath = osp.join(path, filename) 99 | post = json.load(open(filepath, 'r', encoding='utf-8')) 100 | sentences.append(post['source']['content']) 101 | for commnet in post['comment']: 102 | sentences.append(commnet['content']) 103 | return sentences 104 | 105 | 106 | def train_word2vec(sentences, vector_size): 107 | model = Word2Vec(sentences, vector_size=vector_size, window=5, min_count=5, workers=12, epochs=30, sg=1) 108 | return model 109 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAGCL 2 | 3 | Source code for RAGCL in paper: 4 | 5 | **Propagation Tree is not Deep: Adaptive Graph Contrastive Learning Approach for Rumor Detection** 6 | 7 | ## Run 8 | 9 | You 
need to compute node centrality with ```centrality.py``` to process your own data before running the code. 10 | 11 | The code can be run in the following ways: 12 | 13 | ```shell script 14 | nohup python main\(sup\).py --gpu 0 & 15 | ``` 16 | 17 | Dataset is available at [https://pan.baidu.com/s/1Kl5IQjU3a_pdt90YmNNVsQ?pwd=qqul](https://pan.baidu.com/s/1Kl5IQjU3a_pdt90YmNNVsQ?pwd=qqul) 18 | 19 | ## Dependencies 20 | 21 | - [pytorch](https://pytorch.org/) == 1.12.0 22 | 23 | - [torch-geometric](https://github.com/pyg-team/pytorch_geometric) == 2.1.0 24 | 25 | - [gensim](https://radimrehurek.com/gensim/index.html) == 4.0.1 26 | 27 | ## Citation 28 | 29 | If this work is helpful, please kindly cite as: 30 | 31 | ``` 32 | @inproceedings{cui2024propagation, 33 | title={Propagation Tree Is Not Deep: Adaptive Graph Contrastive Learning Approach for Rumor Detection}, 34 | author={Cui, Chaoqun and Jia, Caiyan}, 35 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 36 | volume={38}, 37 | number={1}, 38 | pages={73--81}, 39 | year={2024} 40 | } 41 | ``` 42 | -------------------------------------------------------------------------------- /centrality.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 3 | # @Author : 4 | # @Email : 5 | # @File : centrality.py 6 | # @Software: PyCharm 7 | # @Note : 8 | from torch_scatter import scatter 9 | import networkx as nx 10 | from torch_geometric.utils import degree, to_undirected, to_networkx 11 | import os 12 | import json 13 | import torch 14 | from torch_geometric.data import Data, Batch 15 | 16 | 17 | def write_json(dict, path): 18 | with open(path, 'w', encoding='utf-8') as file_obj: 19 | json.dump(dict, file_obj, indent=4, ensure_ascii=False) 20 | 21 | 22 | def get_root_index(data): 23 | root_index = scatter(torch.ones((data.num_nodes,)).to(data.x.device).to(torch.long), data.batch, reduce='sum') 24 | for i in range(root_index.shape[0] - 1, -1, -1): 25 | root_index[i] = 0 26 | for j in range(i): 27 | root_index[i] += root_index[j] 28 | return root_index 29 | 30 | 31 | # need normalization 32 | # root centrality == no children 1 level reply centrality 33 | def degree_centrality(data): 34 | ud_edge_index = to_undirected(data.edge_index) 35 | # out degree 36 | centrality = degree(ud_edge_index[1]) 37 | centrality[0] = 1 38 | centrality = centrality - 1.0 + 1e-8 39 | return centrality 40 | 41 | 42 | # need normalization 43 | # root centrality = no children 1 level reply centrality 44 | def pagerank_centrality(data, damp=0.85, k=10): 45 | device = data.x.device 46 | bu_edge_index = data.edge_index.clone() 47 | bu_edge_index[0], bu_edge_index[1] = data.edge_index[1], data.edge_index[0] 48 | 49 | num_nodes = data.num_nodes 50 | deg_out = degree(bu_edge_index[0]) 51 | centrality = torch.ones((num_nodes,)).to(device).to(torch.float32) 52 | 53 | for i in range(k): 54 | edge_msg = centrality[bu_edge_index[0]] / deg_out[bu_edge_index[0]] 55 | agg_msg = scatter(edge_msg, bu_edge_index[1], reduce='sum') 56 | pad = torch.zeros((len(centrality) - len(agg_msg),)).to(device).to(torch.float32) 57 | agg_msg = torch.cat((agg_msg, pad), 0) 58 | 59 | centrality = (1 - damp) * centrality + damp * agg_msg 60 | 61 | centrality[0] = centrality.min().item() 62 | return centrality 63 | 64 | 65 | # need normalization 66 | # root centrality == no children 1 level reply centrality 67 | def eigenvector_centrality(data): 68 | bu_data = data.clone() 69 | bu_data.edge_index = 
bu_data.no_root_edge_index 70 | # bu_data.edge_index[0], bu_data.edge_index[1] = data.no_root_edge_index[1], data.no_root_edge_index[0] 71 | 72 | bu_data.edge_index = to_undirected(bu_data.edge_index) 73 | 74 | graph = to_networkx(bu_data) 75 | centrality = nx.eigenvector_centrality(graph, tol=1e-3) 76 | centrality = [centrality[i] for i in range(bu_data.num_nodes)] 77 | centrality = torch.tensor(centrality, dtype=torch.float32).to(bu_data.x.device) 78 | return centrality 79 | 80 | 81 | # need normalization 82 | # root centrality == no children 1 level reply centrality 83 | def betweenness_centrality(data): 84 | bu_data = data.clone() 85 | # bu_data.edge_index[0], bu_data.edge_index[1] = data.edge_index[1], data.edge_index[0] 86 | 87 | graph = to_networkx(bu_data) 88 | centrality = nx.betweenness_centrality(graph) 89 | centrality = [centrality[i] if centrality[i] != 0 else centrality[i] + 1e-16 for i in range(bu_data.num_nodes)] 90 | centrality = torch.tensor(centrality, dtype=torch.float32).to(bu_data.x.device) 91 | return centrality 92 | 93 | 94 | def calculate_centrality(source_path): 95 | raw_file_names = os.listdir(source_path) 96 | for filename in raw_file_names: 97 | filepath = os.path.join(source_path, filename) 98 | post = json.load(open(filepath, 'r', encoding='utf-8')) 99 | 100 | x = torch.ones(len(post['comment']) + 1, 20) 101 | row = [] 102 | col = [] 103 | no_root_row = [] 104 | no_root_col = [] 105 | filepath = os.path.join(source_path, filename) 106 | post = json.load(open(filepath, 'r', encoding='utf-8')) 107 | 108 | for i, comment in enumerate(post['comment']): 109 | if comment['parent'] != -1: 110 | no_root_row.append(comment['parent'] + 1) 111 | no_root_col.append(comment['comment id'] + 1) 112 | row.append(comment['parent'] + 1) 113 | col.append(comment['comment id'] + 1) 114 | edge_index = [row, col] 115 | no_root_edge_index = [no_root_row, no_root_col] 116 | edge_index = torch.LongTensor(edge_index) 117 | no_root_edge_index = torch.LongTensor(no_root_edge_index) 118 | one_data = Data(x=x, edge_index=edge_index, no_root_edge_index=no_root_edge_index) 119 | 120 | if one_data.num_nodes > 1: 121 | dc = degree_centrality(Batch.from_data_list([one_data])).tolist() 122 | pc = pagerank_centrality(Batch.from_data_list([one_data])).tolist() 123 | ec = eigenvector_centrality(Batch.from_data_list([one_data])).tolist() 124 | bc = betweenness_centrality(Batch.from_data_list([one_data])).tolist() 125 | else: 126 | dc = pc = ec = bc = [1] 127 | 128 | post['centrality'] = {} 129 | post['centrality']['Degree'] = dc 130 | post['centrality']['Pagerank'] = pc 131 | post['centrality']['Eigenvector'] = ec 132 | post['centrality']['Betweenness'] = bc 133 | 134 | write_json(post, filepath) 135 | --------------------------------------------------------------------------------
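The README above requires node centrality to be written into each post JSON with `centrality.py` before any training run. The snippet below is a minimal driver sketch, not part of the repository: it only illustrates how `calculate_centrality` might be invoked over a directory of raw post files. The `Data/Weibo/source` default path and the `--source_path` flag are illustrative assumptions, not options defined by the project.

```python
# Minimal driver sketch -- illustrative only, not shipped with the repository.
# It assumes the raw labelled posts sit in a flat directory of *.json files;
# the default path below is a placeholder, not a location defined by the project.
import argparse

from centrality import calculate_centrality

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_path', type=str, default='Data/Weibo/source')
    args = parser.parse_args()

    # calculate_centrality() loads every post JSON under source_path, computes the
    # Degree / Pagerank / Eigenvector / Betweenness lists, and writes them back
    # into each file under the 'centrality' key.
    calculate_centrality(args.source_path)
```

Once the centrality fields are written, training can be launched with the `nohup` command shown in the README, and the `--centrality` argument in `Main/pargs.py` selects which of the four stored metrics is used.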