├── .gitattributes ├── data ├── gos │ ├── test_idx.npy │ ├── val_idx.npy │ ├── train_idx.npy │ ├── struct_temp.pkl │ ├── news_mapping.pickle │ ├── user_mapping.pickle │ └── gos_news_list.txt └── poli │ ├── val_idx.npy │ ├── test_idx.npy │ ├── train_idx.npy │ ├── struct_temp.pkl │ ├── news_mapping.pickle │ ├── user_mapping.pickle │ ├── poli_news_list.txt │ └── label.txt ├── __pycache__ ├── HGAT.cpython-36.pyc ├── HGAT.cpython-37.pyc ├── HGSL.cpython-37.pyc ├── Optim.cpython-36.pyc ├── Optim.cpython-37.pyc ├── layer.cpython-36.pyc ├── layer.cpython-37.pyc ├── Constants.cpython-36.pyc ├── Constants.cpython-37.pyc ├── Metrics.cpython-36.pyc ├── Metrics.cpython-37.pyc ├── dataLoader.cpython-36.pyc ├── dataLoader.cpython-37.pyc ├── graphConstruct.cpython-36.pyc ├── graphConstruct.cpython-37.pyc ├── TransformerBlock.cpython-36.pyc └── TransformerBlock.cpython-37.pyc ├── .idea ├── vcs.xml ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── encodings.xml ├── modules.xml ├── misc.xml ├── sshConfigs.xml ├── MS-HGAT.iml ├── webServers.xml ├── remote-mappings.xml └── deployment.xml ├── Constants.py ├── Metrics.py ├── Optim.py ├── dataLoader.py ├── run.py ├── TransformerBlock.py ├── HGSL.py └── Data_preprocessing.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /data/gos/test_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/test_idx.npy -------------------------------------------------------------------------------- /data/gos/val_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/val_idx.npy 
-------------------------------------------------------------------------------- /data/poli/val_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/val_idx.npy -------------------------------------------------------------------------------- /data/gos/train_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/train_idx.npy -------------------------------------------------------------------------------- /data/poli/test_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/test_idx.npy -------------------------------------------------------------------------------- /data/poli/train_idx.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/train_idx.npy -------------------------------------------------------------------------------- /data/gos/struct_temp.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/struct_temp.pkl -------------------------------------------------------------------------------- /data/poli/struct_temp.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/struct_temp.pkl -------------------------------------------------------------------------------- /data/gos/news_mapping.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/news_mapping.pickle -------------------------------------------------------------------------------- /data/gos/user_mapping.pickle: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/user_mapping.pickle -------------------------------------------------------------------------------- /__pycache__/HGAT.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGAT.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/HGAT.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGAT.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/HGSL.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGSL.cpython-37.pyc -------------------------------------------------------------------------------- /data/poli/news_mapping.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/news_mapping.pickle -------------------------------------------------------------------------------- /data/poli/user_mapping.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/user_mapping.pickle -------------------------------------------------------------------------------- /__pycache__/Optim.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Optim.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/Optim.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Optim.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/layer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/layer.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/layer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/layer.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/Constants.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Constants.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/Constants.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Constants.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/Metrics.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Metrics.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/Metrics.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Metrics.cpython-37.pyc -------------------------------------------------------------------------------- 
/__pycache__/dataLoader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/dataLoader.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/dataLoader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/dataLoader.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/graphConstruct.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/graphConstruct.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/graphConstruct.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/graphConstruct.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/TransformerBlock.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/TransformerBlock.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/TransformerBlock.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/TransformerBlock.cpython-37.pyc -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 
-------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /Constants.py: -------------------------------------------------------------------------------- 1 | import torch 2 | PAD = 0 3 | 4 | step_split = 2 5 | n_heads = 14 6 | 7 | #cate = ['retweet', 'support', 'deny'] 8 | cate = ['retweet'] 9 | early_type = 'time' # 'engage' or 'time' 10 | 11 | GPU = torch.cuda.is_available() 12 | device = torch.device('cuda' if GPU else "cpu") 13 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 7 | -------------------------------------------------------------------------------- /.idea/sshConfigs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 
class Metrics(object):
    """Classification metrics for one batch of predictions.

    Every value returned by :meth:`compute_metric` is multiplied by the
    batch size, so a caller can add up the per-batch results and divide
    by the total number of samples afterwards.
    """

    def __init__(self):
        super().__init__()

    def compute_metric(self, y_prob, y_true):
        """Score a batch of per-class scores against the true labels.

        :param y_prob: per-sample class scores; the predicted class is the
            argmax over axis 1.
        :param y_true: true labels, one per sample (same length as ``y_prob``).
        :return: dict with keys ``'Acc'``, ``'F1'``, ``'Pre'`` and
            ``'Recall'``, each holding ``metric * len(y_prob)``.
        """
        predicted = np.array(y_prob).argmax(axis=1)
        batch_size = len(y_prob)
        assert len(y_prob) == len(y_true)

        # F1 is macro-averaged; precision and recall use the scorer's
        # default averaging and report 0 instead of warning on empty classes.
        per_metric = (
            ('Acc', accuracy_score(y_true, predicted)),
            ('F1', f1_score(y_true, predicted, average='macro')),
            ('Pre', precision_score(y_true, predicted, zero_division=0)),
            ('Recall', recall_score(y_true, predicted, zero_division=0)),
        )
        return {label: 0.0 + value * batch_size for label, value in per_metric}
