├── .gitattributes
├── .gitignore
├── README.md
├── data_preprocessing
│   ├── __init__.py
│   ├── feature_extraction.py
│   ├── graph_structure.py
│   ├── load_data.py
│   ├── text_summarization.py
│   └── visualization.py
├── machine_learning
│   ├── __init__.py
│   ├── gnn_models.py
│   └── gnn_training.py
├── requirements.txt
├── results
│   ├── README.md
│   └── results_static_graphs.pdf
├── scripts
│   ├── generate_graphs.py
│   └── run_experiment.py
└── temporal
    ├── __init__.py
    ├── temporal_gnn_models.py
    ├── temporal_gnn_training.py
    ├── temporal_graph_structure.py
    └── temporal_layers.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .nox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # IPython
77 | profile_default/
78 | ipython_config.py
79 | 
80 | # pyenv
81 | .python-version
82 | 
83 | # celery beat schedule file
84 | celerybeat-schedule
85 | 
86 | # SageMath parsed files
87 | *.sage.py
88 | 
89 | # Environments
90 | .env
91 | .venv
92 | env/
93 | venv/
94 | ENV/
95 | env.bak/
96 | venv.bak/
97 | 
98 | # Spyder project settings
99 | .spyderproject
100 | .spyproject
101 | 
102 | # Rope project settings
103 | .ropeproject
104 | 
105 | # mkdocs documentation
106 | /site
107 | 
108 | # mypy
109 | .mypy_cache/
110 | .dmypy.json
111 | dmypy.json
112 | 
113 | # Pyre type checker
114 | .pyre/
115 | 
116 | .idea
117 | .DS_Store
118 | /data
119 | /scripts/test.ipynb
120 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Heterogeneous Graphs for Fake News Detection
2 | 
3 | We evaluate how heterogeneous graphs constructed around news articles can
4 | be used to detect fake stories. The contextual information captures each
5 | article's social context and is modelled as network structure. Specifically, we use
6 | - news articles
7 | - user postings (tweets)
8 | - user repostings (retweets)
9 | - user accounts
10 | - user timeline-posts
11 | 
12 | as node types in our graphs and reformulate the problem as a
13 | graph classification task.
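For orientation, here is a minimal sketch of one such graph as a PyTorch Geometric `HeteroData` object. The tensors below are random placeholders; the feature sizes assume 768-dimensional BERT text embeddings concatenated with the numeric tweet/user features, and the exact node and edge sets depend on the `include_*` parameters in `data_preprocessing/graph_structure.py`:

```python
import torch
from torch_geometric.data import HeteroData

graph = HeteroData()
graph['article'].x = torch.randn(1, 768)  # BERT embedding of title + body text
graph['article'].y = torch.tensor([1])    # graph label: 1 = fake, 0 = real
graph['tweet'].x = torch.randn(5, 770)    # text embedding + retweet/favorite counts
graph['user'].x = torch.randn(4, 772)     # description embedding + 4 profile counts
# edge_index tensors have shape [2, num_edges]: source nodes, target nodes
graph['tweet', 'cites', 'article'].edge_index = torch.tensor([[0, 1, 2, 3, 4],
                                                              [0, 0, 0, 0, 0]])
graph['user', 'posts', 'tweet'].edge_index = torch.tensor([[0, 1, 2, 3, 3],
                                                           [0, 1, 2, 3, 4]])
graph['tweet', 'retweets', 'tweet'].edge_index = torch.empty(2, 0, dtype=torch.long)
graph['user', 'follows', 'user'].edge_index = torch.empty(2, 0, dtype=torch.long)
```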
We use the Politifact and Gossipcop
14 | datasets from FakeNewsNet (https://github.com/KaiDMML/FakeNewsNet).
15 | 
16 | ## Project Structure:
17 | 
18 | ### `data_preprocessing`
19 | 
20 | Python files to load and preprocess data (place a folder named `data` in the project's
21 | root directory that contains two subfolders with the same structure as FakeNewsNet's
22 | `dataset` and `fakenewsnet_dataset` folders)
23 | 
24 | - `feature_extraction.py`: extracts node-related features such as retweet counts and generates transformer-based text embeddings
25 | - `graph_structure.py`: functions to generate graphs from data. For an example see `scripts/generate_graphs.py`
26 | - `load_data.py`: helper functions to load data from the `data` folder during graph construction
27 | - `text_summarization.py`: generates extractive and abstractive summaries from text (not used yet)
28 | - `visualization.py`: function to visualize homogeneous graphs
29 | 
30 | ### `machine_learning`
31 | 
32 | Python files related to graph machine learning
33 | 
34 | - `gnn_models.py`: GNNs used for the experiments: GraphSAGE, GAT, HGT. The architecture currently assumes graphs that contain all node types (important for the per-type mean pooling)
35 | - `gnn_training.py`: training and evaluation of the models
36 | 
37 | ### `scripts`
38 | 
39 | - `generate_graphs.py`: example script showing how to generate graphs. Parameters can be set to specify which node types should be considered
40 | - `run_experiment.py`: example script that shows how the generated graphs can be used to run graph classification experiments
41 | 
42 | 
43 | # Citation
44 | 
45 | The paper based on this idea was accepted at ECIR 2023. If you use parts of our code or adopt our approach, we kindly ask you to cite our work as follows:
46 | ```
47 | @inproceedings{10.1007/978-3-031-28238-6_29,
48 | author = {Donabauer, Gregor and Kruschwitz, Udo},
49 | title = {Exploring Fake News Detection with Heterogeneous Social Media Context Graphs},
50 | year = {2023},
51 | isbn = {978-3-031-28237-9},
52 | publisher = {Springer-Verlag},
53 | address = {Berlin, Heidelberg},
54 | url = {https://doi.org/10.1007/978-3-031-28238-6_29},
55 | doi = {10.1007/978-3-031-28238-6_29},
56 | booktitle = {Advances in Information Retrieval: 45th European Conference on Information Retrieval, ECIR 2023, Dublin, Ireland, April 2–6, 2023, Proceedings, Part II},
57 | pages = {396–405},
58 | numpages = {10},
59 | location = {Dublin, Ireland}
60 | }
61 | ```
--------------------------------------------------------------------------------
/data_preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doGregor/Graph-FakeNewsNet/c5d86b3244499261898b2ff261a847e3dda8d983/data_preprocessing/__init__.py
--------------------------------------------------------------------------------
/data_preprocessing/feature_extraction.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import json
3 | from flair.embeddings import TransformerDocumentEmbeddings, WordEmbeddings, DocumentPoolEmbeddings
4 | from flair.data import Sentence
5 | import numpy as np
6 | 
7 | 
8 | EMBEDDING = TransformerDocumentEmbeddings('bert-base-cased')
9 | 
10 | 
11 | def text_embeddings(text_array):
12 |     embedding_array = []
13 |     for sentence in text_array:
14 |         if sentence.strip() == '':
15 |             embedding_array.append(np.zeros(768))
16 |         else:
17 |             sent = Sentence(sentence)
18 |             EMBEDDING.embed(sent)
19 | 
embedding_array.append(sent.embedding.cpu().detach().numpy()) 20 | return np.asarray(embedding_array) 21 | 22 | 23 | def get_news_features(news_data): 24 | return news_data['title'] + '. ' + news_data['text'] 25 | 26 | 27 | def get_summaries(news_id, dataset='politifact', subset='fake', *args): 28 | file_path = '../data/fakenewsnet_dataset/' + dataset + '/' + subset + '/' + news_id + '/summary.json' 29 | output = '' 30 | if os.path.isfile(file_path): 31 | with open(file_path) as json_file: 32 | data = json.load(json_file) 33 | for summary_type in args: 34 | output += data[summary_type] + ' ' 35 | return output.strip() 36 | else: 37 | return output 38 | 39 | 40 | def get_tweet_features(tweet_data): 41 | return [tweet_data['text'], tweet_data['retweet_count'], tweet_data['favorite_count']] 42 | 43 | 44 | def get_user_features(user_data): 45 | return [user_data['description'], user_data['followers_count'], user_data['friends_count'], user_data['favourites_count'], 46 | user_data['statuses_count']] 47 | 48 | 49 | 50 | if __name__ == '__main__': 51 | 52 | test_array = ['hello', 'test', 'a string', ''] 53 | embedded_array = text_embeddings(test_array) 54 | print(embedded_array.shape) 55 | print(embedded_array[-1].shape) 56 | -------------------------------------------------------------------------------- /data_preprocessing/graph_structure.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import json 4 | import pickle 5 | import torch 6 | from data_preprocessing.load_data import * 7 | from torch_geometric.data import Data, HeteroData 8 | import torch_geometric.transforms as T 9 | from data_preprocessing.feature_extraction import * 10 | 11 | 12 | def create_homogeneous_graph(news_id_dict, dataset='politifact', include_tweets=True, include_users=True, 13 | include_user_timeline_tweets=True, include_retweets=True, include_user_followers=True, 14 | include_user_following=True, add_new_users=False, to_undirected=True): 15 | node_ids = {'article': [], 16 | 'tweet': [], 17 | 'user': []} 18 | node_ids_all = [] 19 | graph = Data(x=[], edge_index=[[], []], y=[]) 20 | for subset, news_ids in news_id_dict.items(): 21 | for news_id in news_ids: 22 | if content_available(news_id=news_id, dataset=dataset, subset=subset): 23 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 24 | # news node feature 25 | graph.x.append(0) 26 | graph.y.append(0) 27 | node_ids['article'].append(news_id) 28 | node_ids_all.append(news_id) 29 | if include_tweets: 30 | tweets_path, tweet_ids = get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset) 31 | for tweet in tweet_ids: 32 | tweet_data = open_tweet_json(tweets_path, tweet) 33 | if tweet_data['id'] not in node_ids['tweet']: 34 | # tweets node features 35 | graph.x.append(1) 36 | graph.y.append(1) 37 | node_ids['tweet'].append(tweet_data['id']) 38 | node_ids_all.append(tweet_data['id']) 39 | node_id_news = node_ids_all.index(news_id) 40 | node_id_tweet = node_ids_all.index(tweet_data['id']) 41 | graph.edge_index[0] += [node_id_news, node_id_tweet] 42 | graph.edge_index[1] += [node_id_tweet, node_id_news] 43 | if include_users: 44 | user_information = get_user_information(tweet_data['user']['id']) 45 | if user_information: 46 | if user_information['id'] not in node_ids['user']: 47 | # user node features 48 | graph.x.append(2) 49 | graph.y.append(2) 50 | node_ids['user'].append(user_information['id']) 51 | node_ids_all.append(user_information['id']) 52 
| node_id_user = node_ids_all.index(user_information['id']) 53 | graph.edge_index[0] += [node_id_tweet, node_id_user] 54 | graph.edge_index[1] += [node_id_user, node_id_tweet] 55 | #else: 56 | # print(f"[WARNING] excluding sample with id {news_id} no user available") 57 | if include_retweets: 58 | retweets_path, retweet_ids = get_retweet_ids(news_id=news_id, dataset=dataset, subset=subset) 59 | if tweet in retweet_ids: 60 | retweets_data = open_retweet_json(retweets_path, tweet) 61 | for retweet in retweets_data: 62 | if retweet['id'] not in node_ids['tweet']: 63 | # retweets node features 64 | graph.x.append(1) 65 | graph.y.append(1) 66 | node_ids['tweet'].append(retweet['id']) 67 | node_ids_all.append(retweet['id']) 68 | node_id_retweet = node_ids_all.index(retweet['id']) 69 | graph.edge_index[0] += [node_id_retweet] 70 | graph.edge_index[1] += [node_id_tweet] 71 | if include_users: 72 | user_information = get_user_information(retweet['user']['id']) 73 | if user_information: 74 | if user_information['id'] not in node_ids['user']: 75 | # user node features 76 | graph.x.append(2) 77 | graph.y.append(2) 78 | node_ids['user'].append(user_information['id']) 79 | node_ids_all.append(user_information['id']) 80 | node_id_user = node_ids_all.index(user_information['id']) 81 | graph.edge_index[0] += [node_id_user] 82 | graph.edge_index[1] += [node_id_retweet] 83 | else: 84 | print(f"[WARNING] excluding sample with id {news_id} no news or tweets available") 85 | if include_users and include_user_timeline_tweets and len(node_ids['user']) > 0: 86 | for user_id in node_ids['user']: 87 | user_timeline_tweets = get_user_timeline_tweets(user_id) 88 | node_id_user = node_ids_all.index(user_id) 89 | if len(user_timeline_tweets) > 0: 90 | for user_timeline_tweet_data in user_timeline_tweets: 91 | if user_timeline_tweet_data['id'] not in node_ids['tweet']: 92 | # timeline tweets node features 93 | graph.x.append(1) 94 | graph.y.append(1) 95 | node_ids['tweet'].append(user_timeline_tweet_data['id']) 96 | node_ids_all.append(user_timeline_tweet_data['id']) 97 | node_id_tweet = node_ids_all.index(user_timeline_tweet_data['id']) 98 | graph.edge_index[0] += [node_id_user] 99 | graph.edge_index[1] += [node_id_tweet] 100 | if include_user_followers: 101 | user_followers = get_user_followers(user_id) 102 | if len(user_followers) > 0: 103 | for follower_id in user_followers: 104 | if follower_id in node_ids['user']: 105 | node_id_follower = node_ids_all.index(follower_id) 106 | graph.edge_index[0] += [node_id_follower] 107 | graph.edge_index[1] += [node_id_user] 108 | elif add_new_users and get_user_information(follower_id): 109 | user_information = get_user_information(follower_id) 110 | # followers user features 111 | graph.x.append(2) 112 | graph.y.append(2) 113 | node_ids['user'].append(user_information['id']) 114 | node_ids_all.append(user_information['id']) 115 | graph.edge_index[0] += [len(node_ids_all)-1] 116 | graph.edge_index[1] += [node_id_user] 117 | if include_user_following: 118 | user_following = get_user_following(user_id) 119 | if len(user_following) > 0: 120 | for following_id in user_following: 121 | if following_id in node_ids['user']: 122 | node_id_following = node_ids_all.index(following_id) 123 | graph.edge_index[0] += [node_id_user] 124 | graph.edge_index[1] += [node_id_following] 125 | elif add_new_users and get_user_information(following_id): 126 | user_information = get_user_information(following_id) 127 | # following user features 128 | graph.x.append(2) 129 | graph.y.append(2) 130 | 
node_ids['user'].append(user_information['id']) 131 | node_ids_all.append(user_information['id']) 132 | graph.edge_index[0] += [node_id_user] 133 | graph.edge_index[1] += [len(node_ids_all)-1] 134 | graph.x = torch.tensor(graph.x, dtype=torch.float32) 135 | graph.y = torch.tensor(graph.y, dtype=torch.long) 136 | graph.edge_index = torch.tensor(graph.edge_index, dtype=torch.long) 137 | graph.num_classes = torch.unique(graph.y).size()[0] 138 | graph = graph.coalesce() 139 | if to_undirected: 140 | graph = T.ToUndirected()(graph) 141 | return graph 142 | 143 | 144 | def create_heterogeneous_graph(news_id_dict, dataset='politifact', include_tweets=True, include_users=True, 145 | include_user_timeline_tweets=True, include_retweets=True, include_user_followers=True, 146 | include_user_following=True, add_new_users=False, to_undirected=True, include_text=False): 147 | node_ids = {'article': [], 148 | 'tweet': [], 149 | 'user': []} 150 | graph = HeteroData() 151 | graph['article'].x = [] 152 | graph['article'].y = [] 153 | if include_tweets: 154 | graph['tweet'].x = [[], []] 155 | graph['tweet', 'cites', 'article'].edge_index = [[], []] 156 | if include_users: 157 | graph['user'].x = [[], []] 158 | graph['user', 'posts', 'tweet'].edge_index = [[], []] 159 | if include_user_followers or include_user_following: 160 | graph['user', 'follows', 'user'].edge_index = [[], []] 161 | if include_retweets: 162 | graph['tweet', 'retweets', 'tweet'].edge_index = [[], []] 163 | 164 | for subset, news_ids in news_id_dict.items(): 165 | for news_id in news_ids: 166 | if content_available(news_id=news_id, dataset=dataset, subset=subset): 167 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 168 | # news node feature 169 | graph['article'].x.append(get_news_features(news_content)) 170 | if subset == 'fake': 171 | graph['article'].y.append(1) 172 | elif subset == 'real': 173 | graph['article'].y.append(0) 174 | node_ids['article'].append(news_id) 175 | if include_tweets: 176 | tweets_path, tweet_ids = get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset) 177 | for tweet in tweet_ids: 178 | tweet_data = open_tweet_json(tweets_path, tweet) 179 | if tweet_data['id'] not in node_ids['tweet']: 180 | # tweets node features 181 | graph['tweet'].x[0].append(get_tweet_features(tweet_data)[0]) 182 | graph['tweet'].x[1].append(get_tweet_features(tweet_data)[1:]) 183 | node_ids['tweet'].append(tweet_data['id']) 184 | node_id_news = node_ids['article'].index(news_id) 185 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 186 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 187 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 188 | if include_users: 189 | user_information = get_user_information(tweet_data['user']['id']) 190 | if user_information: 191 | if user_information['id'] not in node_ids['user']: 192 | # user node features 193 | graph['user'].x[0].append(get_user_features(user_information)[0]) 194 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 195 | node_ids['user'].append(user_information['id']) 196 | node_id_user = node_ids['user'].index(user_information['id']) 197 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 198 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 199 | #else: 200 | # print(f"[WARNING] excluding sample with id {news_id} no user available") 201 | if include_retweets: 202 | retweets_path, retweet_ids = get_retweet_ids(news_id=news_id, dataset=dataset, 
subset=subset) 203 | if tweet in retweet_ids: 204 | retweets_data = open_retweet_json(retweets_path, tweet) 205 | for retweet in retweets_data: 206 | if retweet['id'] not in node_ids['tweet']: 207 | # retweets node features 208 | graph['tweet'].x[0].append(get_tweet_features(retweet)[0]) 209 | graph['tweet'].x[1].append(get_tweet_features(retweet)[1:]) 210 | node_ids['tweet'].append(retweet['id']) 211 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 212 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 213 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 214 | if include_users: 215 | user_information = get_user_information(retweet['user']['id']) 216 | if user_information: 217 | if user_information['id'] not in node_ids['user']: 218 | # user node features 219 | graph['user'].x[0].append(get_user_features(user_information)[0]) 220 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 221 | node_ids['user'].append(user_information['id']) 222 | node_id_user = node_ids['user'].index(user_information['id']) 223 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 224 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 225 | else: 226 | print(f"[WARNING] excluding sample with id {news_id} no news or tweets available") 227 | if len(news_ids) == 1: 228 | graph['article'].x = torch.tensor(graph['article'].x, dtype=torch.float32) 229 | return graph 230 | if include_users and include_user_timeline_tweets and len(node_ids['user']) > 0: 231 | for user_id in node_ids['user']: 232 | user_timeline_tweets = get_user_timeline_tweets(user_id) 233 | node_id_user = node_ids['user'].index(user_id) 234 | if len(user_timeline_tweets) > 0: 235 | for user_timeline_tweet_data in user_timeline_tweets: 236 | if user_timeline_tweet_data['id'] not in node_ids['tweet']: 237 | graph['tweet'].x[0].append(get_tweet_features(user_timeline_tweet_data)[0]) 238 | graph['tweet'].x[1].append(get_tweet_features(user_timeline_tweet_data)[1:]) 239 | node_ids['tweet'].append(user_timeline_tweet_data['id']) 240 | node_id_tweet = node_ids['tweet'].index(user_timeline_tweet_data['id']) 241 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 242 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 243 | if include_user_followers: 244 | user_followers = get_user_followers(user_id) 245 | if len(user_followers) > 0: 246 | for follower_id in user_followers: 247 | if follower_id in node_ids['user']: 248 | node_id_follower = node_ids['user'].index(follower_id) 249 | graph['user', 'follows', 'user'].edge_index[0] += [node_id_follower] 250 | graph['user', 'follows', 'user'].edge_index[1] += [node_id_user] 251 | elif add_new_users and get_user_information(follower_id): 252 | user_information = get_user_information(follower_id) 253 | graph['user'].x[0].append(get_user_features(user_information)[0]) 254 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 255 | node_ids['user'].append(user_information['id']) 256 | graph['user', 'follows', 'user'].edge_index[0] += [len(node_ids['user'])-1] 257 | graph['user', 'follows', 'user'].edge_index[1] += [node_id_user] 258 | if include_user_following: 259 | user_following = get_user_following(user_id) 260 | if len(user_following) > 0: 261 | for following_id in user_following: 262 | if following_id in node_ids['user']: 263 | node_id_following = node_ids['user'].index(following_id) 264 | graph['user', 'follows', 'user'].edge_index[0] += [node_id_user] 265 | graph['user', 
'follows', 'user'].edge_index[1] += [node_id_following] 266 | elif add_new_users and get_user_information(following_id): 267 | user_information = get_user_information(following_id) 268 | graph['user'].x[0].append(get_user_features(user_information)[0]) 269 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 270 | node_ids['user'].append(user_information['id']) 271 | graph['user', 'follows', 'user'].edge_index[0] += [node_id_user] 272 | graph['user', 'follows', 'user'].edge_index[1] += [len(node_ids['user'])-1] 273 | 274 | graph['article'].x = torch.tensor(text_embeddings(graph['article'].x), dtype=torch.float32) 275 | graph['article'].y = torch.tensor(graph['article'].y, dtype=torch.long) 276 | if include_tweets: 277 | if include_text: 278 | graph['tweet'].x = torch.tensor(np.concatenate((text_embeddings(graph['tweet'].x[0]), np.asarray(graph['tweet'].x[1])), axis=1), dtype=torch.float32) 279 | else: 280 | graph['tweet'].x = torch.tensor(graph['tweet'].x[1], dtype=torch.float32) 281 | graph['tweet', 'cites', 'article'].edge_index = torch.tensor(graph['tweet', 'cites', 'article'].edge_index, dtype=torch.long) 282 | if include_users: 283 | if include_text and np.asarray(graph['user'].x[1]).shape[0] > 0: 284 | graph['user'].x = torch.tensor(np.concatenate((text_embeddings(graph['user'].x[0]), np.asarray(graph['user'].x[1])), axis=1), dtype=torch.float32) 285 | else: 286 | graph['user'].x = torch.tensor(graph['user'].x[1], dtype=torch.float32) 287 | graph['user', 'posts', 'tweet'].edge_index = torch.tensor(graph['user', 'posts', 'tweet'].edge_index, dtype=torch.long) 288 | if include_user_followers or include_user_following: 289 | graph['user', 'follows', 'user'].edge_index = torch.tensor(graph['user', 'follows', 'user'].edge_index, dtype=torch.long) 290 | if include_retweets: 291 | graph['tweet', 'retweets', 'tweet'].edge_index = torch.tensor(graph['tweet', 'retweets', 'tweet'].edge_index, dtype=torch.long) 292 | graph = graph.coalesce() 293 | if to_undirected: 294 | graph = T.ToUndirected(merge=False)(graph) 295 | return graph 296 | 297 | 298 | def make_undirected(graph): 299 | for edge_relation in graph.metadata()[1]: 300 | print([graph[edge_relation]['edge_index'][1], graph[edge_relation]['edge_index'][0]]) 301 | graph[edge_relation[2], 'rev_'+edge_relation[1], edge_relation[0]].edge_index = [graph[edge_relation]['edge_index'][1], graph[edge_relation]['edge_index'][0]] 302 | return graph 303 | 304 | 305 | def graph_to_pickle(graph, file_name): 306 | path = "../data/graphs/" + file_name + ".pickle" 307 | with open(path, 'wb') as handle: 308 | pickle.dump({'graph': graph}, handle, protocol=pickle.HIGHEST_PROTOCOL) 309 | 310 | 311 | def graph_from_pickle(file_name): 312 | path = "../data/graphs/" + file_name + ".pickle" 313 | with open(path, 'rb') as handle: 314 | return pickle.load(handle) 315 | 316 | 317 | if __name__ == '__main__': 318 | from data_preprocessing.visualization import * 319 | 320 | ids_true, ids_fake = get_news_ids() 321 | 322 | #for id in ids_fake[0:5]: 323 | 324 | graph = create_homogeneous_graph({'fake': list(ids_fake[:10])}, include_user_followers=False, include_user_following=False, 325 | to_undirected=True) 326 | print(graph) 327 | visualize_graph(graph, labels=True) 328 | -------------------------------------------------------------------------------- /data_preprocessing/load_data.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import sys 4 | import itertools 5 | 
import os 6 | import json 7 | 8 | 9 | DATA_PATH = { 10 | 'gossipcop_fake_dataset': '../data/dataset/gossipcop_fake.csv', 11 | 'gossipcop_real_dataset': '../data/dataset/gossipcop_real.csv', 12 | 'politifact_fake_dataset': '../data/dataset/politifact_fake.csv', 13 | 'politifact_real_dataset': '../data/dataset/politifact_real.csv', 14 | 'gossipcop_dir': '../data/fakenewsnet_dataset/gossipcop/', 15 | 'politifact_dir': '../data/fakenewsnet_dataset/politifact/', 16 | 'user_profiles_dir': '../data/fakenewsnet_dataset/user_profiles/', 17 | 'user_timeline_tweets_dir': '../data/fakenewsnet_dataset/user_timeline_tweets/', 18 | 'user_followers_dir': '../data/fakenewsnet_dataset/user_followers/', 19 | 'user_following_dir': '../data/fakenewsnet_dataset/user_following/' 20 | } 21 | 22 | 23 | def __get_directories(dataset_name): 24 | if dataset_name == 'politifact': 25 | true = DATA_PATH['politifact_real_dataset'] 26 | fake = DATA_PATH['politifact_fake_dataset'] 27 | directory = DATA_PATH['politifact_dir'] 28 | elif dataset_name == 'gossipcop': 29 | true = DATA_PATH['gossipcop_real_dataset'] 30 | fake = DATA_PATH['gossipcop_fake_dataset'] 31 | directory = DATA_PATH['gossipcop_dir'] 32 | else: 33 | print("[ERROR] Wrong dataset parameter specified.") 34 | sys.exit(0) 35 | return true, fake, directory 36 | 37 | 38 | def get_dataset_info(dataset='politifact'): 39 | true, fake, directory = __get_directories(dataset_name=dataset) 40 | true_data = pd.read_csv(true, sep=',') 41 | true_tweet_information = true_data['tweet_ids'].to_list() 42 | true_tweet_information = [str(x).split('\t') for x in true_tweet_information] 43 | print("True news samples:", len(true_tweet_information), "\t Number of related tweets:", 44 | len(list(itertools.chain.from_iterable(true_tweet_information)))) 45 | fake_data = pd.read_csv(fake, sep=',') 46 | fake_tweet_information = fake_data['tweet_ids'].to_list() 47 | fake_tweet_information = [str(x).split('\t') for x in fake_tweet_information] 48 | print("Fake news samples:", len(fake_tweet_information), "\t Number of related tweets:", 49 | len(list(itertools.chain.from_iterable(fake_tweet_information)))) 50 | 51 | 52 | def get_news_ids(dataset='politifact'): 53 | true, fake, directory = __get_directories(dataset_name=dataset) 54 | return pd.read_csv(true, sep=',')['id'].to_numpy(), pd.read_csv(fake, sep=',')['id'].to_numpy() 55 | 56 | 57 | def get_news_tweet_ids(news_id, dataset='politifact', subset='fake'): 58 | data_path = '../data/fakenewsnet_dataset/' + dataset + '/' + subset + '/' + news_id + '/tweets/' 59 | if os.path.exists(data_path): 60 | news_tweet_files = os.listdir(data_path) 61 | return (data_path, news_tweet_files) 62 | else: 63 | return ("", []) 64 | 65 | 66 | def get_retweet_ids(news_id, dataset='politifact', subset='fake'): 67 | data_path = '../data/fakenewsnet_dataset/' + dataset + '/' + subset + '/' + news_id + '/retweets/' 68 | if os.path.exists(data_path): 69 | retweet_files = os.listdir(data_path) 70 | return (data_path, retweet_files) 71 | else: 72 | return ("", []) 73 | 74 | 75 | def open_tweet_json(data_path, file_name): 76 | file_path = data_path + file_name 77 | if os.path.exists(file_path): 78 | with open(file_path, 'r') as tweet_json: 79 | tweet_data = json.load(tweet_json) 80 | return tweet_data 81 | else: 82 | return {} 83 | 84 | 85 | def open_retweet_json(data_path, file_name): 86 | file_path = data_path + file_name 87 | if os.path.exists(file_path): 88 | with open(file_path, 'r') as retweet_json: 89 | retweet_data = json.load(retweet_json) 90 | return 
retweet_data['retweets'] 91 | else: 92 | return [] 93 | 94 | 95 | def get_news_content(news_id, dataset='politifact', subset='fake'): 96 | file_path = '../data/fakenewsnet_dataset/' + dataset + '/' + subset + '/' + news_id + '/news content.json' 97 | if os.path.exists(file_path): 98 | with open(file_path, 'r') as news_json: 99 | news_data = json.load(news_json) 100 | return news_data 101 | else: 102 | return {} 103 | 104 | 105 | def get_user_information(user_id): 106 | file_path = DATA_PATH['user_profiles_dir'] + str(user_id) + '.json' 107 | if os.path.exists(file_path): 108 | with open(file_path, 'r') as user_information_json: 109 | user_information = json.load(user_information_json) 110 | return user_information 111 | else: 112 | return {} 113 | 114 | 115 | def get_user_timeline_tweets(user_id, n=5): 116 | file_path = DATA_PATH['user_timeline_tweets_dir'] + str(user_id) + '.json' 117 | if os.path.exists(file_path): 118 | with open(file_path, 'r') as user_timeline_tweets_json: 119 | user_timeline_tweets = json.load(user_timeline_tweets_json) 120 | if n == 0: 121 | return user_timeline_tweets 122 | else: 123 | return user_timeline_tweets[:n] 124 | else: 125 | return [] 126 | 127 | 128 | def get_user_followers(user_id): 129 | file_path = DATA_PATH['user_followers_dir'] + str(user_id) + '.json' 130 | if os.path.exists(file_path): 131 | with open(file_path, 'r') as user_followers_json: 132 | user_followers_info = json.load(user_followers_json) 133 | return user_followers_info['followers'] 134 | else: 135 | return [] 136 | 137 | 138 | def get_user_following(user_id): 139 | file_path = DATA_PATH['user_following_dir'] + str(user_id) + '.json' 140 | if os.path.exists(file_path): 141 | with open(file_path, 'r') as user_following_json: 142 | user_following_info = json.load(user_following_json) 143 | return user_following_info['following'] 144 | else: 145 | return [] 146 | 147 | 148 | def content_available(news_id, dataset='politifact', subset='fake'): 149 | if get_news_content(news_id=news_id, dataset=dataset, subset=subset) and (len(get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset)[1]) > 0): 150 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 151 | if news_content['text'] != '': 152 | return True 153 | else: 154 | return False 155 | else: 156 | return False 157 | 158 | 159 | if __name__ == '__main__': 160 | true, fake = get_news_ids() 161 | print(fake[0]) 162 | 163 | print("NEWS CONTENT", get_news_content(fake[1])) 164 | data_path, tweet_ids = get_news_tweet_ids(fake[0]) 165 | print("DATA PATH", data_path) 166 | print("TWEET IDS", tweet_ids) 167 | print("TWEET DATA", open_tweet_json(data_path, tweet_ids[0])) 168 | example_user = open_tweet_json(data_path, tweet_ids[0])['user']['id'] 169 | print("USER DATA", get_user_information(example_user)) 170 | print("USER TIMELINE TWEETS", get_user_timeline_tweets(41)[0]) 171 | -------------------------------------------------------------------------------- /data_preprocessing/text_summarization.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from summarizer import Summarizer 3 | from transformers import BartTokenizerFast, BartForConditionalGeneration, BartConfig 4 | from transformers import BertModel, BertTokenizerFast, BertConfig 5 | import json 6 | import os 7 | 8 | 9 | def generate_extractive_summary(text, model_name='bert-base-uncased', random_state=43, ratio=0.4): 10 | config = BertConfig.from_pretrained(model_name) 11 | 
config.output_hidden_states = True 12 | tokenizer = BertTokenizerFast.from_pretrained(model_name) 13 | model = BertModel.from_pretrained(model_name, config=config) 14 | extractive_summary_model = Summarizer( 15 | custom_model=model, 16 | custom_tokenizer=tokenizer, 17 | random_state=random_state 18 | ) 19 | return [extractive_summary_model(t, ratio=ratio) for t in tqdm(text)] 20 | 21 | 22 | def generate_abstractive_summary(text, model_name='sshleifer/distilbart-cnn-12-6'): 23 | sum_tokenizer = BartTokenizerFast.from_pretrained(model_name) 24 | sum_model = BartForConditionalGeneration.from_pretrained(model_name) 25 | inputs = [sum_tokenizer(t, return_tensors="pt", truncation=True, padding=True, max_length=1024).input_ids for t in text] 26 | outputs = [sum_model.generate(i, min_length=int(len(text[idx].split()) * 0.4), max_length=512, top_k=100, top_p=0.95, do_sample=True) for idx, i in enumerate(tqdm(inputs))] 27 | summarized_texts = [sum_tokenizer.batch_decode(o, skip_special_tokens=True) for o in outputs] 28 | tmps = [s[0].strip() for s in summarized_texts] 29 | return tmps 30 | 31 | 32 | def save_summary(news_id, summary_text, summary_type, dataset='politifact', subset='fake'): 33 | folder_path = '../data/fakenewsnet_dataset/' + dataset + '/' + subset + '/' + news_id 34 | if os.path.isdir(folder_path): 35 | file_path = folder_path + '/summary.json' 36 | if os.path.isfile(file_path): 37 | with open(file_path) as json_file: 38 | data = json.load(json_file) 39 | data[summary_type] = summary_text 40 | else: 41 | data = {summary_type: summary_text} 42 | with open(file_path, 'w', encoding='utf-8') as f: 43 | json.dump(data, f, ensure_ascii=False, indent=4) 44 | else: 45 | print(f"[WARNING] folder '{folder_path}' doesn't exist") 46 | -------------------------------------------------------------------------------- /data_preprocessing/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from torch_geometric.utils.convert import to_networkx 3 | import networkx as nx 4 | import matplotlib.colors as colors 5 | import matplotlib.cm as cmx 6 | 7 | 8 | def visualize_graph(graph_data, labels=False, node_size=6, line_width=5, save=False): 9 | """ 10 | Plot pytorch geometric graph object. 
11 | :param graph_data: pytorch geometric graph object 12 | :param labels: Whether to plot with label information (colored nodes) 13 | :param node_size: size of nodes in graph 14 | :param line_width: strength of edges between nodes in graph 15 | :return: nothing (plots graph) 16 | """ 17 | graph_viz = to_networkx(graph_data) 18 | if labels: 19 | node_labels = graph_data.y[list(graph_viz.nodes)].numpy() 20 | ColorLegend = {'News Text': 0, 'Tweets/Retweets': 1, 'Users': 2} 21 | cNorm = colors.Normalize(vmin=0, vmax=2) 22 | jet = cm = plt.get_cmap('brg') 23 | scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=jet) 24 | f = plt.figure(1) 25 | ax = f.add_subplot(1, 1, 1) 26 | for label in ColorLegend: 27 | ax.plot([0], [0], color=scalarMap.to_rgba(ColorLegend[label]), label=label, marker='.', linestyle='None') 28 | nx.draw(graph_viz, cmap=jet, vmin=0, vmax=2, arrowstyle='-', node_color=node_labels, width=0.3, node_size=node_size, 29 | linewidths=line_width, ax=ax) 30 | plt.axis('off') 31 | f.set_facecolor('w') 32 | plt.legend() 33 | f.tight_layout() 34 | if save: 35 | f.savefig('graph.eps', format='eps') 36 | else: 37 | plt.figure(1, figsize=(7, 7)) 38 | nx.draw(graph_viz, cmap=plt.get_cmap('Set1'), arrowstyle='-', node_size=node_size, linewidths=line_width) 39 | plt.show() 40 | -------------------------------------------------------------------------------- /machine_learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doGregor/Graph-FakeNewsNet/c5d86b3244499261898b2ff261a847e3dda8d983/machine_learning/__init__.py -------------------------------------------------------------------------------- /machine_learning/gnn_models.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.loader import DataLoader 2 | import torch 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, HeteroConv, GATConv, SAGEConv, RGCNConv, HGTConv, Linear 5 | from torch_geometric.nn import global_mean_pool 6 | 7 | 8 | class GraphSAGE(torch.nn.Module): 9 | def __init__(self, hidden_channels, out_channels, metadata, num_layers=2): 10 | super(GraphSAGE, self).__init__() 11 | torch.manual_seed(12345) 12 | 13 | self.convs = torch.nn.ModuleList() 14 | for _ in range(num_layers): 15 | conv = HeteroConv({edge_type: SAGEConv((-1, -1), hidden_channels) for edge_type in metadata[1]}) 16 | self.convs.append(conv) 17 | 18 | self.lin = torch.nn.Linear(hidden_channels*3, out_channels) 19 | 20 | def forward(self, x_dict, edge_index_dict, batch_dict): 21 | 22 | for conv in self.convs: 23 | x_dict = conv(x_dict, edge_index_dict) 24 | x_dict = {key: x.relu() for key, x in x_dict.items()} 25 | 26 | x_dict = {key: global_mean_pool(x, batch_dict[key]) for key, x in x_dict.items()} 27 | x = torch.cat([x_dict['article'], x_dict['tweet'], x_dict['user']], dim=1) 28 | x = F.dropout(x, p=0.5, training=self.training) 29 | x = self.lin(x) 30 | 31 | return x 32 | 33 | 34 | class GAT(torch.nn.Module): 35 | def __init__(self, hidden_channels, out_channels, metadata, num_layers=2, num_attention_heads=3): 36 | super(GAT, self).__init__() 37 | torch.manual_seed(12345) 38 | 39 | self.convs = torch.nn.ModuleList() 40 | for _ in range(num_layers): 41 | conv = HeteroConv({edge_type: GATConv((-1, -1), hidden_channels, heads=num_attention_heads, 42 | add_self_loops=False) for edge_type in metadata[1]}) 43 | self.convs.append(conv) 44 | 45 | self.lin = torch.nn.Linear(hidden_channels*3*num_attention_heads, 
out_channels)
46 | 
47 |     def forward(self, x_dict, edge_index_dict, batch_dict):
48 | 
49 |         for conv in self.convs:
50 |             x_dict = conv(x_dict, edge_index_dict)
51 |             x_dict = {key: x.relu() for key, x in x_dict.items()}
52 | 
53 |         x_dict = {key: global_mean_pool(x, batch_dict[key]) for key, x in x_dict.items()}
54 |         x = torch.cat([x_dict['article'], x_dict['tweet'], x_dict['user']], dim=1)
55 |         x = F.dropout(x, p=0.5, training=self.training)
56 | 
57 |         x = self.lin(x)
58 | 
59 |         return x
60 | 
61 | 
62 | class HGT(torch.nn.Module):
63 |     def __init__(self, hidden_channels, out_channels, metadata, num_layers=2, num_attention_heads=1):
64 |         super(HGT, self).__init__()
65 |         torch.manual_seed(12345)
66 | 
67 |         self.lin_dict = torch.nn.ModuleDict()
68 |         for node_type in metadata[0]:
69 |             self.lin_dict[node_type] = Linear(-1, hidden_channels)
70 | 
71 |         self.convs = torch.nn.ModuleList()
72 |         for _ in range(num_layers):
73 |             conv = HGTConv(hidden_channels, hidden_channels, metadata, num_attention_heads, group='sum')
74 |             self.convs.append(conv)
75 | 
76 |         self.lin = Linear(hidden_channels*3, out_channels)
77 | 
78 |     def forward(self, x_dict, edge_index_dict, batch_dict):
79 |         for node_type, x in x_dict.items():
80 |             x_dict[node_type] = self.lin_dict[node_type](x).relu_()
81 | 
82 |         for conv in self.convs:
83 |             x_dict = conv(x_dict, edge_index_dict)
84 | 
85 |         x_dict = {key: global_mean_pool(x, batch_dict[key]) for key, x in x_dict.items()}
86 |         x = torch.cat([x_dict['article'], x_dict['tweet'], x_dict['user']], dim=1)
87 |         x = F.dropout(x, p=0.5, training=self.training)
88 | 
89 |         x = self.lin(x)
90 | 
91 |         return x
92 | 
--------------------------------------------------------------------------------
/machine_learning/gnn_training.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
4 | 
5 | 
6 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
7 | 
8 | 
9 | def train_model(model, train_loader, loss_fct, optimizer):
10 |     model.train()
11 |     for batch_idx, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
12 |         data.to(DEVICE)
13 |         out = model(data.x_dict, data.edge_index_dict, data.batch_dict)  # Perform a single forward pass.
14 |         loss = loss_fct(out, data['article'].y)  # Compute the loss.
15 |         loss.backward()  # Derive gradients.
16 |         optimizer.step()  # Update parameters based on gradients.
17 |         optimizer.zero_grad()  # Clear gradients.
18 | 
19 | 
20 | def eval_model(model, test_loader, print_classification_report=False):
21 |     model.eval()
22 |     correct = 0
23 |     true_y = []
24 |     pred_y = []
25 |     for data in test_loader:  # Iterate in batches over the training/test dataset.
26 |         data.to(DEVICE)
27 |         out = model(data.x_dict, data.edge_index_dict, data.batch_dict)
28 |         pred = out.argmax(dim=1)  # Use the class with highest probability.
29 |         pred_y.append(pred.cpu().detach().numpy())
30 |         correct += int((pred == data['article'].y).sum())  # Check against ground-truth labels.
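        # (note: `correct` above only keeps a running count; the metrics returned
        # below are computed with scikit-learn over the concatenated predictions)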
31 |         true_y.append(data['article'].y.cpu().detach().numpy())
32 |     if print_classification_report:
33 |         print(classification_report(np.concatenate(true_y), np.concatenate(pred_y), digits=5))
34 |     return accuracy_score(np.concatenate(true_y), np.concatenate(pred_y)), precision_score(np.concatenate(true_y), np.concatenate(pred_y), average='macro'), recall_score(np.concatenate(true_y), np.concatenate(pred_y), average='macro'), f1_score(np.concatenate(true_y), np.concatenate(pred_y), average='macro')
35 | 
36 | 
37 | def train_eval_model(model, train_loader, test_loader, loss_fct, optimizer, num_epochs=1, verbose=1):
38 |     model.to(DEVICE)
39 |     for epoch in range(1, num_epochs+1):
40 |         train_model(model=model, train_loader=train_loader, loss_fct=loss_fct, optimizer=optimizer)
41 |         train_acc, train_p, train_r, train_f1 = eval_model(model, train_loader)
42 |         if epoch == num_epochs:
43 |             test_acc, test_p, test_r, test_f1 = eval_model(model, test_loader, print_classification_report=True)
44 |             return test_acc, test_p, test_r, test_f1
45 |         else:
46 |             test_acc, test_p, test_r, test_f1 = eval_model(model, test_loader)
47 |             if verbose == 1:
48 |                 print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
49 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torch-geometric==2.0.2
3 | flair==0.10
4 | transformers==4.11.3
5 | tqdm
6 | scikit-learn
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | ## Results Overview
2 | 
3 | The file `results_static_graphs.pdf` gives an overview of all experiments we ran with static
4 | heterogeneous/homogeneous graphs with/without social media context features.
The tables are sorted as follows:
5 | 
6 | **Page 1:**
7 | * **Table 5**: Results on the Politifact subset using heterogeneous graphs with social context features
8 | * **Table 6**: Results on the GossipCop subset using heterogeneous graphs with social context features
9 | * **Table 7**: Results on the full FakeNewsNet dataset using heterogeneous graphs with social context features
10 | 
11 | **Page 2:**
12 | * **Table 8**: Results on the Politifact subset using homogeneous graphs with/without social context features in the best
13 | performing setup of the experiments with heterogeneous graphs
14 | * **Table 9**: Results on the GossipCop subset using homogeneous graphs with/without social context features in the best
15 | performing setup of the experiments with heterogeneous graphs
16 | * **Table 10**: Results on the full FakeNewsNet dataset using homogeneous graphs with/without social context features in the best
17 | performing setup of the experiments with heterogeneous graphs
18 | 
19 | **Page 3:**
20 | * **Table 11**: Results on the Politifact subset using heterogeneous graphs without social context features
21 | * **Table 12**: Results on the GossipCop subset using heterogeneous graphs without social context features
22 | * **Table 13**: Results on the full FakeNewsNet dataset using heterogeneous graphs without social context features
--------------------------------------------------------------------------------
/results/results_static_graphs.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doGregor/Graph-FakeNewsNet/c5d86b3244499261898b2ff261a847e3dda8d983/results/results_static_graphs.pdf
--------------------------------------------------------------------------------
/scripts/generate_graphs.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import os
3 | 
4 | os.environ["CUDA_VISIBLE_DEVICES"]="0"
5 | os.environ["TF_ENABLE_ONEDNN_OPTS"]="0"
6 | 
7 | import sys
8 | module_path = os.path.abspath(os.path.join('..'))
9 | if module_path not in sys.path:
10 |     sys.path.append(module_path)
11 | 
12 | from data_preprocessing.graph_structure import *
13 | 
14 | 
15 | # generates graphs for the GossipCop dataset
16 | 
17 | DATASET = 'gossipcop'
18 | 
19 | ids_true, ids_fake = get_news_ids(dataset=DATASET)
20 | 
21 | setups = ['tweets_only']  # 'all_data' 'no_retweets' 'no_timeline' 'tweets_users' 'tweets_only'
22 | 
23 | for s in setups:
24 |     print(s)
25 | 
26 |     if s == 'all_data':
27 |         include_retweets = True
28 |         include_user_timeline_tweets = True
29 |         include_users = True
30 |         include_tweets = True
31 |     if s == 'no_retweets':
32 |         include_retweets = False
33 |         include_user_timeline_tweets = True
34 |         include_users = True
35 |         include_tweets = True
36 |     if s == 'no_timeline':
37 |         include_retweets = True
38 |         include_user_timeline_tweets = False
39 |         include_users = True
40 |         include_tweets = True
41 |     if s == 'tweets_only':
42 |         include_retweets = False
43 |         include_user_timeline_tweets = False
44 |         include_users = False
45 |         include_tweets = True
46 |     if s == 'tweets_users':
47 |         include_retweets = False
48 |         include_user_timeline_tweets = False
49 |         include_users = True
50 |         include_tweets = True
51 | 
52 |     def save_graph_to_pickle(graph, file_name):
53 |         path = "static_graphs/" + DATASET + '/' + s + '/' + file_name + ".pickle"
54 |         with open(path, 'wb') as handle:
55 |             pickle.dump({'graph': graph}, handle, protocol=pickle.HIGHEST_PROTOCOL)
56 | 
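    # The two loops below build one single-article HeteroData graph per news id,
    # pickle every graph that actually contains an article node, and report and
    # skip ids that raise an error (e.g. because of missing files).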
57 |     print("Starting with real news...")
58 |     for id in tqdm(list(ids_true)):
59 |         try:
60 |             graph = create_heterogeneous_graph({'real': [id]},
61 |                                                dataset=DATASET,
62 |                                                include_tweets=include_tweets,
63 |                                                include_user_followers=False,
64 |                                                include_user_following=False,
65 |                                                include_retweets=include_retweets,
66 |                                                include_user_timeline_tweets=include_user_timeline_tweets,
67 |                                                to_undirected=True,
68 |                                                include_text=True,
69 |                                                include_users=include_users)
70 |             if graph['article'].x.size()[0] > 0:
71 |                 save_graph_to_pickle(graph, id)
72 |         except:
73 |             print('ERROR', id)
74 | 
75 | 
76 |     print("Starting with fake news...")
77 |     for id in tqdm(list(ids_fake)):
78 |         try:
79 |             graph = create_heterogeneous_graph({'fake': [id]},
80 |                                                dataset=DATASET,
81 |                                                include_tweets=include_tweets,
82 |                                                include_user_followers=False,
83 |                                                include_user_following=False,
84 |                                                include_retweets=include_retweets,
85 |                                                include_user_timeline_tweets=include_user_timeline_tweets,
86 |                                                to_undirected=True,
87 |                                                include_text=True,
88 |                                                include_users=include_users)
89 |             if graph['article'].x.size()[0] > 0:
90 |                 save_graph_to_pickle(graph, id)
91 |         except:
92 |             print('ERROR', id)
93 | 
--------------------------------------------------------------------------------
/scripts/run_experiment.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from tqdm import tqdm
4 | from random import shuffle
5 | from sklearn.model_selection import KFold
6 | from operator import itemgetter
7 | from sklearn.utils import compute_class_weight
8 | module_path = os.path.abspath(os.path.join('..'))
9 | if module_path not in sys.path:
10 |     sys.path.append(module_path)
11 | from data_preprocessing.graph_structure import *
12 | from machine_learning.gnn_models import *
13 | from machine_learning.gnn_training import *
14 | from data_preprocessing.load_data import *
15 | 
16 | 
17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18 | 
19 | 
20 | CONFIG = {
21 |     'dataset': 'gossipcop',
22 |     'setting': 'all_data',  # must be one of the settings accepted by load_graph_by_path below
23 |     'batch_size': 16,
24 |     'hidden_dim': 64,
25 |     'learning_rate': 0.00008,
26 |     'weight_decay': 0.00005,
27 |     'epochs': 20
28 | }
29 | 
30 | 
31 | def load_graph_by_path(file_name, dataset='gossipcop', setting='all_data'):
32 |     if setting in ['tweets', 'tweets_users', 'tweets_users_retweets', 'tweets_users_timeline', 'all_data'] and \
33 |             dataset in ['gossipcop', 'politifact']:
34 |         path = "../data/graphs_" + dataset + "/" + setting + "/" + file_name + ".pickle"
35 |         with open(path, 'rb') as handle:
36 |             return pickle.load(handle)
37 | 
38 | 
39 | ids_true, ids_fake = get_news_ids('gossipcop')
40 | 
41 | relevant_real = []
42 | relevant_fake = []
43 | 
44 | path_to_graphs = '../data/graphs_' + CONFIG['dataset'] + '/' + CONFIG['setting'] + '/'
45 | for graph_id in os.listdir(path_to_graphs):
46 |     if graph_id.endswith('pickle'):
47 |         graph_id = graph_id.split('.')[0]
48 |         graph = load_graph_by_path(graph_id, CONFIG['dataset'], CONFIG['setting'])['graph']
49 |         if graph['user'].x.size()[0] >= 5 and graph['tweet'].x.size()[0] >= 5 and \
50 |                 (0 == graph['tweet', 'cites', 'article'].edge_index.size()[1] or
51 |                  graph['tweet', 'cites', 'article'].edge_index.size()[1] >= 5) and \
52 |                 (0 == graph['user', 'posts', 'tweet'].edge_index.size()[1] or
53 |                  graph['user', 'posts', 'tweet'].edge_index.size()[1] >= 5) and \
54 |                 (0 == graph['tweet', 'retweets', 'tweet'].edge_index.size()[1] or
55 |                  graph['tweet', 'retweets', 'tweet'].edge_index.size()[1] >= 5):
56 |             if graph_id in
list(ids_true):
57 |                 relevant_real.append(T.NormalizeFeatures()(graph))
58 |             elif graph_id in list(ids_fake):
59 |                 relevant_fake.append(T.NormalizeFeatures()(graph))
60 | 
61 | print("Number of real news graphs:", len(relevant_real))
62 | print("Number of fake news graphs:", len(relevant_fake))
63 | 
64 | all_graphs = relevant_real + relevant_fake
65 | shuffle(all_graphs)
66 | 
67 | kf = KFold(n_splits=5)
68 | kf.get_n_splits(all_graphs)
69 | 
70 | train_splits = []
71 | test_splits = []
72 | 
73 | for train_index, test_index in kf.split(all_graphs):
74 |     print(30*"*")
75 |     X_train, X_test = itemgetter(*train_index)(all_graphs), itemgetter(*test_index)(all_graphs)
76 |     print("Num train:", len(X_train), "Num test:", len(X_test))
77 |     train_splits.append(X_train)
78 |     test_splits.append(X_test)
79 | 
80 | 
81 | # ###################################### #
82 | # ################ SAGE ################ #
83 | # ###################################### #
84 | 
85 | 
86 | acc_all = []
87 | p_all = []
88 | r_all = []
89 | f1_all = []
90 | 
91 | for idx, val in enumerate(train_splits):
92 |     X_train = val
93 |     X_test = test_splits[idx]
94 | 
95 |     y_tensors = []
96 |     for graph in val:
97 |         y_tensors.append(graph['article'].y)
98 | 
99 |     class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.asarray([0, 1]),
100 |                                                       y=torch.cat(y_tensors).cpu().detach().numpy()),
101 |                                  dtype=torch.float32)
102 |     class_weights.to(device)
103 | 
104 |     train_loader = DataLoader(X_train, batch_size=CONFIG['batch_size'], shuffle=False)
105 |     test_loader = DataLoader(X_test, batch_size=CONFIG['batch_size'], shuffle=False)
106 | 
107 |     model = GraphSAGE(hidden_channels=CONFIG['hidden_dim'], out_channels=2, metadata=relevant_fake[0].metadata())
108 | 
109 |     optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay'])
110 |     criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
111 |     criterion.to(device)
112 |     acc, precision, recall, f1 = train_eval_model(model=model, train_loader=train_loader, test_loader=test_loader,
113 |                                                   loss_fct=criterion, optimizer=optimizer, num_epochs=CONFIG['epochs'],
114 |                                                   verbose=0)
115 |     acc_all.append(acc)
116 |     p_all.append(precision)
117 |     r_all.append(recall)
118 |     f1_all.append(f1)
119 | 
120 | print("ACC SAGE", acc_all, sum(acc_all) / len(acc_all))
121 | print("P SAGE", p_all, sum(p_all) / len(p_all))
122 | print("R SAGE", r_all, sum(r_all) / len(r_all))
123 | print("F1 SAGE", f1_all, sum(f1_all) / len(f1_all))
124 | 
125 | 
126 | # ###################################### #
127 | # ################ GAT ################# #
128 | # ###################################### #
129 | 
130 | 
131 | acc_all = []
132 | p_all = []
133 | r_all = []
134 | f1_all = []
135 | 
136 | for idx, val in enumerate(train_splits):
137 |     X_train = val
138 |     X_test = test_splits[idx]
139 | 
140 |     y_tensors = []
141 |     for graph in val:
142 |         y_tensors.append(graph['article'].y)
143 | 
144 |     class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.asarray([0, 1]),
145 |                                                       y=torch.cat(y_tensors).cpu().detach().numpy()),
146 |                                  dtype=torch.float32)
147 |     class_weights.to(device)
148 | 
149 |     train_loader = DataLoader(X_train, batch_size=CONFIG['batch_size'], shuffle=False)
150 |     test_loader = DataLoader(X_test, batch_size=CONFIG['batch_size'], shuffle=False)
151 | 
152 |     model = GAT(hidden_channels=CONFIG['hidden_dim'], out_channels=2, metadata=relevant_fake[0].metadata())
153 | 
154 |     optimizer = torch.optim.Adam(model.parameters(),
lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay']) 155 | criterion = torch.nn.CrossEntropyLoss(weight=class_weights) 156 | criterion.to(device) 157 | acc, precision, recall, f1 = train_eval_model(model=model, train_loader=train_loader, test_loader=test_loader, 158 | loss_fct=criterion, optimizer=optimizer, num_epochs=CONFIG['epochs'], 159 | verbose=0) 160 | acc_all.append(acc) 161 | p_all.append(precision) 162 | r_all.append(recall) 163 | f1_all.append(f1) 164 | 165 | print("ACC GAT", acc_all, sum(acc_all) / len(acc_all)) 166 | print("P GAT", p_all, sum(p_all) / len(p_all)) 167 | print("R GAT", r_all, sum(r_all) / len(r_all)) 168 | print("F1 GAT", f1_all, sum(f1_all) / len(f1_all)) 169 | 170 | 171 | # ###################################### # 172 | # ################ HGT ################# # 173 | # ###################################### # 174 | 175 | 176 | acc_all = [] 177 | p_all = [] 178 | r_all = [] 179 | f1_all = [] 180 | 181 | for idx, val in enumerate(train_splits): 182 | X_train = val 183 | X_test = test_splits[idx] 184 | 185 | y_tensors = [] 186 | for graph in val: 187 | y_tensors.append(graph['article'].y) 188 | 189 | class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.asarray([0, 1]), 190 | y=torch.cat(y_tensors).cpu().detach().numpy()), 191 | dtype=torch.float32) 192 | class_weights.to(device) 193 | 194 | train_loader = DataLoader(X_train, batch_size=CONFIG['batch_size'], shuffle=False) 195 | test_loader = DataLoader(X_test, batch_size=CONFIG['batch_size'], shuffle=False) 196 | 197 | model = HGT(hidden_channels=CONFIG['hidden_dim'], out_channels=2, metadata=relevant_fake[0].metadata()) 198 | 199 | optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'], weight_decay=CONFIG['weight_decay']) 200 | criterion = torch.nn.CrossEntropyLoss(weight=class_weights) 201 | criterion.to(device) 202 | acc, precision, recall, f1 = train_eval_model(model=model, train_loader=train_loader, test_loader=test_loader, 203 | loss_fct=criterion, optimizer=optimizer, num_epochs=CONFIG['epochs'], 204 | verbose=0) 205 | acc_all.append(acc) 206 | p_all.append(precision) 207 | r_all.append(recall) 208 | f1_all.append(f1) 209 | 210 | print("ACC", acc_all, sum(acc_all) / len(acc_all)) 211 | print("P", p_all, sum(p_all) / len(p_all)) 212 | print("R", r_all, sum(r_all) / len(r_all)) 213 | print("F1", f1_all, sum(f1_all) / len(f1_all)) -------------------------------------------------------------------------------- /temporal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doGregor/Graph-FakeNewsNet/c5d86b3244499261898b2ff261a847e3dda8d983/temporal/__init__.py -------------------------------------------------------------------------------- /temporal/temporal_gnn_models.py: -------------------------------------------------------------------------------- 1 | from temporal.temporal_layers import * 2 | import torch.nn.functional as F 3 | import torch 4 | from torch_geometric.nn import global_mean_pool, HGTConv 5 | from torch.nn import Transformer 6 | import math 7 | 8 | 9 | class HeteroLSTMGCN(torch.nn.Module): 10 | def __init__(self, node_feature_dict, metadata, num_hidden_nodes, num_output_features): 11 | super(HeteroLSTMGCN, self).__init__() 12 | self.recurrent = HeteroGCLSTM(node_feature_dict, num_hidden_nodes, metadata) 13 | self.linear = torch.nn.Linear(num_hidden_nodes*3, num_output_features) 14 | 15 | def forward(self, x_dict, edge_index_dict, batch_dict, h_dict, 
c_dict): 16 | h_0, c_0 = self.recurrent(x_dict, edge_index_dict, h_dict, c_dict) 17 | 18 | h = {key: val.relu() for key, val in h_0.items()} 19 | h = {key: global_mean_pool(val, batch_dict[key]) for key, val in h.items()} 20 | h = torch.cat([h['article'], h['tweet'], h['user']], dim=1) 21 | h = F.dropout(h, p=0.5, training=self.training) 22 | h = self.linear(h) 23 | 24 | return h, h_0, c_0 25 | 26 | 27 | class HeteroGRUGCN(torch.nn.Module): 28 | def __init__(self, num_hidden_nodes, metadata, num_output_features): 29 | super(HeteroGRUGCN, self).__init__() 30 | self.recurrent = HeteroGConvGRU(num_hidden_nodes, metadata) 31 | self.linear = torch.nn.Linear(num_hidden_nodes * 3, num_output_features) 32 | 33 | def forward(self, x_dict, edge_index_dict, batch_dict, h_dict): 34 | h_0 = self.recurrent(x_dict, edge_index_dict, h_dict) 35 | 36 | h = {key: val.relu() for key, val in h_0.items()} 37 | h = {key: global_mean_pool(val, batch_dict[key]) for key, val in h.items()} 38 | h = torch.cat([h['article'], h['tweet'], h['user']], dim=1) 39 | h = F.dropout(h, p=0.5, training=self.training) 40 | h = self.linear(h) 41 | 42 | return h, h_0 43 | 44 | 45 | class HeteroGATGRUModel(torch.nn.Module): 46 | def __init__(self, num_hidden_nodes, metadata, num_output_features, num_attention_heads=3): 47 | super(HeteroGATGRUModel, self).__init__() 48 | self.recurrent = HeteroGATGRU(num_hidden_nodes, metadata, num_attention_heads) 49 | self.linear = torch.nn.Linear(num_hidden_nodes * 3 * num_attention_heads, num_output_features) 50 | 51 | def forward(self, x_dict, edge_index_dict, batch_dict, h_dict): 52 | h_0 = self.recurrent(x_dict, edge_index_dict, h_dict) 53 | 54 | h = {key: val.relu() for key, val in h_0.items()} 55 | h = {key: global_mean_pool(val, batch_dict[key]) for key, val in h.items()} 56 | h = torch.cat([h['article'], h['tweet'], h['user']], dim=1) 57 | h = F.dropout(h, p=0.5, training=self.training) 58 | h = self.linear(h) 59 | 60 | return h, h_0 61 | 62 | 63 | class PositionalEncoding(nn.Module): 64 | def __init__(self, d_model, dropout=0.1, max_len=5000): 65 | super().__init__() 66 | self.dropout = nn.Dropout(p=dropout) 67 | 68 | position = torch.arange(max_len).unsqueeze(1) 69 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 70 | pe = torch.zeros(max_len, 1, d_model) 71 | pe[:, 0, 0::2] = torch.sin(position * div_term) 72 | pe[:, 0, 1::2] = torch.cos(position * div_term) 73 | self.register_buffer('pe', pe) 74 | 75 | def forward(self, x): 76 | """ 77 | Args: 78 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 79 | """ 80 | x = x + self.pe[:x.size(0)] 81 | return self.dropout(x) 82 | 83 | 84 | class TransformerGAT(torch.nn.Module): 85 | def __init__(self, 86 | num_timesteps, 87 | num_graph_features, 88 | num_output_features, 89 | metadata, 90 | num_attention_heads=1, 91 | num_layers=2, 92 | num_hidden_gnn=64, 93 | num_output_gnn=8, 94 | num_transformer_heads=8, 95 | num_encoders=3, 96 | num_decoders=3): 97 | super(TransformerGAT, self).__init__() 98 | 99 | self.graph_convs = nn.ModuleList() 100 | for i in range(num_timesteps): 101 | convs = nn.ModuleList() 102 | convs.append(HGTConv(-1, num_hidden_gnn, metadata, num_attention_heads)) 103 | convs.append(HGTConv(-1, num_output_gnn, metadata, num_attention_heads)) 104 | self.graph_convs.append(convs) 105 | 106 | self.pos_encoder = PositionalEncoding(d_model=num_graph_features) 107 | self.transformer = Transformer(d_model=num_graph_features, 108 | num_encoder_layers=num_encoders,
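# note: d_model (= num_graph_features) must match the per-timestep embedding built
# in forward() below, i.e. 3 * num_output_gnn, since HGTConv outputs out_channels
# features per node type and the pooled ['article', 'tweet', 'user'] vectors are
# concatenated.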
109 | num_decoder_layers=num_decoders, 110 | nhead=num_transformer_heads) 111 | 112 | self.linear = torch.nn.Linear(num_timesteps * num_graph_features, num_output_features) 113 | 114 | def forward(self, snapshot_batch): 115 | 116 | hidden_representations = [] 117 | for idx, t_batch in enumerate(snapshot_batch): 118 | conv_layers = self.graph_convs[idx] 119 | batch_x_dict = t_batch.x_dict 120 | for conv in conv_layers: 121 | batch_x_dict = conv(batch_x_dict, t_batch.edge_index_dict) 122 | batch_x_dict = {key: x.relu() for key, x in batch_x_dict.items()} 123 | batch_x_dict = {key: global_mean_pool(x, t_batch.batch_dict[key]) for key, x in batch_x_dict.items()} 124 | x = torch.cat([batch_x_dict['article'], batch_x_dict['tweet'], batch_x_dict['user']], dim=1) 125 | hidden_representations.append(x) 126 | 127 | x_features = torch.stack(hidden_representations) 128 | 129 | x_features = self.pos_encoder(x_features) 130 | x_features = self.transformer(x_features, x_features) 131 | x_features = x_features.relu() 132 | 133 | x_features = torch.moveaxis(x_features, 0, 1) 134 | x_features = torch.reshape(x_features, (x_features.shape[0], x_features.shape[1] * x_features.shape[2])) 135 | 136 | x_features = self.linear(x_features) 137 | 138 | return x_features 139 | -------------------------------------------------------------------------------- /temporal/temporal_gnn_training.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from random import shuffle 4 | from torch_geometric.data import Batch 5 | from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score 6 | 7 | 8 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 9 | 10 | 11 | def train_temporal_model(model, train_batches, loss_fct, optimizer, shuffle_batches=True, cumulative_costs=True): 12 | model.train() 13 | y_true = [] 14 | y_pred = [] 15 | if shuffle_batches: 16 | shuffle(train_batches) 17 | for batch in train_batches: 18 | cost = 0 19 | h_dict = None 20 | for idx_snap in range(0, len(batch[0])): 21 | graph_list = [] 22 | for graph_sequence in batch: 23 | graph_list.append(graph_sequence[idx_snap]) 24 | snapshot_batch = Batch.from_data_list(graph_list) 25 | snapshot_batch.to(DEVICE) 26 | y_hat, h_dict = model(snapshot_batch.x_dict, snapshot_batch.edge_index_dict, 27 | snapshot_batch.batch_dict, h_dict) 28 | if cumulative_costs: 29 | cost += loss_fct(y_hat, torch.flatten(snapshot_batch['article'].y)) 30 | else: 31 | cost = loss_fct(y_hat, torch.flatten(snapshot_batch['article'].y)) 32 | if idx_snap == len(batch[0]) - 1: 33 | y_pred.append(y_hat.argmax(dim=1).cpu().detach().numpy()) 34 | y_true.append(snapshot_batch['article'].y.cpu().detach().numpy()) 35 | if cumulative_costs: 36 | cost /= len(batch[0]) 37 | cost.backward() 38 | optimizer.step() 39 | optimizer.zero_grad() 40 | return y_true, y_pred 41 | 42 | 43 | def evaluate_temporal_model(model, test_batches): 44 | model.eval() 45 | true_y = [] 46 | pred_y = [] 47 | for batch in test_batches: 48 | h_dict = None 49 | for idx_snap in range(0, len(batch[0])): 50 | graph_list = [] 51 | for graph_sequence in batch: 52 | graph_list.append(graph_sequence[idx_snap]) 53 | snapshot_batch = Batch.from_data_list(graph_list) 54 | snapshot_batch.to(DEVICE) 55 | y_hat, h_dict = model(snapshot_batch.x_dict, snapshot_batch.edge_index_dict, 56 | snapshot_batch.batch_dict, h_dict) 57 | pred = y_hat.argmax(dim=1) 58 | pred_y.append(pred.cpu().detach().numpy()) 59 |
true_y.append(snapshot_batch['article'].y.cpu().detach().numpy()) 60 | return true_y, pred_y 61 | 62 | 63 | def train_eval_temporal_model(model, train_batches, test_batches, loss_fct, optimizer, cumulative_costs=True, 64 | shuffle_train_batches=True, num_epochs=20, verbose=1): 65 | for epoch in range(num_epochs): 66 | model.to(DEVICE) 67 | true_train, pred_train = train_temporal_model(model=model, train_batches=train_batches, loss_fct=loss_fct, 68 | shuffle_batches=shuffle_train_batches, optimizer=optimizer, 69 | cumulative_costs=cumulative_costs) 70 | true_test, pred_test = evaluate_temporal_model(model=model, test_batches=test_batches) 71 | 72 | true_train = np.concatenate(true_train) 73 | pred_train = np.concatenate(pred_train) 74 | true_test = np.concatenate(true_test) 75 | pred_test = np.concatenate(pred_test) 76 | 77 | if epoch + 1 == num_epochs: 78 | print(classification_report(true_test, pred_test, digits=5)) 79 | 80 | test_acc = accuracy_score(true_test, pred_test) 81 | test_p = precision_score(true_test, pred_test, average='macro') 82 | test_r = recall_score(true_test, pred_test, average='macro') 83 | test_f1 = f1_score(true_test, pred_test, average='macro') 84 | 85 | return test_acc, test_p, test_r, test_f1 86 | else: 87 | if verbose == 1: 88 | train_acc = accuracy_score(true_train, pred_train) 89 | train_f1 = f1_score(true_train, pred_train, average='macro') 90 | 91 | test_acc = accuracy_score(true_test, pred_test) 92 | test_f1 = f1_score(true_test, pred_test, average='macro') 93 | 94 | print_epoch = epoch + 1 95 | print(f'Epoch: {print_epoch:03d},', 96 | f'Train Acc: {train_acc:.4f}, Train F1: {train_f1:.4f},', 97 | f'Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}') 98 | -------------------------------------------------------------------------------- /temporal/temporal_graph_structure.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from data_preprocessing.load_data import * 3 | from torch_geometric.data import Data, HeteroData 4 | import torch_geometric.transforms as T 5 | from datetime import datetime, timedelta 6 | from data_preprocessing.feature_extraction import * 7 | import torch 8 | 9 | 10 | 11 | def create_heterogeneous_snapshot(news_id_dict, tweet_id_list, start_time, end_time, 12 | dataset='politifact', include_tweets=True, 13 | include_users=True, include_user_timeline_tweets=True, include_retweets=True, 14 | include_user_followers=True, include_user_following=True, add_new_users=False, 15 | to_undirected=True, include_text=False): 16 | node_ids = {'article': [], 17 | 'tweet': [], 18 | 'user': []} 19 | graph = HeteroData() 20 | graph['article'].x = [] 21 | graph['article'].y = [] 22 | if include_tweets: 23 | graph['tweet'].x = [[], []] 24 | graph['tweet', 'cites', 'article'].edge_index = [[], []] 25 | if include_users: 26 | graph['user'].x = [[], []] 27 | graph['user', 'posts', 'tweet'].edge_index = [[], []] 28 | if include_user_followers or include_user_following: 29 | graph['user', 'follows', 'user'].edge_index = [[], []] 30 | if include_retweets: 31 | graph['tweet', 'retweets', 'tweet'].edge_index = [[], []] 32 | 33 | for subset, news_id in news_id_dict.items(): 34 | if content_available(news_id=news_id, dataset=dataset, subset=subset): 35 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 36 | # news node feature 37 | graph['article'].x.append(get_news_features(news_content)) 38 | if subset == 'fake': 39 | graph['article'].y.append(1) 40 | elif subset == 'real': 41 | 
graph['article'].y.append(0) 42 | node_ids['article'].append(news_id) 43 | if include_tweets: 44 | tweets_path, tweet_list = get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset) 45 | for tweet in tweet_list: 46 | tweet_data = open_tweet_json(tweets_path, tweet) 47 | if tweet_data['id'] not in node_ids['tweet']: 48 | # tweets node features 49 | if tweet in tweet_id_list: 50 | graph['tweet'].x[0].append(get_tweet_features(tweet_data)[0]) 51 | graph['tweet'].x[1].append(get_tweet_features(tweet_data)[1:]) 52 | else: 53 | graph['tweet'].x[0].append('') 54 | graph['tweet'].x[1].append([0]*2) 55 | node_ids['tweet'].append(tweet_data['id']) 56 | node_id_news = node_ids['article'].index(news_id) 57 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 58 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 59 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 60 | if include_users: 61 | user_information = get_user_information(tweet_data['user']['id']) 62 | if user_information: 63 | if user_information['id'] not in node_ids['user']: 64 | # user node features 65 | if tweet in tweet_id_list: 66 | graph['user'].x[0].append(get_user_features(user_information)[0]) 67 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 68 | else: 69 | graph['user'].x[0].append('') 70 | graph['user'].x[1].append([0]*4) 71 | node_ids['user'].append(user_information['id']) 72 | node_id_user = node_ids['user'].index(user_information['id']) 73 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 74 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 75 | if include_retweets: 76 | retweets_path, retweet_ids = get_retweet_ids(news_id=news_id, dataset=dataset, subset=subset) 77 | if tweet in retweet_ids: 78 | retweets_data = open_retweet_json(retweets_path, tweet) 79 | for retweet in retweets_data: 80 | if retweet['id'] not in node_ids['tweet']: 81 | if start_time <= datetime.strptime(retweet['created_at'], '%a %b %d %H:%M:%S +0000 %Y') < end_time: 82 | # retweets node features 83 | graph['tweet'].x[0].append(get_tweet_features(retweet)[0]) 84 | graph['tweet'].x[1].append(get_tweet_features(retweet)[1:]) 85 | else: 86 | graph['tweet'].x[0].append('') 87 | graph['tweet'].x[1].append([0]*2) 88 | node_ids['tweet'].append(retweet['id']) 89 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 90 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 91 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 92 | if include_users: 93 | user_information = get_user_information(retweet['user']['id']) 94 | if user_information: 95 | if user_information['id'] not in node_ids['user']: 96 | # user node features 97 | if tweet in tweet_id_list: 98 | graph['user'].x[0].append(get_user_features(user_information)[0]) 99 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 100 | else: 101 | graph['user'].x[0].append('') 102 | graph['user'].x[1].append([0]*4) 103 | node_ids['user'].append(user_information['id']) 104 | node_id_user = node_ids['user'].index(user_information['id']) 105 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 106 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 107 | else: 108 | print(f"[WARNING] excluding sample with id {news_id} no news or tweets available") 109 | graph['article'].x = torch.tensor(graph['article'].x, dtype=torch.float32) 110 | return graph 111 | if include_users and include_user_timeline_tweets and len(node_ids['user']) 
> 0: 112 | for user_id in node_ids['user']: 113 | user_timeline_tweets = get_user_timeline_tweets(user_id, n=0) 114 | node_id_user = node_ids['user'].index(user_id) 115 | if len(user_timeline_tweets) > 0: 116 | for user_timeline_tweet_data in user_timeline_tweets: 117 | if user_timeline_tweet_data['id'] not in node_ids['tweet']: 118 | if start_time <= datetime.strptime(user_timeline_tweet_data['created_at'], '%a %b %d %H:%M:%S +0000 %Y') < end_time: 119 | graph['tweet'].x[0].append(get_tweet_features(user_timeline_tweet_data)[0]) 120 | graph['tweet'].x[1].append(get_tweet_features(user_timeline_tweet_data)[1:]) 121 | else: 122 | graph['tweet'].x[0].append('') 123 | graph['tweet'].x[1].append([0]*2) 124 | node_ids['tweet'].append(user_timeline_tweet_data['id']) 125 | node_id_tweet = node_ids['tweet'].index(user_timeline_tweet_data['id']) 126 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 127 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 128 | graph['article'].x = torch.tensor(text_embeddings(graph['article'].x), dtype=torch.float32) 129 | graph['article'].y = torch.tensor(graph['article'].y, dtype=torch.long) 130 | if include_tweets: 131 | if include_text and np.asarray(graph['tweet'].x[1]).shape[0] > 0: 132 | graph['tweet'].x = torch.tensor(np.concatenate((text_embeddings(graph['tweet'].x[0]), np.asarray(graph['tweet'].x[1])), axis=1), dtype=torch.float32) 133 | else: 134 | graph['tweet'].x = torch.tensor(graph['tweet'].x[1], dtype=torch.float32) 135 | graph['tweet', 'cites', 'article'].edge_index = torch.tensor(graph['tweet', 'cites', 'article'].edge_index, dtype=torch.long) 136 | if include_users: 137 | if include_text and np.asarray(graph['user'].x[1]).shape[0] > 0: 138 | graph['user'].x = torch.tensor(np.concatenate((text_embeddings(graph['user'].x[0]), np.asarray(graph['user'].x[1])), axis=1), dtype=torch.float32) 139 | else: 140 | graph['user'].x = torch.tensor(graph['user'].x[1], dtype=torch.float32) 141 | graph['user', 'posts', 'tweet'].edge_index = torch.tensor(graph['user', 'posts', 'tweet'].edge_index, dtype=torch.long) 142 | if include_user_followers or include_user_following: 143 | graph['user', 'follows', 'user'].edge_index = torch.tensor(graph['user', 'follows', 'user'].edge_index, dtype=torch.long) 144 | if include_retweets: 145 | graph['tweet', 'retweets', 'tweet'].edge_index = torch.tensor(graph['tweet', 'retweets', 'tweet'].edge_index, dtype=torch.long) 146 | graph = graph.coalesce() 147 | if to_undirected: 148 | graph = T.ToUndirected(merge=False)(graph) 149 | return graph 150 | 151 | 152 | 153 | 154 | def create_evolving_heterogeneous_snapshot(news_id_dict, tweet_id_list, start_time, end_time, 155 | dataset='politifact', include_tweets=True, include_users=True, 156 | include_user_timeline_tweets=True, include_retweets=True, 157 | include_user_followers=True, include_user_following=True, 158 | add_new_users=False, to_undirected=True, include_text=False): 159 | node_ids = {'article': [], 160 | 'tweet': [], 161 | 'user': []} 162 | graph = HeteroData() 163 | graph['article'].x = [] 164 | graph['article'].y = [] 165 | if include_tweets: 166 | graph['tweet'].x = [[], []] 167 | graph['tweet', 'cites', 'article'].edge_index = [[], []] 168 | if include_users: 169 | graph['user'].x = [[], []] 170 | graph['user', 'posts', 'tweet'].edge_index = [[], []] 171 | if include_user_followers or include_user_following: 172 | graph['user', 'follows', 'user'].edge_index = [[], []] 173 | if include_retweets: 174 | graph['tweet', 'retweets', 
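# Unlike create_heterogeneous_snapshot above, this "evolving" variant adds only nodes
# that exist in the current window (tweets via tweet_id_list, retweets via the
# [start_time, end_time) check) instead of padding with placeholders, so each snapshot
# grows relative to the previous one, assuming the caller widens the window cumulatively.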
'tweet'].edge_index = [[], []] 175 | 176 | for subset, news_id in news_id_dict.items(): 177 | if content_available(news_id=news_id, dataset=dataset, subset=subset): 178 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 179 | # news node feature 180 | graph['article'].x.append(get_news_features(news_content)) 181 | if subset == 'fake': 182 | graph['article'].y.append(1) 183 | elif subset == 'real': 184 | graph['article'].y.append(0) 185 | node_ids['article'].append(news_id) 186 | if include_tweets: 187 | tweets_path, tweet_list = get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset) 188 | for tweet in tweet_list: 189 | tweet_data = open_tweet_json(tweets_path, tweet) 190 | if tweet in tweet_id_list: 191 | if tweet_data['id'] not in node_ids['tweet']: 192 | graph['tweet'].x[0].append(get_tweet_features(tweet_data)[0]) 193 | graph['tweet'].x[1].append(get_tweet_features(tweet_data)[1:]) 194 | node_ids['tweet'].append(tweet_data['id']) 195 | node_id_news = node_ids['article'].index(news_id) 196 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 197 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 198 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 199 | else: 200 | node_id_news = node_ids['article'].index(news_id) 201 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 202 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 203 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 204 | 205 | if include_users: 206 | user_information = get_user_information(tweet_data['user']['id']) 207 | if user_information: 208 | if tweet in tweet_id_list: 209 | if user_information['id'] not in node_ids['user']: 210 | graph['user'].x[0].append(get_user_features(user_information)[0]) 211 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 212 | node_ids['user'].append(user_information['id']) 213 | node_id_user = node_ids['user'].index(user_information['id']) 214 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 215 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 216 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 217 | else: 218 | node_id_user = node_ids['user'].index(user_information['id']) 219 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 220 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 221 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 222 | 223 | if include_retweets: 224 | retweets_path, retweet_ids = get_retweet_ids(news_id=news_id, dataset=dataset, subset=subset) 225 | if tweet in tweet_id_list and tweet in retweet_ids: 226 | retweets_data = open_retweet_json(retweets_path, tweet) 227 | for retweet in retweets_data: 228 | if retweet['id'] not in node_ids['tweet']: 229 | if start_time <= datetime.strptime(retweet['created_at'], 230 | '%a %b %d %H:%M:%S +0000 %Y') < end_time: 231 | # retweets node features 232 | graph['tweet'].x[0].append(get_tweet_features(retweet)[0]) 233 | graph['tweet'].x[1].append(get_tweet_features(retweet)[1:]) 234 | node_ids['tweet'].append(retweet['id']) 235 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 236 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 237 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 238 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 239 | else: 240 | if start_time <= datetime.strptime(retweet['created_at'], 241 | '%a %b %d %H:%M:%S 
+0000 %Y') < end_time: 242 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 243 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 244 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 245 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 246 | 247 | if include_users: 248 | user_information = get_user_information(retweet['user']['id']) 249 | if user_information: 250 | if tweet in tweet_id_list and retweet['id'] in node_ids['tweet']: 251 | if user_information['id'] not in node_ids['user']: 252 | graph['user'].x[0].append(get_user_features(user_information)[0]) 253 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 254 | node_ids['user'].append(user_information['id']) 255 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 256 | node_id_user = node_ids['user'].index(user_information['id']) 257 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 258 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 259 | else: 260 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 261 | node_id_user = node_ids['user'].index(user_information['id']) 262 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 263 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 264 | else: 265 | print(f"[WARNING] excluding sample with id {news_id} no news or tweets available") 266 | graph['article'].x = torch.tensor(graph['article'].x, dtype=torch.float32) 267 | return graph 268 | 269 | graph['article'].x = torch.tensor(text_embeddings(graph['article'].x), dtype=torch.float32) 270 | graph['article'].y = torch.tensor(graph['article'].y, dtype=torch.long) 271 | if include_tweets: 272 | if include_text and np.asarray(graph['tweet'].x[1]).shape[0] > 0: 273 | graph['tweet'].x = torch.tensor( 274 | np.concatenate((text_embeddings(graph['tweet'].x[0]), np.asarray(graph['tweet'].x[1])), axis=1), 275 | dtype=torch.float32) 276 | else: 277 | graph['tweet'].x = torch.tensor(graph['tweet'].x[1], dtype=torch.float32) 278 | graph['tweet', 'cites', 'article'].edge_index = torch.tensor(graph['tweet', 'cites', 'article'].edge_index, 279 | dtype=torch.long) 280 | if include_users: 281 | if include_text and np.asarray(graph['user'].x[1]).shape[0] > 0: 282 | graph['user'].x = torch.tensor( 283 | np.concatenate((text_embeddings(graph['user'].x[0]), np.asarray(graph['user'].x[1])), axis=1), 284 | dtype=torch.float32) 285 | else: 286 | graph['user'].x = torch.tensor(graph['user'].x[1], dtype=torch.float32) 287 | graph['user', 'posts', 'tweet'].edge_index = torch.tensor(graph['user', 'posts', 'tweet'].edge_index, 288 | dtype=torch.long) 289 | if include_user_followers or include_user_following: 290 | graph['user', 'follows', 'user'].edge_index = torch.tensor(graph['user', 'follows', 'user'].edge_index, 291 | dtype=torch.long) 292 | if include_retweets: 293 | graph['tweet', 'retweets', 'tweet'].edge_index = torch.tensor(graph['tweet', 'retweets', 'tweet'].edge_index, 294 | dtype=torch.long) 295 | graph = graph.coalesce() 296 | if to_undirected: 297 | graph = T.ToUndirected(merge=False)(graph) 298 | return graph 299 | 300 | 301 | def create_fixed_heterogeneous_snapshot(news_id_dict, tweet_id_list, max_tweet_id_list, 302 | start_time, end_time, max_end_time, 303 | dataset='politifact', include_tweets=True, include_users=True, 304 | include_user_timeline_tweets=True, include_retweets=True, 305 | include_user_followers=True, include_user_following=True, 306 | add_new_users=False, to_undirected=True, 
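# "Fixed" snapshots share one node set across all time steps: every node from the
# final snapshot (max_tweet_id_list / max_end_time) is present, but nodes that do
# not yet exist at end_time only receive placeholder features ('' for text, zero
# vectors otherwise), so all snapshots of a sequence have identical shapes.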
include_text=False): 307 | node_ids = {'article': [], 308 | 'tweet': [], 309 | 'user': []} 310 | graph = HeteroData() 311 | graph['article'].x = [] 312 | graph['article'].y = [] 313 | if include_tweets: 314 | graph['tweet'].x = [[], []] 315 | graph['tweet', 'cites', 'article'].edge_index = [[], []] 316 | if include_users: 317 | graph['user'].x = [[], []] 318 | graph['user', 'posts', 'tweet'].edge_index = [[], []] 319 | if include_user_followers or include_user_following: 320 | graph['user', 'follows', 'user'].edge_index = [[], []] 321 | if include_retweets: 322 | graph['tweet', 'retweets', 'tweet'].edge_index = [[], []] 323 | 324 | for subset, news_id in news_id_dict.items(): 325 | if content_available(news_id=news_id, dataset=dataset, subset=subset): 326 | news_content = get_news_content(news_id=news_id, dataset=dataset, subset=subset) 327 | # news node feature 328 | graph['article'].x.append(get_news_features(news_content)) 329 | if subset == 'fake': 330 | graph['article'].y.append(1) 331 | elif subset == 'real': 332 | graph['article'].y.append(0) 333 | node_ids['article'].append(news_id) 334 | 335 | if include_tweets: 336 | tweets_path, tweet_list = get_news_tweet_ids(news_id=news_id, dataset=dataset, subset=subset) 337 | for tweet in tweet_list: 338 | tweet_data = open_tweet_json(tweets_path, tweet) 339 | if tweet in tweet_id_list: 340 | if tweet_data['id'] not in node_ids['tweet']: 341 | graph['tweet'].x[0].append(get_tweet_features(tweet_data)[0]) 342 | graph['tweet'].x[1].append(get_tweet_features(tweet_data)[1:]) 343 | node_ids['tweet'].append(tweet_data['id']) 344 | node_id_news = node_ids['article'].index(news_id) 345 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 346 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 347 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 348 | else: 349 | node_id_news = node_ids['article'].index(news_id) 350 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 351 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 352 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 353 | elif tweet in max_tweet_id_list: 354 | if tweet_data['id'] not in node_ids['tweet']: 355 | graph['tweet'].x[0].append('') 356 | graph['tweet'].x[1].append([0]*2) 357 | node_ids['tweet'].append(tweet_data['id']) 358 | node_id_news = node_ids['article'].index(news_id) 359 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 360 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 361 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 362 | else: 363 | node_id_news = node_ids['article'].index(news_id) 364 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 365 | graph['tweet', 'cites', 'article'].edge_index[0] += [node_id_tweet] 366 | graph['tweet', 'cites', 'article'].edge_index[1] += [node_id_news] 367 | 368 | if include_users: 369 | user_information = get_user_information(tweet_data['user']['id']) 370 | if user_information: 371 | if tweet in tweet_id_list: 372 | if user_information['id'] not in node_ids['user']: 373 | graph['user'].x[0].append(get_user_features(user_information)[0]) 374 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 375 | node_ids['user'].append(user_information['id']) 376 | node_id_user = node_ids['user'].index(user_information['id']) 377 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 378 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 379 | graph['user', 'posts', 
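# placeholder widths mirror the real extractors: tweets get '' plus [0]*2 and users
# '' plus [0]*4, matching the non-text outputs of get_tweet_features and
# get_user_features used above.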
'tweet'].edge_index[1] += [node_id_tweet] 380 | else: 381 | node_id_user = node_ids['user'].index(user_information['id']) 382 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 383 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 384 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 385 | elif tweet in max_tweet_id_list: 386 | if user_information['id'] not in node_ids['user']: 387 | graph['user'].x[0].append('') 388 | graph['user'].x[1].append([0]*4) 389 | node_ids['user'].append(user_information['id']) 390 | node_id_user = node_ids['user'].index(user_information['id']) 391 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 392 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 393 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 394 | else: 395 | node_id_user = node_ids['user'].index(user_information['id']) 396 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 397 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 398 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_tweet] 399 | 400 | if include_retweets: 401 | retweets_path, retweet_ids = get_retweet_ids(news_id=news_id, dataset=dataset, subset=subset) 402 | if tweet in tweet_id_list and tweet in retweet_ids: 403 | retweets_data = open_retweet_json(retweets_path, tweet) 404 | for retweet in retweets_data: 405 | if retweet['id'] not in node_ids['tweet']: 406 | if start_time <= datetime.strptime(retweet['created_at'], 407 | '%a %b %d %H:%M:%S +0000 %Y') < end_time: 408 | # retweets node features 409 | graph['tweet'].x[0].append(get_tweet_features(retweet)[0]) 410 | graph['tweet'].x[1].append(get_tweet_features(retweet)[1:]) 411 | node_ids['tweet'].append(retweet['id']) 412 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 413 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 414 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 415 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 416 | elif start_time <= datetime.strptime(retweet['created_at'], 417 | '%a %b %d %H:%M:%S +0000 %Y') < max_end_time: 418 | # retweets node features 419 | graph['tweet'].x[0].append('') 420 | graph['tweet'].x[1].append([0]*2) 421 | node_ids['tweet'].append(retweet['id']) 422 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 423 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 424 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 425 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 426 | else: 427 | if start_time <= datetime.strptime(retweet['created_at'], 428 | '%a %b %d %H:%M:%S +0000 %Y') < max_end_time: 429 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 430 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 431 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 432 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 433 | 434 | if include_users: 435 | user_information = get_user_information(retweet['user']['id']) 436 | if user_information: 437 | if tweet in tweet_id_list and retweet['id'] in node_ids['tweet']: 438 | if user_information['id'] not in node_ids['user']: 439 | graph['user'].x[0].append(get_user_features(user_information)[0]) 440 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 441 | node_ids['user'].append(user_information['id']) 442 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 443 | node_id_user = 
node_ids['user'].index(user_information['id']) 444 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 445 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 446 | else: 447 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 448 | node_id_user = node_ids['user'].index(user_information['id']) 449 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 450 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 451 | elif tweet in max_tweet_id_list and retweet['id'] in node_ids['tweet']: 452 | if user_information['id'] not in node_ids['user']: 453 | graph['user'].x[0].append('') 454 | graph['user'].x[1].append([0]*4) 455 | node_ids['user'].append(user_information['id']) 456 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 457 | node_id_user = node_ids['user'].index(user_information['id']) 458 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 459 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 460 | else: 461 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 462 | node_id_user = node_ids['user'].index(user_information['id']) 463 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 464 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 465 | elif tweet in max_tweet_id_list and tweet in retweet_ids: 466 | retweets_data = open_retweet_json(retweets_path, tweet) 467 | for retweet in retweets_data: 468 | if retweet['id'] not in node_ids['tweet']: 469 | if start_time <= datetime.strptime(retweet['created_at'], 470 | '%a %b %d %H:%M:%S +0000 %Y') < end_time: 471 | # retweets node features 472 | graph['tweet'].x[0].append(get_tweet_features(retweet)[0]) 473 | graph['tweet'].x[1].append(get_tweet_features(retweet)[1:]) 474 | node_ids['tweet'].append(retweet['id']) 475 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 476 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 477 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 478 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 479 | elif start_time <= datetime.strptime(retweet['created_at'], 480 | '%a %b %d %H:%M:%S +0000 %Y') < max_end_time: 481 | # retweets node features 482 | graph['tweet'].x[0].append('') 483 | graph['tweet'].x[1].append([0]*2) 484 | node_ids['tweet'].append(retweet['id']) 485 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 486 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 487 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 488 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 489 | else: 490 | if start_time <= datetime.strptime(retweet['created_at'], 491 | '%a %b %d %H:%M:%S +0000 %Y') < max_end_time: 492 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 493 | node_id_tweet = node_ids['tweet'].index(tweet_data['id']) 494 | graph['tweet', 'retweets', 'tweet'].edge_index[0] += [node_id_retweet] 495 | graph['tweet', 'retweets', 'tweet'].edge_index[1] += [node_id_tweet] 496 | 497 | if include_users: 498 | user_information = get_user_information(retweet['user']['id']) 499 | if user_information: 500 | if tweet in tweet_id_list and retweet['id'] in node_ids['tweet']: 501 | if user_information['id'] not in node_ids['user']: 502 | graph['user'].x[0].append(get_user_features(user_information)[0]) 503 | graph['user'].x[1].append(get_user_features(user_information)[1:]) 504 | node_ids['user'].append(user_information['id']) 505 | node_id_retweet = 
node_ids['tweet'].index(retweet['id']) 506 | node_id_user = node_ids['user'].index(user_information['id']) 507 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 508 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 509 | else: 510 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 511 | node_id_user = node_ids['user'].index(user_information['id']) 512 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 513 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 514 | elif tweet in max_tweet_id_list and retweet['id'] in node_ids['tweet']: 515 | if user_information['id'] not in node_ids['user']: 516 | graph['user'].x[0].append('') 517 | graph['user'].x[1].append([0]*4) 518 | node_ids['user'].append(user_information['id']) 519 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 520 | node_id_user = node_ids['user'].index(user_information['id']) 521 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 522 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 523 | else: 524 | node_id_retweet = node_ids['tweet'].index(retweet['id']) 525 | node_id_user = node_ids['user'].index(user_information['id']) 526 | graph['user', 'posts', 'tweet'].edge_index[0] += [node_id_user] 527 | graph['user', 'posts', 'tweet'].edge_index[1] += [node_id_retweet] 528 | 529 | else: 530 | print(f"[WARNING] excluding sample with id {news_id} no news or tweets available") 531 | graph['article'].x = torch.tensor(graph['article'].x, dtype=torch.float32) 532 | return graph 533 | 534 | graph['article'].x = torch.tensor(text_embeddings(graph['article'].x), dtype=torch.float32) 535 | graph['article'].y = torch.tensor(graph['article'].y, dtype=torch.long) 536 | if include_tweets: 537 | if include_text and np.asarray(graph['tweet'].x[1]).shape[0] > 0: 538 | graph['tweet'].x = torch.tensor( 539 | np.concatenate((text_embeddings(graph['tweet'].x[0]), np.asarray(graph['tweet'].x[1])), axis=1), 540 | dtype=torch.float32) 541 | else: 542 | graph['tweet'].x = torch.tensor(graph['tweet'].x[1], dtype=torch.float32) 543 | graph['tweet', 'cites', 'article'].edge_index = torch.tensor(graph['tweet', 'cites', 'article'].edge_index, 544 | dtype=torch.long) 545 | if include_users: 546 | if include_text and np.asarray(graph['user'].x[1]).shape[0] > 0: 547 | graph['user'].x = torch.tensor( 548 | np.concatenate((text_embeddings(graph['user'].x[0]), np.asarray(graph['user'].x[1])), axis=1), 549 | dtype=torch.float32) 550 | else: 551 | graph['user'].x = torch.tensor(graph['user'].x[1], dtype=torch.float32) 552 | graph['user', 'posts', 'tweet'].edge_index = torch.tensor(graph['user', 'posts', 'tweet'].edge_index, 553 | dtype=torch.long) 554 | if include_user_followers or include_user_following: 555 | graph['user', 'follows', 'user'].edge_index = torch.tensor(graph['user', 'follows', 'user'].edge_index, 556 | dtype=torch.long) 557 | if include_retweets: 558 | graph['tweet', 'retweets', 'tweet'].edge_index = torch.tensor(graph['tweet', 'retweets', 'tweet'].edge_index, 559 | dtype=torch.long) 560 | graph = graph.coalesce() 561 | if to_undirected: 562 | graph = T.ToUndirected(merge=False)(graph) 563 | return graph 564 | -------------------------------------------------------------------------------- /temporal/temporal_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | from torch_geometric.nn import HeteroConv, SAGEConv, GATConv 4 | from 
torch_geometric.nn.inits import glorot 5 | import torch.nn as nn 6 | 7 | 8 | class HeteroGCLSTM(torch.nn.Module): 9 | r"""An implementation similar to the Integrated Graph Convolutional Long Short Term 10 | Memory Cell for heterogeneous Graphs. 11 | Args: 12 | in_channels_dict (dict of keys=str and values=int): Dimension of each node's input features. 13 | out_channels (int): Number of output features. 14 | metadata (tuple): Metadata on node types and edge types in the graphs. Can be generated via PyG method 15 | :obj:`snapshot.metadata()` where snapshot is a single HeteroData object. 16 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 17 | an additive bias. (default: :obj:`True`) 18 | """ 19 | 20 | def __init__( 21 | self, 22 | in_channels_dict: dict, 23 | out_channels: int, 24 | metadata: tuple, 25 | bias: bool = True 26 | ): 27 | super(HeteroGCLSTM, self).__init__() 28 | 29 | self.in_channels_dict = in_channels_dict 30 | self.out_channels = out_channels 31 | self.metadata = metadata 32 | self.bias = bias 33 | self._create_parameters_and_layers() 34 | self._set_parameters() 35 | 36 | def _create_input_gate_parameters_and_layers(self): 37 | self.conv_i = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 38 | out_channels=self.out_channels, 39 | bias=self.bias) for edge_type in self.metadata[1]}) 40 | 41 | self.W_i = nn.ParameterDict({node_type: Parameter(torch.Tensor(in_channels, self.out_channels)) 42 | for node_type, in_channels in self.in_channels_dict.items()}) 43 | self.b_i = nn.ParameterDict({node_type: Parameter(torch.Tensor(1, self.out_channels)) 44 | for node_type in self.in_channels_dict}) 45 | 46 | def _create_forget_gate_parameters_and_layers(self): 47 | self.conv_f = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 48 | out_channels=self.out_channels, 49 | bias=self.bias) for edge_type in self.metadata[1]}) 50 | 51 | self.W_f = nn.ParameterDict({node_type: Parameter(torch.Tensor(in_channels, self.out_channels)) 52 | for node_type, in_channels in self.in_channels_dict.items()}) 53 | self.b_f = nn.ParameterDict({node_type: Parameter(torch.Tensor(1, self.out_channels)) 54 | for node_type in self.in_channels_dict}) 55 | 56 | def _create_cell_state_parameters_and_layers(self): 57 | self.conv_c = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 58 | out_channels=self.out_channels, 59 | bias=self.bias) for edge_type in self.metadata[1]}) 60 | 61 | self.W_c = nn.ParameterDict({node_type: Parameter(torch.Tensor(in_channels, self.out_channels)) 62 | for node_type, in_channels in self.in_channels_dict.items()}) 63 | self.b_c = nn.ParameterDict({node_type: Parameter(torch.Tensor(1, self.out_channels)) 64 | for node_type in self.in_channels_dict}) 65 | 66 | def _create_output_gate_parameters_and_layers(self): 67 | self.conv_o = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 68 | out_channels=self.out_channels, 69 | bias=self.bias) for edge_type in self.metadata[1]}) 70 | 71 | self.W_o = nn.ParameterDict({node_type: Parameter(torch.Tensor(in_channels, self.out_channels)) 72 | for node_type, in_channels in self.in_channels_dict.items()}) 73 | self.b_o = nn.ParameterDict({node_type: Parameter(torch.Tensor(1, self.out_channels)) 74 | for node_type in self.in_channels_dict}) 75 | 76 | def _create_parameters_and_layers(self): 77 | self._create_input_gate_parameters_and_layers() 78 | self._create_forget_gate_parameters_and_layers() 79 | self._create_cell_state_parameters_and_layers() 80 | self._create_output_gate_parameters_and_layers() 81 | 82 | def 
_set_parameters(self): 83 | for key in self.W_i: 84 | glorot(self.W_i[key]) 85 | for key in self.W_f: 86 | glorot(self.W_f[key]) 87 | for key in self.W_c: 88 | glorot(self.W_c[key]) 89 | for key in self.W_o: 90 | glorot(self.W_o[key]) 91 | for key in self.b_i: 92 | glorot(self.b_i[key]) 93 | for key in self.b_f: 94 | glorot(self.b_f[key]) 95 | for key in self.b_c: 96 | glorot(self.b_c[key]) 97 | for key in self.b_o: 98 | glorot(self.b_o[key]) 99 | 100 | def _set_hidden_state(self, x_dict, h_dict): 101 | if h_dict is None: 102 | h_dict = {node_type: torch.zeros(X.shape[0], self.out_channels).to(X.device) for node_type, X in x_dict.items()} 103 | return h_dict 104 | 105 | def _set_cell_state(self, x_dict, c_dict): 106 | if c_dict is None: 107 | c_dict = {node_type: torch.zeros(X.shape[0], self.out_channels).to(X.device) for node_type, X in x_dict.items()} 108 | return c_dict 109 | 110 | def _calculate_input_gate(self, x_dict, edge_index_dict, h_dict, c_dict): 111 | i_dict = {node_type: torch.matmul(X, self.W_i[node_type]) for node_type, X in x_dict.items()} 112 | conv_i = self.conv_i(h_dict, edge_index_dict) 113 | i_dict = {node_type: I + conv_i[node_type] for node_type, I in i_dict.items()} 114 | i_dict = {node_type: I + self.b_i[node_type] for node_type, I in i_dict.items()} 115 | i_dict = {node_type: torch.sigmoid(I) for node_type, I in i_dict.items()} 116 | return i_dict 117 | 118 | def _calculate_forget_gate(self, x_dict, edge_index_dict, h_dict, c_dict): 119 | f_dict = {node_type: torch.matmul(X, self.W_f[node_type]) for node_type, X in x_dict.items()} 120 | conv_f = self.conv_f(h_dict, edge_index_dict) 121 | f_dict = {node_type: F + conv_f[node_type] for node_type, F in f_dict.items()} 122 | f_dict = {node_type: F + self.b_f[node_type] for node_type, F in f_dict.items()} 123 | f_dict = {node_type: torch.sigmoid(F) for node_type, F in f_dict.items()} 124 | return f_dict 125 | 126 | def _calculate_cell_state(self, x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict): 127 | t_dict = {node_type: torch.matmul(X, self.W_c[node_type]) for node_type, X in x_dict.items()} 128 | conv_c = self.conv_c(h_dict, edge_index_dict) 129 | t_dict = {node_type: T + conv_c[node_type] for node_type, T in t_dict.items()} 130 | t_dict = {node_type: T + self.b_c[node_type] for node_type, T in t_dict.items()} 131 | t_dict = {node_type: torch.tanh(T) for node_type, T in t_dict.items()} 132 | c_dict = {node_type: f_dict[node_type] * C + i_dict[node_type] * t_dict[node_type] for node_type, C in c_dict.items()} 133 | return c_dict 134 | 135 | def _calculate_output_gate(self, x_dict, edge_index_dict, h_dict, c_dict): 136 | o_dict = {node_type: torch.matmul(X, self.W_o[node_type]) for node_type, X in x_dict.items()} 137 | conv_o = self.conv_o(h_dict, edge_index_dict) 138 | o_dict = {node_type: O + conv_o[node_type] for node_type, O in o_dict.items()} 139 | o_dict = {node_type: O + self.b_o[node_type] for node_type, O in o_dict.items()} 140 | o_dict = {node_type: torch.sigmoid(O) for node_type, O in o_dict.items()} 141 | return o_dict 142 | 143 | def _calculate_hidden_state(self, o_dict, c_dict): 144 | h_dict = {node_type: o_dict[node_type] * torch.tanh(C) for node_type, C in c_dict.items()} 145 | return h_dict 146 | 147 | def forward( 148 | self, 149 | x_dict, 150 | edge_index_dict, 151 | h_dict=None, 152 | c_dict=None, 153 | ): 154 | """ 155 | Making a forward pass. If the hidden state and cell state 156 | matrix dicts are not present when the forward pass is called these are 157 | initialized with zeros.
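The zero states are created on the same device as the corresponding node feature tensors.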
158 | Arg types: 159 | * **x_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensors)* - Node features dicts. Can 160 | be obtained via PyG method :obj:`snapshot.x_dict` where snapshot is a single HeteroData object. 161 | * **edge_index_dict** *(Dictionary where keys=Tuples and values=PyTorch Long Tensors)* - Graph edge type 162 | and index dicts. Can be obtained via PyG method :obj:`snapshot.edge_index_dict`. 163 | * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and 164 | hidden state matrix dict for all nodes. 165 | * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor, optional)* - Node type and 166 | cell state matrix dict for all nodes. 167 | Return types: 168 | * **h_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and 169 | hidden state matrix dict for all nodes. 170 | * **c_dict** *(Dictionary where keys=Strings and values=PyTorch Float Tensor)* - Node type and 171 | cell state matrix dict for all nodes. 172 | """ 173 | 174 | h_dict = self._set_hidden_state(x_dict, h_dict) 175 | c_dict = self._set_cell_state(x_dict, c_dict) 176 | i_dict = self._calculate_input_gate(x_dict, edge_index_dict, h_dict, c_dict) 177 | f_dict = self._calculate_forget_gate(x_dict, edge_index_dict, h_dict, c_dict) 178 | c_dict = self._calculate_cell_state(x_dict, edge_index_dict, h_dict, c_dict, i_dict, f_dict) 179 | o_dict = self._calculate_output_gate(x_dict, edge_index_dict, h_dict, c_dict) 180 | h_dict = self._calculate_hidden_state(o_dict, c_dict) 181 | return h_dict, c_dict 182 | 183 | 184 | class HeteroGConvGRU(torch.nn.Module): 185 | def __init__( 186 | self, 187 | out_channels: int, 188 | metadata: tuple, 189 | bias: bool = True 190 | ): 191 | super(HeteroGConvGRU, self).__init__() 192 | 193 | self.out_channels = out_channels 194 | self.metadata = metadata 195 | self.bias = bias 196 | self._create_parameters_and_layers() 197 | 198 | def _create_update_gate_parameters_and_layers(self): 199 | 200 | self.conv_x_z = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 201 | out_channels=self.out_channels, 202 | bias=self.bias) for edge_type in self.metadata[1]}) 203 | 204 | self.conv_h_z = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 205 | out_channels=self.out_channels, 206 | bias=self.bias) for edge_type in self.metadata[1]}) 207 | 208 | def _create_reset_gate_parameters_and_layers(self): 209 | 210 | self.conv_x_r = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 211 | out_channels=self.out_channels, 212 | bias=self.bias) for edge_type in self.metadata[1]}) 213 | 214 | self.conv_h_r = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 215 | out_channels=self.out_channels, 216 | bias=self.bias) for edge_type in self.metadata[1]}) 217 | 218 | def _create_candidate_state_parameters_and_layers(self): 219 | 220 | self.conv_x_h = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 221 | out_channels=self.out_channels, 222 | bias=self.bias) for edge_type in self.metadata[1]}) 223 | 224 | self.conv_h_h = HeteroConv({edge_type: SAGEConv(in_channels=(-1, -1), 225 | out_channels=self.out_channels, 226 | bias=self.bias) for edge_type in self.metadata[1]}) 227 | 228 | def _create_parameters_and_layers(self): 229 | self._create_update_gate_parameters_and_layers() 230 | self._create_reset_gate_parameters_and_layers() 231 | self._create_candidate_state_parameters_and_layers() 232 | 233 | def _set_hidden_state(self, x_dict, h_dict): 234 | if h_dict is None: 235 | h_dict = 
{node_type: torch.zeros(X.shape[0], self.out_channels).to(X.device) for node_type, X in x_dict.items()} 236 | return h_dict 237 | 238 | def _calculate_update_gate(self, x_dict, edge_index_dict, h_dict): 239 | z_dict = self.conv_x_z(x_dict, edge_index_dict) 240 | conv_h_z = self.conv_h_z(h_dict, edge_index_dict) 241 | z_dict = {node_type: Z + conv_h_z[node_type] for node_type, Z in z_dict.items()} 242 | z_dict = {node_type: torch.sigmoid(Z) for node_type, Z in z_dict.items()} 243 | return z_dict 244 | 245 | def _calculate_reset_gate(self, x_dict, edge_index_dict, h_dict): 246 | r_dict = self.conv_x_r(x_dict, edge_index_dict) 247 | conv_h_r = self.conv_h_r(h_dict, edge_index_dict) 248 | r_dict = {node_type: R + conv_h_r[node_type] for node_type, R in r_dict.items()} 249 | r_dict = {node_type: torch.sigmoid(R) for node_type, R in r_dict.items()} 250 | return r_dict 251 | 252 | def _calculate_candidate_state(self, x_dict, edge_index_dict, h_dict, r_dict): 253 | h_tilde_dict = self.conv_x_h(x_dict, edge_index_dict) 254 | h_tilde_dict = {node_type: h_tilde for node_type, h_tilde in h_tilde_dict.items()} 255 | h_r_dict = {node_type: H * r_dict[node_type] for node_type, H in h_dict.items()} 256 | conv_h_h = self.conv_h_h(h_r_dict, edge_index_dict) 257 | h_tilde_dict = {node_type: h_tilde + conv_h_h[node_type] for node_type, h_tilde in h_tilde_dict.items()} 258 | h_tilde_dict = {node_type: torch.tanh(h_tilde) for node_type, h_tilde in h_tilde_dict.items()} 259 | return h_tilde_dict 260 | 261 | def _calculate_hidden_state(self, z_dict, h_dict, h_tilde_dict): 262 | h_dict = {node_type: z_dict[node_type] * H + (1 - z_dict[node_type]) * h_tilde_dict[node_type] for node_type, H in h_dict.items()} 263 | return h_dict 264 | 265 | def forward( 266 | self, 267 | x_dict, 268 | edge_index_dict, 269 | h_dict=None, 270 | ): 271 | h_dict = self._set_hidden_state(x_dict, h_dict) 272 | z_dict = self._calculate_update_gate(x_dict, edge_index_dict, h_dict) 273 | r_dict = self._calculate_reset_gate(x_dict, edge_index_dict, h_dict) 274 | h_tilde_dict = self._calculate_candidate_state(x_dict, edge_index_dict, h_dict, r_dict) 275 | h_dict = self._calculate_hidden_state(z_dict, h_dict, h_tilde_dict) 276 | return h_dict 277 | 278 | 279 | class HeteroGATGRU(torch.nn.Module): 280 | def __init__( 281 | self, 282 | out_channels: int, 283 | metadata: tuple, 284 | num_attention_heads: int = 3, 285 | bias: bool = True 286 | ): 287 | super(HeteroGATGRU, self).__init__() 288 | 289 | self.out_channels = out_channels 290 | self.metadata = metadata 291 | self.num_attention_heads = num_attention_heads 292 | self.bias = bias 293 | self._create_parameters_and_layers() 294 | 295 | def _create_update_gate_parameters_and_layers(self): 296 | 297 | self.conv_x_z = HeteroConv({edge_type: GATConv(in_channels=(-1, -1), 298 | out_channels=self.out_channels, 299 | heads=self.num_attention_heads, 300 | add_self_loops=False) for edge_type in self.metadata[1]}) 301 | 302 | self.conv_h_z = HeteroConv({edge_type: GATConv(in_channels=(-1, -1), 303 | out_channels=self.out_channels, 304 | heads=self.num_attention_heads, 305 | add_self_loops=False) for edge_type in self.metadata[1]}) 306 | 307 | def _create_reset_gate_parameters_and_layers(self): 308 | 309 | self.conv_x_r = HeteroConv({edge_type: GATConv(in_channels=(-1, -1), 310 | out_channels=self.out_channels, 311 | heads=self.num_attention_heads, 312 | add_self_loops=False) for edge_type in self.metadata[1]}) 313 | 314 | self.conv_h_r = HeteroConv({edge_type: GATConv(in_channels=(-1, -1),
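# note: GATConv concatenates its attention heads, so each gate outputs
# out_channels * num_attention_heads features per node; _set_hidden_state below
# therefore allocates zero states of that width, and HeteroGATGRUModel's linear
# layer expects num_hidden_nodes * 3 * num_attention_heads inputs.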
315 | out_channels=self.out_channels, 316 | heads=self.num_attention_heads, 317 | add_self_loops=False) for edge_type in self.metadata[1]}) 318 | 319 | def _create_candidate_state_parameters_and_layers(self): 320 | 321 | self.conv_x_h = HeteroConv({edge_type: GATConv(in_channels=(-1, -1), 322 | out_channels=self.out_channels, 323 | heads=self.num_attention_heads, 324 | add_self_loops=False) for edge_type in self.metadata[1]}) 325 | 326 | self.conv_h_h = HeteroConv({edge_type: GATConv(in_channels=(-1, -1), 327 | out_channels=self.out_channels, 328 | heads=self.num_attention_heads, 329 | add_self_loops=False) for edge_type in self.metadata[1]}) 330 | 331 | def _create_parameters_and_layers(self): 332 | self._create_update_gate_parameters_and_layers() 333 | self._create_reset_gate_parameters_and_layers() 334 | self._create_candidate_state_parameters_and_layers() 335 | 336 | def _set_hidden_state(self, x_dict, h_dict): 337 | if h_dict is None: 338 | h_dict = {node_type: torch.zeros(X.shape[0], self.out_channels*self.num_attention_heads).to(X.device) 339 | for node_type, X in x_dict.items()} 340 | return h_dict 341 | 342 | def _calculate_update_gate(self, x_dict, edge_index_dict, h_dict): 343 | z_dict = self.conv_x_z(x_dict, edge_index_dict) 344 | conv_h_z = self.conv_h_z(h_dict, edge_index_dict) 345 | z_dict = {node_type: Z + conv_h_z[node_type] for node_type, Z in z_dict.items()} 346 | z_dict = {node_type: torch.sigmoid(Z) for node_type, Z in z_dict.items()} 347 | return z_dict 348 | 349 | def _calculate_reset_gate(self, x_dict, edge_index_dict, h_dict): 350 | r_dict = self.conv_x_r(x_dict, edge_index_dict) 351 | conv_h_r = self.conv_h_r(h_dict, edge_index_dict) 352 | r_dict = {node_type: R + conv_h_r[node_type] for node_type, R in r_dict.items()} 353 | r_dict = {node_type: torch.sigmoid(R) for node_type, R in r_dict.items()} 354 | return r_dict 355 | 356 | def _calculate_candidate_state(self, x_dict, edge_index_dict, h_dict, r_dict): 357 | h_tilde_dict = self.conv_x_h(x_dict, edge_index_dict) 358 | h_tilde_dict = {node_type: h_tilde for node_type, h_tilde in h_tilde_dict.items()} 359 | h_r_dict = {node_type: H * r_dict[node_type] for node_type, H in h_dict.items()} 360 | conv_h_h = self.conv_h_h(h_r_dict, edge_index_dict) 361 | h_tilde_dict = {node_type: h_tilde + conv_h_h[node_type] for node_type, h_tilde in h_tilde_dict.items()} 362 | h_tilde_dict = {node_type: torch.tanh(h_tilde) for node_type, h_tilde in h_tilde_dict.items()} 363 | return h_tilde_dict 364 | 365 | def _calculate_hidden_state(self, z_dict, h_dict, h_tilde_dict): 366 | h_dict = {node_type: z_dict[node_type] * H + (1 - z_dict[node_type]) * h_tilde_dict[node_type] 367 | for node_type, H in h_dict.items()} 368 | return h_dict 369 | 370 | def forward( 371 | self, 372 | x_dict, 373 | edge_index_dict, 374 | h_dict=None, 375 | ): 376 | h_dict = self._set_hidden_state(x_dict, h_dict) 377 | z_dict = self._calculate_update_gate(x_dict, edge_index_dict, h_dict) 378 | r_dict = self._calculate_reset_gate(x_dict, edge_index_dict, h_dict) 379 | h_tilde_dict = self._calculate_candidate_state(x_dict, edge_index_dict, h_dict, r_dict) 380 | h_dict = self._calculate_hidden_state(z_dict, h_dict, h_tilde_dict) 381 | return h_dict 382 | --------------------------------------------------------------------------------