├── .gitattributes
├── data
│   ├── gos
│   │   ├── test_idx.npy
│   │   ├── val_idx.npy
│   │   ├── train_idx.npy
│   │   ├── struct_temp.pkl
│   │   ├── news_mapping.pickle
│   │   ├── user_mapping.pickle
│   │   └── gos_news_list.txt
│   └── poli
│       ├── val_idx.npy
│       ├── test_idx.npy
│       ├── train_idx.npy
│       ├── struct_temp.pkl
│       ├── news_mapping.pickle
│       ├── user_mapping.pickle
│       ├── poli_news_list.txt
│       └── label.txt
├── __pycache__
│   ├── HGAT.cpython-36.pyc
│   ├── HGAT.cpython-37.pyc
│   ├── HGSL.cpython-37.pyc
│   ├── Optim.cpython-36.pyc
│   ├── Optim.cpython-37.pyc
│   ├── layer.cpython-36.pyc
│   ├── layer.cpython-37.pyc
│   ├── Constants.cpython-36.pyc
│   ├── Constants.cpython-37.pyc
│   ├── Metrics.cpython-36.pyc
│   ├── Metrics.cpython-37.pyc
│   ├── dataLoader.cpython-36.pyc
│   ├── dataLoader.cpython-37.pyc
│   ├── graphConstruct.cpython-36.pyc
│   ├── graphConstruct.cpython-37.pyc
│   ├── TransformerBlock.cpython-36.pyc
│   └── TransformerBlock.cpython-37.pyc
├── .idea
│   ├── vcs.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── encodings.xml
│   ├── modules.xml
│   ├── misc.xml
│   ├── sshConfigs.xml
│   ├── MS-HGAT.iml
│   ├── webServers.xml
│   ├── remote-mappings.xml
│   └── deployment.xml
├── Constants.py
├── Metrics.py
├── Optim.py
├── dataLoader.py
├── run.py
├── TransformerBlock.py
├── HGSL.py
└── Data_preprocessing.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/data/gos/test_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/test_idx.npy
--------------------------------------------------------------------------------
/data/gos/val_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/val_idx.npy
--------------------------------------------------------------------------------
/data/poli/val_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/val_idx.npy
--------------------------------------------------------------------------------
/data/gos/train_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/train_idx.npy
--------------------------------------------------------------------------------
/data/poli/test_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/test_idx.npy
--------------------------------------------------------------------------------
/data/poli/train_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/train_idx.npy
--------------------------------------------------------------------------------
/data/gos/struct_temp.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/struct_temp.pkl
--------------------------------------------------------------------------------
/data/poli/struct_temp.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/struct_temp.pkl
--------------------------------------------------------------------------------
/data/gos/news_mapping.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/news_mapping.pickle
--------------------------------------------------------------------------------
/data/gos/user_mapping.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/gos/user_mapping.pickle
--------------------------------------------------------------------------------
/__pycache__/HGAT.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGAT.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/HGAT.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGAT.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/HGSL.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/HGSL.cpython-37.pyc
--------------------------------------------------------------------------------
/data/poli/news_mapping.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/news_mapping.pickle
--------------------------------------------------------------------------------
/data/poli/user_mapping.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/data/poli/user_mapping.pickle
--------------------------------------------------------------------------------
/__pycache__/Optim.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Optim.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/Optim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Optim.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/layer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/layer.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/layer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/layer.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/Constants.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Constants.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/Constants.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Constants.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/Metrics.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Metrics.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/Metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/Metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/dataLoader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/dataLoader.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/dataLoader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/dataLoader.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/graphConstruct.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/graphConstruct.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/graphConstruct.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/graphConstruct.cpython-37.pyc
--------------------------------------------------------------------------------
/__pycache__/TransformerBlock.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/TransformerBlock.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/TransformerBlock.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/slingling/HG-SL/HEAD/__pycache__/TransformerBlock.cpython-37.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 | PAD = 0
3 |
4 | step_split = 2
5 | n_heads = 14
6 |
7 | #cate = ['retweet', 'support', 'deny']
8 | cate = ['retweet']
9 | early_type = 'time' # 'engage' or 'time'
10 |
11 | GPU = torch.cuda.is_available()
12 | device = torch.device('cuda' if GPU else "cpu")
13 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/sshConfigs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/MS-HGAT.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/webServers.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/Metrics.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 |
4 | from sklearn.metrics import f1_score, accuracy_score, recall_score, precision_score, roc_auc_score, average_precision_score
5 |
6 |
7 | """
8 | Utility functions for evaluating the model performance
9 | """
10 |
11 | class Metrics(object):
12 |
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def compute_metric(self, y_prob, y_true):
17 | k_list = ['Acc', 'F1', 'Pre', 'Recall']
18 | y_pre = np.array(y_prob).argmax(axis=1)
19 | size = len(y_prob)
20 | assert len(y_prob) == len(y_true)
21 |
22 | scores = {str(k): 0.0 for k in k_list}
23 | scores['Acc'] += accuracy_score(y_true, y_pre) * size
24 | scores['F1'] += f1_score(y_true, y_pre, average='macro') * size
25 | scores['Pre'] += precision_score(y_true, y_pre, zero_division=0) * size
26 | scores['Recall'] += recall_score(y_true, y_pre, zero_division=0) * size
27 |
28 | # y_true = np.array(y_true)
29 | # prob_log = y_prob[:, 1].tolist()
30 | #scores['auc'] = roc_auc_score(y_true, prob_log)
31 |
32 | return scores
33 |
34 |
35 |
--------------------------------------------------------------------------------
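
Note (not part of the repository): a minimal sketch of how Metrics.compute_metric can be exercised on a toy batch. The arrays below are illustrative; in run.py the returned values are pre-multiplied by the batch size and divided by the total sample count after all batches.

# Hedged example: toy batch for Metrics.compute_metric (illustrative data only).
import numpy as np
from Metrics import Metrics

metric = Metrics()
y_prob = np.array([[0.9, 0.1],   # argmax -> class 0
                   [0.2, 0.8],   # argmax -> class 1
                   [0.4, 0.6]])  # argmax -> class 1
y_true = [0, 1, 0]

scores = metric.compute_metric(y_prob, y_true)      # each value is metric * batch_size
print({k: v / len(y_true) for k, v in scores.items()})
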
/Optim.py:
--------------------------------------------------------------------------------
1 | '''A wrapper class for optimizer '''
2 | import numpy as np
3 |
4 | class ScheduledOptim(object):
5 | '''A simple wrapper class for learning rate scheduling'''
6 |
7 | def __init__(self, optimizer, d_model, n_warmup_steps):
8 | self.optimizer = optimizer
9 | self.d_model = d_model
10 | self.n_warmup_steps = n_warmup_steps
11 | self.n_current_steps = 0
12 |
13 | def step(self):
14 | "Step by the inner optimizer"
15 | self.optimizer.step()
16 |
17 | def zero_grad(self):
18 | "Zero out the gradients by the inner optimizer"
19 | self.optimizer.zero_grad()
20 |
21 | def update_learning_rate(self):
22 | ''' Learning rate scheduling per step '''
23 | self.n_current_steps += 1
24 | new_lr = np.power(self.d_model, -0.5) * np.min([
25 | np.power(self.n_current_steps, -0.5),
26 | np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])
27 |
28 | for param_group in self.optimizer.param_groups:
29 | param_group['lr'] = new_lr
30 |
--------------------------------------------------------------------------------
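
Note (not part of the repository): ScheduledOptim applies a Noam-style warmup schedule, lr = d_model^(-0.5) * min(step^(-0.5), step * n_warmup_steps^(-1.5)), after each inner optimizer step. A minimal sketch with a toy linear model (sizes are illustrative; run.py uses the same Adam betas/eps):

# Hedged example: wrapping Adam in ScheduledOptim on a toy model.
import torch
from Optim import ScheduledOptim

model = torch.nn.Linear(64, 2)
adam = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-09)
optimizer = ScheduledOptim(adam, d_model=64, n_warmup_steps=1000)

x, y = torch.randn(8, 64), torch.randint(0, 2, (8,))
for _ in range(3):
    optimizer.zero_grad()
    loss = torch.nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()                  # inner Adam update
    optimizer.update_learning_rate()  # warmup/decay schedule described above
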
/.idea/deployment.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/dataLoader.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Created on Nov 1 22:28:02 2021
4 |
5 | @author: Ling Sun
6 | """
7 | import numpy as np
8 | import torch
9 | from torch.autograd import Variable
10 | import Constants
11 | import pickle
12 |
13 | class Options(object):
14 |
15 | def __init__(self, data_name='poli'):
16 | self.news_centered = 'data/' + data_name + '/Processed/news_centered.pickle'
17 | self.user_centered = 'data/' + data_name + '/Processed/user_centered.pickle'
18 |
19 | #self.user_features = 'data/' + data_name + '/user_features.pickle'
20 | self.test_data = 'data/' + data_name + '/Processed/test_processed.pickle'
21 | self.valid_data = 'data/' + data_name + '/Processed/valid_processed.pickle'
22 | self.train_data = 'data/' + data_name + '/Processed/train_processed.pickle'
23 | self.news_features = 'data/' + data_name + '/struct_temp.pkl'
24 | self.news_mapping = 'data/' + data_name + '/news_mapping.pickle'
25 |
26 | self.save_path = ''
27 |
28 | def DataReader(data_name):
29 | options = Options(data_name)
30 | with open(options.train_data, 'rb') as f:
31 | train_data = pickle.load(f)
32 | with open(options.valid_data, 'rb') as f:
33 | valid_data = pickle.load(f)
34 | with open(options.test_data, 'rb') as f:
35 | test_data = pickle.load(f)
36 |
37 | #print(train_data)
38 |
39 | total_size = len(train_data[0])+len(test_data[0])+len(valid_data[0])
40 |
41 | print("news cascades size:%d " % (total_size))
42 | print("train size:%d " % (len(train_data[0])))
43 | print("test and valid size:%d " % (len(test_data[0])+len(valid_data[0])))
44 |
45 | return train_data, valid_data, test_data, total_size
46 |
47 | def FeatureReader(data_name):
48 | options = Options(data_name)
49 | with open(options.news_mapping, 'rb') as handle:
50 | n2idx = pickle.load(handle)
51 | '''Spread status: S1, S2, T1, T2
52 | Structural:(S1)number of sub-cascades, (S2)proportion of non-isolated cascades;
53 | Temporal: (T1) duration of spread,(T2) the average response time from tweet to retweet'''
54 | with open(options.news_features, 'rb') as f:
55 | features = np.array(pickle.load(f))
56 | news_size = len(features)
57 | spread_status = np.zeros((news_size + 1, 4))
58 | for news in features:
59 | #print(news)
60 | spread_status[n2idx[news[0]]]=np.array(news[1:])
61 | #print(spread_status[n2idx[news[0]]])
62 | return spread_status
63 |
64 | def GraphReader(data_name):
65 | options = Options(data_name)
66 | with open(options.news_centered, 'rb') as f:
67 | news_centered_graph = pickle.load(f)
68 |
69 | with open(options.user_centered, 'rb') as f:
70 | user_centered_graph = pickle.load(f)
71 |
72 | useq, user_inf = (item for item in user_centered_graph)
73 | seq, timestamps, user_level, news_inf = (item for item in news_centered_graph)
74 | spread_status = FeatureReader(data_name)
75 |
76 | user_size = len(useq)
77 | user_inf[user_inf>0] = 1
78 | act_level = user_inf[1:].sum(1)
79 | avg_inf = np.append([0],act_level)
80 |
81 | news_centered_graph = [seq, timestamps, user_level]
82 | user_centered_graph = [useq, news_inf, avg_inf]
83 |
84 |
85 | return [[torch.LongTensor(i).to(Constants.device) for i in news_centered_graph], [torch.LongTensor(i).to(Constants.device) for i in user_centered_graph],
86 | torch.LongTensor(spread_status).to(Constants.device)], user_size
87 |
88 | class DataLoader(object):
89 | ''' For data iteration '''
90 |
91 | def __init__(
92 | self, data, batch_size=64, cuda=True, test=False):
93 | self._batch_size = batch_size
94 | self.idx = data[0]
95 | self.label = data[1]
96 | self.test = test
97 | self.cuda = cuda
98 |
99 |
100 | self._n_batch = int(np.ceil(len(self.idx) / self._batch_size))
101 | self._iter_count = 0
102 |
103 | def __iter__(self):
104 | return self
105 |
106 | def __next__(self):
107 | return self.next()
108 |
109 | def __len__(self):
110 | return self._n_batch
111 |
112 | def next(self):
113 | ''' Get the next batch '''
114 |
115 |         def seq_to_tensor(insts):
116 |             # Variable/volatile are deprecated since PyTorch 0.4; a plain LongTensor
117 |             # is enough (run.py wraps evaluation in torch.no_grad()).
118 |             inst_data_tensor = torch.LongTensor(insts)
119 | 
120 |             return inst_data_tensor
121 |
122 | if self._iter_count < self._n_batch:
123 | batch_idx = self._iter_count
124 | self._iter_count += 1
125 |
126 | start_idx = batch_idx * self._batch_size
127 | end_idx = (batch_idx + 1) * self._batch_size
128 |
129 | idx = self.idx[start_idx:end_idx]
130 | labels = self.label[start_idx:end_idx]
131 | idx = seq_to_tensor(idx)
132 | labels = seq_to_tensor(labels)
133 |
134 | return idx, labels
135 | else:
136 |
137 | self._iter_count = 0
138 | raise StopIteration()
139 |
--------------------------------------------------------------------------------
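
Note (not part of the repository): DataLoader only batches (index, label) pairs; the real arrays come from DataReader, which expects the Processed pickles produced by Data_preprocessing.py. A minimal sketch with toy data (illustrative values, assuming the dataLoader module is importable):

# Hedged example: iterating DataLoader over toy (index, label) data.
from dataLoader import DataLoader

idx = list(range(10))             # news indices, as DataReader would provide
labels = [[i % 2] for i in idx]   # one label per news item
loader = DataLoader((idx, labels), batch_size=4, cuda=False, test=True)

for batch_idx, batch_labels in loader:
    print(batch_idx.size(), batch_labels.size())
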
/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Jan 18 22:42:32 2021
4 |
5 | @author: Ling Sun
6 | """
7 |
8 | import argparse
9 | import time
10 | import numpy as np
11 | import Constants
12 | import torch
13 | from dataLoader import DataReader, GraphReader, DataLoader
14 | from Metrics import Metrics
15 | from HGSL import HGSL
16 | from Optim import ScheduledOptim
17 | import torch.nn.functional as F
18 |
19 |
20 | torch.backends.cudnn.deterministic = True
21 | torch.manual_seed(0)
22 | torch.cuda.manual_seed_all(0)
23 | np.random.seed(0)
24 | torch.cuda.manual_seed(0)
25 |
26 | metric = Metrics()
27 |
28 |
29 | parser = argparse.ArgumentParser()
30 | parser.add_argument('-data_name', default='poli')
31 | parser.add_argument('-epoch', type=int, default=200)
32 | parser.add_argument('-batch_size', type=int, default=32)
33 | parser.add_argument('-d_model', type=int, default=64)
34 | parser.add_argument('-initialFeatureSize', type=int, default=64)
35 | parser.add_argument('-early_time', type=int, default=10)
36 | parser.add_argument('-n_warmup_steps', type=int, default=1000)
37 | parser.add_argument('-dropout', type=float, default=0.5)
38 | parser.add_argument('-save_path', default= "./checkpoint/fake_detection.pt")
39 | parser.add_argument('-save_mode', type=str, choices=['all', 'best'], default='best')
40 | parser.add_argument('-no_cuda', action='store_true')
41 |
42 | opt = parser.parse_args()
43 |
44 |
45 | def train_epoch(model, training_data, hypergraph_list, optimizer):
46 | # train
47 | model.train()
48 | total_loss = 0.0
49 |
50 | for i, batch in enumerate(training_data):
51 | # data preparing
52 | tgt, labels = (item.to(Constants.device) for item in batch)
53 | # training
54 | optimizer.zero_grad()
55 | pred= model(tgt, hypergraph_list)
56 |
57 | # loss
58 | loss = F.nll_loss(pred, labels.squeeze())
59 | loss.backward()
60 |
61 | # parameter update
62 | optimizer.step()
63 | optimizer.update_learning_rate()
64 |
65 | total_loss += loss.item()
66 |
67 | return total_loss
68 |
69 | def train_model(HGSL, data_name):
70 | # ========= Preparing DataLoader =========#
71 | train, valid, test, news_size, = DataReader(data_name)
72 | hypergraph_list, user_size = GraphReader(data_name)
73 |
74 | train_data = DataLoader(train, batch_size=opt.batch_size, cuda=False)
75 | valid_data = DataLoader(valid, batch_size=opt.batch_size, cuda=False)
76 | test_data = DataLoader(test, batch_size=opt.batch_size, cuda=False)
77 |
78 |
79 | opt.user_size = user_size
80 | opt.edge_size = news_size+1
81 |
82 | # ========= Preparing Model =========#
83 | model = HGSL(opt)
84 | params = model.parameters()
85 | optimizerAdam = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-09)
86 | optimizer = ScheduledOptim(optimizerAdam, opt.d_model, opt.n_warmup_steps)
87 |
88 | if torch.cuda.is_available():
89 | model = model.to(Constants.device)
90 |
91 | validation_history = 0.0
92 | best_scores = {}
93 | for epoch_i in range(opt.epoch):
94 | print('\n[ Epoch', epoch_i, ']')
95 |
96 | start = time.time()
97 | train_loss = train_epoch(model, train_data, hypergraph_list, optimizer)
98 |         print(' - (Training) loss: {loss: 8.5f}, ' \
99 | 'elapse: {elapse:3.3f} min'.format(
100 | loss=train_loss,
101 | elapse=(time.time() - start) / 60))
102 |
103 | if epoch_i > 5:
104 | #start = time.time()
105 | scores = test_epoch(model, valid_data, hypergraph_list)
106 | print(' - (Validation) ')
107 | for metric in scores.keys():
108 | print(metric + ': ' + "%.5f"%(scores[metric]*100) +"%")
109 |
110 | print(' - (Test) ')
111 | scores = test_epoch(model, test_data, hypergraph_list)
112 | for metric in scores.keys():
113 | print(metric + ': ' + "%.5f"%(scores[metric]*100) +"%")
114 |
115 | if validation_history <= sum(scores.values()):
116 | print("Best Test Accuracy:{}% at Epoch:{}".format(round(scores["Acc"]*100,5), epoch_i))
117 | validation_history = sum(scores.values())
118 | best_scores = scores
119 | print("Save best model!!!")
120 | torch.save(model.state_dict(), opt.save_path)
121 |
122 | print(" - (Finished!!) \n Best scores: ")
123 | for metric in best_scores.keys():
124 | print(metric + ': ' + "%.5f"%(best_scores[metric]*100) +"%")
125 |
126 | def test_epoch(model, validation_data, hypergraph_list):
127 | ''' Epoch operation in evaluation phase '''
128 | model.eval()
129 |
130 | scores = {}
131 | k_list = ['Acc', 'F1', 'Pre', 'Recall']
132 | for k in k_list:
133 | scores[k] = 0
134 |
135 | n_total_words = 0
136 | with torch.no_grad():
137 | for i, batch in enumerate(validation_data):
138 | tgt, labels = (item.to(Constants.device) for item in batch)
139 | y_labels = labels.detach().cpu().numpy()
140 | # forward
141 | pred = model(tgt, hypergraph_list)
142 | y_pred = pred.detach().cpu().numpy()
143 | n_total_words += len(tgt)
144 |
145 | scores_batch= metric.compute_metric(y_pred, y_labels)
146 | for k in k_list:
147 | scores[k] += scores_batch[k]
148 |
149 | for k in k_list:
150 | scores[k] = scores[k] / n_total_words
151 | return scores
152 |
153 | if __name__ == "__main__":
154 | model = HGSL
155 | train_model(model, opt.data_name)
156 |
157 |
158 |
159 |
--------------------------------------------------------------------------------
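
Note (not part of the repository): training is launched from the command line with the argparse flags defined above, for example

    python run.py -data_name poli -batch_size 32 -epoch 200

This assumes the Processed pickles for the chosen dataset already exist (produced by Data_preprocessing.py) and that the directory for the default -save_path (./checkpoint/) has been created.
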
/TransformerBlock.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import torch.nn as nn
6 | import math
7 | import Constants
8 |
9 |
10 | class PositionalEncoding(nn.Module):
11 | "Implement the PE function."
12 |
13 | def __init__(self, d_model, dropout, max_len=800):
14 | super(PositionalEncoding, self).__init__()
15 | self.dropout = nn.Dropout(p=dropout)
16 |
17 | # Compute the positional encodings once in log space.
18 | pe = torch.zeros(max_len, d_model)
19 | position = torch.arange(0, max_len).unsqueeze(1).float()
20 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
21 | pe[:, 0::2] = torch.sin(position * div_term)
22 | pe[:, 1::2] = torch.cos(position * div_term)
23 | pe = pe.unsqueeze(0)
24 | self.register_buffer('pe', pe)
25 |
26 | def forward(self, x):
27 | x = x + self.pe[:, :x.size(1)]
28 | return self.dropout(x)
29 |
30 | class TransformerBlock(nn.Module):
31 |
32 | def __init__(self, input_size, n_heads=2, is_layer_norm=True, attn_dropout=0.1):
33 | super(TransformerBlock, self).__init__()
34 | self.n_heads = n_heads
35 | self.d_k = input_size
36 | self.d_v = input_size
37 |
38 | self.is_layer_norm = is_layer_norm
39 | if is_layer_norm:
40 | self.layer_norm = nn.LayerNorm(normalized_shape=input_size)
41 |
42 | self.pos_encoding= PositionalEncoding(d_model=input_size, dropout=0.5)
43 |
44 | self.W_q = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
45 | self.W_k = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_k))
46 | self.W_v = nn.Parameter(torch.Tensor(input_size, n_heads * self.d_v))
47 |
48 | self.W_o = nn.Parameter(torch.Tensor(self.d_v*n_heads, input_size))
49 | self.linear1 = nn.Linear(input_size, input_size)
50 | self.linear2 = nn.Linear(input_size, input_size)
51 |
52 | self.dropout = nn.Dropout(attn_dropout)
53 | self.__init_weights__()
54 |
55 | def __init_weights__(self):
56 | init.xavier_normal_(self.W_q)
57 | init.xavier_normal_(self.W_k)
58 | init.xavier_normal_(self.W_v)
59 | init.xavier_normal_(self.W_o)
60 |
61 | init.xavier_normal_(self.linear1.weight)
62 | init.xavier_normal_(self.linear2.weight)
63 |
64 | def FFN(self, X):
65 | output = self.linear2(F.relu(self.linear1(X)))
66 | output = self.dropout(output)
67 | return output
68 |
69 |     def scaled_dot_product_attention(self, Q, K, V, mask, epsilon=1e-6):
70 |         '''
71 |         :param Q: (*, max_q_words, n_heads, input_size)
72 |         :param K: (*, max_k_words, n_heads, input_size)
73 |         :param V: (*, max_v_words, n_heads, input_size)
74 |         :param mask: (*, max_q_words)
75 |         :param epsilon:
76 |         :return:
77 |         '''
78 |         temperature = self.d_k ** 0.5
79 | 
80 |         Q_K = (torch.einsum("bqd,bkd->bqk", Q, K)) / (temperature + epsilon)
81 | if mask is not None:
82 | pad_mask = mask.unsqueeze(dim=-1).expand(-1, -1, K.size(1))
83 | mask = torch.triu(torch.ones(pad_mask.size()), diagonal=1).bool().to(Constants.device)
84 | mask_ = mask + pad_mask
85 | Q_K = Q_K.masked_fill(mask_, -2**32+1)
86 |
87 | Q_K_score = F.softmax(Q_K, dim=-1) # (batch_size, max_q_words, max_k_words)
88 | Q_K_score = self.dropout(Q_K_score)
89 |         # batched matrix multiplication of two 3-D tensors
90 | V_att = Q_K_score.bmm(V) # (*, max_q_words, input_size)
91 | return V_att
92 |
93 |
94 | def multi_head_attention(self, Q, K, V, mask):
95 | '''
96 | :param Q:
97 | :param K:
98 | :param V:
99 | :param mask: (bsz, max_q_words)
100 | :return:
101 | '''
102 | bsz, q_len, _ = Q.size()
103 | bsz, k_len, _ = K.size()
104 | bsz, v_len, _ = V.size()
105 | #print(self.W_q.size(), bsz, q_len, self.n_heads, self.d_k)
106 | Q_ = Q.matmul(self.W_q).view(bsz, q_len, self.n_heads, self.d_k)
107 | K_ = K.matmul(self.W_k).view(bsz, k_len, self.n_heads, self.d_k)
108 | V_ = V.matmul(self.W_v).view(bsz, v_len, self.n_heads, self.d_v)
109 | #print(Q_.size(), bsz, q_len, self.n_heads, self.d_k)
110 | Q_ = Q_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, q_len, self.d_k)
111 |         K_ = K_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, k_len, self.d_k)
112 |         V_ = V_.permute(0, 2, 1, 3).contiguous().view(bsz*self.n_heads, v_len, self.d_v)
113 |
114 | if mask is not None:
115 | mask = mask.unsqueeze(dim=1).expand(-1, self.n_heads, -1) # For head axis broadcasting.
116 | mask = mask.reshape(-1, mask.size(-1))
117 |
118 | V_att = self.scaled_dot_product_attention(Q_, K_, V_, mask)
119 | V_att = V_att.view(bsz, self.n_heads, q_len, self.d_v)
120 | V_att = V_att.permute(0, 2, 1, 3).contiguous().view(bsz, q_len, self.n_heads*self.d_v)
121 |
122 | output = self.dropout(V_att.matmul(self.W_o)) # (batch_size, max_q_words, input_size)
123 | return output
124 |
125 |
126 | def forward(self, Q, K, V, mask=None, pos = True):
127 | '''
128 | :param Q: (batch_size, max_q_words, input_size)
129 | :param K: (batch_size, max_k_words, input_size)
130 | :param V: (batch_size, max_v_words, input_size)
131 | :return: output: (batch_size, max_q_words, input_size) same size as Q
132 | '''
133 | if pos:
134 | Q = self.pos_encoding(Q)
135 | K = self.pos_encoding(K)
136 | V = self.pos_encoding(V)
137 |
138 | V_att = self.multi_head_attention(Q, K, V, mask)
139 |
140 | if self.is_layer_norm:
141 | X = self.layer_norm(Q + V_att) # (batch_size, max_r_words, embedding_dim)
142 | output = self.layer_norm(self.FFN(X) + X)
143 | else:
144 | X = Q + V_att
145 | output = self.FFN(X) + X
146 | return output
147 |
--------------------------------------------------------------------------------
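
Note (not part of the repository): a minimal sketch running TransformerBlock self-attention on random tensors (shapes are illustrative; True in the mask marks PAD positions, and Constants.device resolves to CPU when CUDA is unavailable):

# Hedged example: toy self-attention pass through TransformerBlock.
import torch
import Constants
from TransformerBlock import TransformerBlock

block = TransformerBlock(input_size=64, n_heads=2).to(Constants.device)
x = torch.randn(4, 10, 64).to(Constants.device)                        # (batch, seq_len, input_size)
pad_mask = torch.zeros(4, 10, dtype=torch.bool).to(Constants.device)   # no positions padded here

out = block(x, x, x, mask=pad_mask)   # same shape as the query: (4, 10, 64)
print(out.size())
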
/data/poli/poli_news_list.txt:
--------------------------------------------------------------------------------
1 | politifact4190
2 | politifact6657
3 | politifact582
4 | politifact6646
5 | politifact13138
6 | politifact13068
7 | politifact720
8 | politifact4181
9 | politifact7511
10 | politifact9802
11 | politifact8989
12 | politifact548
13 | politifact3228
14 | politifact7376
15 | politifact13682
16 | politifact7506
17 | politifact8069
18 | politifact695
19 | politifact6932
20 | politifact31
21 | politifact2393
22 | politifact979
23 | politifact6931
24 | politifact1052
25 | politifact13283
26 | politifact2048
27 | politifact379
28 | politifact14940
29 | politifact9033
30 | politifact6360
31 | politifact10332
32 | politifact12052
33 | politifact1446
34 | politifact12721
35 | politifact779
36 | politifact11777
37 | politifact6519
38 | politifact6641
39 | politifact8310
40 | politifact1500
41 | politifact12945
42 | politifact1575
43 | politifact1337
44 | politifact423
45 | politifact2128
46 | politifact13244
47 | politifact10787
48 | politifact11627
49 | politifact6556
50 | politifact1454
51 | politifact11899
52 | politifact12057
53 | politifact1690
54 | politifact1467
55 | politifact426
56 | politifact4275
57 | politifact2298
58 | politifact421
59 | politifact1519
60 | politifact809
61 | politifact4586
62 | politifact554
63 | politifact806
64 | politifact4588
65 | politifact5321
66 | politifact11552
67 | politifact7182
68 | politifact11960
69 | politifact186
70 | politifact8470
71 | politifact12627
72 | politifact73
73 | politifact3428
74 | politifact14174
75 | politifact10533
76 | politifact8737
77 | politifact8805
78 | politifact5608
79 | politifact160
80 | politifact462
81 | politifact7489
82 | politifact339
83 | politifact4433
84 | politifact8557
85 | politifact5237
86 | politifact8259
87 | politifact11066
88 | politifact780
89 | politifact5469
90 | politifact11314
91 | politifact52
92 | politifact13013
93 | politifact1714
94 | politifact245
95 | politifact12801
96 | politifact5659
97 | politifact8611
98 | politifact10185
99 | politifact7665
100 | politifact15453
101 | politifact7563
102 | politifact608
103 | politifact11577
104 | politifact9438
105 | politifact13305
106 | politifact10731
107 | politifact13477
108 | politifact581
109 | politifact6998
110 | politifact3527
111 | politifact1106
112 | politifact11761
113 | politifact6473
114 | politifact12755
115 | politifact724
116 | politifact1028
117 | politifact13548
118 | politifact2166
119 | politifact1424
120 | politifact13132
121 | politifact14064
122 | politifact14511
123 | politifact2139
124 | politifact230
125 | politifact8172
126 | politifact7259
127 | politifact134
128 | politifact10276
129 | politifact746
130 | politifact14474
131 | politifact8130
132 | politifact12418
133 | politifact1185
134 | politifact74
135 | politifact13303
136 | politifact681
137 | politifact13052
138 | politifact440
139 | politifact986
140 | politifact943
141 | politifact10903
142 | politifact150
143 | politifact8071
144 | politifact65
145 | politifact206
146 | politifact9622
147 | politifact228
148 | politifact1177
149 | politifact11189
150 | politifact6603
151 | politifact12944
152 | politifact128
153 | politifact9691
154 | politifact15645
155 | politifact208
156 | politifact13395
157 | politifact514
158 | politifact14960
159 | politifact13561
160 | politifact14587
161 | politifact13765
162 | politifact14205
163 | politifact13978
164 | politifact13565
165 | politifact14040
166 | politifact15232
167 | politifact14776
168 | politifact13887
169 | politifact14062
170 | politifact14516
171 | politifact14621
172 | politifact14119
173 | politifact14507
174 | politifact14755
175 | politifact14472
176 | politifact13784
177 | politifact14947
178 | politifact15267
179 | politifact13591
180 | politifact15156
181 | politifact14356
182 | politifact14927
183 | politifact14815
184 | politifact14795
185 | politifact15533
186 | politifact15545
187 | politifact14722
188 | politifact15352
189 | politifact15367
190 | politifact14135
191 | politifact14860
192 | politifact15204
193 | politifact14818
194 | politifact14500
195 | politifact14187
196 | politifact15514
197 | politifact15409
198 | politifact14207
199 | politifact14166
200 | politifact15327
201 | politifact13827
202 | politifact13731
203 | politifact14503
204 | politifact15210
205 | politifact13836
206 | politifact14005
207 | politifact14021
208 | politifact14169
209 | politifact14273
210 | politifact14063
211 | politifact14693
212 | politifact13934
213 | politifact13744
214 | politifact14993
215 | politifact15224
216 | politifact13698
217 | politifact14991
218 | politifact14893
219 | politifact15456
220 | politifact14794
221 | politifact15096
222 | politifact14051
223 | politifact14362
224 | politifact14448
225 | politifact13600
226 | politifact14788
227 | politifact15494
228 | politifact14879
229 | politifact14718
230 | politifact15049
231 | politifact14333
232 | politifact14426
233 | politifact15591
234 | politifact15349
235 | politifact15525
236 | politifact13584
237 | politifact14406
238 | politifact15341
239 | politifact15217
240 | politifact14876
241 | politifact14789
242 | politifact14905
243 | politifact14355
244 | politifact15161
245 | politifact15095
246 | politifact13912
247 | politifact13577
248 | politifact14733
249 | politifact13038
250 | politifact15423
251 | politifact15532
252 | politifact14544
253 | politifact14386
254 | politifact15266
255 | politifact15626
256 | politifact14785
257 | politifact14164
258 | politifact15534
259 | politifact14258
260 | politifact15146
261 | politifact15540
262 | politifact13823
263 | politifact14235
264 | politifact14264
265 | politifact13663
266 | politifact15486
267 | politifact15307
268 | politifact14694
269 | politifact15630
270 | politifact13982
271 | politifact15579
272 | politifact15097
273 | politifact14328
274 | politifact15130
275 | politifact14330
276 | politifact14699
277 | politifact13921
278 | politifact13468
279 | politifact15309
280 | politifact14003
281 | politifact13942
282 | politifact15262
283 | politifact15477
284 | politifact15539
285 | politifact15188
286 | politifact14222
287 | politifact15383
288 | politifact14517
289 | politifact15623
290 | politifact14128
291 | politifact13766
292 | politifact14556
293 | politifact14447
294 | politifact14402
295 | politifact14595
296 | politifact15159
297 | politifact15564
298 | politifact14890
299 | politifact14605
300 | politifact15301
301 | politifact14361
302 | politifact13999
303 | politifact15356
304 | politifact14395
305 | politifact15554
306 | politifact13931
307 | politifact14576
308 | politifact14270
309 | politifact13943
310 | politifact14085
311 | politifact15246
312 | politifact13794
313 | politifact14469
314 | politifact15178
315 |
--------------------------------------------------------------------------------
/data/poli/label.txt:
--------------------------------------------------------------------------------
1 | politifact13584 1
2 | politifact8172 0
3 | politifact15477 1
4 | politifact6657 0
5 | politifact14135 1
6 | politifact13912 1
7 | politifact1446 0
8 | politifact7489 0
9 | politifact13468 1
10 | politifact13052 0
11 | politifact14940 0
12 | politifact14402 1
13 | politifact13943 1
14 | politifact14927 1
15 | politifact6932 0
16 | politifact15534 1
17 | politifact14517 1
18 | politifact780 0
19 | politifact2166 0
20 | politifact10185 0
21 | politifact15367 1
22 | politifact608 0
23 | politifact15349 1
24 | politifact15352 1
25 | politifact14960 1
26 | politifact15341 1
27 | politifact2139 0
28 | politifact13982 1
29 | politifact14605 1
30 | politifact8259 0
31 | politifact15525 1
32 | politifact13934 1
33 | politifact5321 0
34 | politifact1052 0
35 | politifact73 0
36 | politifact13698 1
37 | politifact14063 1
38 | politifact1337 0
39 | politifact14258 1
40 | politifact14795 1
41 | politifact15453 0
42 | politifact462 0
43 | politifact8470 0
44 | politifact809 0
45 | politifact7665 0
46 | politifact14876 1
47 | politifact13548 0
48 | politifact14507 1
49 | politifact15630 1
50 | politifact14119 1
51 | politifact5608 0
52 | politifact14205 1
53 | politifact15626 1
54 | politifact14755 1
55 | politifact14788 1
56 | politifact12944 0
57 | politifact681 0
58 | politifact13477 0
59 | politifact12057 0
60 | politifact779 0
61 | politifact245 0
62 | politifact5469 0
63 | politifact150 0
64 | politifact15210 1
65 | politifact5659 0
66 | politifact4586 0
67 | politifact14503 1
68 | politifact8805 0
69 | politifact52 0
70 | politifact10533 0
71 | politifact13682 0
72 | politifact12755 0
73 | politifact514 0
74 | politifact14395 1
75 | politifact14595 1
76 | politifact14235 1
77 | politifact14386 1
78 | politifact13132 0
79 | politifact14500 1
80 | politifact379 0
81 | politifact14330 1
82 | politifact13663 1
83 | politifact13942 1
84 | politifact13038 1
85 | politifact14062 1
86 | politifact15307 1
87 | politifact12418 0
88 | politifact8611 0
89 | politifact14333 1
90 | politifact14040 1
91 | politifact14694 1
92 | politifact11189 0
93 | politifact339 0
94 | politifact15579 1
95 | politifact15591 1
96 | politifact13561 1
97 | politifact8310 0
98 | politifact9438 0
99 | politifact15188 1
100 | politifact14448 1
101 | politifact14273 1
102 | politifact6641 0
103 | politifact14776 1
104 | politifact14264 1
105 | politifact15049 1
106 | politifact943 0
107 | politifact12801 0
108 | politifact13013 0
109 | politifact979 0
110 | politifact13836 1
111 | politifact14556 1
112 | politifact14722 1
113 | politifact11960 0
114 | politifact2128 0
115 | politifact65 0
116 | politifact13591 1
117 | politifact1185 0
118 | politifact15532 1
119 | politifact806 0
120 | politifact6360 0
121 | politifact10903 0
122 | politifact14718 1
123 | politifact14893 1
124 | politifact14222 1
125 | politifact13766 1
126 | politifact14207 1
127 | politifact14576 1
128 | politifact548 0
129 | politifact1519 0
130 | politifact2393 0
131 | politifact421 0
132 | politifact8069 0
133 | politifact13999 1
134 | politifact13744 1
135 | politifact7182 0
136 | politifact720 0
137 | politifact8989 0
138 | politifact581 0
139 | politifact14587 1
140 | politifact14818 1
141 | politifact14362 1
142 | politifact15309 1
143 | politifact1714 0
144 | politifact10332 0
145 | politifact1106 0
146 | politifact14003 1
147 | politifact14187 1
148 | politifact7376 0
149 | politifact1467 0
150 | politifact14794 1
151 | politifact11552 0
152 | politifact13244 0
153 | politifact12052 0
154 | politifact13305 0
155 | politifact7563 0
156 | politifact8737 0
157 | politifact13565 1
158 | politifact4275 0
159 | politifact15266 1
160 | politifact14991 1
161 | politifact13303 0
162 | politifact15540 1
163 | politifact13784 1
164 | politifact6473 0
165 | politifact11761 0
166 | politifact14164 1
167 | politifact15224 1
168 | politifact15301 1
169 | politifact6931 0
170 | politifact14356 1
171 | politifact14785 1
172 | politifact15327 1
173 | politifact14544 1
174 | politifact3228 0
175 | politifact14085 1
176 | politifact3527 0
177 | politifact14815 1
178 | politifact13931 1
179 | politifact15204 1
180 | politifact11066 0
181 | politifact13138 0
182 | politifact13794 1
183 | politifact15645 0
184 | politifact74 0
185 | politifact15494 1
186 | politifact9622 0
187 | politifact13068 0
188 | politifact6603 0
189 | politifact14270 1
190 | politifact15156 1
191 | politifact13887 1
192 | politifact13283 0
193 | politifact160 0
194 | politifact13731 1
195 | politifact15486 1
196 | politifact1177 0
197 | politifact1424 0
198 | politifact14789 1
199 | politifact14947 1
200 | politifact13921 1
201 | politifact6519 0
202 | politifact208 0
203 | politifact14406 1
204 | politifact4181 0
205 | politifact986 0
206 | politifact15095 1
207 | politifact6646 0
208 | politifact15533 1
209 | politifact14361 1
210 | politifact9802 0
211 | politifact14693 1
212 | politifact186 0
213 | politifact15423 1
214 | politifact128 0
215 | politifact14905 1
216 | politifact15545 1
217 | politifact14469 1
218 | politifact13827 1
219 | politifact15356 1
220 | politifact14860 1
221 | politifact230 0
222 | politifact15130 1
223 | politifact15539 1
224 | politifact14621 1
225 | politifact15097 1
226 | politifact1575 0
227 | politifact9691 0
228 | politifact7259 0
229 | politifact13823 1
230 | politifact15623 1
231 | politifact15564 1
232 | politifact3428 0
233 | politifact14733 1
234 | politifact1454 0
235 | politifact14021 1
236 | politifact14064 0
237 | politifact15096 1
238 | politifact12721 0
239 | politifact14169 1
240 | politifact11314 0
241 | politifact12945 0
242 | politifact11777 0
243 | politifact15456 1
244 | politifact15246 1
245 | politifact13600 1
246 | politifact15383 1
247 | politifact11577 0
248 | politifact8071 0
249 | politifact14051 1
250 | politifact14426 1
251 | politifact8557 0
252 | politifact11627 0
253 | politifact14890 1
254 | politifact13978 1
255 | politifact426 0
256 | politifact14447 1
257 | politifact10787 0
258 | politifact14174 0
259 | politifact15262 1
260 | politifact15409 1
261 | politifact15554 1
262 | politifact2048 0
263 | politifact554 0
264 | politifact6998 0
265 | politifact4190 0
266 | politifact14166 1
267 | politifact13395 0
268 | politifact15267 1
269 | politifact1028 0
270 | politifact13577 1
271 | politifact14511 0
272 | politifact15217 1
273 | politifact14128 1
274 | politifact206 0
275 | politifact1500 0
276 | politifact13765 1
277 | politifact1690 0
278 | politifact4433 0
279 | politifact8130 0
280 | politifact12627 0
281 | politifact11899 0
282 | politifact15514 1
283 | politifact14516 1
284 | politifact14879 1
285 | politifact695 0
286 | politifact228 0
287 | politifact15161 1
288 | politifact15159 1
289 | politifact31 0
290 | politifact14472 1
291 | politifact5237 0
292 | politifact14699 1
293 | politifact14474 0
294 | politifact14328 1
295 | politifact15232 1
296 | politifact6556 0
297 | politifact10731 0
298 | politifact10276 0
299 | politifact7506 0
300 | politifact746 0
301 | politifact440 0
302 | politifact15178 1
303 | politifact14005 1
304 | politifact9033 0
305 | politifact134 0
306 | politifact14355 1
307 | politifact15146 1
308 | politifact423 0
309 | politifact2298 0
310 | politifact724 0
311 | politifact4588 0
312 | politifact7511 0
313 | politifact14993 1
314 | politifact582 0
315 |
--------------------------------------------------------------------------------
/HGSL.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Mon Jan 18 22:30:16 2021
4 |
5 | @author: Ling Sun
6 | """
7 |
8 | import math
9 | #import numpy as np
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 | #from layer import HGATLayer
14 | import torch.nn.init as init
15 | import Constants
16 | from torch.nn.parameter import Parameter
17 | from TransformerBlock import TransformerBlock
18 |
19 | class Gated_fusion(nn.Module):
20 | def __init__(self, input_size, out_size=1, dropout=0.2):
21 | super(Gated_fusion, self).__init__()
22 | self.linear1 = nn.Linear(input_size, input_size)
23 | self.linear2 = nn.Linear(input_size, out_size)
24 | self.dropout = nn.Dropout(dropout)
25 | self.init_weights()
26 |
27 | def init_weights(self):
28 | init.xavier_normal_(self.linear1.weight)
29 | init.xavier_normal_(self.linear2.weight)
30 |
31 | def forward(self, X1, X2):
32 | emb = torch.cat([X1.unsqueeze(dim=0), X2.unsqueeze(dim=0)], dim=0)
33 | emb_score = F.softmax(self.linear2(torch.tanh(self.linear1(emb))), dim=0)
34 | emb_score = self.dropout(emb_score)
35 | out = torch.sum(emb_score * emb, dim=0)
36 | return out
37 |
38 | class HGSL(nn.Module):
39 | def __init__(self, opt):
40 | super(HGSL, self).__init__()
41 |
42 | self.hidden_size = opt.d_model
43 | self.n_node = opt.user_size
44 | self.dropout = nn.Dropout(opt.dropout)
45 | self.initial_feature = opt.initialFeatureSize
46 | self.hgnn = HGNN(self.initial_feature, self.hidden_size, dropout = opt.dropout)
47 |
48 | self.user_embedding = nn.Embedding(self.n_node, self.initial_feature)
49 | self.stru_attention = TransformerBlock(self.hidden_size, n_heads=8)
50 | self.temp_attention = TransformerBlock(self.hidden_size, n_heads=8)
51 |
52 | self.global_cen_embedding = nn.Embedding(600, self.hidden_size)
53 | self.local_time_embedding = nn.Embedding(5000, self.hidden_size)
54 | self.cas_pos_embedding = nn.Embedding(50, self.hidden_size)
55 | self.local_inf_embedding = nn.Embedding(200, self.hidden_size)
56 |
57 | self.weight = Parameter(torch.Tensor(self.hidden_size+2, self.hidden_size+2))
58 | self.weight2 = Parameter(torch.Tensor(self.hidden_size+2, self.hidden_size+2))
59 | self.fus = Gated_fusion(self.hidden_size+2)
60 | self.linear = nn.Linear((self.hidden_size+2), 2)
61 | self.reset_parameters()
62 |
63 | def reset_parameters(self):
64 | stdv = 1.0 / math.sqrt(self.hidden_size)
65 | for weight in self.parameters():
66 | weight.data.uniform_(-stdv, stdv)
67 |
68 |
69 | def forward(self, data_idx, hypergraph_list):
70 |
71 | news_centered_graph, user_centered_graph, spread_status = (item for item in hypergraph_list)
72 | seq, timestamps, user_level = (item for item in news_centered_graph)
73 | useq, user_inf, user_cen = (item for item in user_centered_graph)
74 |
75 | #Global learning
76 | hidden = self.dropout(self.user_embedding.weight)
77 | user_cen = self.global_cen_embedding(user_cen)
78 | tweet_hidden = hidden + user_cen
79 | user_hgnn_out = self.hgnn(tweet_hidden, seq, useq)
80 | #print(user_hgnn_out.device)
81 |
82 | #Normalize
83 | zero_vec1 = -9e15 * torch.ones_like(seq[data_idx])
84 | one_vec = torch.ones_like(seq[data_idx], dtype=torch.float)
85 | nor_input = torch.where(seq[data_idx] > 0, one_vec, zero_vec1)
86 | nor_input = F.softmax(nor_input, 1)
87 | att_mask = (seq[data_idx] == Constants.PAD)
88 | adj_with_fea = F.embedding(seq[data_idx], user_hgnn_out)
89 | #print(seq[data_idx].size(), user_hgnn_out.size())
90 |
91 | #Local temporal learning
92 | global_time = self.local_time_embedding(timestamps[data_idx])
93 | att_hidden = adj_with_fea + global_time
94 |
95 | att_out = self.temp_attention(att_hidden, att_hidden, att_hidden, mask = att_mask )
96 | news_out = torch.einsum("abc,ab->ac", (att_out, nor_input))
97 |
98 | #Concatenate temporal propagation status
99 | news_out = torch.cat([news_out, spread_status[data_idx][:, 2:]/3600/24], dim=-1)
100 | news_out = news_out.matmul(self.weight)
101 |
102 | #Local structural learning
103 | local_inf = self.local_inf_embedding(user_inf[data_idx])
104 | cas_pos = self.cas_pos_embedding(user_level[data_idx])
105 | att_hidden_str = adj_with_fea + local_inf + cas_pos
106 |
107 | att_out_str = self.stru_attention(att_hidden_str, att_hidden_str, att_hidden_str, mask=att_mask, pos = False)
108 | news_out_str = torch.einsum("abc,ab->ac", (att_out_str, nor_input))
109 |
110 | # Concatenate structural propagation status
111 | news_out_str = torch.cat([news_out_str, spread_status[data_idx][:,:2]], dim=-1)
112 | news_out_str = news_out_str.matmul(self.weight2)
113 |
114 | #Gated fusion
115 | news_out = self.fus(news_out, news_out_str)
116 | output = self.linear(news_out)
117 | output = F.log_softmax(output, dim=1)
118 | #print(output)
119 |
120 | return output
121 |
122 | '''Learn hypergraphs'''
123 | class HGNN_layer(nn.Module):
124 | def __init__(self, input_size, output_size, dropout=0.5):
125 | super(HGNN_layer, self).__init__()
126 | self.dropout = dropout
127 | self.in_features = input_size
128 | self.out_features = output_size
129 | self.weight1 = Parameter(torch.Tensor(self.in_features, self.out_features))
130 | self.weight2 = Parameter(torch.Tensor(self.out_features, self.out_features))
131 | self.reset_parameters()
132 |
133 | def reset_parameters(self):
134 | stdv = 1.0 / math.sqrt(self.in_features)
135 | for weight in self.parameters():
136 | weight.data.uniform_(-stdv, stdv)
137 | self.weight1.data.uniform_(-stdv, stdv)
138 | self.weight2.data.uniform_(-stdv, stdv)
139 |
140 | def forward(self, x, seq, useq):
141 | x = x.matmul(self.weight1)
142 | adj_with_fea = F.embedding(seq, x)
143 | zero_vec1 = -9e15 * torch.ones_like(seq)
144 | one_vec = torch.ones_like(seq, dtype=torch.float)
145 | nor_input = torch.where(seq > 0, one_vec, zero_vec1)
146 | nor_input = F.softmax(nor_input, 1)
147 |
148 | edge = torch.einsum("abc,ab->ac", (adj_with_fea, nor_input))
149 | edge = F.dropout(edge, self.dropout, training=self.training)
150 | edge = F.relu(edge, inplace=False)
151 | e1 = edge.matmul(self.weight2)
152 | edge_adj_with_fea = F.embedding(useq, e1)
153 |
154 | zero_vec1 = -9e15 * torch.ones_like(useq)
155 | one_vec = torch.ones_like(useq, dtype=torch.float)
156 | u_nor_input = torch.where(useq > 0, one_vec, zero_vec1)
157 | u_nor_input = F.softmax(u_nor_input, 1)
158 | node = torch.einsum("abc,ab->ac", (edge_adj_with_fea, u_nor_input))
159 |
160 | node = F.dropout(node, self.dropout, training=self.training)
161 |
162 | return node
163 |
164 | class HGNN(nn.Module):
165 | def __init__(self, input_size, output_size, dropout=0.5):
166 | super(HGNN, self).__init__()
167 | self.dropout = dropout
168 | self.gnn1 = HGNN_layer(input_size, output_size, dropout=self.dropout)
169 |
170 |
171 | def forward(self, x, seq, useq):
172 | node = self.gnn1(x, seq, useq)
173 | return node
174 |
--------------------------------------------------------------------------------
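
Note (not part of the repository): the full HGSL model needs the preprocessed hypergraphs, but Gated_fusion can be exercised in isolation. A minimal sketch with random tensors (the feature width 66 corresponds to hidden_size + 2 under the default d_model of 64):

# Hedged example: gated fusion of two toy branch outputs.
import torch
from HGSL import Gated_fusion

fusion = Gated_fusion(input_size=66)
x_temporal = torch.randn(32, 66)     # toy temporal-branch representation
x_structural = torch.randn(32, 66)   # toy structural-branch representation

fused = fusion(x_temporal, x_structural)   # softmax-gated weighted sum of the two views
print(fused.size())                        # torch.Size([32, 66])
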
/Data_preprocessing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import Constants
4 | import pickle
5 | import os
6 |
7 |
8 | class Options(object):
9 |
10 | def __init__(self, data_name='poli'):
11 | self.nretweet = 'data/' + data_name + '/news_centered_data.txt'
12 | self.uretweet = 'data/' + data_name + '/user_centered_data.txt'
13 | self.label = 'data/' + data_name + '/label.txt'
14 | self.news_list = 'data/' + data_name + '/' + data_name + '_news_list.txt'
15 |
16 | self.news_centered = 'data/' + data_name + '/Processed/news_centered.pickle'
17 | self.user_centered = 'data/' + data_name + '/Processed/user_centered.pickle'
18 |
19 | self.train_idx = torch.from_numpy(np.load('data/' + data_name +'/train_idx.npy'))
20 | self.valid_idx = torch.from_numpy(np.load('data/' + data_name +'/val_idx.npy'))
21 | self.test_idx = torch.from_numpy(np.load('data/' + data_name +'/test_idx.npy'))
22 |
23 | self.train = 'data/' + data_name + '/Processed/train_processed.pickle'
24 | self.valid = 'data/' + data_name + '/Processed/valid_processed.pickle'
25 | self.test = 'data/' + data_name + '/Processed/test_processed.pickle'
26 |
27 | self.user_mapping = 'data/' + data_name + '/user_mapping.pickle'
28 | self.news_mapping = 'data/' + data_name + '/news_mapping.pickle'
29 | self.save_path = ''
30 | self.embed_dim = 64
31 |
32 |
33 | def buildIndex(user_set, news_set):
34 | n2idx = {}
35 | u2idx = {}
36 |
37 | pos = 0
38 | u2idx[''] = pos
39 | pos += 1
40 | for user in user_set:
41 | u2idx[user] = pos
42 | pos += 1
43 |
44 | pos = 0
45 | n2idx[''] = pos
46 | pos += 1
47 | for news in news_set:
48 | n2idx[news] = pos
49 | pos += 1
50 |
51 | user_size = len(user_set)
52 | news_size = len(news_set)
53 | return user_size, news_size, u2idx, n2idx
54 |
55 | def Pre_data(data_name, early_type, early, max_len=200):
56 | options = Options(data_name)
57 | cascades = {}
58 |
59 | '''load news-centered retweet data'''
60 | for line in open(options.nretweet):
61 | userlist = []
62 | timestamps = []
63 | levels = []
64 | infs = []
65 |
66 | chunks = line.strip().split(',')
67 | cascades[chunks[0]] = []
68 |
69 | for chunk in chunks[1:]:
70 | try:
71 | user, timestamp, level, inf = chunk.split()
72 | userlist.append(user)
73 | timestamps.append(float(timestamp)/3600/24)
74 | levels.append(int(level)+1)
75 | infs.append(inf)
76 | except:
77 | user = chunk
78 | userlist.append(user)
79 | timestamps.append(float(0.0))
80 | infs.append(1)
81 | levels.append(1)
82 | print('tweet root', chunk)
83 | cascades[chunks[0]] = [userlist, timestamps, levels, infs]
84 |
85 | news_list = []
86 | for line in open(options.news_list):
87 | news_list.append(line.strip())
88 | cascades = {key: value for key, value in cascades.items() if key in news_list}
89 |
90 | if early:
91 | if early_type == 'engage':
92 | max_len = early
93 | elif early_type == 'time':
94 | mint = []
95 | for times in np.array(list(cascades.values()))[:,1]:
96 | if max(times)-min(times) < early:
97 | mint.append(len(times))
98 | else:
99 | for t in times:
100 | if t - min(times) >= early:
101 | mint.append(times.index(t))
102 | break
103 |
104 |
105 | '''ordered by timestamps'''
106 | for idx, cas in enumerate(cascades.keys()):
107 | max_ = mint[idx] if early and early_type == 'time' and mint[idx] < max_len else max_len
108 | cascades[cas] = [i[:max_] for i in cascades[cas]]
109 |
110 | order = [i[0] for i in sorted(enumerate(cascades[cas][1]), key=lambda x: float(x[1]))]
111 | #print(cascades[cas].shape)
112 | cascades[cas] = [[x[i] for i in order] for x in cascades[cas]]
113 | #cascades[cas] = cascades[cas][:,order]
114 | #cascades[cas][1][:] = [cascades[cas][1][i] for i in order]
115 | #cascades[cas][0][:] = [cascades[cas][0][i] for i in order]
116 | #cascades[cas][2][:] = [cascades[cas][2][i] for i in order]
117 | #cascades[cas][3][:] = [cascades[cas][3][i] for i in order]
118 |
119 |
120 |
121 | ucascades = {}
122 | '''load user-centered retweet data'''
123 | for line in open(options.uretweet):
124 | newslist = []
125 | userinf = []
126 |
127 | chunks = line.strip().split(',')
128 |
129 | ucascades[chunks[0]] = []
130 |
131 | for chunk in chunks[1:]:
132 | news, timestamp, inf= chunk.split()
133 | newslist.append(news)
134 | userinf.append(inf)
135 |
136 | ucascades[chunks[0]] = np.array([newslist, userinf])
137 |
138 | '''ordered by timestamps'''
139 | for cas in list(ucascades.keys()):
140 | order = [i[0] for i in sorted(enumerate(ucascades[cas][1]), key=lambda x: float(x[1]))]
141 | #ucascades[cas] = cascades[cas][:, order]
142 | ucascades[cas] = [[x[i] for i in order] for x in ucascades[cas]]
143 | #ucascades[cas][1][:] = [ucascades[cas][1][i] for i in order]
144 | #ucascades[cas][0][:] = [ucascades[cas][0][i] for i in order]
145 | user_set = ucascades.keys()
146 |
147 |
148 | if os.path.exists(options.user_mapping):
149 | with open(options.user_mapping, 'rb') as handle:
150 | u2idx = pickle.load(handle)
151 | user_size = len(list(user_set))
152 | with open(options.news_mapping, 'rb') as handle:
153 | n2idx = pickle.load(handle)
154 | news_size = len(news_list)
155 | else:
156 | user_size, news_size, u2idx, n2idx = buildIndex(user_set, news_list)
157 | with open(options.user_mapping, 'wb') as handle:
158 | pickle.dump(u2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
159 | with open(options.news_mapping, 'wb') as handle:
160 | pickle.dump(n2idx, handle, protocol=pickle.HIGHEST_PROTOCOL)
161 |
162 | for cas in cascades:
163 | cascades[cas][0] = [u2idx[u] for u in cascades[cas][0]]
164 | t_cascades = dict([(n2idx[key], cascades[key]) for key in cascades])
165 |
166 | for cas in ucascades:
167 | ucascades[cas][0] = [n2idx[n] for n in ucascades[cas][0]]
168 | u_cascades = dict([(u2idx[key], ucascades[key]) for key in ucascades])
169 |
170 | '''load labels'''
171 | labels = np.zeros((news_size + 1, 1))
172 | for line in open(options.label):
173 | news, label = line.strip().split(' ')
174 | if news in n2idx:
175 | labels[n2idx[news]] = label
176 |
177 | seq = np.zeros((news_size + 1, max_len))
178 | timestamps = np.zeros((news_size + 1, max_len))
179 | user_level = np.zeros((news_size + 1, max_len))
180 | user_inf = np.zeros((news_size + 1, max_len))
181 | news_list = [0] + news_list
182 | for n, s in cascades.items():
183 | news_list[n2idx[n]] = n
184 | se_data = np.hstack((s[0], np.array([Constants.PAD] * (max_len - len(s[0])))))
185 | seq[n2idx[n]] = se_data
186 |
187 | t_data = np.hstack((s[1], np.array([Constants.PAD] * (max_len - len(s[1])))))
188 | timestamps[n2idx[n]] = t_data
189 |
190 | lv_data = np.hstack((s[2], np.array([Constants.PAD] * (max_len - len(s[2])))))
191 | user_level[n2idx[n]] = lv_data
192 |
193 | inf_data = np.hstack((s[3], np.array([Constants.PAD] * (max_len - len(s[3])))))
194 | user_inf[n2idx[n]] = inf_data
195 |
196 | useq = np.zeros((user_size + 1, max_len))
197 | uinfs = np.zeros((user_size + 1, max_len))
198 |
199 | for n, s in ucascades.items():
200 | if len(s[0])