class ScheduledOptim(object):
    """Wrap an optimizer and drive its learning rate with a warm-up schedule.

    The learning rate set on every parameter group is
    ``d_model ** -0.5 * min(step ** -0.5, n_warmup_steps ** -1.5 * step)``.
    """

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0

    def step(self):
        """Delegate one update step to the wrapped optimizer."""
        self.optimizer.step()

    def zero_grad(self):
        """Delegate gradient clearing to the wrapped optimizer."""
        self.optimizer.zero_grad()

    def update_learning_rate(self):
        """Advance the step counter and write the new rate to every group."""
        self.n_current_steps += 1
        step_no = self.n_current_steps

        decay_term = np.power(step_no, -0.5)
        warmup_term = np.power(self.n_warmup_steps, -1.5) * step_no
        scale = np.power(self.d_model, -0.5)
        new_lr = scale * np.min([decay_term, warmup_term])

        for group in self.optimizer.param_groups:
            group['lr'] = new_lr
def DataReader(data_name):
    """Load the pickled train/valid/test splits of a dataset.

    The file locations come from ``Options(data_name)``. The size of each
    split is taken as the length of its first element, and a short summary
    is printed.

    :param data_name: dataset name used to build the file paths.
    :return: ``(train_data, valid_data, test_data, total_size)`` where
        ``total_size`` is the summed size of the three splits.
    """
    paths = Options(data_name)

    loaded = []
    for split_path in (paths.train_data, paths.valid_data, paths.test_data):
        with open(split_path, 'rb') as handle:
            loaded.append(pickle.load(handle))
    train_data, valid_data, test_data = loaded

    n_train = len(train_data[0])
    n_held_out = len(test_data[0]) + len(valid_data[0])
    total_size = n_train + n_held_out

    print("news cascades size:%d " % (total_size))
    print("train size:%d " % (n_train))
    print("test and valid size:%d " % (n_held_out))

    return train_data, valid_data, test_data, total_size
class DataLoader(object):
    """Iterate over ``(idx, labels)`` mini-batches as ``LongTensor`` pairs.

    ``data`` is indexed as ``data[0]`` (sample indices) and ``data[1]``
    (labels); both are sliced in lock-step, ``batch_size`` items at a time.
    The final batch may be shorter. After the last batch ``StopIteration``
    is raised and the internal counter is reset, so the same object can be
    iterated again for the next epoch.

    :param data: pair-like object holding indices and labels.
    :param batch_size: number of samples per batch.
    :param cuda: stored on the instance; not used by the loader itself.
    :param test: stored on the instance. It used to be forwarded as the
        ``volatile`` flag of ``torch.autograd.Variable``, which PyTorch has
        deprecated (it only emits a warning); wrap evaluation in
        ``torch.no_grad()`` instead.
    """

    def __init__(
            self, data, batch_size=64, cuda=True, test=False):
        self._batch_size = batch_size
        self.idx = data[0]
        self.label = data[1]
        self.test = test
        self.cuda = cuda

        self._n_batch = int(np.ceil(len(self.idx) / self._batch_size))
        self._iter_count = 0

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def __len__(self):
        """Return the number of batches per epoch."""
        return self._n_batch

    def next(self):
        """Return the next ``(idx, labels)`` batch or raise ``StopIteration``."""
        if self._iter_count >= self._n_batch:
            # Rewind so the loader can be reused for another epoch.
            self._iter_count = 0
            raise StopIteration()

        start_idx = self._iter_count * self._batch_size
        end_idx = start_idx + self._batch_size
        self._iter_count += 1

        # Plain tensors replace the deprecated Variable(..., volatile=...)
        # wrapper; the values and dtype are the same.
        idx = torch.LongTensor(self.idx[start_idx:end_idx])
        labels = torch.LongTensor(self.label[start_idx:end_idx])
        return idx, labels
