├── data
│   ├── ml_UCI-Msg.npy
│   └── ml_UCI-Msg_node.npy
├── train_example.sh
├── .idea
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── .gitignore
│   ├── misc.xml
│   ├── modules.xml
│   ├── DGNN.iml
│   ├── remote-mappings.xml
│   └── deployment.xml
├── README.md
├── modules
│   ├── merge.py
│   ├── message_function.py
│   ├── memory.py
│   ├── message_aggregator.py
│   ├── propagater.py
│   └── update.py
├── models
│   ├── time_encoding.py
│   └── dgnn.py
├── LICENSE
├── utils
│   ├── preprocess_data.py
│   ├── preprocess_uci_msg.py
│   ├── utils.py
│   └── data_processing.py
├── evaluation.py
├── train.py
└── trainUCI.py
/data/ml_UCI-Msg.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyd1502/DGNN/HEAD/data/ml_UCI-Msg.npy
--------------------------------------------------------------------------------
/data/ml_UCI-Msg_node.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyd1502/DGNN/HEAD/data/ml_UCI-Msg_node.npy
--------------------------------------------------------------------------------
/train_example.sh:
--------------------------------------------------------------------------------
1 | python trainUCI.py --bs 100 --prefix ucitest2.3 --n_runs 15 --different_new_nodes --gpu 1 --lr 0.0000005 --patience 25 --threshold 25
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../:\Users\41374\Downloads\DGNN\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DGNN
2 | PyTorch implementation of "Streaming Graph Neural Networks"
3 | (https://arxiv.org/abs/1810.10627)
4 | (I am not the author; I am just trying to reproduce the experimental results.)
5 | 
6 | 
7 | The best result of my implementation on the UCI dataset is:
8 | MRR: 0.0259, recall@20: 0.1276, recall@50: 0.2078.
9 | 
10 | The result in the original paper is:
11 | MRR: 0.0342, recall@20: 0.1284, recall@50: 0.2547.
12 | 
13 | Any changes to the existing implementation are welcome.
14 | 
15 | The main difficulty is how to process several events in one batch (batch size > 1). I borrow the message aggregation idea from TGN (http://arxiv.org/abs/2006.10637) to solve this problem. Besides, my implementation uses two propagation modules instead of the four used in the paper.
16 | 
--------------------------------------------------------------------------------
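To make the borrowed idea concrete: within a batch, every event appends a (message, timestamp) pair to a per-node list, and only one aggregated message per node is then applied to the memory. A minimal sketch of the "keep the last message" strategy, mirroring `LastMessageAggregator` in `modules/message_aggregator.py` (the toy tensors below are illustrative only):

```python
from collections import defaultdict

import torch


def aggregate_last(node_ids, messages):
    """Keep only the most recent (message, timestamp) pair per node.

    node_ids: iterable of node ids appearing in the batch.
    messages: dict mapping node id -> list of (message tensor, timestamp tensor),
              assumed to be stored in chronological order.
    """
    to_update, last_msgs, last_ts = [], [], []
    for node_id in sorted(set(node_ids)):
        if messages[node_id]:
            to_update.append(node_id)
            last_msgs.append(messages[node_id][-1][0])  # newest message wins
            last_ts.append(messages[node_id][-1][1])    # its timestamp
    if to_update:
        return to_update, torch.stack(last_msgs), torch.stack(last_ts)
    return to_update, [], []


# Example: node 7 interacts twice in one batch; only its latest message survives.
msgs = defaultdict(list)
msgs[7] = [(torch.zeros(4), torch.tensor(1.0)), (torch.ones(4), torch.tensor(2.0))]
msgs[3] = [(torch.full((4,), 5.0), torch.tensor(1.5))]
nodes, m, t = aggregate_last([7, 7, 3], msgs)
print(nodes, m.shape, t)  # [3, 7] torch.Size([2, 4]) tensor([1.5000, 2.0000])
```

The repo switches between this strategy and an averaging variant through `get_message_aggregator(aggregator_type, device)` (the `--aggregator` flag in the training scripts).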
"Streaming Graph Neural Network" 3 | (https://arxiv.org/abs/1810.10627) 4 | (I'm not author just trying to reproduce the experiment results) 5 | 6 | 7 | The best result of my implementation on UCI dataset is : 8 | mrr: 0.0259, recall@20: 0.1276, recall@50:0.2078. 9 | 10 | The result in orgin paper is : 11 | mrr:0.0342 recall@20:0.1284 recall@50: 0.2547. 12 | 13 | Any changes to the existing implementation are welcome. 14 | 15 | The main difficulty is how to process several events in one batch(batch size > 1). I borrow the message aggregation idea from Tgn(http://arxiv.org/abs/2006.10637) to solve the problem. Besides, I have not used 4 propagation module as the paper does but 2 in my implementation. 16 | 17 | -------------------------------------------------------------------------------- /modules/merge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from collections import defaultdict 5 | 6 | 7 | class MemoryMerge(nn.Module): 8 | def __init__(self, memory_dimension, device='cpu'): 9 | super(MemoryMerge, self).__init__() 10 | self.device = device 11 | self.W1 = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device)) 12 | self.W2 = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device)) 13 | self.bias = nn.Parameter(torch.zeros(memory_dimension).to(self.device)) 14 | self.act = torch.nn.ReLU() 15 | 16 | def forward(self, memory_s, memory_g): 17 | return torch.matmul(memory_s, self.W1)+torch.matmul(memory_g, self.W2) + self.bias 18 | -------------------------------------------------------------------------------- /models/time_encoding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class TimeEncode(torch.nn.Module): 6 | # Time Encoding proposed by TGAT 7 | def __init__(self, dimension): 8 | super(TimeEncode, self).__init__() 9 | 10 | self.dimension = dimension 11 | self.w = torch.nn.Linear(1, dimension) 12 | 13 | self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))) 14 | .float().reshape(dimension, -1)) 15 | self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float()) 16 | 17 | def forward(self, t): 18 | # t has shape [batch_size, seq_len] 19 | # Add dimension at the end to apply linear layer --> [batch_size, seq_len, 1] 20 | t = t.unsqueeze(dim=2) 21 | 22 | # output has shape [batch_size, seq_len, dimension] 23 | output = torch.cos(self.w(t)) 24 | 25 | return output 26 | -------------------------------------------------------------------------------- /modules/message_function.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class MessageFunction(nn.Module): 6 | def __init__(self, memory_dimension, message_dimension, edge_dimension=0, device="cpu"): 7 | super(MessageFunction, self).__init__() 8 | self.device = device 9 | self.W1 = nn.Parameter(torch.zeros((memory_dimension, message_dimension,)).to(self.device)) 10 | self.W2 = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device)) 11 | self.W3 = nn.Parameter(torch.zeros((edge_dimension, message_dimension)).to(self.device)) 12 | nn.init.xavier_uniform_(self.W1) 13 | nn.init.xavier_uniform_(self.W2) 14 | self.bias = nn.Parameter(torch.zeros(message_dimension).to(self.device)) 15 | self.act = nn.ReLU() 16 | 17 | def compute_message(self, memory_s, memory_g, 
/modules/message_function.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 | 
4 | 
5 | class MessageFunction(nn.Module):
6 |     def __init__(self, memory_dimension, message_dimension, edge_dimension=0, device="cpu"):
7 |         super(MessageFunction, self).__init__()
8 |         self.device = device
9 |         self.W1 = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
10 |         self.W2 = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
11 |         self.W3 = nn.Parameter(torch.zeros((edge_dimension, message_dimension)).to(self.device))
12 |         nn.init.xavier_uniform_(self.W1)
13 |         nn.init.xavier_uniform_(self.W2)
14 |         nn.init.xavier_uniform_(self.W3)
15 |         self.bias = nn.Parameter(torch.zeros(message_dimension).to(self.device))
16 |         self.act = nn.ReLU()
17 | 
18 |     def compute_message(self, memory_s, memory_g, edge_fea=None):
19 |         # message = ReLU(m_s * W1 + m_g * W2 + e * W3 + b)
20 |         messages = self.act(torch.matmul(memory_s, self.W1) + torch.matmul(memory_g, self.W2) +
21 |                             torch.matmul(edge_fea, self.W3) + self.bias)
22 |         return messages
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2020 wyd1502
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/utils/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import pandas as pd
4 | from pathlib import Path
5 | import argparse
6 | 
7 | 
8 | def preprocess(data_name):
9 |     u_list, i_list, ts_list, label_list = [], [], [], []
10 |     feat_l = []
11 |     idx_list = []
12 | 
13 |     with open(data_name) as f:
14 |         s = next(f)  # skip the header line
15 |         for idx, line in enumerate(f):
16 |             e = line.strip().split(',')
17 |             u = int(e[0])
18 |             i = int(e[1])
19 | 
20 |             ts = float(e[2])
21 |             label = float(e[3])  # int(e[3])
22 | 
23 |             feat = np.array([float(x) for x in e[4:]])
24 | 
25 |             u_list.append(u)
26 |             i_list.append(i)
27 |             ts_list.append(ts)
28 |             label_list.append(label)
29 |             idx_list.append(idx)
30 | 
31 |             feat_l.append(feat)
32 |     return pd.DataFrame({'u': u_list,
33 |                          'i': i_list,
34 |                          'ts': ts_list,
35 |                          'label': label_list,
36 |                          'idx': idx_list}), np.array(feat_l)
37 | 
38 | 
39 | def reindex(df, bipartite=True):
40 |     new_df = df.copy()
41 |     if bipartite:
42 |         assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
43 |         assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))
44 | 
45 |         upper_u = df.u.max() + 1
46 |         new_i = df.i + upper_u
47 | 
48 |         new_df.i = new_i
49 |         new_df.u += 1
50 |         new_df.i += 1
51 |         new_df.idx += 1
52 |     else:
53 |         new_df.u += 1
54 |         new_df.i += 1
55 |         new_df.idx += 1
56 | 
57 |     return new_df
58 | 
59 | 
60 | def run(data_name, bipartite=True):
61 |     Path("data/").mkdir(parents=True, exist_ok=True)
62 |     PATH = './data/{}.csv'.format(data_name)
63 |     OUT_DF = './data/ml_{}.csv'.format(data_name)
64 |     OUT_FEAT = './data/ml_{}.npy'.format(data_name)
65 |     OUT_NODE_FEAT = 
'./data/ml_{}_node.npy'.format(data_name) 66 | 67 | df, feat = preprocess(PATH) 68 | new_df = reindex(df, bipartite) 69 | 70 | empty = np.zeros(feat.shape[1])[np.newaxis, :] 71 | feat = np.vstack([empty, feat]) 72 | 73 | max_idx = max(new_df.u.max(), new_df.i.max()) 74 | rand_feat = np.zeros((max_idx + 1, 172)) 75 | 76 | new_df.to_csv(OUT_DF) 77 | np.save(OUT_FEAT, feat) 78 | np.save(OUT_NODE_FEAT, rand_feat) 79 | 80 | parser = argparse.ArgumentParser('Interface for DGNN data preprocessing') 81 | parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)', 82 | default='wikipedia') 83 | 84 | args = parser.parse_args() 85 | 86 | run(args.data) -------------------------------------------------------------------------------- /utils/preprocess_uci_msg.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | from pathlib import Path 5 | import argparse 6 | 7 | 8 | def preprocess(data_name): 9 | u_list, i_list, ts_list, label_list = [], [], [], [] 10 | feat_l = [] 11 | idx_list = [] 12 | 13 | with open(data_name) as f: 14 | s = next(f) 15 | for idx, line in enumerate(f): 16 | e = line.strip().split(' ') 17 | u = int(e[0]) 18 | i = int(e[1]) 19 | 20 | ts = float(e[2]) 21 | label = float(0) # int(e[3]) 22 | 23 | feat = np.array([0, 0, 0, 0, 0, 0]) 24 | 25 | u_list.append(u) 26 | i_list.append(i) 27 | ts_list.append(ts) 28 | label_list.append(label) 29 | idx_list.append(idx) 30 | 31 | feat_l.append(feat) 32 | return pd.DataFrame({'u': u_list, 33 | 'i': i_list, 34 | 'ts': ts_list, 35 | 'label': label_list, 36 | 'idx': idx_list}), np.array(feat_l) 37 | 38 | 39 | def reindex(df, bipartite=True): 40 | new_df = df.copy() 41 | new_df.ts = df.ts - df.ts.min() 42 | if bipartite: 43 | assert (df.u.max() - df.u.min() + 1 == len(df.u.unique())) 44 | assert (df.i.max() - df.i.min() + 1 == len(df.i.unique())) 45 | 46 | upper_u = df.u.max() + 1 47 | new_i = df.i + upper_u 48 | 49 | new_df.i = new_i 50 | new_df.u += 1 51 | new_df.i += 1 52 | new_df.idx += 1 53 | else: 54 | new_df.u += 1 55 | new_df.i += 1 56 | new_df.idx += 1 57 | 58 | return new_df 59 | 60 | 61 | def run(data_name, bipartite=True): 62 | Path("data/").mkdir(parents=True, exist_ok=True) 63 | PATH = './data/{}.txt'.format(data_name) 64 | OUT_DF = './data/ml_{}.csv'.format(data_name) 65 | OUT_FEAT = './data/ml_{}.npy'.format(data_name) 66 | OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name) 67 | 68 | df, feat = preprocess(PATH) 69 | new_df = reindex(df, bipartite) 70 | 71 | empty = np.zeros(feat.shape[1])[np.newaxis, :] 72 | feat = np.vstack([empty, feat]) 73 | 74 | max_idx = max(new_df.u.max(), new_df.i.max()) 75 | rand_feat = np.zeros((max_idx + 1, 100)) 76 | 77 | new_df.to_csv(OUT_DF) 78 | np.save(OUT_FEAT, feat) 79 | np.save(OUT_NODE_FEAT, rand_feat) 80 | 81 | parser = argparse.ArgumentParser('Interface for DGNN data preprocessing') 82 | parser.add_argument('--data', type=str, help='Dataset name (eg. 
wikipedia or reddit)',
83 |                     default='UCI-Msg')
84 | 
85 | args = parser.parse_args()
86 | 
87 | run(args.data, False)
--------------------------------------------------------------------------------
/modules/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | 
4 | from collections import defaultdict
5 | from copy import deepcopy
6 | 
7 | 
8 | class Memory(nn.Module):
9 | 
10 |     def __init__(self, n_nodes, memory_dimension, message_dimension=None,
11 |                  device="cpu"):
12 |         super(Memory, self).__init__()
13 |         self.n_nodes = n_nodes
14 |         self.memory_dimension = memory_dimension
15 |         self.message_dimension = message_dimension
16 |         self.device = device
17 | 
18 | 
19 |         self.__init_memory__()
20 | 
21 |     def __init_memory__(self, seed=0):
22 |         """
23 |         Re-initializes the memory (Xavier-normal hidden/cell states, empty message store). It should be called at the start of each epoch.
24 |         """
25 |         # Treat memory as parameter so that it is saved and loaded together with the model
26 |         torch.manual_seed(seed)
27 |         self.cell = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
28 |                                  requires_grad=False)
29 |         self.hidden = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
30 |                                    requires_grad=False)
31 |         self.memory = [self.cell, self.hidden]
32 |         self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
33 |                                         requires_grad=False)
34 |         nn.init.xavier_normal_(self.hidden)
35 |         nn.init.xavier_normal_(self.cell)
36 |         self.messages = defaultdict(list)
37 | 
38 |     def store_raw_messages(self, nodes, node_id_to_messages):
39 |         for node in nodes:
40 |             self.messages[node].extend(node_id_to_messages[node])
41 | 
42 |     def get_memory(self, node_idxs):
43 |         return [self.memory[i][node_idxs, :] for i in range(2)]
44 | 
45 |     def set_cell(self, node_idxs, values):
46 |         self.cell[node_idxs, :] = values
47 | 
48 |     def set_hidden(self, node_idxs, values):
49 |         self.hidden[node_idxs, :] = values
50 | 
51 |     def get_last_update(self, node_idxs):
52 |         return self.last_update[node_idxs]
53 | 
54 |     def backup_memory(self):
55 |         messages_clone = {}
56 |         for k, v in self.messages.items():
57 |             messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]
58 | 
59 |         return [self.memory[i].data.clone() for i in range(2)], self.last_update.data.clone(), messages_clone
60 | 
61 |     def restore_memory(self, memory_backup):
62 |         self.cell.data, self.hidden.data, self.last_update.data = \
63 |             memory_backup[0][0].clone(), memory_backup[0][1].clone(), memory_backup[1].clone()
64 |         self.messages = defaultdict(list)
65 |         for k, v in memory_backup[2].items():
66 |             self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]
67 | 
68 |     def detach_memory(self):
69 |         self.hidden.detach_()
70 |         self.cell.detach_()
71 | 
72 |         # Detach all stored messages
73 |         for k, v in self.messages.items():
74 |             new_node_messages = []
75 |             for message in v:
76 |                 new_node_messages.append((message[0].detach(), message[1]))
77 | 
78 |             self.messages[k] = new_node_messages
79 | 
80 |     def clear_messages(self, nodes):
81 |         for node in nodes:
82 |             self.messages[node] = []
83 | 
--------------------------------------------------------------------------------
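The intended life cycle of `Memory` is epoch-scoped: re-initialize at the start of every epoch (as the training script does for both `memory_s` and `memory_g`), and snapshot/restore around evaluation so validation batches do not leak state into training. A minimal sketch of that pattern, where `model`, `train_one_epoch`, `evaluate`, and `n_epochs` are hypothetical placeholders:

```python
# Sketch of epoch-level Memory handling; helper names are hypothetical.
for epoch in range(n_epochs):
    model.memory_s.__init_memory__()  # reset both memories before each epoch
    model.memory_g.__init_memory__()
    train_one_epoch(model)            # fills memories via store_raw_messages / updates

    backup_s = model.memory_s.backup_memory()  # snapshot state before validation
    backup_g = model.memory_g.backup_memory()
    evaluate(model)                            # validation mutates the memories
    model.memory_s.restore_memory(backup_s)    # roll back to the training-time state
    model.memory_g.restore_memory(backup_g)
```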
/modules/message_aggregator.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | class MessageAggregator(torch.nn.Module):
7 |     """
8 |     Abstract class for the message aggregator module, which, given a batch of node ids and
9 |     corresponding messages, aggregates messages with the same node id.
10 |     """
11 |     def __init__(self, device):
12 |         super(MessageAggregator, self).__init__()
13 |         self.device = device
14 | 
15 |     def aggregate(self, node_ids, messages=None):
16 |         """
17 |         Given a list of node ids and a message store, aggregate the different
18 |         messages for the same id using one of the possible strategies.
19 |         :param node_ids: a list of node ids of length batch_size
20 |         :param messages: a dict mapping each node id to a list of (message, timestamp) pairs
21 |         :return: a list of node ids to update, a tensor of shape [n_unique_node_ids, message_dimension]
22 |                  with the aggregated messages, and a tensor of the corresponding timestamps
23 |         """
24 | 
25 |     def group_by_id(self, node_ids, messages, timestamps):
26 |         node_id_to_messages = defaultdict(list)
27 | 
28 |         for i, node_id in enumerate(node_ids):
29 |             node_id_to_messages[node_id].append((messages[i], timestamps[i]))
30 | 
31 |         return node_id_to_messages
32 | 
33 | 
34 | class LastMessageAggregator(MessageAggregator):
35 |     def __init__(self, device):
36 |         super(LastMessageAggregator, self).__init__(device)
37 | 
38 |     def aggregate(self, node_ids, messages=None):
39 |         """Only keep the last message for each node"""
40 |         unique_node_ids = np.unique(node_ids)
41 |         unique_messages = []
42 |         unique_timestamps = []
43 | 
44 |         to_update_node_ids = []
45 | 
46 |         for node_id in unique_node_ids:
47 |             if len(messages[node_id]) > 0:
48 |                 to_update_node_ids.append(node_id)
49 |                 unique_messages.append(messages[node_id][-1][0])
50 |                 unique_timestamps.append(messages[node_id][-1][1])
51 | 
52 |         unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
53 |         unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
54 | 
55 |         return to_update_node_ids, unique_messages, unique_timestamps
56 | 
57 | 
58 | class MeanMessageAggregator(MessageAggregator):
59 |     def __init__(self, device):
60 |         super(MeanMessageAggregator, self).__init__(device)
61 | 
62 |     def aggregate(self, node_ids, messages=None):
63 |         """Average all the messages for each node; keep the timestamp of the last one"""
64 |         unique_node_ids = np.unique(node_ids)
65 |         unique_messages = []
66 |         unique_timestamps = []
67 | 
68 |         to_update_node_ids = []
69 |         n_messages = 0
70 | 
71 |         for node_id in unique_node_ids:
72 |             if len(messages[node_id]) > 0:
73 |                 n_messages += len(messages[node_id])
74 |                 to_update_node_ids.append(node_id)
75 |                 unique_messages.append(torch.mean(torch.stack([m[0] for m in messages[node_id]]), dim=0))
76 |                 unique_timestamps.append(messages[node_id][-1][1])
77 | 
78 |         unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
79 |         unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []
80 | 
81 |         return to_update_node_ids, unique_messages, unique_timestamps
82 | 
83 | 
84 | def get_message_aggregator(aggregator_type, device):
85 |     if aggregator_type == "last":
86 |         return LastMessageAggregator(device=device)
87 |     elif aggregator_type == "mean":
88 |         return MeanMessageAggregator(device=device)
89 |     else:
90 |         raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import numpy as np
4 | import torch
5 | from sklearn.metrics import average_precision_score, roc_auc_score
6 | 
7 | def choose_target(model, memory_s, memory_g, 
src_mem): 8 | u = model.memory_merge(memory_s[1], memory_g[1]) #[num_nodes,mem_d] 9 | u_norm = torch.norm(u, dim=1) #[num_nodes, 1] 10 | u_normalized = u/u_norm.view(-1, 1) #[num_nodes,mem_d] 11 | src_mem_norm = torch.norm(src_mem, dim=1) #[batch_size, 1] 12 | src_mem_normalized = src_mem / src_mem_norm.view(-1, 1) #[batch_size, mem_d] 13 | cos_similarity = torch.matmul(src_mem_normalized, u_normalized.t()) #[batch_size, num_nodes] 14 | cos_similarity, idx = torch.sort(cos_similarity, descending=True) 15 | return cos_similarity, idx 16 | 17 | def recall(des_node, idx, top_k): 18 | bs = idx.shape[0] 19 | idx = idx[:, :top_k] #[bs,top_k] 20 | recall = np.array([a in idx[i] for i, a in enumerate(des_node)])#[bs,1] 21 | recall = recall.sum() / recall.size 22 | return recall 23 | 24 | def MRR(des_node, idx): 25 | bs = idx.shape[0] 26 | mrr = np.array([float(np.where(idx[i].cpu() == a)[0] + 1) for i, a in enumerate(des_node)])#[bs,1] 27 | mrr = (1 / mrr).mean() 28 | return mrr 29 | 30 | 31 | def eval_edge_prediction(model, negative_edge_sampler, data, n_neighbors, batch_size=200): 32 | # Ensures the random sampler uses a seed for evaluation (i.e. we sample always the same 33 | # negatives for validation / test set) 34 | assert negative_edge_sampler.seed is not None 35 | negative_edge_sampler.reset_random_state() 36 | 37 | val_mrr, val_recall_20, val_recall_50 = [], [], [] 38 | with torch.no_grad(): 39 | model = model.eval() 40 | # While usually the test batch size is as big as it fits in memory, here we keep it the same 41 | # size as the training batch size, since it allows the memory to be updated more frequently, 42 | # and later test batches to access information from interactions in previous test batches 43 | # through the memory 44 | TEST_BATCH_SIZE = batch_size 45 | num_test_instance = len(data.sources) 46 | num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE) 47 | 48 | for k in range(num_test_batch): 49 | s_idx = k * TEST_BATCH_SIZE 50 | e_idx = min(num_test_instance, s_idx + TEST_BATCH_SIZE) 51 | sources_batch = data.sources[s_idx:e_idx] 52 | destinations_batch = data.destinations[s_idx:e_idx] 53 | timestamps_batch = data.timestamps[s_idx:e_idx] 54 | edge_idxs_batch = data.edge_idxs[s_idx: e_idx] 55 | 56 | size = len(sources_batch) 57 | _, negative_samples = negative_edge_sampler.sample(size) 58 | 59 | src_mem, des_mem = model(sources_batch, destinations_batch, 60 | negative_samples, timestamps_batch, 61 | edge_idxs_batch, test=True) 62 | 63 | src_cos_sim, src_idx = choose_target(model, model.memory_s.memory, model.memory_g.memory, src_mem) 64 | des_cos_sim, des_idx = choose_target(model, model.memory_s.memory, model.memory_g.memory, des_mem) 65 | recall_20 = (recall(destinations_batch, src_idx, 20) + recall(sources_batch, des_idx, 20)) / 2 66 | recall_50 = (recall(destinations_batch, src_idx, 50) + recall(sources_batch, des_idx, 50)) / 2 67 | mrr = (MRR(destinations_batch, src_idx) + MRR(sources_batch, des_idx)) / 2 68 | true_label = np.concatenate([np.ones(size), np.zeros(size)]) 69 | 70 | val_mrr.append(mrr) 71 | val_recall_20.append(recall_20) 72 | val_recall_50.append(recall_50) 73 | 74 | return np.mean(val_mrr), np.mean(val_recall_20), np.mean(val_recall_50) 75 | -------------------------------------------------------------------------------- /modules/propagater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class Propagater(nn.Module): 7 | def 
__init__(self, memory, message_dimension, memory_dimension, mean_time_shift_src, neighbor_finder, n_neighbors, 8 | tau=2, device='cpu'): 9 | super(Propagater, self).__init__() 10 | self.memory = memory 11 | self.layer_norm = torch.nn.LayerNorm(memory_dimension) 12 | self.message_dimension = message_dimension 13 | self.device = device 14 | self.alpha = 2 15 | self.neighbor_finder = neighbor_finder 16 | self.n_neighbors = n_neighbors 17 | self.tau = tau * mean_time_shift_src 18 | 19 | self.tanh = nn.Tanh() 20 | self.W_s = nn.Parameter(torch.zeros((message_dimension, memory_dimension)).to(self.device)) 21 | self.bias = nn.Parameter(torch.zeros(memory_dimension).to(self.device)) 22 | 23 | 24 | def compute_time_discount(self, edge_delta): 25 | time_intervals = self.alpha * edge_delta 26 | time_discount = torch.exp(-time_intervals) 27 | return time_discount 28 | 29 | def compute_attention_weight(self, memory, sources_idx, neighbors): 30 | batch_s = neighbors.shape[0] 31 | sources = memory[0][sources_idx].view(batch_s, 1, -1) 32 | softmax = nn.Softmax(dim=1) 33 | att = softmax(torch.matmul(neighbors, sources.transpose(1, 2))) 34 | return att 35 | 36 | def forward(self, memory, unique_node_ids, unique_messages, timestamps, inplace=True): 37 | if len(unique_messages) == 0: 38 | return memory 39 | timestamps = timestamps.cpu().numpy() 40 | neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(unique_node_ids, timestamps, 41 | n_neighbors=self.n_neighbors) 42 | neighbors_torch = torch.from_numpy(neighbors).long().to(self.device) 43 | batch_s, _ = neighbors.shape 44 | edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device) 45 | edge_deltas = timestamps[:, np.newaxis] - edge_times 46 | edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device) 47 | mask = (torch.from_numpy(edge_deltas).float().to(self.device) > 0).long() 48 | mask_re = mask.view(batch_s, self.n_neighbors, -1) 49 | neighbors_cell = memory[0][neighbors_torch.flatten()].view(batch_s, self.n_neighbors, -1) 50 | edge_delta_re = edge_deltas_torch.view(batch_s, self.n_neighbors, -1) 51 | unique_messages_re = unique_messages.repeat((self.n_neighbors, 1)).view(batch_s, self.n_neighbors, -1) 52 | time_discounts = self.compute_time_discount(edge_delta_re) 53 | time_threshold = (edge_delta_re < self.tau).long() 54 | att = self.compute_attention_weight(memory, unique_node_ids, neighbors_cell) #(b_s,n_neighbors,1) 55 | unique_messages = torch.matmul(unique_messages_re, self.W_s) #(b_s,n_neighbors,memory_size) 56 | C_v = memory[0][neighbors_torch.flatten()] + (mask_re*time_threshold*time_discounts*att * \ 57 | unique_messages).view(batch_s*self.n_neighbors, -1) #(b_s,n_neighbors,memory_size) 58 | h_v = self.tanh(C_v) 59 | if inplace: 60 | memory[0][neighbors_torch.flatten()] = C_v 61 | memory[1][neighbors_torch.flatten()] = h_v 62 | return memory 63 | else: 64 | memory_cell = memory[0].data 65 | memory_hidden = memory[1].data 66 | memory_cell[neighbors_torch.flatten()] = C_v 67 | memory_hidden[neighbors_torch.flatten()] = h_v 68 | return [memory_cell, memory_hidden] 69 | 70 | 71 | -------------------------------------------------------------------------------- /modules/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from collections import defaultdict 5 | 6 | 7 | class MemoryUpdater(nn.Module): 8 | def __init__(self, memory, message_dimension, memory_dimension, mean_time_shift_src, device): 9 | 
super(MemoryUpdater, self).__init__()
10 |         self.memory = memory
11 |         self.layer_norm = torch.nn.LayerNorm(memory_dimension)
12 |         self.message_dimension = message_dimension
13 |         self.device = device
14 |         self.alpha = 2
15 | 
16 |         self.sig = nn.Sigmoid()
17 |         self.tanh = nn.Tanh()
18 |         self.W_d = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
19 |         self.b_d = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
20 |         self.W_f = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
21 |         self.U_f = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
22 |         self.b_f = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
23 |         self.W_i = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
24 |         self.U_i = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
25 |         self.b_i = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
26 |         self.W_o = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
27 |         self.U_o = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
28 |         self.b_o = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
29 |         self.W_c = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
30 |         self.U_c = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
31 |         self.b_c = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
32 |     def update_memory(self, unique_node_ids, unique_messages, timestamps, inplace=True):
33 |         if len(unique_messages) == 0:
34 |             return self.memory.memory, self.memory.last_update
35 |         hidden = self.memory.memory[1][unique_node_ids]  # hidden state h_v: [bs, memory_size]
36 |         cell = self.memory.memory[0][unique_node_ids]  # cell state C_v: [bs, memory_size]
37 |         messages = unique_messages
38 |         time_discounts = self.compute_time_discount(unique_node_ids, timestamps)  # [bs]
39 |         bs = hidden.shape[0]
40 |         C_vi = self.tanh(torch.matmul(self.W_d, cell.t()).t() + self.b_d)  # short-term component: [bs, memory_size]
41 |         C_v_discount = torch.mul(C_vi, time_discounts.view(bs, 1))
42 |         C_v_t = cell - C_vi  # remove the short-term component from the cell ...
43 |         C_v_star = C_v_t + C_v_discount  # ... and add back its time-discounted version
44 |         f_t = self.sig(torch.matmul(self.W_f, messages.t()).t() + torch.matmul(self.U_f, hidden.t()).t() + self.b_f)
45 |         i_t = self.sig(torch.matmul(self.W_i, messages.t()).t() + torch.matmul(self.U_i, hidden.t()).t() + self.b_i)
46 |         o_t = self.sig(torch.matmul(self.W_o, messages.t()).t() + torch.matmul(self.U_o, hidden.t()).t() + self.b_o)
47 |         C_hat_t = self.tanh(torch.matmul(self.W_c, messages.t()).t() + torch.matmul(self.U_c, hidden.t()).t() + self.b_c)
48 |         C_v_t = torch.mul(f_t, C_v_star) + torch.mul(i_t, C_hat_t)
49 |         h_v_t = torch.mul(o_t, self.tanh(C_v_t))
50 |         if inplace:
51 |             self.memory.memory[0][unique_node_ids] = C_v_t
52 |             self.memory.memory[1][unique_node_ids] = h_v_t
53 |             self.memory.last_update[unique_node_ids] = timestamps
54 |             return self.memory.memory, self.memory.last_update
55 |         else:
56 |             memory_cell = self.memory.memory[0].data.clone()
57 |             memory_hidden = self.memory.memory[1].data.clone()
58 |             last_update = self.memory.last_update.data.clone()
59 |             memory_cell[unique_node_ids] = C_v_t
60 |             memory_hidden[unique_node_ids] = h_v_t
61 |             last_update[unique_node_ids] = timestamps
62 |             return [memory_cell, memory_hidden], last_update
63 | 
64 | 
65 |     def compute_time_discount(self, unique_node_ids, timestamps):
66 |         time_intervals = timestamps - self.memory.last_update[unique_node_ids]
67 | 
time_intervals = self.alpha*time_intervals 68 | time_discount = torch.exp(-time_intervals) 69 | return time_discount 70 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class MergeLayer(torch.nn.Module): 6 | def __init__(self, dim1, dim2, dim3, dim4): 7 | super().__init__() 8 | self.fc1 = torch.nn.Linear(dim1 + dim2, dim3) 9 | self.fc2 = torch.nn.Linear(dim3, dim4) 10 | self.act = torch.nn.ReLU() 11 | 12 | torch.nn.init.xavier_normal_(self.fc1.weight) 13 | torch.nn.init.xavier_normal_(self.fc2.weight) 14 | 15 | def forward(self, x1, x2): 16 | x = torch.cat([x1, x2], dim=1) 17 | h = self.act(self.fc1(x)) 18 | return self.fc2(h) 19 | 20 | 21 | class EarlyStopMonitor(object): 22 | def __init__(self, max_round=3, higher_better=True, tolerance=1e-10): 23 | self.max_round = max_round 24 | self.num_round = 0 25 | 26 | self.epoch_count = 0 27 | self.best_epoch = 0 28 | 29 | self.last_best = None 30 | self.higher_better = higher_better 31 | self.tolerance = tolerance 32 | 33 | def early_stop_check(self, curr_val): 34 | if not self.higher_better: 35 | curr_val *= -1 36 | if self.last_best is None: 37 | self.last_best = curr_val 38 | elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance: 39 | self.last_best = curr_val 40 | self.num_round = 0 41 | self.best_epoch = self.epoch_count 42 | else: 43 | self.num_round += 1 44 | 45 | self.epoch_count += 1 46 | 47 | return self.num_round >= self.max_round 48 | 49 | 50 | class RandEdgeSampler(object): 51 | def __init__(self, src_list, dst_list, seed=None): 52 | self.seed = None 53 | self.src_list = np.unique(src_list) 54 | self.dst_list = np.unique(dst_list) 55 | 56 | if seed is not None: 57 | self.seed = seed 58 | self.random_state = np.random.RandomState(self.seed) 59 | 60 | def sample(self, size): 61 | if self.seed is None: 62 | src_index = np.random.randint(0, len(self.src_list), size) 63 | dst_index = np.random.randint(0, len(self.dst_list), size) 64 | else: 65 | 66 | src_index = self.random_state.randint(0, len(self.src_list), size) 67 | dst_index = self.random_state.randint(0, len(self.dst_list), size) 68 | return self.src_list[src_index], self.dst_list[dst_index] 69 | 70 | def reset_random_state(self): 71 | self.random_state = np.random.RandomState(self.seed) 72 | 73 | 74 | def get_neighbor_finder(data, uniform): 75 | max_node_idx = max(data.sources.max(), data.destinations.max()) 76 | adj_list = [[] for _ in range(max_node_idx + 1)] 77 | for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations, 78 | data.edge_idxs, 79 | data.timestamps): 80 | adj_list[source].append((destination, edge_idx, timestamp)) 81 | adj_list[destination].append((source, edge_idx, timestamp)) 82 | 83 | return NeighborFinder(adj_list, uniform=uniform) 84 | 85 | 86 | class NeighborFinder: 87 | def __init__(self, adj_list, uniform=False, seed=None): 88 | self.node_to_neighbors = [] 89 | self.node_to_edge_idxs = [] 90 | self.node_to_edge_timestamps = [] 91 | 92 | for neighbors in adj_list: 93 | # Neighbors is a list of tuples (neighbor, edge_idx, timestamp) 94 | # We sort the list based on timestamp 95 | sorted_neighhbors = sorted(neighbors, key=lambda x: x[2]) 96 | self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors])) 97 | self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors])) 98 | 
self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors])) 99 | 100 | self.uniform = uniform 101 | 102 | if seed is not None: 103 | self.seed = seed 104 | self.random_state = np.random.RandomState(self.seed) 105 | 106 | def find_before(self, src_idx, cut_time): 107 | """ 108 | Extracts all the interactions happening before cut_time for user src_idx in the overall interaction graph. The returned interactions are sorted by time. 109 | 110 | Returns 3 lists: neighbors, edge_idxs, timestamps 111 | 112 | """ 113 | i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time) 114 | 115 | return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[src_idx][:i] 116 | 117 | def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20): 118 | """ 119 | Given a list of users ids and relative cut times, extracts a sampled temporal neighborhood of each user in the list. 120 | 121 | Params 122 | ------ 123 | src_idx_l: List[int] 124 | cut_time_l: List[float], 125 | num_neighbors: int 126 | """ 127 | assert (len(source_nodes) == len(timestamps)) 128 | 129 | tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1 130 | # NB! All interactions described in these matrices are sorted in each row by time 131 | neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype( 132 | np.int32) # each entry in position (i,j) represent the id of the item targeted by user src_idx_l[i] with an interaction happening before cut_time_l[i] 133 | edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype( 134 | np.float32) # each entry in position (i,j) represent the timestamp of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i] 135 | edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype( 136 | np.int32) # each entry in position (i,j) represent the interaction index of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i] 137 | 138 | for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)): 139 | source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node, 140 | timestamp) # extracts all neighbors, interactions indexes and timestamps of all interactions of user source_node happening before cut_time 141 | if len(source_neighbors) > 0 and n_neighbors > 0: 142 | if self.uniform: # if we are applying uniform sampling, shuffles the data above before sampling 143 | sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors) 144 | 145 | neighbors[i, :] = source_neighbors[sampled_idx] 146 | edge_times[i, :] = source_edge_times[sampled_idx] 147 | edge_idxs[i, :] = source_edge_idxs[sampled_idx] 148 | 149 | # re-sort based on time 150 | pos = edge_times[i, :].argsort() 151 | neighbors[i, :] = neighbors[i, :][pos] 152 | edge_times[i, :] = edge_times[i, :][pos] 153 | edge_idxs[i, :] = edge_idxs[i, :][pos] 154 | else: 155 | # Take most recent interactions 156 | source_edge_times = source_edge_times[-n_neighbors:] 157 | source_neighbors = source_neighbors[-n_neighbors:] 158 | source_edge_idxs = source_edge_idxs[-n_neighbors:] 159 | 160 | assert (len(source_neighbors) <= n_neighbors) 161 | assert (len(source_edge_times) <= n_neighbors) 162 | assert (len(source_edge_idxs) <= n_neighbors) 163 | 164 | neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors 165 | edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times 166 | edge_idxs[i, n_neighbors - 
len(source_edge_idxs):] = source_edge_idxs 167 | 168 | return neighbors, edge_idxs, edge_times -------------------------------------------------------------------------------- /utils/data_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import pandas as pd 4 | 5 | 6 | class Data: 7 | def __init__(self, sources, destinations, timestamps, edge_idxs, labels): 8 | self.sources = sources 9 | self.destinations = destinations 10 | self.timestamps = timestamps 11 | self.edge_idxs = edge_idxs 12 | self.labels = labels 13 | self.n_interactions = len(sources) 14 | self.unique_nodes = set(sources) | set(destinations) 15 | self.n_unique_nodes = len(self.unique_nodes) 16 | 17 | 18 | def get_data(dataset_name, different_new_nodes_between_val_and_test=False): 19 | ### Load data and train val test split 20 | graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name)) 21 | edge_features = np.load('./data/ml_{}.npy'.format(dataset_name)) 22 | node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name)) 23 | 24 | val_time, test_time = list(np.quantile(graph_df.ts, [0.80, 0.9])) 25 | 26 | sources = graph_df.u.values 27 | destinations = graph_df.i.values 28 | edge_idxs = graph_df.idx.values 29 | labels = graph_df.label.values 30 | timestamps = graph_df.ts.values 31 | 32 | full_data = Data(sources, destinations, timestamps, edge_idxs, labels) 33 | 34 | random.seed(2020) 35 | 36 | node_set = set(sources) | set(destinations) 37 | n_total_unique_nodes = len(node_set) 38 | 39 | # Compute nodes which appear at test time 40 | test_node_set = set(sources[timestamps > val_time]).union( 41 | set(destinations[timestamps > val_time])) 42 | # Sample nodes which we keep as new nodes (to test inductiveness), so than we have to remove all 43 | # their edges from training 44 | new_test_node_set = set(random.sample(test_node_set, int(0.1 * n_total_unique_nodes))) 45 | 46 | # Mask saying for each source and destination whether they are new test nodes 47 | new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values 48 | new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values 49 | 50 | # Mask which is true for edges with both destination and source not being new test nodes (because 51 | # we want to remove all edges involving any new test node) 52 | observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask) 53 | 54 | # For train we keep edges happening before the validation time which do not involve any new node 55 | # used for inductiveness 56 | train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask) 57 | 58 | train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask], 59 | edge_idxs[train_mask], labels[train_mask]) 60 | 61 | # define the new nodes sets for testing inductiveness of the model 62 | train_node_set = set(train_data.sources).union(train_data.destinations) 63 | assert len(train_node_set & new_test_node_set) == 0 64 | new_node_set = node_set - train_node_set 65 | 66 | val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time) 67 | test_mask = timestamps > test_time 68 | 69 | if different_new_nodes_between_val_and_test: 70 | n_new_nodes = len(new_test_node_set) // 2 71 | val_new_node_set = set(list(new_test_node_set)[:n_new_nodes]) 72 | test_new_node_set = set(list(new_test_node_set)[n_new_nodes:]) 73 | 74 | edge_contains_new_val_node_mask = np.array( 75 | [(a in val_new_node_set or b 
in val_new_node_set) for a, b in zip(sources, destinations)]) 76 | edge_contains_new_test_node_mask = np.array( 77 | [(a in test_new_node_set or b in test_new_node_set) for a, b in zip(sources, destinations)]) 78 | new_node_val_mask = np.logical_and(val_mask, edge_contains_new_val_node_mask) 79 | new_node_test_mask = np.logical_and(test_mask, edge_contains_new_test_node_mask) 80 | 81 | 82 | else: 83 | edge_contains_new_node_mask = np.array( 84 | [(a in new_node_set or b in new_node_set) for a, b in zip(sources, destinations)]) 85 | new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask) 86 | new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask) 87 | 88 | # validation and test with all edges 89 | val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask], 90 | edge_idxs[val_mask], labels[val_mask]) 91 | 92 | test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask], 93 | edge_idxs[test_mask], labels[test_mask]) 94 | 95 | # validation and test with edges that at least has one new node (not in training set) 96 | new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask], 97 | timestamps[new_node_val_mask], 98 | edge_idxs[new_node_val_mask], labels[new_node_val_mask]) 99 | 100 | new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask], 101 | timestamps[new_node_test_mask], edge_idxs[new_node_test_mask], 102 | labels[new_node_test_mask]) 103 | 104 | print("The dataset has {} interactions, involving {} different nodes".format(full_data.n_interactions, 105 | full_data.n_unique_nodes)) 106 | print("The training dataset has {} interactions, involving {} different nodes".format( 107 | train_data.n_interactions, train_data.n_unique_nodes)) 108 | print("The validation dataset has {} interactions, involving {} different nodes".format( 109 | val_data.n_interactions, val_data.n_unique_nodes)) 110 | print("The test dataset has {} interactions, involving {} different nodes".format( 111 | test_data.n_interactions, test_data.n_unique_nodes)) 112 | print("The new node validation dataset has {} interactions, involving {} different nodes".format( 113 | new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes)) 114 | print("The new node test dataset has {} interactions, involving {} different nodes".format( 115 | new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes)) 116 | print("{} nodes were used for the inductive testing, i.e. 
are never seen during training".format( 117 | len(new_test_node_set))) 118 | 119 | return node_features, edge_features, full_data, train_data, val_data, test_data, \ 120 | new_node_val_data, new_node_test_data 121 | 122 | 123 | def compute_time_statistics(sources, destinations, timestamps): 124 | last_timestamp_sources = dict() 125 | last_timestamp_dst = dict() 126 | all_timediffs_src = [] 127 | all_timediffs_dst = [] 128 | for k in range(len(sources)): 129 | source_id = sources[k] 130 | dest_id = destinations[k] 131 | c_timestamp = timestamps[k] 132 | if source_id not in last_timestamp_sources.keys(): 133 | last_timestamp_sources[source_id] = 0 134 | if dest_id not in last_timestamp_dst.keys(): 135 | last_timestamp_dst[dest_id] = 0 136 | all_timediffs_src.append(c_timestamp - last_timestamp_sources[source_id]) 137 | all_timediffs_dst.append(c_timestamp - last_timestamp_dst[dest_id]) 138 | last_timestamp_sources[source_id] = c_timestamp 139 | last_timestamp_dst[dest_id] = c_timestamp 140 | assert len(all_timediffs_src) == len(sources) 141 | assert len(all_timediffs_dst) == len(sources) 142 | mean_time_shift_src = np.mean(all_timediffs_src) 143 | std_time_shift_src = np.std(all_timediffs_src) 144 | mean_time_shift_dst = np.mean(all_timediffs_dst) 145 | std_time_shift_dst = np.std(all_timediffs_dst) 146 | 147 | return mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst 148 | -------------------------------------------------------------------------------- /models/dgnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from collections import defaultdict 4 | import logging 5 | import numpy as np 6 | 7 | from modules.memory import Memory 8 | from modules.message_aggregator import get_message_aggregator 9 | from modules.merge import MemoryMerge 10 | from modules.message_function import MessageFunction 11 | from modules.update import MemoryUpdater 12 | from modules.propagater import Propagater 13 | 14 | 15 | class DGNN(nn.Module): 16 | def __init__(self, neighbor_finder, node_features, edge_features, device, 17 | dropout=0.1, 18 | memory_update_at_start=True, message_dimension=100, 19 | memory_dimension=200, n_neighbors=None, aggregator_type="last", 20 | mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0, 21 | std_time_shift_dst=1, threshold=2): 22 | super(DGNN, self).__init__() 23 | self.neighbor_finder = neighbor_finder 24 | self.device = device 25 | self.logger = logging.getLogger(__name__) 26 | 27 | self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device) 28 | self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device) 29 | 30 | self.n_node_features = self.node_raw_features.shape[1] 31 | self.n_nodes = self.node_raw_features.shape[0] 32 | self.n_edge_features = self.edge_raw_features.shape[1] 33 | self.embedding_dimension = self.n_node_features 34 | self.n_neighbors = n_neighbors 35 | self.memory_s = None 36 | self.memory_g = None 37 | 38 | self.threshold = threshold 39 | self.mean_time_shift_src = mean_time_shift_src 40 | self.std_time_shift_src = std_time_shift_src 41 | self.mean_time_shift_dst = mean_time_shift_dst 42 | self.std_time_shift_dst = std_time_shift_dst 43 | self.memory_dimension = memory_dimension 44 | self.memory_update_at_start = memory_update_at_start 45 | self.message_dimension = message_dimension 46 | self.memory_merge = MemoryMerge(self.memory_dimension, self.device) 47 | self.memory_s = 
Memory(n_nodes=self.n_nodes, 48 | memory_dimension=self.memory_dimension, 49 | message_dimension=message_dimension, 50 | device=device) 51 | self.memory_g = Memory(n_nodes=self.n_nodes, 52 | memory_dimension=self.memory_dimension, 53 | message_dimension=message_dimension, 54 | device=device) 55 | self.message_dim = message_dimension 56 | self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type, 57 | device=device) 58 | self.message_function = MessageFunction(memory_dimension=memory_dimension, 59 | message_dimension=message_dimension, edge_dimension=self.n_edge_features, 60 | device=self.device) 61 | self.memory_updater_s = MemoryUpdater(memory=self.memory_s, message_dimension=message_dimension, 62 | memory_dimension=self.memory_dimension, 63 | mean_time_shift_src=self.mean_time_shift_src/2, 64 | device=self.device) 65 | self.memory_updater_g = MemoryUpdater(memory=self.memory_g, message_dimension=message_dimension, 66 | memory_dimension=self.memory_dimension, 67 | mean_time_shift_src=self.mean_time_shift_dst / 2, 68 | device=self.device) 69 | self.propagater_s = Propagater(memory=self.memory_s, message_dimension=message_dimension, 70 | memory_dimension=self.memory_dimension, 71 | mean_time_shift_src=self.mean_time_shift_src / 2, 72 | neighbor_finder=self.neighbor_finder, n_neighbors=self.n_neighbors, tau=self.threshold, 73 | device=self.device) 74 | self.propagater_g = Propagater(memory=self.memory_g, message_dimension=message_dimension, 75 | memory_dimension=self.memory_dimension, 76 | mean_time_shift_src=self.mean_time_shift_dst / 2, 77 | neighbor_finder=self.neighbor_finder, n_neighbors=self.n_neighbors, tau=self.threshold, 78 | device=self.device) 79 | self.W_s = nn.Parameter(torch.zeros((memory_dimension, memory_dimension // 2)).to(self.device)) 80 | #nn.xavier_ 81 | self.W_g = nn.Parameter(torch.zeros((memory_dimension, memory_dimension // 2)).to(self.device)) 82 | 83 | def update_memory(self, source_nodes, destination_nodes, messages_s, messages_g): 84 | # Aggregate messages for the same nodes 85 | 86 | unique_src_nodes, unique_src_messages, unique_src_timestamps = self.message_aggregator.aggregate(source_nodes, 87 | messages_s) 88 | unique_des_nodes, unique_des_messages, unique_des_timestamps = self.message_aggregator.aggregate(destination_nodes, 89 | messages_g) 90 | 91 | # Update the memory with the aggregated messages 92 | self.memory_updater_s.update_memory(unique_src_nodes, unique_src_messages, 93 | timestamps=unique_src_timestamps) 94 | self.memory_updater_g.update_memory(unique_des_nodes, unique_des_messages, 95 | timestamps=unique_des_timestamps) 96 | 97 | def propagate(self, source_nodes, destination_nodes, messages_s, messages_g): 98 | unique_src_nodes, unique_src_messages, unique_src_timestamps = self.message_aggregator.aggregate(source_nodes, 99 | messages_s) 100 | unique_des_nodes, unique_des_messages, unique_des_timestamps = self.message_aggregator.aggregate( 101 | destination_nodes, 102 | messages_g) 103 | 104 | self.propagater_s(self.memory_s.memory, unique_src_nodes, unique_src_messages, 105 | timestamps=unique_src_timestamps) 106 | self.propagater_g(self.memory_g.memory, unique_des_nodes, unique_des_messages, 107 | timestamps=unique_des_timestamps) 108 | 109 | def compute_loss(self, memory_s, memory_g, source_nodes, destination_nodes): 110 | source_mem = self.memory_merge(memory_s[1][source_nodes], memory_g[1][source_nodes]) 111 | destination_mem = self.memory_merge(memory_s[1][destination_nodes], memory_g[1][destination_nodes]) 112 | 
source_emb = torch.matmul(source_mem, self.W_s) 113 | destination_emb = torch.matmul(destination_mem, self.W_g) 114 | score = torch.sum(source_emb*destination_emb, dim=1) 115 | return score.sigmoid() 116 | 117 | def forward(self, source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, test=False): 118 | n_samples = len(source_nodes) 119 | nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes]) 120 | positives = np.concatenate([source_nodes, destination_nodes]) 121 | timestamps = np.concatenate([edge_times, edge_times]) 122 | memory = None 123 | time_diffs = None 124 | memory_s, last_update_s, memory_g, last_update_g = \ 125 | self.get_updated_memory(list(range(self.n_nodes)), list(range(self.n_nodes)), 126 | self.memory_s.messages, self.memory_g.messages) 127 | 128 | pos_score = self.compute_loss(memory_s, memory_g, source_nodes, destination_nodes) 129 | neg_score = self.compute_loss(memory_s, memory_g, source_nodes, negative_nodes) 130 | self.update_memory(source_nodes, destination_nodes, 131 | self.memory_s.messages, self.memory_g.messages) 132 | self.propagate(source_nodes, destination_nodes, 133 | self.memory_s.messages, self.memory_g.messages) 134 | self.memory_s.clear_messages(positives) 135 | self.memory_g.clear_messages(positives) 136 | unique_sources, source_id_to_messages = self.get_messages(source_nodes, 137 | destination_nodes, 138 | edge_times, edge_idxs) 139 | unique_destinations, destination_id_to_messages = self.get_messages(destination_nodes, 140 | source_nodes, 141 | edge_times, 142 | edge_idxs) 143 | self.memory_s.store_raw_messages(unique_sources, source_id_to_messages) 144 | self.memory_g.store_raw_messages(unique_destinations, destination_id_to_messages) 145 | if not test: 146 | return pos_score, neg_score 147 | else: 148 | source_mem = self.memory_merge(memory_s[1][source_nodes], memory_g[1][source_nodes]) 149 | destination_mem = self.memory_merge(memory_s[1][destination_nodes], memory_g[1][destination_nodes]) 150 | return source_mem, destination_mem 151 | 152 | def get_messages(self, source_nodes, destination_nodes, edge_times, edge_idxs): 153 | edge_times = torch.from_numpy(edge_times).float().to(self.device) 154 | edge_features = self.edge_raw_features[edge_idxs] 155 | source_memory = self.memory_merge(self.memory_s.memory[1][source_nodes], 156 | self.memory_g.memory[1][source_nodes]) 157 | destination_memory = self.memory_merge(self.memory_s.memory[1][destination_nodes], 158 | self.memory_g.memory[1][destination_nodes]) 159 | 160 | source_message = self.message_function.compute_message(source_memory, destination_memory, edge_features) 161 | messages = defaultdict(list) 162 | unique_sources = np.unique(source_nodes) 163 | 164 | for i in range(len(source_nodes)): 165 | messages[source_nodes[i]].append((source_message[i], edge_times[i])) 166 | 167 | return unique_sources, messages 168 | 169 | def get_updated_memory(self, source_nodes, destination_nodes, message_s, message_g): 170 | unique_src_nodes, unique_src_messages, unique_src_timestamps = self.message_aggregator.aggregate(source_nodes, 171 | message_s) 172 | unique_des_nodes, unique_des_messages, unique_des_timestamps = self.message_aggregator.aggregate(destination_nodes, 173 | message_g) 174 | updated_src_memory, updated_src_last_update = self.memory_updater_s.update_memory(unique_src_nodes, 175 | unique_src_messages, 176 | timestamps=unique_src_timestamps, 177 | inplace=False) 178 | updated_des_memory, updated_des_last_update = 
self.memory_updater_g.update_memory(unique_des_nodes, 179 | unique_des_messages, 180 | timestamps=unique_des_timestamps, 181 | inplace=False) 182 | updated_src_memory = self.propagater_s(updated_src_memory, unique_src_nodes, unique_src_messages, 183 | timestamps=unique_src_timestamps, inplace=False) 184 | updated_des_memory = self.propagater_g(updated_des_memory, unique_des_nodes, unique_des_messages, 185 | timestamps=unique_des_timestamps, inplace=False) 186 | 187 | return updated_src_memory, updated_src_last_update, updated_des_memory, updated_des_last_update 188 | 189 | def set_neighbor_finder(self, neighbor_finder): 190 | self.neighbor_finder = neighbor_finder 191 | self.propagater_s.neighbor_finder = neighbor_finder 192 | self.propagater_g.neighbor_finder = neighbor_finder 193 | 194 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import time 4 | import sys 5 | import argparse 6 | import torch 7 | import numpy as np 8 | import pickle 9 | from pathlib import Path 10 | 11 | from evaluation import eval_edge_prediction 12 | from models.dgnn import DGNN 13 | from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder 14 | from utils.data_processing import get_data, compute_time_statistics 15 | 16 | 17 | 18 | ### Argument and global variables 19 | parser = argparse.ArgumentParser('DGGN training') 20 | parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia or reddit)', 21 | default='reddit') 22 | parser.add_argument('--bs', type=int, default=200, help='Batch_size') 23 | parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints') 24 | parser.add_argument('--n_degree', type=int, default=20, help='Number of neighbors to sample') 25 | parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs') 26 | parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate') 27 | parser.add_argument('--patience', type=int, default=8, help='Patience for early stopping') 28 | parser.add_argument('--n_runs', type=int, default=10, help='Number of runs') 29 | parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability') 30 | parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use') 31 | parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding') 32 | parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to ' 33 | 'backprop') 34 | 35 | parser.add_argument('--aggregator', type=str, default="last", help='Type of message ' 36 | 'aggregator') 37 | 38 | parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages') 39 | 40 | parser.add_argument('--memory_dim', type=int, default=172, help='Dimensions of the memory for ' 41 | 'each user') 42 | parser.add_argument('--different_new_nodes', action='store_true', 43 | help='Whether to use disjoint set of new nodes for train and val') 44 | parser.add_argument('--uniform', action='store_true', 45 | help='take uniform sampling from temporal neighbors') 46 | parser.add_argument('--seed', type=int, default=0, help='random seed') 47 | 48 | try: 49 | args = parser.parse_args() 50 | except: 51 | parser.print_help() 52 | sys.exit(0) 53 | torch.manual_seed(args.seed) 54 | np.random.seed(args.seed) 55 | BATCH_SIZE = args.bs 56 | NUM_NEIGHBORS = args.n_degree 57 | 
NUM_NEG = 1 58 | NUM_EPOCH = args.n_epoch 59 | DROP_OUT = args.drop_out 60 | GPU = args.gpu 61 | SEQ_LEN = NUM_NEIGHBORS 62 | DATA = args.data 63 | LEARNING_RATE = args.lr 64 | NODE_DIM = args.node_dim 65 | MESSAGE_DIM = args.message_dim 66 | MEMORY_DIM = args.memory_dim 67 | 68 | Path("./saved_models/").mkdir(parents=True, exist_ok=True) 69 | Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True) 70 | MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth' 71 | get_checkpoint_path = lambda \ 72 | epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth' 73 | 74 | ### set up logger 75 | logging.basicConfig(level=logging.INFO) 76 | logger = logging.getLogger() 77 | logger.setLevel(logging.DEBUG) 78 | Path("log/").mkdir(parents=True, exist_ok=True) 79 | timenow = time.strftime("%Y-%m-%d_%H:%M:%S",time.localtime(time.time())) 80 | fh = logging.FileHandler('log/{}.log'.format(str(args.prefix)+timenow)) 81 | fh.setLevel(logging.DEBUG) 82 | ch = logging.StreamHandler() 83 | ch.setLevel(logging.WARN) 84 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 85 | fh.setFormatter(formatter) 86 | ch.setFormatter(formatter) 87 | logger.addHandler(fh) 88 | logger.addHandler(ch) 89 | logger.info(args) 90 | 91 | ### Extract data for training, validation and testing 92 | node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, \ 93 | new_node_test_data = get_data(DATA, 94 | different_new_nodes_between_val_and_test=args.different_new_nodes) 95 | 96 | # Initialize training neighbor finder to retrieve temporal graph 97 | train_ngh_finder = get_neighbor_finder(train_data, args.uniform) 98 | 99 | # Initialize validation and test neighbor finder to retrieve temporal graph 100 | full_ngh_finder = get_neighbor_finder(full_data, args.uniform) 101 | 102 | # Initialize negative samplers. 
Set seeds for validation and testing so negatives are the same 103 | # across different runs 104 | # NB: in the inductive setting, negatives are sampled only amongst other new nodes 105 | train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations) 106 | val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0) 107 | nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations, 108 | seed=1) 109 | test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2) 110 | nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources, 111 | new_node_test_data.destinations, 112 | seed=3) 113 | 114 | # Set device 115 | device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu' 116 | device = torch.device(device_string) 117 | 118 | # Compute time statistics 119 | mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \ 120 | compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps) 121 | 122 | for i in range(args.n_runs): 123 | results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix) 124 | Path("results/").mkdir(parents=True, exist_ok=True) 125 | 126 | # Initialize Model 127 | dgnn = DGNN(neighbor_finder=train_ngh_finder, node_features=node_features, 128 | edge_features=edge_features, device=device, dropout=DROP_OUT, 129 | message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM, 130 | aggregator_type=args.aggregator, n_neighbors=NUM_NEIGHBORS, 131 | mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src, 132 | mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst) 133 | for a in dgnn.parameters(): 134 | if a.ndim > 1: 135 | torch.nn.init.xavier_uniform_(a) 136 | 137 | dgnn.memory_s.__init_memory__() 138 | dgnn.memory_g.__init_memory__() 139 | 140 | criterion = torch.nn.BCELoss() 141 | optimizer = torch.optim.Adam(dgnn.parameters(), lr=LEARNING_RATE) 142 | dgnn = dgnn.to(device) 143 | 144 | num_instance = len(train_data.sources) 145 | num_batch = math.ceil(num_instance / BATCH_SIZE) 146 | 147 | logger.info('num of training instances: {}'.format(num_instance)) 148 | logger.info('num of batches per epoch: {}'.format(num_batch)) 149 | idx_list = np.arange(num_instance) 150 | 151 | new_nodes_val_aps = [] 152 | val_aps = [] 153 | epoch_times = [] 154 | total_epoch_times = [] 155 | train_losses = [] 156 | 157 | early_stopper = EarlyStopMonitor(max_round=args.patience) 158 | for epoch in range(NUM_EPOCH): 159 | start_epoch = time.time() 160 | ### Training 161 | 162 | # Reinitialize memory of the model at the start of each epoch 163 | dgnn.memory_s.__init_memory__() 164 | dgnn.memory_g.__init_memory__() 165 | 166 | # Train using only training graph 167 | dgnn.set_neighbor_finder(train_ngh_finder) 168 | m_loss = [] 169 | 170 | logger.info('starting epoch {}'.format(epoch)) 171 | for k in range(0, num_batch, args.backprop_every): 172 | loss = 0 173 | optimizer.zero_grad() 174 | 175 | # Custom loop that performs backpropagation only once every 'backprop_every' batches 176 | for j in range(args.backprop_every): 177 | batch_idx = k + j 178 | 179 | if batch_idx >= num_batch: 180 | continue 181 | 182 | start_idx = batch_idx * BATCH_SIZE 183 | end_idx = min(num_instance, start_idx + BATCH_SIZE) 184 | sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \ 185 | train_data.destinations[start_idx:end_idx] 186 
| edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx] 187 | timestamps_batch = train_data.timestamps[start_idx:end_idx] 188 | 189 | size = len(sources_batch) 190 | _, negatives_batch = train_rand_sampler.sample(size) 191 | 192 | with torch.no_grad(): 193 | pos_label = torch.ones(size, dtype=torch.float, device=device) 194 | neg_label = torch.zeros(size, dtype=torch.float, device=device) 195 | 196 | dgnn = dgnn.train() 197 | pos_prob, neg_prob = dgnn(sources_batch, destinations_batch, negatives_batch, 198 | timestamps_batch, edge_idxs_batch) 199 | 200 | loss += criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label) 201 | 202 | loss /= args.backprop_every 203 | 204 | loss.backward() 205 | optimizer.step() 206 | m_loss.append(loss.item()) 207 | 208 | # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to 209 | # the start of time 210 | dgnn.memory_s.detach_memory() 211 | dgnn.memory_g.detach_memory() 212 | 213 | epoch_time = time.time() - start_epoch 214 | epoch_times.append(epoch_time) 215 | 216 | ### Validation 217 | # Validation uses the full graph 218 | dgnn.set_neighbor_finder(full_ngh_finder) 219 | 220 | # Backup memory at the end of training, so later we can restore it and use it for the 221 | # validation on unseen nodes 222 | train_memory_backup_s = dgnn.memory_s.backup_memory() 223 | train_memory_backup_g = dgnn.memory_g.backup_memory() 224 | 225 | val_ap, val_auc = eval_edge_prediction(model=dgnn, 226 | negative_edge_sampler=val_rand_sampler, 227 | data=val_data, 228 | n_neighbors=NUM_NEIGHBORS) 229 | 230 | val_memory_backup_s = dgnn.memory_s.backup_memory() 231 | val_memory_backup_g = dgnn.memory_g.backup_memory() 232 | # Restore memory we had at the end of training to be used when validating on new nodes. 
233 | # Also backup memory after validation so it can be used for testing (since test edges are 234 | # strictly later in time than validation edges) 235 | dgnn.memory_s.restore_memory(train_memory_backup_s) 236 | dgnn.memory_g.restore_memory(train_memory_backup_g) 237 | 238 | # Validate on unseen nodes (negatives drawn amongst other new nodes, see NB above) 239 | nn_val_ap, nn_val_auc = eval_edge_prediction(model=dgnn, negative_edge_sampler=nn_val_rand_sampler, 240 | data=new_node_val_data, n_neighbors=NUM_NEIGHBORS) 241 | 242 | # Restore memory we had at the end of validation 243 | dgnn.memory_s.restore_memory(val_memory_backup_s) 244 | dgnn.memory_g.restore_memory(val_memory_backup_g) 245 | 246 | new_nodes_val_aps.append(nn_val_ap) 247 | val_aps.append(val_ap) 248 | train_losses.append(np.mean(m_loss)) 249 | 250 | # Save temporary results to disk 251 | pickle.dump({ 252 | "val_aps": val_aps, 253 | "new_nodes_val_aps": new_nodes_val_aps, 254 | "train_losses": train_losses, 255 | "epoch_times": epoch_times, 256 | "total_epoch_times": total_epoch_times 257 | }, open(results_path, "wb")) 258 | 259 | total_epoch_time = time.time() - start_epoch 260 | total_epoch_times.append(total_epoch_time) 261 | 262 | logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time)) 263 | logger.info('Epoch mean loss: {}'.format(np.mean(m_loss))) 264 | logger.info( 265 | 'val auc: {}, new node val auc: {}'.format(val_auc, nn_val_auc)) 266 | logger.info( 267 | 'val ap: {}, new node val ap: {}'.format(val_ap, nn_val_ap)) 268 | 269 | # Early stopping 270 | if early_stopper.early_stop_check(val_ap): 271 | logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round)) 272 | logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}') 273 | best_model_path = get_checkpoint_path(early_stopper.best_epoch) 274 | dgnn.load_state_dict(torch.load(best_model_path)) 275 | logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference') 276 | dgnn.eval() 277 | break 278 | else: 279 | torch.save(dgnn.state_dict(), get_checkpoint_path(epoch)) 280 | 281 | # Training has finished, we have loaded the best model, and we want to backup its current 282 | # memory (which has seen validation edges) so that it can also be used when testing on unseen 283 | # nodes 284 | val_memory_backup_s = dgnn.memory_s.backup_memory() 285 | val_memory_backup_g = dgnn.memory_g.backup_memory() 286 | 287 | ### Test 288 | dgnn.propagater_g.neighbor_finder = full_ngh_finder 289 | dgnn.propagater_s.neighbor_finder = full_ngh_finder 290 | test_ap, test_auc = eval_edge_prediction(model=dgnn, negative_edge_sampler=test_rand_sampler, 291 | data=test_data, n_neighbors=NUM_NEIGHBORS) 292 | 293 | dgnn.memory_s.restore_memory(val_memory_backup_s) 294 | dgnn.memory_g.restore_memory(val_memory_backup_g) 295 | 296 | # Test on unseen nodes 297 | nn_test_ap, nn_test_auc = eval_edge_prediction(model=dgnn, negative_edge_sampler=nn_test_rand_sampler, 298 | data=new_node_test_data, n_neighbors=NUM_NEIGHBORS) 299 | 300 | logger.info( 301 | 'Test statistics: Old nodes -- auc: {}, ap: {}'.format(test_auc, test_ap)) 302 | logger.info( 303 | 'Test statistics: New nodes -- auc: {}, ap: {}'.format(nn_test_auc, nn_test_ap)) 304 | # Save results for this run 305 | pickle.dump({ 306 | "val_aps": val_aps, 307 | "new_nodes_val_aps": new_nodes_val_aps, 308 | "test_ap": test_ap, 309 | "new_node_test_ap": nn_test_ap, 310 | "epoch_times": epoch_times, 311 | "train_losses": train_losses, 312 | "total_epoch_times": total_epoch_times 313 | }, open(results_path, "wb")) 314 | 
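# For reference, a minimal sketch (hypothetical, not part of the training flow) of how
# the per-run results pickle written above could be read back and summarized; the keys
# mirror the pickle.dump call above, everything else is illustrative:
#     import pickle
#     import numpy as np
#     with open(results_path, "rb") as f:
#         res = pickle.load(f)
#     print('best val ap: {:.4f}'.format(max(res["val_aps"])))
#     print('mean epoch time: {:.2f}s'.format(np.mean(res["epoch_times"])))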
315 | logger.info('Saving DGNN model') 316 | # Restore memory at the end of validation (save a model which is ready for testing) 317 | dgnn.memory_s.restore_memory(val_memory_backup_s) 318 | dgnn.memory_g.restore_memory(val_memory_backup_g) 319 | torch.save(dgnn.state_dict(), MODEL_SAVE_PATH) 320 | logger.info('DGNN model saved') 321 | -------------------------------------------------------------------------------- /trainUCI.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import time 4 | import sys 5 | import argparse 6 | import torch 7 | import numpy as np 8 | import pickle 9 | from pathlib import Path 10 | 11 | from evaluation import eval_edge_prediction 12 | from models.dgnn import DGNN 13 | from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder 14 | from utils.data_processing import get_data, compute_time_statistics 15 | 16 | 17 | 18 | ### Argument and global variables 19 | parser = argparse.ArgumentParser('DGNN training') 20 | parser.add_argument('-d', '--data', type=str, help='Dataset name (e.g. wikipedia or reddit)', 21 | default='UCI-Msg') 22 | parser.add_argument('--bs', type=int, default=200, help='Batch_size') 23 | parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints') 24 | parser.add_argument('--n_degree', type=int, default=20, help='Number of neighbors to sample') 25 | parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs') 26 | parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate') 27 | parser.add_argument('--patience', type=int, default=8, help='Patience for early stopping') 28 | parser.add_argument('--n_runs', type=int, default=10, help='Number of runs') 29 | parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability') 30 | parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use') 31 | parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding') 32 | parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to ' 33 | 'backprop') 34 | 35 | parser.add_argument('--aggregator', type=str, default="last", help='Type of message ' 36 | 'aggregator') 37 | 38 | parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages') 39 | 40 | parser.add_argument('--memory_dim', type=int, default=100, help='Dimensions of the memory for ' 41 | 'each user') 42 | parser.add_argument('--different_new_nodes', action='store_true', 43 | help='Whether to use disjoint set of new nodes for train and val') 44 | parser.add_argument('--uniform', action='store_true', 45 | help='take uniform sampling from temporal neighbors') 46 | parser.add_argument('--seed', type=int, default=0, help='random seed') 47 | parser.add_argument('--threshold', type=int, default=2, help='time threshold') 48 | 49 | try: 50 | args = parser.parse_args() 51 | except: 52 | parser.print_help() 53 | sys.exit(0) 54 | torch.manual_seed(args.seed) 55 | np.random.seed(args.seed) 56 | BATCH_SIZE = args.bs 57 | NUM_NEIGHBORS = args.n_degree 58 | NUM_NEG = 1 59 | NUM_EPOCH = args.n_epoch 60 | DROP_OUT = args.drop_out 61 | GPU = args.gpu 62 | SEQ_LEN = NUM_NEIGHBORS 63 | DATA = args.data 64 | LEARNING_RATE = args.lr 65 | NODE_DIM = args.node_dim 66 | MESSAGE_DIM = args.message_dim 67 | MEMORY_DIM = args.memory_dim 68 | 69 | Path("./saved_models/").mkdir(parents=True, exist_ok=True) 70 | 
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True) 71 | MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth' 72 | get_checkpoint_path = lambda \ 73 | epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth' 74 | 75 | ### set up logger 76 | logging.basicConfig(level=logging.INFO) 77 | logger = logging.getLogger() 78 | logger.setLevel(logging.DEBUG) 79 | Path("log/").mkdir(parents=True, exist_ok=True) 80 | timenow = time.strftime("%Y-%m-%d_%H:%M:%S",time.localtime(time.time())) 81 | fh = logging.FileHandler('log/{}.log'.format(str(args.prefix)+timenow)) 82 | fh.setLevel(logging.DEBUG) 83 | ch = logging.StreamHandler() 84 | ch.setLevel(logging.WARN) 85 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 86 | fh.setFormatter(formatter) 87 | ch.setFormatter(formatter) 88 | logger.addHandler(fh) 89 | logger.addHandler(ch) 90 | logger.info(args) 91 | 92 | ### Extract data for training, validation and testing 93 | node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, \ 94 | new_node_test_data = get_data(DATA, 95 | different_new_nodes_between_val_and_test=args.different_new_nodes) 96 | 97 | # Initialize training neighbor finder to retrieve temporal graph 98 | train_ngh_finder = get_neighbor_finder(train_data, args.uniform) 99 | 100 | # Initialize validation and test neighbor finder to retrieve temporal graph 101 | full_ngh_finder = get_neighbor_finder(full_data, args.uniform) 102 | 103 | # Initialize negative samplers. Set seeds for validation and testing so negatives are the same 104 | # across different runs 105 | # NB: in the inductive setting, negatives are sampled only amongst other new nodes 106 | train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations) 107 | val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0) 108 | nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations, 109 | seed=1) 110 | test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2) 111 | nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources, 112 | new_node_test_data.destinations, 113 | seed=3) 114 | 115 | # Set device 116 | device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu' 117 | device = torch.device(device_string) 118 | 119 | # Compute time statistics 120 | mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \ 121 | compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps) 122 | 123 | for i in range(args.n_runs): 124 | results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix) 125 | Path("results/").mkdir(parents=True, exist_ok=True) 126 | 127 | # Initialize Model 128 | dgnn = DGNN(neighbor_finder=train_ngh_finder, node_features=node_features, 129 | edge_features=edge_features, device=device, dropout=DROP_OUT, 130 | message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM, 131 | aggregator_type=args.aggregator, n_neighbors=NUM_NEIGHBORS, 132 | mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src, 133 | mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst, threshold=args.threshold) 134 | for a in dgnn.parameters(): 135 | if a.ndim > 1: 136 | torch.nn.init.xavier_uniform_(a) 137 | 138 | dgnn.memory_s.__init_memory__(args.seed) 139 | dgnn.memory_g.__init_memory__(args.seed) 140 | 141 | 
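# Note: BCELoss expects probabilities in [0, 1]; DGNN.forward returns sigmoid-activated
# scores (see compute_loss in models/dgnn.py), so the pairing below is consistent. A
# numerically more stable alternative (purely illustrative, not used here) would be to
# return raw logits from the model and use:
#     criterion = torch.nn.BCEWithLogitsLoss()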
criterion = torch.nn.BCELoss() 142 | optimizer = torch.optim.Adam(dgnn.parameters(), lr=LEARNING_RATE) 143 | dgnn = dgnn.to(device) 144 | 145 | num_instance = len(train_data.sources) 146 | num_batch = math.ceil(num_instance / BATCH_SIZE) 147 | 148 | logger.info('num of training instances: {}'.format(num_instance)) 149 | logger.info('num of batches per epoch: {}'.format(num_batch)) 150 | idx_list = np.arange(num_instance) 151 | 152 | new_nodes_val_mrrs = [] 153 | val_mrrs = [] 154 | epoch_times = [] 155 | total_epoch_times = [] 156 | train_losses = [] 157 | 158 | early_stopper = EarlyStopMonitor(max_round=args.patience) 159 | for epoch in range(NUM_EPOCH): 160 | start_epoch = time.time() 161 | ### Training 162 | 163 | # Reinitialize memory of the model at the start of each epoch 164 | dgnn.memory_s.__init_memory__(args.seed) 165 | dgnn.memory_g.__init_memory__(args.seed) 166 | 167 | # Train using only training graph 168 | dgnn.set_neighbor_finder(train_ngh_finder) 169 | m_loss = [] 170 | 171 | logger.info('starting epoch {}'.format(epoch)) 172 | for k in range(0, num_batch, args.backprop_every): 173 | loss = 0 174 | optimizer.zero_grad() 175 | 176 | # Custom loop that performs backpropagation only once every 'backprop_every' batches 177 | for j in range(args.backprop_every): 178 | batch_idx = k + j 179 | 180 | if batch_idx >= num_batch: 181 | continue 182 | 183 | start_idx = batch_idx * BATCH_SIZE 184 | end_idx = min(num_instance, start_idx + BATCH_SIZE) 185 | sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \ 186 | train_data.destinations[start_idx:end_idx] 187 | edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx] 188 | timestamps_batch = train_data.timestamps[start_idx:end_idx] 189 | 190 | size = len(sources_batch) 191 | _, negatives_batch = train_rand_sampler.sample(size) 192 | 193 | with torch.no_grad(): 194 | pos_label = torch.ones(size, dtype=torch.float, device=device) 195 | neg_label = torch.zeros(size, dtype=torch.float, device=device) 196 | 197 | dgnn = dgnn.train() 198 | pos_prob, neg_prob = dgnn(sources_batch, destinations_batch, negatives_batch, 199 | timestamps_batch, edge_idxs_batch) 200 | 201 | loss += criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label) 202 | 203 | loss /= args.backprop_every 204 | 205 | loss.backward() 206 | optimizer.step() 207 | m_loss.append(loss.item()) 208 | 209 | # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to 210 | # the start of time 211 | dgnn.memory_s.detach_memory() 212 | dgnn.memory_g.detach_memory() 213 | 214 | epoch_time = time.time() - start_epoch 215 | epoch_times.append(epoch_time) 216 | 217 | ### Validation 218 | # Validation uses the full graph 219 | dgnn.set_neighbor_finder(full_ngh_finder) 220 | 221 | # Backup memory at the end of training, so later we can restore it and use it for the 222 | # validation on unseen nodes 223 | train_memory_backup_s = dgnn.memory_s.backup_memory() 224 | train_memory_backup_g = dgnn.memory_g.backup_memory() 225 | 226 | val_mrr, val_recall_20, val_recall_50 = eval_edge_prediction(model=dgnn, negative_edge_sampler=val_rand_sampler, data=val_data, 227 | n_neighbors=NUM_NEIGHBORS) 228 | 229 | val_memory_backup_s = dgnn.memory_s.backup_memory() 230 | val_memory_backup_g = dgnn.memory_g.backup_memory() 231 | # Restore memory we had at the end of training to be used when validating on new nodes. 
232 | # Also backup memory after validation so it can be used for testing (since test edges are 233 | # strictly later in time than validation edges) 234 | dgnn.memory_s.restore_memory(train_memory_backup_s) 235 | dgnn.memory_g.restore_memory(train_memory_backup_g) 236 | 237 | # Validate on unseen nodes (negatives drawn amongst other new nodes, see NB above) 238 | nn_val_mrr, nn_val_recall_20, nn_val_recall_50 = eval_edge_prediction(model=dgnn, negative_edge_sampler=nn_val_rand_sampler, 239 | data=new_node_val_data, n_neighbors=NUM_NEIGHBORS) 240 | 241 | # Restore memory we had at the end of validation 242 | dgnn.memory_s.restore_memory(val_memory_backup_s) 243 | dgnn.memory_g.restore_memory(val_memory_backup_g) 244 | 245 | new_nodes_val_mrrs.append(nn_val_mrr) 246 | val_mrrs.append(val_mrr) 247 | train_losses.append(np.mean(m_loss)) 248 | 249 | # Save temporary results to disk 250 | pickle.dump({ 251 | "val_mrrs": val_mrrs, 252 | "new_nodes_val_mrrs": new_nodes_val_mrrs, 253 | "train_losses": train_losses, 254 | "epoch_times": epoch_times, 255 | "total_epoch_times": total_epoch_times 256 | }, open(results_path, "wb")) 257 | 258 | total_epoch_time = time.time() - start_epoch 259 | total_epoch_times.append(total_epoch_time) 260 | 261 | logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time)) 262 | logger.info('Epoch mean loss: {}'.format(np.mean(m_loss))) 263 | logger.info( 264 | 'val mrr: {}, new node val mrr: {}'.format(val_mrr, nn_val_mrr)) 265 | logger.info( 266 | 'val recall 20: {}, new node val recall 20: {}'.format(val_recall_20, nn_val_recall_20)) 267 | logger.info( 268 | 'val recall 50: {}, new node val recall 50: {}'.format(val_recall_50, nn_val_recall_50)) 269 | 270 | # Early stopping 271 | if early_stopper.early_stop_check(val_mrr): 272 | logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round)) 273 | logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}') 274 | best_model_path = get_checkpoint_path(early_stopper.best_epoch) 275 | dgnn.load_state_dict(torch.load(best_model_path)) 276 | logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference') 277 | dgnn.eval() 278 | break 279 | else: 280 | torch.save(dgnn.state_dict(), get_checkpoint_path(epoch)) 281 | 282 | # Training has finished, we have loaded the best model, and we want to backup its current 283 | # memory (which has seen validation edges) so that it can also be used when testing on unseen 284 | # nodes 285 | val_memory_backup_s = dgnn.memory_s.backup_memory() 286 | val_memory_backup_g = dgnn.memory_g.backup_memory() 287 | 288 | ### Test 289 | dgnn.propagater_g.neighbor_finder = full_ngh_finder 290 | dgnn.propagater_s.neighbor_finder = full_ngh_finder 291 | test_mrr, test_recall_20, test_recall_50 = eval_edge_prediction(model=dgnn, negative_edge_sampler=test_rand_sampler, 292 | data=test_data, n_neighbors=NUM_NEIGHBORS) 293 | 294 | dgnn.memory_s.restore_memory(val_memory_backup_s) 295 | dgnn.memory_g.restore_memory(val_memory_backup_g) 296 | 297 | # Test on unseen nodes 298 | nn_test_mrr, nn_test_recall_20, nn_test_recall_50 = eval_edge_prediction(model=dgnn, negative_edge_sampler=nn_test_rand_sampler, 299 | data=new_node_test_data, n_neighbors=NUM_NEIGHBORS) 300 | 301 | logger.info( 302 | 'Test statistics: Old nodes -- mrr: {}, recall_20: {}, recall_50: {}'.format(test_mrr, test_recall_20, 303 | test_recall_50)) 304 | logger.info( 305 | 'Test statistics: New nodes -- mrr: {}, recall_20: {}, recall_50: {}'.format(nn_test_mrr, nn_test_recall_20, 306 | nn_test_recall_50)) 
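# For reference, a minimal sketch (hypothetical; the actual logic lives in evaluation.py)
# of how MRR and recall@k are typically computed from the 1-based rank of each true
# destination among the ranked candidate nodes:
#     ranks = np.asarray(ranks)            # hypothetical array of 1-based ranks, one per test event
#     mrr = (1.0 / ranks).mean()
#     recall_at_20 = (ranks <= 20).mean()
#     recall_at_50 = (ranks <= 50).mean()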
307 | # Save results for this run 308 | pickle.dump({ 309 | "val_mrrs": val_mrrs, 310 | "new_nodes_val_mrrs": new_nodes_val_mrrs, 311 | "test_mrr": test_mrr, 312 | "new_node_test_mrr": nn_test_mrr, 313 | "epoch_times": epoch_times, 314 | "train_losses": train_losses, 315 | "total_epoch_times": total_epoch_times 316 | }, open(results_path, "wb")) 317 | 318 | logger.info('Saving DGNN model') 319 | # Restore memory at the end of validation (save a model which is ready for testing) 320 | dgnn.memory_s.restore_memory(val_memory_backup_s) 321 | dgnn.memory_g.restore_memory(val_memory_backup_g) 322 | torch.save(dgnn.state_dict(), MODEL_SAVE_PATH) 323 | logger.info('DGNN model saved') 324 | --------------------------------------------------------------------------------
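Usage note: after a run finishes, the best model is written to MODEL_SAVE_PATH ('./saved_models/{prefix}-{data}.pth'). A minimal sketch for reloading it for inference, assuming `dgnn` has been constructed with the same arguments as in the training scripts above (illustrative only, not part of the repository):

    import torch
    dgnn.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
    dgnn.eval()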