def train_model(HGSL, data_name):
    """Train a model on ``data_name`` and checkpoint the best epoch.

    From epoch 6 onwards the model is scored on the validation and test
    splits after every training epoch. The checkpoint is chosen by the sum
    of the *validation* metrics; the test metrics of that epoch are kept
    as the reported best scores. (The previous code overwrote the
    validation scores with the test scores before comparing, which selected
    the checkpoint on the test set.)

    :param HGSL: model class; it is instantiated as ``HGSL(opt)``.
    :param data_name: dataset name passed to ``DataReader``/``GraphReader``.
    """
    # Function-scope import: only needed to create the checkpoint directory.
    import os

    # ========= Preparing DataLoader =========#
    train, valid, test, news_size = DataReader(data_name)
    hypergraph_list, user_size = GraphReader(data_name)

    train_data = DataLoader(train, batch_size=opt.batch_size, cuda=False)
    valid_data = DataLoader(valid, batch_size=opt.batch_size, cuda=False)
    test_data = DataLoader(test, batch_size=opt.batch_size, cuda=False)

    opt.user_size = user_size
    opt.edge_size = news_size + 1

    # ========= Preparing Model =========#
    model = HGSL(opt)
    params = model.parameters()
    optimizerAdam = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-09)
    optimizer = ScheduledOptim(optimizerAdam, opt.d_model, opt.n_warmup_steps)

    if torch.cuda.is_available():
        model = model.to(Constants.device)

    # torch.save does not create missing parent directories, so make sure
    # the checkpoint folder exists before the first save.
    save_dir = os.path.dirname(opt.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    validation_history = 0.0
    best_scores = {}
    for epoch_i in range(opt.epoch):
        print('\n[ Epoch', epoch_i, ']')

        start = time.time()
        train_loss = train_epoch(model, train_data, hypergraph_list, optimizer)
        print(' - (Training) loss: {loss: 8.5f} %, ' \
              'elapse: {elapse:3.3f} min'.format(
                  loss=train_loss,
                  elapse=(time.time() - start) / 60))

        if epoch_i > 5:
            valid_scores = test_epoch(model, valid_data, hypergraph_list)
            print(' - (Validation) ')
            for name in valid_scores.keys():
                print(name + ': ' + "%.5f" % (valid_scores[name] * 100) + "%")

            print(' - (Test) ')
            test_scores = test_epoch(model, test_data, hypergraph_list)
            for name in test_scores.keys():
                print(name + ': ' + "%.5f" % (test_scores[name] * 100) + "%")

            # Model selection uses validation metrics only; the test scores
            # of the selected epoch are what gets reported.
            valid_total = sum(valid_scores.values())
            if validation_history <= valid_total:
                print("Best Test Accuracy:{}% at Epoch:{}".format(round(test_scores["Acc"] * 100, 5), epoch_i))
                validation_history = valid_total
                best_scores = test_scores
                print("Save best model!!!")
                torch.save(model.state_dict(), opt.save_path)

    print(" - (Finished!!) \n Best scores: ")
    for name in best_scores.keys():
        print(name + ': ' + "%.5f" % (best_scores[name] * 100) + "%")
class TransformerBlock(nn.Module):
    """Multi-head attention followed by a position-wise feed-forward layer.

    Every head works at the full ``input_size`` width
    (``d_k == d_v == input_size``). Both residual connections share one
    ``LayerNorm`` when ``is_layer_norm`` is true.

    :param input_size: feature size of the inputs and of the output.
    :param n_heads: number of attention heads.
    :param is_layer_norm: apply layer normalisation after each residual sum.
    :param attn_dropout: dropout probability used on the attention weights,
        the attention output and the feed-forward output.
    """

    def __init__(self, input_size, n_heads=2, is_layer_norm=True, attn_dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.n_heads = n_heads
        self.d_k = input_size
        self.d_v = input_size

        self.is_layer_norm = is_layer_norm
        if is_layer_norm:
            self.layer_norm = nn.LayerNorm(normalized_shape=input_size)

        self.pos_encoding = PositionalEncoding(d_model=input_size, dropout=0.5)

        self.W_q = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
        self.W_k = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
        self.W_v = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_v))

        self.W_o = nn.Parameter(torch.Tensor(self.d_v * n_heads, input_size))
        self.linear1 = nn.Linear(input_size, input_size)
        self.linear2 = nn.Linear(input_size, input_size)

        self.dropout = nn.Dropout(attn_dropout)
        self.__init_weights__()

    def __init_weights__(self):
        """Xavier-normal initialisation of all projection weights."""
        init.xavier_normal_(self.W_q)
        init.xavier_normal_(self.W_k)
        init.xavier_normal_(self.W_v)
        init.xavier_normal_(self.W_o)

        init.xavier_normal_(self.linear1.weight)
        init.xavier_normal_(self.linear2.weight)

    def FFN(self, X):
        """Two linear layers with a ReLU in between, followed by dropout."""
        output = self.linear2(F.relu(self.linear1(X)))
        output = self.dropout(output)
        return output

    def scaled_dot_product_attention(self, Q, K, V, mask, episilon=1e-6):
        '''
        :param Q: (bsz * n_heads, q_len, d_k)
        :param K: (bsz * n_heads, k_len, d_k)
        :param V: (bsz * n_heads, k_len, d_v)
        :param mask: (bsz * n_heads, q_len) padding mask, true/non-zero at
            query positions to be masked, or None for unmasked attention.
            When a mask is given, an upper-triangular (causal) mask is
            applied as well.
        :param episilon: small constant added to the scaling factor.
        :return: (bsz * n_heads, q_len, d_v)
        '''
        temperature = self.d_k ** 0.5

        Q_K = (torch.einsum("bqd,bkd->bqk", Q, K)) / (temperature + episilon)
        if mask is not None:
            # Build both masks on the device of the scores instead of a
            # global device, and combine them as booleans so that non-bool
            # padding masks are accepted too.
            pad_mask = mask.unsqueeze(dim=-1).expand(-1, -1, K.size(1))
            pad_mask = pad_mask.to(device=Q_K.device, dtype=torch.bool)
            causal_mask = torch.triu(
                torch.ones(pad_mask.size(), device=Q_K.device), diagonal=1).bool()
            mask_ = causal_mask | pad_mask
            Q_K = Q_K.masked_fill(mask_, -2**32+1)

        Q_K_score = F.softmax(Q_K, dim=-1)  # (bsz * n_heads, q_len, k_len)
        Q_K_score = self.dropout(Q_K_score)
        # Batched matrix product of two 3-D tensors.
        V_att = Q_K_score.bmm(V)  # (bsz * n_heads, q_len, d_v)
        return V_att

    def multi_head_attention(self, Q, K, V, mask):
        '''
        :param Q: (bsz, q_len, input_size)
        :param K: (bsz, k_len, input_size)
        :param V: (bsz, v_len, input_size); v_len must equal k_len for the
            attention-weighted sum.
        :param mask: (bsz, q_len) or None
        :return: (bsz, q_len, input_size)
        '''
        bsz, q_len, _ = Q.size()
        bsz, k_len, _ = K.size()
        bsz, v_len, _ = V.size()

        Q_ = Q.matmul(self.W_q).view(bsz, q_len, self.n_heads, self.d_k)
        K_ = K.matmul(self.W_k).view(bsz, k_len, self.n_heads, self.d_k)
        V_ = V.matmul(self.W_v).view(bsz, v_len, self.n_heads, self.d_v)

        # Fold the head axis into the batch axis. K and V keep their own
        # sequence lengths; reshaping them with q_len fails whenever the
        # key/value length differs from the query length.
        Q_ = Q_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, q_len, self.d_k)
        K_ = K_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, k_len, self.d_k)
        V_ = V_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, v_len, self.d_v)

        if mask is not None:
            mask = mask.unsqueeze(dim=1).expand(-1, self.n_heads, -1)  # For head axis broadcasting.
            mask = mask.reshape(-1, mask.size(-1))

        V_att = self.scaled_dot_product_attention(Q_, K_, V_, mask)
        V_att = V_att.view(bsz, self.n_heads, q_len, self.d_v)
        V_att = V_att.permute(0, 2, 1, 3).contiguous().view(bsz, q_len, self.n_heads*self.d_v)

        output = self.dropout(V_att.matmul(self.W_o))  # (bsz, q_len, input_size)
        return output

    def forward(self, Q, K, V, mask=None, pos = True):
        '''
        :param Q: (batch_size, max_q_words, input_size)
        :param K: (batch_size, max_k_words, input_size)
        :param V: (batch_size, max_v_words, input_size)
        :param mask: optional (batch_size, max_q_words) padding mask
        :param pos: when true, positional encoding is added to Q, K and V
        :return: output: (batch_size, max_q_words, input_size) same size as Q
        '''
        if pos:
            Q = self.pos_encoding(Q)
            K = self.pos_encoding(K)
            V = self.pos_encoding(V)

        V_att = self.multi_head_attention(Q, K, V, mask)

        if self.is_layer_norm:
            X = self.layer_norm(Q + V_att)  # (batch_size, max_q_words, input_size)
            output = self.layer_norm(self.FFN(X) + X)
        else:
            X = Q + V_att
            output = self.FFN(X) + X
        return output
35 | politifact779 36 | politifact11777 37 | politifact6519 38 | politifact6641 39 | politifact8310 40 | politifact1500 41 | politifact12945 42 | politifact1575 43 | politifact1337 44 | politifact423 45 | politifact2128 46 | politifact13244 47 | politifact10787 48 | politifact11627 49 | politifact6556 50 | politifact1454 51 | politifact11899 52 | politifact12057 53 | politifact1690 54 | politifact1467 55 | politifact426 56 | politifact4275 57 | politifact2298 58 | politifact421 59 | politifact1519 60 | politifact809 61 | politifact4586 62 | politifact554 63 | politifact806 64 | politifact4588 65 | politifact5321 66 | politifact11552 67 | politifact7182 68 | politifact11960 69 | politifact186 70 | politifact8470 71 | politifact12627 72 | politifact73 73 | politifact3428 74 | politifact14174 75 | politifact10533 76 | politifact8737 77 | politifact8805 78 | politifact5608 79 | politifact160 80 | politifact462 81 | politifact7489 82 | politifact339 83 | politifact4433 84 | politifact8557 85 | politifact5237 86 | politifact8259 87 | politifact11066 88 | politifact780 89 | politifact5469 90 | politifact11314 91 | politifact52 92 | politifact13013 93 | politifact1714 94 | politifact245 95 | politifact12801 96 | politifact5659 97 | politifact8611 98 | politifact10185 99 | politifact7665 100 | politifact15453 101 | politifact7563 102 | politifact608 103 | politifact11577 104 | politifact9438 105 | politifact13305 106 | politifact10731 107 | politifact13477 108 | politifact581 109 | politifact6998 110 | politifact3527 111 | politifact1106 112 | politifact11761 113 | politifact6473 114 | politifact12755 115 | politifact724 116 | politifact1028 117 | politifact13548 118 | politifact2166 119 | politifact1424 120 | politifact13132 121 | politifact14064 122 | politifact14511 123 | politifact2139 124 | politifact230 125 | politifact8172 126 | politifact7259 127 | politifact134 128 | politifact10276 129 | politifact746 130 | politifact14474 131 | politifact8130 132 | 
politifact12418 133 | politifact1185 134 | politifact74 135 | politifact13303 136 | politifact681 137 | politifact13052 138 | politifact440 139 | politifact986 140 | politifact943 141 | politifact10903 142 | politifact150 143 | politifact8071 144 | politifact65 145 | politifact206 146 | politifact9622 147 | politifact228 148 | politifact1177 149 | politifact11189 150 | politifact6603 151 | politifact12944 152 | politifact128 153 | politifact9691 154 | politifact15645 155 | politifact208 156 | politifact13395 157 | politifact514 158 | politifact14960 159 | politifact13561 160 | politifact14587 161 | politifact13765 162 | politifact14205 163 | politifact13978 164 | politifact13565 165 | politifact14040 166 | politifact15232 167 | politifact14776 168 | politifact13887 169 | politifact14062 170 | politifact14516 171 | politifact14621 172 | politifact14119 173 | politifact14507 174 | politifact14755 175 | politifact14472 176 | politifact13784 177 | politifact14947 178 | politifact15267 179 | politifact13591 180 | politifact15156 181 | politifact14356 182 | politifact14927 183 | politifact14815 184 | politifact14795 185 | politifact15533 186 | politifact15545 187 | politifact14722 188 | politifact15352 189 | politifact15367 190 | politifact14135 191 | politifact14860 192 | politifact15204 193 | politifact14818 194 | politifact14500 195 | politifact14187 196 | politifact15514 197 | politifact15409 198 | politifact14207 199 | politifact14166 200 | politifact15327 201 | politifact13827 202 | politifact13731 203 | politifact14503 204 | politifact15210 205 | politifact13836 206 | politifact14005 207 | politifact14021 208 | politifact14169 209 | politifact14273 210 | politifact14063 211 | politifact14693 212 | politifact13934 213 | politifact13744 214 | politifact14993 215 | politifact15224 216 | politifact13698 217 | politifact14991 218 | politifact14893 219 | politifact15456 220 | politifact14794 221 | politifact15096 222 | politifact14051 223 | politifact14362 224 | 
politifact14448 225 | politifact13600 226 | politifact14788 227 | politifact15494 228 | politifact14879 229 | politifact14718 230 | politifact15049 231 | politifact14333 232 | politifact14426 233 | politifact15591 234 | politifact15349 235 | politifact15525 236 | politifact13584 237 | politifact14406 238 | politifact15341 239 | politifact15217 240 | politifact14876 241 | politifact14789 242 | politifact14905 243 | politifact14355 244 | politifact15161 245 | politifact15095 246 | politifact13912 247 | politifact13577 248 | politifact14733 249 | politifact13038 250 | politifact15423 251 | politifact15532 252 | politifact14544 253 | politifact14386 254 | politifact15266 255 | politifact15626 256 | politifact14785 257 | politifact14164 258 | politifact15534 259 | politifact14258 260 | politifact15146 261 | politifact15540 262 | politifact13823 263 | politifact14235 264 | politifact14264 265 | politifact13663 266 | politifact15486 267 | politifact15307 268 | politifact14694 269 | politifact15630 270 | politifact13982 271 | politifact15579 272 | politifact15097 273 | politifact14328 274 | politifact15130 275 | politifact14330 276 | politifact14699 277 | politifact13921 278 | politifact13468 279 | politifact15309 280 | politifact14003 281 | politifact13942 282 | politifact15262 283 | politifact15477 284 | politifact15539 285 | politifact15188 286 | politifact14222 287 | politifact15383 288 | politifact14517 289 | politifact15623 290 | politifact14128 291 | politifact13766 292 | politifact14556 293 | politifact14447 294 | politifact14402 295 | politifact14595 296 | politifact15159 297 | politifact15564 298 | politifact14890 299 | politifact14605 300 | politifact15301 301 | politifact14361 302 | politifact13999 303 | politifact15356 304 | politifact14395 305 | politifact15554 306 | politifact13931 307 | politifact14576 308 | politifact14270 309 | politifact13943 310 | politifact14085 311 | politifact15246 312 | politifact13794 313 | politifact14469 314 | politifact15178 315 
| -------------------------------------------------------------------------------- /data/poli/label.txt: -------------------------------------------------------------------------------- 1 | politifact13584 1 2 | politifact8172 0 3 | politifact15477 1 4 | politifact6657 0 5 | politifact14135 1 6 | politifact13912 1 7 | politifact1446 0 8 | politifact7489 0 9 | politifact13468 1 10 | politifact13052 0 11 | politifact14940 0 12 | politifact14402 1 13 | politifact13943 1 14 | politifact14927 1 15 | politifact6932 0 16 | politifact15534 1 17 | politifact14517 1 18 | politifact780 0 19 | politifact2166 0 20 | politifact10185 0 21 | politifact15367 1 22 | politifact608 0 23 | politifact15349 1 24 | politifact15352 1 25 | politifact14960 1 26 | politifact15341 1 27 | politifact2139 0 28 | politifact13982 1 29 | politifact14605 1 30 | politifact8259 0 31 | politifact15525 1 32 | politifact13934 1 33 | politifact5321 0 34 | politifact1052 0 35 | politifact73 0 36 | politifact13698 1 37 | politifact14063 1 38 | politifact1337 0 39 | politifact14258 1 40 | politifact14795 1 41 | politifact15453 0 42 | politifact462 0 43 | politifact8470 0 44 | politifact809 0 45 | politifact7665 0 46 | politifact14876 1 47 | politifact13548 0 48 | politifact14507 1 49 | politifact15630 1 50 | politifact14119 1 51 | politifact5608 0 52 | politifact14205 1 53 | politifact15626 1 54 | politifact14755 1 55 | politifact14788 1 56 | politifact12944 0 57 | politifact681 0 58 | politifact13477 0 59 | politifact12057 0 60 | politifact779 0 61 | politifact245 0 62 | politifact5469 0 63 | politifact150 0 64 | politifact15210 1 65 | politifact5659 0 66 | politifact4586 0 67 | politifact14503 1 68 | politifact8805 0 69 | politifact52 0 70 | politifact10533 0 71 | politifact13682 0 72 | politifact12755 0 73 | politifact514 0 74 | politifact14395 1 75 | politifact14595 1 76 | politifact14235 1 77 | politifact14386 1 78 | politifact13132 0 79 | politifact14500 1 80 | politifact379 0 81 | politifact14330 1 82 
| politifact13663 1 83 | politifact13942 1 84 | politifact13038 1 85 | politifact14062 1 86 | politifact15307 1 87 | politifact12418 0 88 | politifact8611 0 89 | politifact14333 1 90 | politifact14040 1 91 | politifact14694 1 92 | politifact11189 0 93 | politifact339 0 94 | politifact15579 1 95 | politifact15591 1 96 | politifact13561 1 97 | politifact8310 0 98 | politifact9438 0 99 | politifact15188 1 100 | politifact14448 1 101 | politifact14273 1 102 | politifact6641 0 103 | politifact14776 1 104 | politifact14264 1 105 | politifact15049 1 106 | politifact943 0 107 | politifact12801 0 108 | politifact13013 0 109 | politifact979 0 110 | politifact13836 1 111 | politifact14556 1 112 | politifact14722 1 113 | politifact11960 0 114 | politifact2128 0 115 | politifact65 0 116 | politifact13591 1 117 | politifact1185 0 118 | politifact15532 1 119 | politifact806 0 120 | politifact6360 0 121 | politifact10903 0 122 | politifact14718 1 123 | politifact14893 1 124 | politifact14222 1 125 | politifact13766 1 126 | politifact14207 1 127 | politifact14576 1 128 | politifact548 0 129 | politifact1519 0 130 | politifact2393 0 131 | politifact421 0 132 | politifact8069 0 133 | politifact13999 1 134 | politifact13744 1 135 | politifact7182 0 136 | politifact720 0 137 | politifact8989 0 138 | politifact581 0 139 | politifact14587 1 140 | politifact14818 1 141 | politifact14362 1 142 | politifact15309 1 143 | politifact1714 0 144 | politifact10332 0 145 | politifact1106 0 146 | politifact14003 1 147 | politifact14187 1 148 | politifact7376 0 149 | politifact1467 0 150 | politifact14794 1 151 | politifact11552 0 152 | politifact13244 0 153 | politifact12052 0 154 | politifact13305 0 155 | politifact7563 0 156 | politifact8737 0 157 | politifact13565 1 158 | politifact4275 0 159 | politifact15266 1 160 | politifact14991 1 161 | politifact13303 0 162 | politifact15540 1 163 | politifact13784 1 164 | politifact6473 0 165 | politifact11761 0 166 | politifact14164 1 167 | 
politifact15224 1 168 | politifact15301 1 169 | politifact6931 0 170 | politifact14356 1 171 | politifact14785 1 172 | politifact15327 1 173 | politifact14544 1 174 | politifact3228 0 175 | politifact14085 1 176 | politifact3527 0 177 | politifact14815 1 178 | politifact13931 1 179 | politifact15204 1 180 | politifact11066 0 181 | politifact13138 0 182 | politifact13794 1 183 | politifact15645 0 184 | politifact74 0 185 | politifact15494 1 186 | politifact9622 0 187 | politifact13068 0 188 | politifact6603 0 189 | politifact14270 1 190 | politifact15156 1 191 | politifact13887 1 192 | politifact13283 0 193 | politifact160 0 194 | politifact13731 1 195 | politifact15486 1 196 | politifact1177 0 197 | politifact1424 0 198 | politifact14789 1 199 | politifact14947 1 200 | politifact13921 1 201 | politifact6519 0 202 | politifact208 0 203 | politifact14406 1 204 | politifact4181 0 205 | politifact986 0 206 | politifact15095 1 207 | politifact6646 0 208 | politifact15533 1 209 | politifact14361 1 210 | politifact9802 0 211 | politifact14693 1 212 | politifact186 0 213 | politifact15423 1 214 | politifact128 0 215 | politifact14905 1 216 | politifact15545 1 217 | politifact14469 1 218 | politifact13827 1 219 | politifact15356 1 220 | politifact14860 1 221 | politifact230 0 222 | politifact15130 1 223 | politifact15539 1 224 | politifact14621 1 225 | politifact15097 1 226 | politifact1575 0 227 | politifact9691 0 228 | politifact7259 0 229 | politifact13823 1 230 | politifact15623 1 231 | politifact15564 1 232 | politifact3428 0 233 | politifact14733 1 234 | politifact1454 0 235 | politifact14021 1 236 | politifact14064 0 237 | politifact15096 1 238 | politifact12721 0 239 | politifact14169 1 240 | politifact11314 0 241 | politifact12945 0 242 | politifact11777 0 243 | politifact15456 1 244 | politifact15246 1 245 | politifact13600 1 246 | politifact15383 1 247 | politifact11577 0 248 | politifact8071 0 249 | politifact14051 1 250 | politifact14426 1 251 | politifact8557 
0 252 | politifact11627 0 253 | politifact14890 1 254 | politifact13978 1 255 | politifact426 0 256 | politifact14447 1 257 | politifact10787 0 258 | politifact14174 0 259 | politifact15262 1 260 | politifact15409 1 261 | politifact15554 1 262 | politifact2048 0 263 | politifact554 0 264 | politifact6998 0 265 | politifact4190 0 266 | politifact14166 1 267 | politifact13395 0 268 | politifact15267 1 269 | politifact1028 0 270 | politifact13577 1 271 | politifact14511 0 272 | politifact15217 1 273 | politifact14128 1 274 | politifact206 0 275 | politifact1500 0 276 | politifact13765 1 277 | politifact1690 0 278 | politifact4433 0 279 | politifact8130 0 280 | politifact12627 0 281 | politifact11899 0 282 | politifact15514 1 283 | politifact14516 1 284 | politifact14879 1 285 | politifact695 0 286 | politifact228 0 287 | politifact15161 1 288 | politifact15159 1 289 | politifact31 0 290 | politifact14472 1 291 | politifact5237 0 292 | politifact14699 1 293 | politifact14474 0 294 | politifact14328 1 295 | politifact15232 1 296 | politifact6556 0 297 | politifact10731 0 298 | politifact10276 0 299 | politifact7506 0 300 | politifact746 0 301 | politifact440 0 302 | politifact15178 1 303 | politifact14005 1 304 | politifact9033 0 305 | politifact134 0 306 | politifact14355 1 307 | politifact15146 1 308 | politifact423 0 309 | politifact2298 0 310 | politifact724 0 311 | politifact4588 0 312 | politifact7511 0 313 | politifact14993 1 314 | politifact582 0 315 | -------------------------------------------------------------------------------- /HGSL.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Jan 18 22:30:16 2021 4 | 5 | @author: Ling Sun 6 | """ 7 | 8 | import math 9 | #import numpy as np 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | #from layer import HGATLayer 14 | import torch.nn.init as init 15 | import Constants 16 | from 
def _nonpad_uniform_weights(index):
    """Row-wise softmax weights that are uniform over entries with index > 0.

    Entries that are not > 0 receive a logit of -9e15, so their weight is
    zero whenever the row has at least one positive entry; a row without
    any positive entry comes out uniform over all of its slots.
    """
    blocked = -9e15 * torch.ones_like(index)
    allowed = torch.ones_like(index, dtype=torch.float)
    return F.softmax(torch.where(index > 0, allowed, blocked), 1)


class Gated_fusion(nn.Module):
    """Fuse two equally-shaped tensors with a learned softmax gate."""

    def __init__(self, input_size, out_size=1, dropout=0.2):
        super().__init__()
        self.linear1 = nn.Linear(input_size, input_size)
        self.linear2 = nn.Linear(input_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.init_weights()

    def init_weights(self):
        """Xavier-normal initialisation of both linear weight matrices."""
        for layer in (self.linear1, self.linear2):
            init.xavier_normal_(layer.weight)

    def forward(self, X1, X2):
        """Return the gate-weighted sum of ``X1`` and ``X2``.

        The gate is a softmax across the two inputs (dim 0 of the stacked
        tensor), followed by dropout on the gate scores.
        """
        stacked = torch.stack([X1, X2], dim=0)
        gate = F.softmax(self.linear2(torch.tanh(self.linear1(stacked))), dim=0)
        gate = self.dropout(gate)
        return torch.sum(gate * stacked, dim=0)


class HGSL(nn.Module):
    """Classifier combining hypergraph features with two attention branches.

    ``opt`` must provide ``d_model``, ``user_size``, ``dropout`` and
    ``initialFeatureSize``.  The sizes of the four auxiliary embedding
    tables (600 / 5000 / 50 / 200) are fixed here, so the indices fed to
    them in ``forward`` have to stay below those limits.
    """

    def __init__(self, opt):
        super().__init__()

        self.hidden_size = opt.d_model
        self.n_node = opt.user_size
        self.dropout = nn.Dropout(opt.dropout)
        self.initial_feature = opt.initialFeatureSize
        self.hgnn = HGNN(self.initial_feature, self.hidden_size, dropout=opt.dropout)

        self.user_embedding = nn.Embedding(self.n_node, self.initial_feature)
        self.stru_attention = TransformerBlock(self.hidden_size, n_heads=8)
        self.temp_attention = TransformerBlock(self.hidden_size, n_heads=8)

        self.global_cen_embedding = nn.Embedding(600, self.hidden_size)
        self.local_time_embedding = nn.Embedding(5000, self.hidden_size)
        self.cas_pos_embedding = nn.Embedding(50, self.hidden_size)
        self.local_inf_embedding = nn.Embedding(200, self.hidden_size)

        # "+2": each branch appends two spread-status columns to its
        # hidden_size-wide summary before the square projection.
        fused_dim = self.hidden_size + 2
        self.weight = Parameter(torch.Tensor(fused_dim, fused_dim))
        self.weight2 = Parameter(torch.Tensor(fused_dim, fused_dim))
        self.fus = Gated_fusion(fused_dim)
        self.linear = nn.Linear(fused_dim, 2)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw every parameter (sub-modules included) from U(-b, b),
        with b = 1 / sqrt(hidden_size)."""
        bound = 1.0 / math.sqrt(self.hidden_size)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)

    def forward(self, data_idx, hypergraph_list):
        """Return log-probabilities over the two classes for ``data_idx``.

        ``hypergraph_list`` unpacks into
        ``(news_centered_graph, user_centered_graph, spread_status)``, where
        the first two unpack further into ``(seq, timestamps, user_level)``
        and ``(useq, user_inf, user_cen)``.
        NOTE(review): ``user_cen`` is added to the full user embedding
        table, so it presumably holds one index per user - confirm with the
        data loader.
        """
        news_centered_graph, user_centered_graph, spread_status = hypergraph_list
        seq, timestamps, user_level = news_centered_graph
        useq, user_inf, user_cen = user_centered_graph

        # Global learning: user features refined by the hypergraph network.
        user_feats = self.dropout(self.user_embedding.weight)
        user_feats = user_feats + self.global_cen_embedding(user_cen)
        user_hgnn_out = self.hgnn(user_feats, seq, useq)

        # Per-news pooling weights and the padding mask for attention.
        batch_seq = seq[data_idx]
        pool_weights = _nonpad_uniform_weights(batch_seq)
        att_mask = (batch_seq == Constants.PAD)
        member_feats = F.embedding(batch_seq, user_hgnn_out)
        status = spread_status[data_idx]

        # Local temporal learning.
        temporal_in = member_feats + self.local_time_embedding(timestamps[data_idx])
        temporal_att = self.temp_attention(temporal_in, temporal_in, temporal_in,
                                           mask=att_mask)
        news_out = torch.einsum("abc,ab->ac", temporal_att, pool_weights)

        # Append the spread-status columns from index 2 on, divided by
        # 3600 and 24, then project.
        news_out = torch.cat([news_out, status[:, 2:] / 3600 / 24], dim=-1)
        news_out = news_out.matmul(self.weight)

        # Local structural learning (no positional encoding in this branch).
        structural_in = (member_feats
                         + self.local_inf_embedding(user_inf[data_idx])
                         + self.cas_pos_embedding(user_level[data_idx]))
        structural_att = self.stru_attention(structural_in, structural_in,
                                             structural_in, mask=att_mask, pos=False)
        news_out_str = torch.einsum("abc,ab->ac", structural_att, pool_weights)

        # Append the first two spread-status columns, then project.
        news_out_str = torch.cat([news_out_str, status[:, :2]], dim=-1)
        news_out_str = news_out_str.matmul(self.weight2)

        # Gated fusion of both branches followed by the classifier head.
        fused = self.fus(news_out, news_out_str)
        return F.log_softmax(self.linear(fused), dim=1)


'''Learn hypergraphs'''
class HGNN_layer(nn.Module):
    """One node -> hyperedge -> node aggregation step."""

    def __init__(self, input_size, output_size, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.in_features = input_size
        self.out_features = output_size
        self.weight1 = Parameter(torch.Tensor(self.in_features, self.out_features))
        self.weight2 = Parameter(torch.Tensor(self.out_features, self.out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw both weight matrices from U(-b, b), b = 1 / sqrt(in_features).

        Each matrix is drawn twice (once through ``parameters()``, once
        explicitly); the second pass is kept so that random-number
        consumption stays exactly as before.
        """
        bound = 1.0 / math.sqrt(self.in_features)
        for param in self.parameters():
            param.data.uniform_(-bound, bound)
        for param in (self.weight1, self.weight2):
            param.data.uniform_(-bound, bound)

    def forward(self, x, seq, useq):
        """Aggregate rows of ``x`` along ``seq`` and back along ``useq``.

        ``seq`` and ``useq`` are index matrices in which entries that are
        not > 0 are treated as padding; every row of ``seq`` is pooled into
        one edge vector, and every row of ``useq`` pools edge vectors into
        one output row.
        """
        projected = x.matmul(self.weight1)

        # Nodes -> edges: uniform average over the non-padding members.
        edge = torch.einsum("abc,ab->ac",
                            F.embedding(seq, projected),
                            _nonpad_uniform_weights(seq))
        edge = F.dropout(edge, self.dropout, training=self.training)
        edge = F.relu(edge, inplace=False)
        edge = edge.matmul(self.weight2)

        # Edges -> nodes: the same pooling over each row of ``useq``.
        node = torch.einsum("abc,ab->ac",
                            F.embedding(useq, edge),
                            _nonpad_uniform_weights(useq))
        return F.dropout(node, self.dropout, training=self.training)


class HGNN(nn.Module):
    """Wrapper around a single ``HGNN_layer``."""

    def __init__(self, input_size, output_size, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.gnn1 = HGNN_layer(input_size, output_size, dropout=self.dropout)

    def forward(self, x, seq, useq):
        """Delegate to the wrapped layer and return its output."""
        return self.gnn1(x, seq, useq)
def buildIndex(user_set, news_set):
    """Assign 1-based integer ids to users and news items.

    Index 0 is reserved in both mappings for the empty-string key ''.
    Ids follow the iteration order of the inputs; if an item occurs more
    than once, its last position wins.

    Returns ``(user_size, news_size, u2idx, n2idx)`` where the two sizes
    are ``len(user_set)`` and ``len(news_set)`` (the reserved '' entry is
    not counted).
    """

    def index_from_one(items):
        # '' is inserted first so that it owns index 0.
        mapping = {'': 0}
        mapping.update((item, position) for position, item in enumerate(items, 1))
        return mapping

    u2idx = index_from_one(user_set)
    n2idx = index_from_one(news_set)
    return len(user_set), len(news_set), u2idx, n2idx
cascades[cas] = [i[:max_] for i in cascades[cas]] 109 | 110 | order = [i[0] for i in sorted(enumerate(cascades[cas][1]), key=lambda x: float(x[1]))] 111 | #print(cascades[cas].shape) 112 | cascades[cas] = [[x[i] for i in order] for x in cascades[cas]] 113 | #cascades[cas] = cascades[cas][:,order] 114 | #cascades[cas][1][:] = [cascades[cas][1][i] for i in order] 115 | #cascades[cas][0][:] = [cascades[cas][0][i] for i in order] 116 | #cascades[cas][2][:] = [cascades[cas][2][i] for i in order] 117 | #cascades[cas][3][:] = [cascades[cas][3][i] for i in order] 118 | 119 | 120 | 121 | ucascades = {} 122 | '''load user-centered retweet data''' 123 | for line in open(options.uretweet): 124 | newslist = [] 125 | userinf = [] 126 | 127 | chunks = line.strip().split(',') 128 | 129 | ucascades[chunks[0]] = [] 130 | 131 | for chunk in chunks[1:]: 132 | news, timestamp, inf= chunk.split() 133 | newslist.append(news) 134 | userinf.append(inf) 135 | 136 | ucascades[chunks[0]] = np.array([newslist, userinf]) 137 | 138 | '''ordered by timestamps''' 139 | for cas in list(ucascades.keys()): 140 | order = [i[0] for i in sorted(enumerate(ucascades[cas][1]), key=lambda x: float(x[1]))] 141 | #ucascades[cas] = cascades[cas][:, order] 142 | ucascades[cas] = [[x[i] for i in order] for x in ucascades[cas]] 143 | #ucascades[cas][1][:] = [ucascades[cas][1][i] for i in order] 144 | #ucascades[cas][0][:] = [ucascades[cas][0][i] for i in order] 145 | user_set = ucascades.keys() 146 | 147 | 148 | if os.path.exists(options.user_mapping): 149 | with open(options.user_mapping, 'rb') as handle: 150 | u2idx = pickle.load(handle) 151 | user_size = len(list(user_set)) 152 | with open(options.news_mapping, 'rb') as handle: 153 | n2idx = pickle.load(handle) 154 | news_size = len(news_list) 155 | else: 156 | user_size, news_size, u2idx, n2idx = buildIndex(user_set, news_list) 157 | with open(options.user_mapping, 'wb') as handle: 158 | pickle.dump(u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL) 159 | with 
open(options.news_mapping, 'wb') as handle: 160 | pickle.dump(n2idx, handle, protocol=pickle.HIGHEST_PROTOCOL) 161 | 162 | for cas in cascades: 163 | cascades[cas][0] = [u2idx[u] for u in cascades[cas][0]] 164 | t_cascades = dict([(n2idx[key], cascades[key]) for key in cascades]) 165 | 166 | for cas in ucascades: 167 | ucascades[cas][0] = [n2idx[n] for n in ucascades[cas][0]] 168 | u_cascades = dict([(u2idx[key], ucascades[key]) for key in ucascades]) 169 | 170 | '''load labels''' 171 | labels = np.zeros((news_size + 1, 1)) 172 | for line in open(options.label): 173 | news, label = line.strip().split(' ') 174 | if news in n2idx: 175 | labels[n2idx[news]] = label 176 | 177 | seq = np.zeros((news_size + 1, max_len)) 178 | timestamps = np.zeros((news_size + 1, max_len)) 179 | user_level = np.zeros((news_size + 1, max_len)) 180 | user_inf = np.zeros((news_size + 1, max_len)) 181 | news_list = [0] + news_list 182 | for n, s in cascades.items(): 183 | news_list[n2idx[n]] = n 184 | se_data = np.hstack((s[0], np.array([Constants.PAD] * (max_len - len(s[0]))))) 185 | seq[n2idx[n]] = se_data 186 | 187 | t_data = np.hstack((s[1], np.array([Constants.PAD] * (max_len - len(s[1]))))) 188 | timestamps[n2idx[n]] = t_data 189 | 190 | lv_data = np.hstack((s[2], np.array([Constants.PAD] * (max_len - len(s[2]))))) 191 | user_level[n2idx[n]] = lv_data 192 | 193 | inf_data = np.hstack((s[3], np.array([Constants.PAD] * (max_len - len(s[3]))))) 194 | user_inf[n2idx[n]] = inf_data 195 | 196 | useq = np.zeros((user_size + 1, max_len)) 197 | uinfs = np.zeros((user_size + 1, max_len)) 198 | 199 | for n, s in ucascades.items(): 200 | if len(s[0])