├── net ├── __pycache__ │ └── tgn.cpython-36.pyc └── tgn.py ├── utils ├── __pycache__ │ ├── utils.cpython-36.pyc │ └── data_processing.cpython-36.pyc ├── preprocess_data.py ├── data_processing.py └── utils.py ├── modules ├── __pycache__ │ ├── memory.cpython-36.pyc │ ├── time_encoding.cpython-36.pyc │ ├── memory_updater.cpython-36.pyc │ ├── embedding_module.cpython-36.pyc │ ├── message_aggregator.cpython-36.pyc │ ├── message_function.cpython-36.pyc │ └── temporal_attention.cpython-36.pyc ├── time_encoding.py ├── message_function.py ├── memory.py ├── temporal_attention.py ├── memory_updater.py ├── message_aggregator.py └── embedding_module.py ├── LICENSE ├── README.md └── link_prediction.py /net/__pycache__/tgn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/net/__pycache__/tgn.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/memory.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/memory.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/time_encoding.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/time_encoding.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/data_processing.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/utils/__pycache__/data_processing.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/memory_updater.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/memory_updater.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/embedding_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/embedding_module.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/message_aggregator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/message_aggregator.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/message_function.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/message_function.cpython-36.pyc -------------------------------------------------------------------------------- /modules/__pycache__/temporal_attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KimMeen/TGN/HEAD/modules/__pycache__/temporal_attention.cpython-36.pyc -------------------------------------------------------------------------------- /modules/time_encoding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Aug 25 17:32:19 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class TimeEncode(torch.nn.Module): 13 | """ 14 | Time Encoding proposed by TGAT 15 | """ 16 | def __init__(self, dimension): 17 | super(TimeEncode, self).__init__() 18 | 19 | self.dimension = dimension 20 | 21 | self.w = torch.nn.Linear(1, dimension) 22 | 23 | self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))) 24 | .float().reshape(dimension, -1)) 25 | self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float()) 26 | 27 | def forward(self, t): 28 | # t has shape [batch_size, seq_len], i.e. [source_nodes, num_temp_neighbors] 29 | # [batch_size, seq_len, 1], i.e. [source_nodes, num_temp_neighbors, 1] 30 | t = t.unsqueeze(dim=2) 31 | 32 | # output has shape [batch_size, seq_len, dimension], i.e. [source_nodes, num_temp_neighbors, dimension] 33 | output = torch.cos(self.w(t)) 34 | 35 | return output -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ming Jin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TGN 2 | A PyTorch annotated replication of [twitter-research](https://github.com/twitter-research)/**[tgn](https://github.com/twitter-research/tgn)** 3 | 4 | Paper: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) 5 | 6 | ## Requirements 7 | 8 | Python >= 3.6 9 | 10 | ``` 11 | pandas==1.1.0 12 | torch==1.6.0 13 | scikit_learn==0.23.1 14 | ``` 15 | 16 | ## Preprocess datasets 17 | 18 | #### Download the public data 19 | 20 | Download the sample datasets (eg. wikipedia and reddit) from [here](http://snap.stanford.edu/jodie/) and store their csv files in a folder named `./data` 21 | 22 | #### Preprocess the data 23 | 24 | We use the dense `npy` format to save the features in binary format. If edge features or nodes features are absent, they will be replaced by a vector of zeros. 25 | 26 | ``` 27 | python utils/preprocess_data.py --data wikipedia 28 | python utils/preprocess_data.py --data reddit 29 | ``` 30 | 31 | ## Model training 32 | 33 | Self-supervised learning using the link prediction task: 34 | 35 | ``` 36 | # TGN-attn: self-supervised learning on the wikipedia dataset 37 | python link_prediction.py --data wikipedia --embedding_module graph_sum --use_memory --memory_update_at_start 38 | 39 | # TGN-attn-reddit: self-supervised learning on the reddit dataset 40 | python link_prediction.py --data reddit --embedding_module graph_sum --use_memory --memory_update_at_start 41 | ``` 42 | 43 | ** Check more commands with `--help` 44 | 45 | ## TODOs 46 | 47 | - Add code for training on the downstream node-classification task (semi-supervised setting) 48 | 49 | ## Cite the paper 50 | 51 | ``` 52 | @inproceedings{tgn_icml_grl2020, 53 | title={Temporal Graph Networks for Deep Learning on Dynamic Graphs}, 54 | author={Emanuele Rossi and Ben Chamberlain and Fabrizio Frasca and Davide Eynard and Federico 55 | Monti and Michael Bronstein}, 56 | booktitle={ICML 2020 Workshop on Graph Representation Learning}, 57 | year={2020} 58 | } 59 | ``` -------------------------------------------------------------------------------- /modules/message_function.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 14:04:31 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | from torch import nn 9 | 10 | 11 | class MessageFunction(nn.Module): 12 | """ 13 | Abstract class 14 | 15 | Module which computes the message for a given interaction. 16 | """ 17 | 18 | def compute_message(self, raw_messages): 19 | return None 20 | 21 | 22 | class MLPMessageFunction(MessageFunction): 23 | """ 24 | MLP message function to calculate the message m(t) 25 | 26 | INPUT: 27 | raw_message_dimension: Dimension of the raw_message 28 | message_dimension: Dimension of the message 29 | raw_messages: [S_i(t-1) || S_j(t-1) || delta_t || e(t)] for interaction events 30 | 31 | OUTPUT: 32 | message: m(t) <-- [S_i(t-1) || S_j(t-1) || delta_t || e(t)] for interation events 33 | 34 | """ 35 | def __init__(self, raw_message_dimension, message_dimension): 36 | super(MLPMessageFunction, self).__init__() 37 | 38 | self.mlp = self.layers = nn.Sequential( 39 | nn.Linear(raw_message_dimension, raw_message_dimension // 2), 40 | nn.ReLU(), 41 | nn.Linear(raw_message_dimension // 2, message_dimension), 42 | ) 43 | 44 | def compute_message(self, raw_messages): 45 | messages = self.mlp(raw_messages) 46 | 47 | return messages 48 | 49 | 50 | class IdentityMessageFunction(MessageFunction): 51 | """ 52 | message function returns m(t) = raw_message 53 | 54 | """ 55 | 56 | def compute_message(self, raw_messages): 57 | 58 | return raw_messages 59 | 60 | 61 | ''' 62 | ############### 63 | REFERENCE ENTRY 64 | ############### 65 | ''' 66 | def get_message_function(module_type, raw_message_dimension, message_dimension): 67 | if module_type == "mlp": 68 | return MLPMessageFunction(raw_message_dimension, message_dimension) 69 | elif module_type == "identity": 70 | return IdentityMessageFunction() 71 | else: 72 | raise ValueError("Message function {} not implemented".format(module_type)) -------------------------------------------------------------------------------- /modules/memory.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 17:56:38 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | from collections import defaultdict 11 | 12 | 13 | class Memory(nn.Module): 14 | """ 15 | Memory class, represented as 'S' in the paper 16 | 17 | INIT INPUT: 18 | n_nodes: Number of unique ndoes 19 | memory_dimension: Memory dimension 20 | input_dimension: Message dimension 21 | """ 22 | 23 | def __init__(self, n_nodes, memory_dimension, input_dimension, device="cpu"): 24 | 25 | super(Memory, self).__init__() 26 | 27 | self.n_nodes = n_nodes 28 | self.memory_dimension = memory_dimension 29 | self.input_dimension = input_dimension 30 | # self.message_dimension = message_dimension 31 | self.device = device 32 | 33 | # self.combination_method = combination_method 34 | 35 | self.__init_memory__() 36 | 37 | def __init_memory__(self): 38 | """ 39 | Initializes the memory to all zeros. 40 | It should be called at the start of each epoch. 41 | """ 42 | 43 | # Treat memory as parameter so that it is saved and loaded together with the model 44 | # requires_grad has been set as FALSE 45 | self.memory = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device), requires_grad=False) 46 | self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device), requires_grad=False) 47 | 48 | self.messages = defaultdict(list) 49 | 50 | def store_raw_messages(self, nodes, node_id_to_messages): 51 | """ 52 | Set nodes' raw message (i.e. self.message) by values in node_id_to_messages 53 | """ 54 | for node in nodes: 55 | self.messages[node].extend(node_id_to_messages[node]) 56 | 57 | def get_memory(self, node_idxs): 58 | """ 59 | Return node_idxs' memory 60 | """ 61 | return self.memory[node_idxs, :] 62 | 63 | def set_memory(self, node_idxs, values): 64 | """ 65 | Set node_idxs' memory by values 66 | """ 67 | self.memory[node_idxs, :] = values 68 | 69 | def get_last_update(self, node_idxs): 70 | """ 71 | Return node_idxs' last updated timestamp 72 | """ 73 | return self.last_update[node_idxs] 74 | 75 | def backup_memory(self): 76 | """ 77 | Return a copy of all nodes' memory, last update timestamp, and message 78 | """ 79 | messages_clone = {} 80 | for k, v in self.messages.items(): 81 | messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v] 82 | 83 | return self.memory.data.clone(), self.last_update.data.clone(), messages_clone 84 | 85 | def restore_memory(self, memory_backup): 86 | """ 87 | Set all nodes' memory, last update timestamp, and message by using memory_backup 88 | """ 89 | self.memory.data, self.last_update.data = memory_backup[0].clone(), memory_backup[1].clone() 90 | 91 | self.messages = defaultdict(list) 92 | for k, v in memory_backup[2].items(): 93 | self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v] 94 | 95 | def detach_memory(self): 96 | """ 97 | Detach memory and all stored messages from the network 98 | """ 99 | self.memory.detach_() 100 | 101 | # Detach all stored messages 102 | for k, v in self.messages.items(): 103 | new_node_messages = [] 104 | for message in v: 105 | new_node_messages.append((message[0].detach(), message[1])) 106 | 107 | self.messages[k] = new_node_messages 108 | 109 | def clear_messages(self, nodes): 110 | """ 111 | Clear given nodes' message 112 | """ 113 | for node in nodes: 114 | self.messages[node] = [] -------------------------------------------------------------------------------- /modules/temporal_attention.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 22:27:23 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import torch 9 | from torch import nn 10 | 11 | from utils.utils import MergeLayer 12 | 13 | 14 | class TemporalAttentionLayer(torch.nn.Module): 15 | """ 16 | Temporal attention layer. Return the temporal embedding of a node given the node itself, 17 | its neighbors and the edge timestamps. 18 | """ 19 | 20 | def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim, 21 | output_dimension, n_head=2, dropout=0.1): 22 | super(TemporalAttentionLayer, self).__init__() 23 | 24 | self.n_head = n_head 25 | 26 | self.feat_dim = n_node_features 27 | self.time_dim = time_dim 28 | 29 | self.query_dim = n_node_features + time_dim 30 | self.key_dim = n_neighbors_features + time_dim + n_edge_features 31 | 32 | self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension) 33 | 34 | self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim, 35 | kdim=self.key_dim, 36 | vdim=self.key_dim, 37 | num_heads=n_head, 38 | dropout=dropout) 39 | 40 | def forward(self, src_node_features, src_time_features, neighbors_features, 41 | neighbors_time_features, edge_features, neighbors_padding_mask): 42 | """ 43 | Temporal attention model 44 | :param src_node_features: float Tensor of shape [batch_size, n_node_features] 45 | :param src_time_features: float Tensor of shape [batch_size, 1, time_dim] 46 | :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_node_features] 47 | :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, 48 | time_dim] 49 | :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features] 50 | :param neighbors_padding_mask: float Tensor of shape [batch_size, n_neighbors] 51 | :return: 52 | attn_output: float Tensor of shape [1, batch_size, n_node_features] 53 | attn_output_weights: [batch_size, 1, n_neighbors] 54 | """ 55 | 56 | src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1) 57 | 58 | query = torch.cat([src_node_features_unrolled, src_time_features], dim=2) 59 | key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2) 60 | 61 | # print(neighbors_features.shape, edge_features.shape, neighbors_time_features.shape) 62 | # Reshape tensors so to expected shape by multi head attention 63 | query = query.permute([1, 0, 2]) # [1, batch_size, num_of_features] 64 | key = key.permute([1, 0, 2]) # [n_neighbors, batch_size, num_of_features] 65 | 66 | # Compute mask of which source nodes have no valid neighbors 67 | invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True) 68 | # If a source node has no valid neighbor, set it's first neighbor to be valid. This will 69 | # force the attention to just 'attend' on this neighbor (which has the same features as all 70 | # the others since they are fake neighbors) and will produce an equivalent result to the 71 | # original tgat paper which was forcing fake neighbors to all have same attention of 1e-10 72 | neighbors_padding_mask[invalid_neighborhood_mask.squeeze(), 0] = False 73 | 74 | # print(query.shape, key.shape) 75 | 76 | attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key, 77 | key_padding_mask=neighbors_padding_mask) 78 | 79 | # mask = torch.unsqueeze(neighbors_padding_mask, dim=2) # mask [B, N, 1] 80 | # mask = mask.permute([0, 2, 1]) 81 | # attn_output, attn_output_weights = self.multi_head_target(q=query, k=key, v=key, 82 | # mask=mask) 83 | 84 | attn_output = attn_output.squeeze() 85 | attn_output_weights = attn_output_weights.squeeze() 86 | 87 | # Source nodes with no neighbors have an all zero attention output. The attention output is 88 | # then added or concatenated to the original source node features and then fed into an MLP. 89 | # This means that an all zero vector is not used. 90 | attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0) 91 | attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0) 92 | 93 | # Skip connection with temporal attention over neighborhood and the features of the node itself 94 | attn_output = self.merger(attn_output, src_node_features) 95 | 96 | return attn_output, attn_output_weights -------------------------------------------------------------------------------- /utils/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 26 14:52:24 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import numpy as np 9 | import pandas as pd 10 | from pathlib import Path 11 | import argparse 12 | 13 | 14 | def preprocess(data_name): 15 | """ 16 | u: Users 17 | i: Items 18 | ts: Timestamps 19 | label: Dynamic node labels 20 | feat: Interaction features 21 | 22 | See JODIE paper for deatials about the meaning of user and item 23 | """ 24 | u_list, i_list, ts_list, label_list = [], [], [], [] 25 | feat_l = [] 26 | idx_list = [] 27 | 28 | with open(data_name) as f: 29 | 30 | # s = next(f) 31 | 32 | for idx, line in enumerate(f): 33 | e = line.strip().split(',') 34 | u = int(e[0]) 35 | i = int(e[1]) 36 | ts = float(e[2]) 37 | label = float(e[3]) 38 | feat = np.array([float(x) for x in e[4:]]) 39 | 40 | u_list.append(u) 41 | i_list.append(i) 42 | ts_list.append(ts) 43 | label_list.append(label) 44 | idx_list.append(idx) 45 | feat_l.append(feat) 46 | 47 | return pd.DataFrame({'u': u_list, 48 | 'i': i_list, 49 | 'ts': ts_list, 50 | 'label': label_list, 51 | 'idx': idx_list}), np.array(feat_l) 52 | 53 | 54 | def reindex(df, bipartite=True): 55 | """ 56 | Treat users and items as "nodes", their interactions as "temporal edges" 57 | Specifically, users are "source nodes", and items are "destination nodes" in a bipartite graph 58 | 59 | df looks like this: 60 | 61 | u i ts label idx 62 | 0 0 0 0.0 0.0 0 63 | 1 1 1 36.0 0.0 1 64 | 2 1 1 77.0 0.0 2 65 | 3 2 2 131.0 0.0 3 66 | 4 1 1 150.0 0.0 4 67 | ... ... ... ... ... ... 68 | 157469 2003 632 2678155.0 0.0 157469 69 | 157470 3762 798 2678158.0 0.0 157470 70 | 157471 2399 495 2678293.0 0.0 157471 71 | 157472 7479 920 2678333.0 0.0 157472 72 | 157473 2399 495 2678373.0 0.0 157473 73 | 74 | new_df looks like this: 75 | u i ts label idx 76 | 0 1 8228 0.0 0.0 1 77 | 1 2 8229 36.0 0.0 2 78 | 2 2 8229 77.0 0.0 3 79 | 3 3 8230 131.0 0.0 4 80 | 4 2 8229 150.0 0.0 5 81 | ... ... ... ... ... ... 82 | 157469 2004 8860 2678155.0 0.0 157470 83 | 157470 3763 9026 2678158.0 0.0 157471 84 | 157471 2400 8723 2678293.0 0.0 157472 85 | 157472 7480 9148 2678333.0 0.0 157473 86 | 157473 2400 8723 2678373.0 0.0 157474 87 | """ 88 | new_df = df.copy() 89 | 90 | if bipartite: 91 | 92 | assert (df.u.max() - df.u.min() + 1 == len(df.u.unique())) 93 | assert (df.i.max() - df.i.min() + 1 == len(df.i.unique())) 94 | 95 | upper_u = df.u.max() + 1 # last source node index 96 | new_i = df.i + upper_u # create dest node indeies after source nodes 97 | 98 | new_df.i = new_i # dest node indeies 99 | new_df.u += 1 # source node indeies (start from 1) 100 | new_df.i += 1 101 | new_df.idx += 1 102 | 103 | else: 104 | 105 | new_df.u += 1 106 | new_df.i += 1 107 | new_df.idx += 1 108 | 109 | return new_df 110 | 111 | def run(data_name, bipartite=True): 112 | Path("data/").mkdir(parents=True, exist_ok=True) 113 | PATH = './data/{}.csv'.format(data_name) 114 | OUT_DF = './data/ml_{}.csv'.format(data_name) 115 | OUT_FEAT = './data/ml_{}.npy'.format(data_name) 116 | OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name) 117 | 118 | df, feat = preprocess(PATH) # get the interaction feature vectors and a dataframe which contains index, u, i, ts, label 119 | new_df = reindex(df, bipartite) # get bipartite version of df 120 | 121 | empty = np.zeros(feat.shape[1])[np.newaxis, :] # with shape [1, feat_dim] 122 | feat = np.vstack([empty, feat]) # with shape [interactions, feat_dim] 123 | 124 | max_idx = max(new_df.u.max(), new_df.i.max()) # number of nodes 125 | rand_feat = np.zeros((max_idx + 1, 172)) # initialize node features with fixed 172 dimension size for datasets without dynamic node features 126 | 127 | new_df.to_csv(OUT_DF) # temporal bipartite interaction graph 128 | np.save(OUT_FEAT, feat) # interaction (i.e. Temporal edge) features 129 | np.save(OUT_NODE_FEAT, rand_feat) # initial node features 130 | 131 | ### Entry 132 | parser = argparse.ArgumentParser('Interface for TGN data preprocessing') 133 | parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)', default='wikipedia') 134 | 135 | args = parser.parse_args() 136 | run(args.data) -------------------------------------------------------------------------------- /modules/memory_updater.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 16:51:23 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | from torch import nn 9 | import torch 10 | 11 | 12 | class MemoryUpdater(nn.Module): 13 | """ 14 | Abstract class for updating node memory 15 | """ 16 | def update_memory(self, unique_node_ids, unique_messages, timestamps): 17 | pass 18 | 19 | class SequenceMemoryUpdater(MemoryUpdater): 20 | """ 21 | RNN based memory updater. 22 | Node's memory as the hidden state of RNN, aggregated message as the input, updated memory as the new hidden state 23 | 24 | INPUT: 25 | memory: Memory instance 26 | message_dimension: The dim of the message 27 | memory_dimension: The dim of the memory 28 | --------------------------------------- 29 | unique_node_ids: A list of unique node ids 30 | unique_messages: A tensor of shape [unique_node_ids, aggregated message] 31 | timestamps: A tensor contains corresponding timestamp for those aggregated messages 32 | 33 | OUTPUT: 34 | update_memory(): There is no output, we update node's memory by referring set_memory() method 35 | 36 | updated_memory: A tensor of shape [unique_nodes, memory_dimension] 37 | updated_last_update: A tensor of shape [unique_nodes] 38 | 39 | EXAMPLE (last_aggregator + memory_updator): 40 | node_ids = [0, 1, 1, 2] 41 | messages = {0: [(tensor([1., 2., 3., 4., 5.]), tensor(1))], 42 | 1: [(tensor([2., 3., 4., 5., 6.]), tensor(1)), (tensor([3., 4., 5., 6., 7.]), tensor(2))], 43 | 2: [(tensor([4., 5., 6., 7., 8.]), tensor(2))]} 44 | 45 | ==> 46 | to_update_node_ids: [0, 1, 2] 47 | 48 | unique_messages: tensor([[1., 2., 3., 4., 5.], 49 | [3., 4., 5., 6., 7.], 50 | [4., 5., 6., 7., 8.]]) 51 | 52 | unique_timestamps: tensor([1., 2., 2.]) 53 | 54 | ==> 55 | 56 | updated_memory: tensor([[-0.8096], 57 | [-0.9309], 58 | [-0.9771]], grad_fn=) 59 | 60 | updated_last_update: tensor([1., 2., 2.]) 61 | """ 62 | def __init__(self, memory, message_dimension, memory_dimension, device): 63 | super(SequenceMemoryUpdater, self).__init__() 64 | self.memory = memory 65 | self.layer_norm = torch.nn.LayerNorm(memory_dimension) 66 | self.message_dimension = message_dimension 67 | self.device = device 68 | 69 | def update_memory(self, unique_node_ids, unique_messages, timestamps): 70 | """ 71 | Linked with Memory instance to update node's memory 72 | """ 73 | if len(unique_node_ids) <= 0: # This will happen at the very begining if memory_update_at_start 74 | return 75 | 76 | assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \ 77 | "update memory to time in the past" 78 | 79 | memory = self.memory.get_memory(unique_node_ids) # get the memory of these nodes 80 | self.memory.last_update[unique_node_ids] = timestamps # set the last update timestamp of node ids 81 | 82 | # E.g. S_1(t1) <-- aggregated message m^line_1(t1) + previous memory S_1(t0) 83 | updated_memory = self.memory_updater(unique_messages, memory) 84 | 85 | self.memory.set_memory(unique_node_ids, updated_memory) 86 | 87 | def get_updated_memory(self, unique_node_ids, unique_messages, timestamps): 88 | """ 89 | Detached from Memory instance to update node's memory and then return updated_memory and updated_last_update 90 | """ 91 | if len(unique_node_ids) <= 0: # This will happen at the very begining if memory_update_at_start 92 | return self.memory.memory.data.clone(), self.memory.last_update.data.clone() 93 | 94 | assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), "Trying to " \ 95 | "update memory to time in the past" 96 | 97 | updated_memory = self.memory.memory.data.clone() 98 | updated_memory[unique_node_ids] = self.memory_updater(unique_messages, updated_memory[unique_node_ids]) 99 | 100 | updated_last_update = self.memory.last_update.data.clone() 101 | updated_last_update[unique_node_ids] = timestamps 102 | 103 | return updated_memory, updated_last_update 104 | 105 | class GRUMemoryUpdater(SequenceMemoryUpdater): 106 | def __init__(self, memory, message_dimension, memory_dimension, device): 107 | super(GRUMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device) 108 | 109 | self.memory_updater = nn.GRUCell(input_size=message_dimension, 110 | hidden_size=memory_dimension) 111 | 112 | class RNNMemoryUpdater(SequenceMemoryUpdater): 113 | def __init__(self, memory, message_dimension, memory_dimension, device): 114 | super(RNNMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device) 115 | 116 | self.memory_updater = nn.RNNCell(input_size=message_dimension, 117 | hidden_size=memory_dimension) 118 | 119 | ''' 120 | ############### 121 | REFERENCE ENTRY 122 | ############### 123 | ''' 124 | def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device): 125 | if module_type == "gru": 126 | return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device) 127 | elif module_type == "rnn": 128 | return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device) -------------------------------------------------------------------------------- /modules/message_aggregator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 14:16:46 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | 9 | from collections import defaultdict 10 | import torch 11 | import numpy as np 12 | 13 | 14 | class MessageAggregator(torch.nn.Module): 15 | """ 16 | Abstract class for the message aggregator module 17 | """ 18 | def __init__(self, device): 19 | super(MessageAggregator, self).__init__() 20 | self.device = device 21 | 22 | def aggregate(self, node_ids, messages): 23 | """ 24 | Aggregate functions to be implemented 25 | """ 26 | 27 | # def group_by_id(self, node_ids, messages, timestamps): 28 | # """ 29 | # NOT HAS BEEN USED 30 | # """ 31 | # node_id_to_messages = defaultdict(list) 32 | 33 | # for i, node_id in enumerate(node_ids): 34 | # node_id_to_messages[node_id].append((messages[i], timestamps[i])) 35 | 36 | # return node_id_to_messages 37 | 38 | 39 | class LastMessageAggregator(MessageAggregator): 40 | def __init__(self, device): 41 | super(LastMessageAggregator, self).__init__(device) 42 | 43 | def aggregate(self, node_ids, messages): 44 | """ 45 | Given a list of node ids in a batch and associated messages m_i(t), aggregate different 46 | messages for the same id using the lastest message. 47 | 48 | 49 | INPUT: 50 | node_ids: A list of node ids of length batch_size 51 | messages: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]} 52 | 53 | P.S. timestamps: A tensor of shape [batch_size] 54 | 55 | OUTPUT: 56 | to_update_node_ids: A list of unique node ids 57 | unique_messages: A tensor of shape [unique_node_ids, aggregated message] 58 | unique_timestamps: A tensor contains corresponding timestamp for those aggregated messages 59 | 60 | 61 | EXAMPLE: 62 | node_ids = [1,2,2,3] 63 | messages = {1: [(tensor([1., 2., 3., 4., 5.]), tensor(1))], 64 | 2: [(tensor([2., 3., 4., 5., 6.]), tensor(1)), (tensor([3., 4., 5., 6., 7.]), tensor(2))], 65 | 3: [(tensor([4., 5., 6., 7., 8.]), tensor(2))]} 66 | 67 | ==> 68 | to_update_node_ids: [1, 2, 3] 69 | 70 | unique_messages: tensor([[1., 2., 3., 4., 5.], 71 | [3., 4., 5., 6., 7.], 72 | [4., 5., 6., 7., 8.]]) 73 | 74 | unique_timestamps: tensor([1, 2, 2]) 75 | 76 | """ 77 | unique_node_ids = np.unique(node_ids) 78 | unique_messages = [] 79 | unique_timestamps = [] 80 | 81 | to_update_node_ids = [] 82 | 83 | for node_id in unique_node_ids: 84 | if len(messages[node_id]) > 0: 85 | to_update_node_ids.append(node_id) 86 | unique_messages.append(messages[node_id][-1][0]) 87 | unique_timestamps.append(messages[node_id][-1][1]) 88 | 89 | unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else [] 90 | unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else [] 91 | 92 | return to_update_node_ids, unique_messages, unique_timestamps 93 | 94 | 95 | class MeanMessageAggregator(MessageAggregator): 96 | def __init__(self, device): 97 | super(MeanMessageAggregator, self).__init__(device) 98 | 99 | def aggregate(self, node_ids, messages): 100 | """ 101 | Given a list of node ids in a batch and associated messages m_i(t_j), aggregate different 102 | messages for the same id by averaging them. 103 | 104 | 105 | INPUT: 106 | node_ids: A list of node ids of length batch_size 107 | messages: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]} 108 | 109 | P.S. timestamps: A tensor of shape [batch_size] 110 | 111 | OUTPUT: 112 | to_update_node_ids: A list of unique node ids 113 | unique_messages: A tensor of shape [unique_node_ids, aggregated message] 114 | unique_timestamps: A tensor contains corresponding timestamp for those aggregated messages 115 | 116 | 117 | EXAMPLE: 118 | node_ids = [1,2,2,3] 119 | messages = {1: [(tensor([1., 2., 3., 4., 5.]), tensor(1))], 120 | 2: [(tensor([2., 3., 4., 5., 6.]), tensor(1)), (tensor([3., 4., 5., 6., 7.]), tensor(2))], 121 | 3: [(tensor([4., 5., 6., 7., 8.]), tensor(2))]} 122 | 123 | ==> 124 | to_update_node_ids: [1, 2, 3] 125 | 126 | unique_messages: tensor([[1.0000, 2.0000, 3.0000, 4.0000, 5.0000], 127 | [2.5000, 3.5000, 4.5000, 5.5000, 6.5000], 128 | [4.0000, 5.0000, 6.0000, 7.0000, 8.0000]]) 129 | 130 | unique_timestamps: tensor([1, 2, 2]) 131 | 132 | """ 133 | unique_node_ids = np.unique(node_ids) 134 | unique_messages = [] 135 | unique_timestamps = [] 136 | 137 | to_update_node_ids = [] 138 | n_messages = 0 139 | 140 | for node_id in unique_node_ids: 141 | if len(messages[node_id]) > 0: 142 | n_messages += len(messages[node_id]) 143 | to_update_node_ids.append(node_id) 144 | unique_messages.append(torch.mean(torch.stack([m[0] for m in messages[node_id]]), dim=0)) # This is the difference 145 | unique_timestamps.append(messages[node_id][-1][1]) 146 | 147 | unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else [] 148 | unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else [] 149 | 150 | return to_update_node_ids, unique_messages, unique_timestamps 151 | 152 | 153 | ''' 154 | ############### 155 | REFERENCE ENTRY 156 | ############### 157 | ''' 158 | def get_message_aggregator(aggregator_type, device): 159 | if aggregator_type == "last": 160 | return LastMessageAggregator(device=device) 161 | elif aggregator_type == "mean": 162 | return MeanMessageAggregator(device=device) 163 | else: 164 | raise ValueError("Message aggregator {} not implemented".format(aggregator_type)) -------------------------------------------------------------------------------- /utils/data_processing.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 26 14:35:26 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import numpy as np 9 | import random 10 | import pandas as pd 11 | 12 | 13 | class Data: 14 | 15 | def __init__(self, sources, destinations, timestamps, edge_idxs, labels): 16 | 17 | self.sources = sources 18 | self.destinations = destinations 19 | self.timestamps = timestamps 20 | self.edge_idxs = edge_idxs 21 | self.labels = labels 22 | self.n_interactions = len(sources) 23 | self.unique_nodes = set(sources) | set(destinations) 24 | self.n_unique_nodes = len(self.unique_nodes) 25 | 26 | 27 | def get_data(dataset_name, different_new_nodes_between_val_and_test=False): 28 | """ 29 | INPUTS: 30 | dataset_name: Wikipedia or Reddit 31 | different_new_nodes_between_val_and_test: Val and test set will use different unseen nodes (to test inductiveness) 32 | 33 | OUTPUTS: 34 | node_features: Array of shape [n_nodes, node_feat_dim], node_feat_dim is fixed to 172 35 | edge_features: Array of shape [n_interactions, edge_feat_dim] 36 | full_data: Data instance; It contains interactions of the whole temporal graph (i.e. acrossing the entire timespan) 37 | train_data: Data instance; It contains interactions happening before the validation time which not contains unseen nodes for testing inductiveness 38 | val_data: Data instance; It contains interactions after training time but before the testing time. This setting may contain nodes in train_data (transductive setting) 39 | test_data: Similar to val_data, this setting may contain nodes in train_data (transductive setting) 40 | new_node_val_data: Inductive val_data with edges that at least have one unseen node (inductive setting) 41 | new_node_test_data: Inductive test_data with edges that at least have one unseen node (inductive setting) 42 | 43 | P.S. 70%-15%-15% data split ratio applied 44 | """ 45 | ### Load data and train val test split 46 | graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name)) 47 | edge_features = np.load('./data/ml_{}.npy'.format(dataset_name)) 48 | node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name)) 49 | 50 | # val and test splite timestamp: 70%-15%-15% 51 | val_time, test_time = list(np.quantile(graph_df.ts, [0.70, 0.85])) # return two float numbers 52 | 53 | # list of length n_interactions, which may contain duplicate nodes 54 | sources = graph_df.u.values 55 | destinations = graph_df.i.values 56 | edge_idxs = graph_df.idx.values 57 | labels = graph_df.label.values 58 | timestamps = graph_df.ts.values 59 | 60 | full_data = Data(sources, destinations, timestamps, edge_idxs, labels) 61 | 62 | random.seed(2020) 63 | 64 | node_set = set(sources) | set(destinations) # a set of all nodes (no duplications) 65 | n_total_unique_nodes = len(node_set) # notice: set() will remove duplications 66 | 67 | # get nodes which appear at val & test time 68 | test_node_set = set(sources[timestamps > val_time]).union(set(destinations[timestamps > val_time])) 69 | 70 | # Sample (10% * n_nodes) nodes from val & test nodes to be unseen nodes 71 | new_test_node_set = set(random.sample(test_node_set, int(0.1 * n_total_unique_nodes))) 72 | 73 | # Mask saying for each source and destination whether they are unseen nodes 74 | # Two lists of length n_interactions where True for element (i.e. interaction) if src_node or dest_node belongs to new_test_node_set (i.e. unseen nodes) 75 | new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values 76 | new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values 77 | 78 | # A list of length n_interaction where True for element (i.e. interaction) if both src_node and dest_node are not unseen nodes 79 | observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask) 80 | 81 | # For train we keep edges happening before the validation time and not involve any unseen node 82 | train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask) 83 | train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask], edge_idxs[train_mask], labels[train_mask]) 84 | 85 | train_node_set = set(train_data.sources).union(train_data.destinations) 86 | assert len(train_node_set & new_test_node_set) == 0 87 | 88 | # define the new nodes sets for testing inductiveness of the model 89 | new_node_set = node_set - train_node_set 90 | 91 | # # TODO: Their relationships 92 | # print(len(node_set)) 93 | # print(len(test_node_set)) 94 | # print(len(new_test_node_set)) 95 | # print(len(train_node_set)) 96 | # print(len(new_node_set)) 97 | 98 | # exit() 99 | 100 | # val and test mask where we don't consider unseen nodes issue 101 | val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time) 102 | test_mask = timestamps > test_time 103 | 104 | if different_new_nodes_between_val_and_test: 105 | # val and test set will contain different unseen nodes 106 | n_new_nodes = len(new_test_node_set) // 2 107 | val_new_node_set = set(list(new_test_node_set)[:n_new_nodes]) 108 | test_new_node_set = set(list(new_test_node_set)[n_new_nodes:]) 109 | # one of each (src or dest) contains unseen nodes then True 110 | edge_contains_new_val_node_mask = np.array([(a in val_new_node_set or b in val_new_node_set) for a, b in zip(sources, destinations)]) 111 | edge_contains_new_test_node_mask = np.array([(a in test_new_node_set or b in test_new_node_set) for a, b in zip(sources, destinations)]) 112 | new_node_val_mask = np.logical_and(val_mask, edge_contains_new_val_node_mask) 113 | new_node_test_mask = np.logical_and(test_mask, edge_contains_new_test_node_mask) 114 | 115 | else: 116 | # val and test set may contain same unseen nodes 117 | edge_contains_new_node_mask = np.array([(a in new_node_set or b in new_node_set) for a, b in zip(sources, destinations)]) 118 | new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask) 119 | new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask) 120 | 121 | # validation and test with all edges 122 | val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask], edge_idxs[val_mask], labels[val_mask]) 123 | test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask], edge_idxs[test_mask], labels[test_mask]) 124 | 125 | # validation and test with edges that at least has one new node (not in training set) 126 | new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask], 127 | timestamps[new_node_val_mask], edge_idxs[new_node_val_mask], 128 | labels[new_node_val_mask]) 129 | 130 | new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask], 131 | timestamps[new_node_test_mask], edge_idxs[new_node_test_mask], 132 | labels[new_node_test_mask]) 133 | 134 | print("The dataset has {} interactions, involving {} different unique nodes".format(full_data.n_interactions, full_data.n_unique_nodes)) 135 | print("The training dataset has {} interactions, involving {} different unique nodes".format(train_data.n_interactions, train_data.n_unique_nodes)) 136 | print("The validation dataset has {} interactions, involving {} different unique nodes".format(val_data.n_interactions, val_data.n_unique_nodes)) 137 | print("The test dataset has {} interactions, involving {} different unique nodes".format(test_data.n_interactions, test_data.n_unique_nodes)) 138 | print("The inductive validation dataset has {} interactions, involving {} different unique nodes".format(new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes)) 139 | print("The inductive test dataset has {} interactions, involving {} different unique nodes".format(new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes)) 140 | print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(len(new_test_node_set))) 141 | 142 | return node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data 143 | 144 | 145 | def compute_time_statistics(sources, destinations, timestamps): 146 | last_timestamp_sources = dict() 147 | last_timestamp_dst = dict() 148 | all_timediffs_src = [] 149 | all_timediffs_dst = [] 150 | for k in range(len(sources)): 151 | source_id = sources[k] 152 | dest_id = destinations[k] 153 | c_timestamp = timestamps[k] 154 | if source_id not in last_timestamp_sources.keys(): 155 | last_timestamp_sources[source_id] = 0 156 | if dest_id not in last_timestamp_dst.keys(): 157 | last_timestamp_dst[dest_id] = 0 158 | all_timediffs_src.append(c_timestamp - last_timestamp_sources[source_id]) 159 | all_timediffs_dst.append(c_timestamp - last_timestamp_dst[dest_id]) 160 | last_timestamp_sources[source_id] = c_timestamp 161 | last_timestamp_dst[dest_id] = c_timestamp 162 | assert len(all_timediffs_src) == len(sources) 163 | assert len(all_timediffs_dst) == len(sources) 164 | mean_time_shift_src = np.mean(all_timediffs_src) 165 | std_time_shift_src = np.std(all_timediffs_src) 166 | mean_time_shift_dst = np.mean(all_timediffs_dst) 167 | std_time_shift_dst = np.std(all_timediffs_dst) 168 | return mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 26 21:30:25 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | import math 11 | from sklearn.metrics import average_precision_score, roc_auc_score 12 | 13 | 14 | ############################## Neighbor Finder ############################### 15 | class NeighborFinder: 16 | """ 17 | INIT INPUTS: 18 | adj_list: A list of shape [max_node_idx, 1] in this format: [[src_node/dest_node, edge_idx to dest_node/src_node, timestamp]] 19 | uniform: Bool, if Ture then we randomly sample n_neighbors before the cut_time 20 | seed: random seed for 21 | """ 22 | def __init__(self, adj_list, uniform=False, seed=None): 23 | self.node_to_neighbors = [] # neighbor ids 24 | self.node_to_edge_idxs = [] # corresponding edge idx 25 | self.node_to_edge_timestamps = [] # corresponding timestamp 26 | 27 | for neighbors in adj_list: 28 | # neighbors is a tuple: (neighbor, edge_idx, timestamp) 29 | # We sort the list based on timestamp 30 | sorted_neighhbors = sorted(neighbors, key=lambda x: x[2]) 31 | self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighhbors])) 32 | self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighhbors])) 33 | self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighhbors])) 34 | 35 | self.uniform = uniform 36 | 37 | if seed is not None: 38 | self.seed = seed 39 | self.random_state = np.random.RandomState(self.seed) 40 | 41 | def find_before(self, src_idx, cut_time): 42 | """ 43 | Extracts all the interactions happening before cut_time for src_idx in the overall interaction graph. 44 | The returned interactions are sorted by time. 45 | 3 lists will be returned: 46 | node_to_neighbors: List of length [temporal_neighbors_before_cut_time] 47 | node_to_edge_idxs: List of length [temporal_neighbors_before_cut_time] 48 | node_to_edge_timestamps: List of length [temporal_neighbors_before_cut_time] 49 | """ 50 | i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time) 51 | 52 | return self.node_to_neighbors[src_idx][:i], self.node_to_edge_idxs[src_idx][:i], self.node_to_edge_timestamps[src_idx][:i] 53 | 54 | def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20): 55 | """ 56 | Given a list source_nodes and correspond cut times (i.e. their current timestamps), 57 | this method extracts a list of sampled temporal neighbors of each node in source_nodes. 58 | 59 | INPUTS: 60 | source_nodes: A list (int) of nodes which temporal neighbors need to be extracted 61 | timestamps: A list (float) of timestamps for nodes in source_nodes 62 | n_neighbors: Extract this number of neighbors between time range [0, timestamps] 63 | 64 | OUTPUTS: 65 | neighbors: Arrary of shape [source_nodes, n_neighbors] 66 | edge_idxs: Arrary of shape [source_nodes, n_neighbors] 67 | edge_times: Arrary of shape [source_nodes, n_neighbors] 68 | 69 | """ 70 | assert (len(source_nodes) == len(timestamps)) 71 | 72 | tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1 73 | 74 | # ALL interactions described in these matrices are sorted in each row by time 75 | neighbors = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32) # each entry in position (i,j) represent the id of the item targeted by user src_idx_l[i] with an interaction happening before cut_time_l[i] 76 | edge_times = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.float32) # each entry in position (i,j) represent the timestamp of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i] 77 | edge_idxs = np.zeros((len(source_nodes), tmp_n_neighbors)).astype(np.int32) # each entry in position (i,j) represent the interaction index of an interaction between user src_idx_l[i] and item neighbors[i,j] happening before cut_time_l[i] 78 | 79 | for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)): 80 | # extracts all neighbors, interactions (i.e. edges) indexes and timestamps of ALL interactions of source_nodes happening before their corresponding cut_time (i.e. timestamps) 81 | source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node, timestamp) 82 | 83 | if len(source_neighbors) > 0 and n_neighbors > 0: 84 | 85 | if self.uniform: # if we are applying uniform sampling, shuffles the data above before sampling 86 | 87 | sampled_idx = np.random.randint(0, len(source_neighbors), n_neighbors) # random sample n_neighbors temporal neighbors 88 | 89 | neighbors[i, :] = source_neighbors[sampled_idx] 90 | edge_times[i, :] = source_edge_times[sampled_idx] 91 | edge_idxs[i, :] = source_edge_idxs[sampled_idx] 92 | 93 | # re-sort based on time cuz provided source_nodes are not sorted yet 94 | # so that neighbors, edge_times, and edge_idxs are all sorted by time 95 | pos = edge_times[i, :].argsort() 96 | neighbors[i, :] = neighbors[i, :][pos] 97 | edge_times[i, :] = edge_times[i, :][pos] 98 | edge_idxs[i, :] = edge_idxs[i, :][pos] 99 | 100 | else: 101 | 102 | # Take most recent n_neighbors interactions 103 | source_edge_times = source_edge_times[-n_neighbors:] 104 | source_neighbors = source_neighbors[-n_neighbors:] 105 | source_edge_idxs = source_edge_idxs[-n_neighbors:] 106 | 107 | assert (len(source_neighbors) <= n_neighbors) 108 | assert (len(source_edge_times) <= n_neighbors) 109 | assert (len(source_edge_idxs) <= n_neighbors) 110 | 111 | neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors 112 | edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times 113 | edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs 114 | 115 | return neighbors, edge_idxs, edge_times 116 | 117 | 118 | def get_neighbor_finder(data, uniform, max_node_idx=None): 119 | 120 | max_node_idx = max(data.sources.max(), data.destinations.max()) if max_node_idx is None else max_node_idx 121 | adj_list = [[] for _ in range(max_node_idx + 1)] 122 | for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations, 123 | data.edge_idxs, 124 | data.timestamps): 125 | adj_list[source].append((destination, edge_idx, timestamp)) 126 | adj_list[destination].append((source, edge_idx, timestamp)) 127 | 128 | return NeighborFinder(adj_list, uniform=uniform) 129 | 130 | 131 | ############################ MLP Score function ############################## 132 | class MergeLayer(torch.nn.Module): 133 | """ 134 | Compute probability on an edge given two node embeddings 135 | 136 | INIT INPUTS: 137 | dim1 = dim2 = dim3 = node_feat_dim, emb_dim = node_feat_dim 138 | dim4 = 1 139 | 140 | INPUTS: 141 | x1: torch.cat([source_node_embedding, source_node_embedding], dim=0) with shape [batch_size * 2, emb_dim] 142 | x2: torch.cat([destination_node_embedding, negative_node_embedding]) with shape [batch_size * 2, emb_dim] 143 | """ 144 | def __init__(self, dim1, dim2, dim3, dim4): 145 | super().__init__() 146 | self.fc1 = torch.nn.Linear(dim1 + dim2, dim3) 147 | self.fc2 = torch.nn.Linear(dim3, dim4) 148 | self.act = torch.nn.ReLU() 149 | 150 | torch.nn.init.xavier_normal_(self.fc1.weight) 151 | torch.nn.init.xavier_normal_(self.fc2.weight) 152 | 153 | def forward(self, x1, x2): 154 | x = torch.cat([x1, x2], dim=1) # [batch_size * 2, emb_dim * 2] 155 | h = self.act(self.fc1(x)) # [batch_size * 2, emb_dim] 156 | return self.fc2(h) # [batch_size * 2, 1] 157 | 158 | ############################ Negative Simpler ############################## 159 | class RandEdgeSampler(object): 160 | """ 161 | Negative simpler to randomly simple negatives from provided src and dest list 162 | 163 | INIT INPUTS: 164 | src_list: List of node ids 165 | dst_list: List of node ids 166 | 167 | INPUTS: 168 | size: How many negatives to sample 169 | """ 170 | def __init__(self, src_list, dst_list, seed=None): 171 | self.seed = None 172 | self.src_list = np.unique(src_list) 173 | self.dst_list = np.unique(dst_list) 174 | 175 | if seed is not None: 176 | self.seed = seed 177 | self.random_state = np.random.RandomState(self.seed) 178 | 179 | def sample(self, size): 180 | if self.seed is None: 181 | src_index = np.random.randint(0, len(self.src_list), size) 182 | dst_index = np.random.randint(0, len(self.dst_list), size) 183 | else: 184 | src_index = self.random_state.randint(0, len(self.src_list), size) 185 | dst_index = self.random_state.randint(0, len(self.dst_list), size) 186 | return self.src_list[src_index], self.dst_list[dst_index] 187 | 188 | def reset_random_state(self): 189 | self.random_state = np.random.RandomState(self.seed) 190 | 191 | ############################ EarlyStopMonitor ############################## 192 | class EarlyStopMonitor(object): 193 | 194 | def __init__(self, max_round=3, higher_better=True, tolerance=1e-10): 195 | 196 | self.max_round = max_round 197 | self.num_round = 0 198 | 199 | self.epoch_count = 0 200 | self.best_epoch = 0 201 | 202 | self.last_best = None 203 | self.higher_better = higher_better 204 | self.tolerance = tolerance 205 | 206 | def early_stop_check(self, curr_val): 207 | 208 | if not self.higher_better: 209 | curr_val *= -1 210 | if self.last_best is None: 211 | self.last_best = curr_val 212 | elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance: 213 | self.last_best = curr_val 214 | self.num_round = 0 215 | self.best_epoch = self.epoch_count 216 | else: 217 | self.num_round += 1 218 | 219 | self.epoch_count += 1 220 | 221 | return self.num_round >= self.max_round 222 | 223 | ########################## Evaluate on val & test ########################## 224 | def eval_edge_prediction(model, negative_edge_sampler, data, n_neighbors, batch_size=200): 225 | 226 | # Ensures the random sampler uses a seed for evaluation (i.e. we sample always the same 227 | # negatives for validation / test set) 228 | assert negative_edge_sampler.seed is not None 229 | negative_edge_sampler.reset_random_state() 230 | 231 | val_ap, val_auc = [], [] 232 | with torch.no_grad(): 233 | model = model.eval() 234 | # While usually the test batch size is as big as it fits in memory, here we keep it the same 235 | # size as the training batch size, since it allows the memory to be updated more frequently, 236 | # and later test batches to access information from interactions in previous test batches 237 | # through the memory 238 | TEST_BATCH_SIZE = batch_size 239 | num_test_instance = len(data.sources) 240 | num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE) 241 | 242 | for k in range(num_test_batch): 243 | s_idx = k * TEST_BATCH_SIZE 244 | e_idx = min(num_test_instance, s_idx + TEST_BATCH_SIZE) 245 | sources_batch = data.sources[s_idx:e_idx] 246 | destinations_batch = data.destinations[s_idx:e_idx] 247 | timestamps_batch = data.timestamps[s_idx:e_idx] 248 | edge_idxs_batch = data.edge_idxs[s_idx: e_idx] 249 | 250 | size = len(sources_batch) 251 | _, negative_samples = negative_edge_sampler.sample(size) 252 | 253 | pos_prob, neg_prob = model.compute_edge_probabilities(sources_batch, destinations_batch, 254 | negative_samples, timestamps_batch, 255 | edge_idxs_batch, n_neighbors) 256 | 257 | pred_score = np.concatenate([(pos_prob).cpu().numpy(), (neg_prob).cpu().numpy()]) 258 | true_label = np.concatenate([np.ones(size), np.zeros(size)]) 259 | 260 | val_ap.append(average_precision_score(true_label, pred_score)) 261 | val_auc.append(roc_auc_score(true_label, pred_score)) 262 | 263 | return np.mean(val_ap), np.mean(val_auc) -------------------------------------------------------------------------------- /link_prediction.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Aug 28 16:31:51 2020 4 | 5 | @author: Ming Jin 6 | 7 | TGN - Self-supervised link prediction 8 | """ 9 | 10 | import math 11 | import logging 12 | import time 13 | import sys 14 | import argparse 15 | import torch 16 | import numpy as np 17 | import pickle 18 | from pathlib import Path 19 | 20 | from utils.utils import eval_edge_prediction 21 | from net.tgn import TGN 22 | from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder 23 | from utils.data_processing import get_data, compute_time_statistics 24 | 25 | torch.manual_seed(0) 26 | np.random.seed(0) 27 | 28 | ### Argument and global variables 29 | parser = argparse.ArgumentParser('TGN self-supervised training') 30 | 31 | parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)', default='wikipedia') 32 | parser.add_argument('--different_new_nodes', action='store_true', help='Whether val and test set use different unseen nodes (to test inductiveness)') 33 | parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints') 34 | parser.add_argument('--batch', type=int, default=200, help='Batch_size') 35 | parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs') 36 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') 37 | parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability') 38 | parser.add_argument('--n_runs', type=int, default=1, help='Number of runs for this script') 39 | parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to backprop') 40 | parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use') 41 | parser.add_argument('--patience', type=int, default=5, help='Patience for early stopping') 42 | 43 | parser.add_argument('--use_memory', action='store_true', help='Whether to augment the model with a node memory') 44 | parser.add_argument('--memory_update_at_end', action='store_true', help='Whether to update memory at the end or at the start of the batch') 45 | parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding') 46 | parser.add_argument('--time_dim', type=int, default=100, help='Dimensions of the time embedding') 47 | parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages') 48 | parser.add_argument('--memory_dim', type=int, default=172, help='Dimensions of the memory for each node') 49 | parser.add_argument('--neighbors', type=int, default=10, help='Number of neighbors to sample') 50 | parser.add_argument('--uniform', action='store_true', help='take uniform sampling from temporal neighbors') 51 | parser.add_argument('--embedding_module', type=str, default="graph_attention", choices=["graph_attention", "graph_sum", "identity", "time"], help='Type of embedding module') 52 | parser.add_argument('--message_function', type=str, default="identity", choices=["mlp", "identity"], help='Type of message function') 53 | parser.add_argument('--aggregator', type=str, default="last", choices=["last", "mean"], help='Type of message aggregator') 54 | parser.add_argument('--memory_updater', type=str, default="gru", choices=["gru", "rnn"], help='Type of memory updater') 55 | parser.add_argument('--n_layer', type=int, default=1, help='Number of network layers') 56 | parser.add_argument('--n_head', type=int, default=2, help='Number of heads used in attention layer') 57 | parser.add_argument('--use_source_embedding_in_message', action='store_true', help='Whether to use the embedding of the source node as part of the message') 58 | parser.add_argument('--use_destination_embedding_in_message', action='store_true', help='Whether to use the embedding of the destination node as part of the message') 59 | 60 | try: 61 | args = parser.parse_args() 62 | except: 63 | parser.print_help() 64 | sys.exit(0) 65 | 66 | BATCH_SIZE = args.batch 67 | NUM_NEIGHBORS = args.neighbors 68 | NUM_NEG = 1 69 | NUM_EPOCH = args.n_epoch 70 | NUM_HEADS = args.n_head 71 | DROP_OUT = args.drop_out 72 | GPU = args.gpu 73 | DATA = args.data 74 | NUM_LAYER = args.n_layer 75 | LEARNING_RATE = args.lr 76 | NODE_DIM = args.node_dim # Notice: node_dim=172 for dataset without node features 77 | TIME_DIM = args.time_dim 78 | USE_MEMORY = args.use_memory 79 | MESSAGE_DIM = args.message_dim 80 | MEMORY_DIM = args.memory_dim 81 | 82 | Path("./saved_models/").mkdir(parents=True, exist_ok=True) 83 | Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True) 84 | MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth' 85 | get_checkpoint_path = lambda epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth' 86 | 87 | ### set up logger 88 | logging.basicConfig(level=logging.INFO) 89 | logger = logging.getLogger() 90 | logger.setLevel(logging.DEBUG) 91 | Path("log/").mkdir(parents=True, exist_ok=True) 92 | fh = logging.FileHandler('log/{}.log'.format(str(time.time()))) 93 | fh.setLevel(logging.DEBUG) 94 | ch = logging.StreamHandler() 95 | ch.setLevel(logging.WARN) 96 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 97 | fh.setFormatter(formatter) 98 | ch.setFormatter(formatter) 99 | logger.addHandler(fh) 100 | logger.addHandler(ch) 101 | logger.info(args) 102 | 103 | ### Extract data for training, validation and testing 104 | node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, \ 105 | new_node_test_data = get_data(DATA, different_new_nodes_between_val_and_test=args.different_new_nodes) 106 | 107 | # Initialize training neighbor finder to retrieve temporal graph 108 | train_ngh_finder = get_neighbor_finder(train_data, args.uniform) 109 | 110 | # Initialize validation and test neighbor finder to retrieve temporal graph 111 | full_ngh_finder = get_neighbor_finder(full_data, args.uniform) 112 | 113 | # Initialize negative samplers 114 | # Set seeds for validation and testing so negatives are the same across different runs 115 | # NB: in the inductive setting, negatives are sampled only amongst other new nodes 116 | train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations) 117 | val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0) 118 | test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2) 119 | nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources, new_node_test_data.destinations, seed=3) 120 | nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations, seed=1) 121 | 122 | # Set device 123 | device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu' 124 | device = torch.device(device_string) 125 | 126 | # Compute time statistics 127 | mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \ 128 | compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps) 129 | 130 | ############################### START ################################# 131 | for i in range(args.n_runs): 132 | 133 | results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix) 134 | Path("results/").mkdir(parents=True, exist_ok=True) 135 | 136 | # Initialize Model 137 | tgn = TGN(neighbor_finder=train_ngh_finder, node_features=node_features, 138 | edge_features=edge_features, device=device, 139 | n_layers=NUM_LAYER, 140 | n_heads=NUM_HEADS, dropout=DROP_OUT, use_memory=USE_MEMORY, 141 | message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM, 142 | memory_update_at_start=not args.memory_update_at_end, 143 | embedding_module_type=args.embedding_module, 144 | message_function=args.message_function, 145 | aggregator_type=args.aggregator, 146 | memory_updater_type=args.memory_updater, 147 | n_neighbors=NUM_NEIGHBORS, 148 | mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src, 149 | mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst, 150 | use_destination_embedding_in_message=args.use_destination_embedding_in_message, 151 | use_source_embedding_in_message=args.use_source_embedding_in_message) 152 | 153 | criterion = torch.nn.BCELoss() 154 | optimizer = torch.optim.Adam(tgn.parameters(), lr=LEARNING_RATE) 155 | tgn = tgn.to(device) 156 | 157 | num_instance = len(train_data.sources) # n_interactions 158 | num_batch = math.ceil(num_instance / BATCH_SIZE) 159 | 160 | logger.info('num of training instances: {}'.format(num_instance)) 161 | logger.info('num of batches per epoch: {}'.format(num_batch)) 162 | idx_list = np.arange(num_instance) 163 | 164 | new_nodes_val_aps = [] 165 | val_aps = [] 166 | epoch_times = [] 167 | total_epoch_times = [] 168 | train_losses = [] 169 | 170 | early_stopper = EarlyStopMonitor(max_round=args.patience) 171 | 172 | ###################### TRAINING ################### 173 | for epoch in range(NUM_EPOCH): 174 | 175 | start_epoch = time.time() 176 | 177 | # Reinitialize memory of the model at the start of each epoch 178 | if USE_MEMORY: 179 | tgn.memory.__init_memory__() 180 | 181 | # Train using only training graph 182 | tgn.set_neighbor_finder(train_ngh_finder) 183 | m_loss = [] 184 | 185 | logger.info('start {} epoch'.format(epoch)) 186 | 187 | ### Start to train on this epoch 188 | for k in range(0, num_batch, args.backprop_every): 189 | 190 | loss = 0 191 | optimizer.zero_grad() 192 | 193 | # Custom loop to allow to perform backpropagation only every a certain number of batches 194 | for j in range(args.backprop_every): 195 | 196 | batch_idx = k + j 197 | 198 | if batch_idx >= num_batch: 199 | continue 200 | 201 | # get a src and dest node training batch 202 | start_idx = batch_idx * BATCH_SIZE 203 | end_idx = min(num_instance, start_idx + BATCH_SIZE) 204 | sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \ 205 | train_data.destinations[start_idx:end_idx] 206 | 207 | # as well as the edge and timestamps for this batch 208 | edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx] 209 | timestamps_batch = train_data.timestamps[start_idx:end_idx] 210 | 211 | # sample batch_size dest negatives (nodes) 212 | size = len(sources_batch) 213 | _, negatives_batch = train_rand_sampler.sample(size) 214 | 215 | # self-supervised labels setting 216 | with torch.no_grad(): 217 | pos_label = torch.ones(size, dtype=torch.float, device=device) 218 | neg_label = torch.zeros(size, dtype=torch.float, device=device) 219 | 220 | # forward propagation 221 | tgn = tgn.train() 222 | pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch, 223 | timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS) 224 | 225 | loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label) 226 | 227 | # backward propagation 228 | loss /= args.backprop_every 229 | loss.backward() 230 | optimizer.step() 231 | m_loss.append(loss.item()) 232 | 233 | ### Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to the start of time 234 | # TODO: If not, "Trying backpropagate but buffers have not been freed" error will happen because: 235 | # 1). For mem_update_at_end: Memory updated at the end may contain this batch information that loss will not cover, 236 | # so we have to detach to ensure the memory has the information that loss has covered to backpropagate. 237 | # 2). For mem_update_at_start: We don't have the issue on (1) but some node messages may be removed after update_memory 238 | # so we may try to backpropagate on those freed messages. 239 | if USE_MEMORY: 240 | tgn.memory.detach_memory() 241 | 242 | epoch_time = time.time() - start_epoch 243 | epoch_times.append(epoch_time) 244 | 245 | ####################### VALIDATION ###################### 246 | # Validation uses the full graph 247 | tgn.set_neighbor_finder(full_ngh_finder) 248 | 249 | if USE_MEMORY: 250 | # Backup memory at the end of training, so later we can restore it and use it for the 251 | # validation on unseen nodes (since validation edges are strictly later in time than training edges) 252 | train_memory_backup = tgn.memory.backup_memory() 253 | 254 | val_ap, val_auc = eval_edge_prediction(model=tgn, negative_edge_sampler=val_rand_sampler, 255 | data=val_data, n_neighbors=NUM_NEIGHBORS) 256 | if USE_MEMORY: 257 | # Backup memory after validation so it can be used for testing (since test edges are 258 | # strictly later in time than validation edges) 259 | val_memory_backup = tgn.memory.backup_memory() 260 | # Restore memory we had at the end of training to be used when validating on unseen nodes. 261 | tgn.memory.restore_memory(train_memory_backup) 262 | 263 | # Validate on unseen nodes 264 | nn_val_ap, nn_val_auc = eval_edge_prediction(model=tgn, negative_edge_sampler=nn_val_rand_sampler, 265 | data=new_node_val_data, n_neighbors=NUM_NEIGHBORS) 266 | 267 | if USE_MEMORY: 268 | # Restore memory we had at the end of validation to get ready for testing if: 269 | # 1). This is last epoch 270 | # 2). Early stopping happen on this epoch 271 | tgn.memory.restore_memory(val_memory_backup) 272 | 273 | new_nodes_val_aps.append(nn_val_ap) 274 | val_aps.append(val_ap) 275 | train_losses.append(np.mean(m_loss)) 276 | 277 | # Save temporary results to disk 278 | pickle.dump({ 279 | "val_aps": val_aps, 280 | "new_nodes_val_aps": new_nodes_val_aps, 281 | "train_losses": train_losses, 282 | "epoch_times": epoch_times, 283 | "total_epoch_times": total_epoch_times 284 | }, open(results_path, "wb")) 285 | 286 | total_epoch_time = time.time() - start_epoch 287 | total_epoch_times.append(total_epoch_time) 288 | 289 | logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time)) 290 | logger.info('Epoch mean loss: {}'.format(np.mean(m_loss))) 291 | logger.info('transductive val auc: {}, inductive val auc: {}'.format(val_auc, nn_val_auc)) 292 | logger.info('transductive val ap: {}, inductive val ap: {}'.format(val_ap, nn_val_ap)) 293 | 294 | # Early stopping 295 | if early_stopper.early_stop_check(val_ap): 296 | logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round)) 297 | logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}') 298 | best_model_path = get_checkpoint_path(early_stopper.best_epoch) 299 | tgn.load_state_dict(torch.load(best_model_path)) 300 | logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference') 301 | tgn.eval() 302 | break 303 | else: 304 | torch.save(tgn.state_dict(), get_checkpoint_path(epoch)) 305 | 306 | ############################ TESTING #################################### 307 | # Training has finished, we have loaded the best model, and we want to backup its current 308 | # memory (which has seen validation edges) so that it can also be used when testing on unseen 309 | # nodes 310 | if USE_MEMORY: 311 | val_memory_backup = tgn.memory.backup_memory() 312 | 313 | ### Test 314 | tgn.embedding_module.neighbor_finder = full_ngh_finder 315 | test_ap, test_auc = eval_edge_prediction(model=tgn, negative_edge_sampler=test_rand_sampler, 316 | data=test_data, n_neighbors=NUM_NEIGHBORS) 317 | 318 | if USE_MEMORY: 319 | tgn.memory.restore_memory(val_memory_backup) 320 | 321 | # Test on unseen nodes 322 | nn_test_ap, nn_test_auc = eval_edge_prediction(model=tgn, negative_edge_sampler=nn_test_rand_sampler, 323 | data=new_node_test_data, n_neighbors=NUM_NEIGHBORS) 324 | 325 | logger.info('Test statistics: Transductive -- auc: {}, ap: {}'.format(test_auc, test_ap)) 326 | logger.info('Test statistics: Inductive -- auc: {}, ap: {}'.format(nn_test_auc, nn_test_ap)) 327 | 328 | # Save results for this run 329 | pickle.dump({ 330 | "val_aps": val_aps, 331 | "new_nodes_val_aps": new_nodes_val_aps, 332 | "test_ap": test_ap, 333 | "new_node_test_ap": nn_test_ap, 334 | "epoch_times": epoch_times, 335 | "train_losses": train_losses, 336 | "total_epoch_times": total_epoch_times 337 | }, open(results_path, "wb")) 338 | 339 | logger.info('Saving TGN model') 340 | 341 | if USE_MEMORY: 342 | # Restore memory at the end of validation (save a model which is ready for testing) 343 | tgn.memory.restore_memory(val_memory_backup) 344 | 345 | torch.save(tgn.state_dict(), MODEL_SAVE_PATH) 346 | logger.info('TGN model saved') -------------------------------------------------------------------------------- /modules/embedding_module.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Mon Aug 24 22:01:32 2020 4 | 5 | @author: Ming Jin 6 | 7 | There are four types of embedding calculation methods: Identity, Time projection, Temporal graph attention, Temporal graph sum 8 | 9 | Providing a batch of nodes (i.e. source nodes), return the embeddings for those nodes 10 | """ 11 | 12 | import torch 13 | from torch import nn 14 | import numpy as np 15 | import math 16 | 17 | from modules.temporal_attention import TemporalAttentionLayer 18 | 19 | 20 | class EmbeddingModule(nn.Module): 21 | """ 22 | Abstract class for the embedding calculation 23 | 24 | INIT: 25 | node_features: Nodes raw features of shape [n_nodes, node_feat_dim] 26 | edge_features: Edges raw features of shape [n_interactinon, edge_feat_dim] 27 | neighbor_finder: NeighborFinder instance 28 | time_encoder: TimeEncoder instance encodes t to a vector of shape [n_time_features] 29 | n_layers: L in the paper, corresponding to L-hops as well 30 | n_node_features: Nodes raw feature dimension equals to node_feat_dim 31 | n_edge_features: Edges raw feature dimension euqals to edge_feat_dim 32 | n_time_features: Time encoding dimension equals to node_feat_dim 33 | embedding_dimension: Embedding dim for nodes, which equals to node_feat_dim 34 | device: 35 | dropout: Not in codes 36 | 37 | INPUTS: 38 | memory: A Tensor of shape [n_nodes, mem_dim]; Memory.memory 39 | source_nodes: Array of shape [source_nodes]; Nodes in a batch in a certain time which embeddings to be calculated 40 | timestamps: Array of shape [source_nodes]; Timestamps of interactions (i.e. Current timestamps) for those nodes 41 | n_layers: A number; Indicating how many graph conv layers (i.e. How deep to aggregate neighbors' information) 42 | n_neighbors: A number; Indicating how many temporal neighbors to be considered in a certain hop 43 | time_diffs: A Tensor of shape [source_nodes]; Delta t, i.e. Differences between the time of a node was last updated (i.e. memory.last_update), 44 | and the time for which we want to compute the embedding of a node 45 | use_time_proj: Not in codes 46 | """ 47 | def __init__(self, node_features, edge_features, neighbor_finder, time_encoder, n_layers, 48 | n_node_features, n_edge_features, n_time_features, embedding_dimension, device, 49 | dropout): 50 | 51 | super(EmbeddingModule, self).__init__() 52 | 53 | self.node_features = node_features 54 | self.edge_features = edge_features 55 | self.neighbor_finder = neighbor_finder 56 | self.time_encoder = time_encoder 57 | self.n_layers = n_layers 58 | self.n_node_features = n_node_features 59 | self.n_edge_features = n_edge_features 60 | self.n_time_features = n_time_features 61 | self.dropout = dropout 62 | self.embedding_dimension = embedding_dimension 63 | self.device = device 64 | 65 | def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None, 66 | use_time_proj=True): 67 | pass 68 | 69 | 70 | class IdentityEmbedding(EmbeddingModule): 71 | """ 72 | Identity embedding calculation: Z_i(t) = S_i(t) 73 | 74 | Node embedding of shape [source_nodes, mem_dim] 75 | """ 76 | def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None, 77 | use_time_proj=True): 78 | return memory[source_nodes, :] 79 | 80 | 81 | class TimeEmbedding(EmbeddingModule): 82 | """ 83 | Identity embedding calculation: Z_i(t) = S_i(t) * (1 + Linear(time_diff)) 84 | 85 | Node embedding of shape [source_nodes, emb_dim] 86 | """ 87 | def __init__(self, node_features, edge_features, neighbor_finder, time_encoder, n_layers, 88 | n_node_features, n_edge_features, n_time_features, embedding_dimension, device, 89 | n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1): 90 | 91 | super(TimeEmbedding, self).__init__(node_features, edge_features, 92 | neighbor_finder, time_encoder, n_layers, 93 | n_node_features, n_edge_features, n_time_features, 94 | embedding_dimension, device, dropout) 95 | 96 | class NormalLinear(nn.Linear): 97 | # From Jodie code 98 | def reset_parameters(self): 99 | stdv = 1. / math.sqrt(self.weight.size(1)) 100 | self.weight.data.normal_(0, stdv) 101 | if self.bias is not None: 102 | self.bias.data.normal_(0, stdv) 103 | 104 | self.embedding_layer = NormalLinear(1, self.n_node_features) # node_dim = 172 = mem_dim for wiki and reddit 105 | 106 | def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None, 107 | use_time_proj=True): 108 | 109 | source_embeddings = memory[source_nodes, :] * (1 + self.embedding_layer(time_diffs.unsqueeze(1))) 110 | 111 | return source_embeddings 112 | 113 | 114 | class GraphEmbedding(EmbeddingModule): 115 | """ 116 | THe second abstract class (Relationship: EmbeddingModule <-- GraphEmbedding <-- GraphSumEmbedding/GraphAttentionEmbedding) 117 | for the graph-based embedding calculation 118 | """ 119 | def __init__(self, node_features, edge_features, neighbor_finder, time_encoder, n_layers, 120 | n_node_features, n_edge_features, n_time_features, embedding_dimension, device, 121 | n_heads=2, dropout=0.1, use_memory=True): 122 | 123 | 124 | super(GraphEmbedding, self).__init__(node_features, edge_features, 125 | neighbor_finder, time_encoder, n_layers, 126 | n_node_features, n_edge_features, n_time_features, 127 | embedding_dimension, device, dropout) 128 | 129 | self.use_memory = use_memory 130 | self.device = device 131 | 132 | def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None, 133 | use_time_proj=True): 134 | """ 135 | Recursive implementation of n_layers temporal graph embedding calculation. 136 | Finally we have h^n_layers (t), which is h^L (t) in paper. 137 | 138 | INPUTS: 139 | same as EmbeddingModule 140 | 141 | OUTPUT: 142 | source_node_features (if n_layer==0): h^0 (t) of shape [source_nodes, node_feat_dim] 143 | OR 144 | source_embedding: h^1 (t) to h^n_layers (t), each of shape [source_nodes, emb_dim] 145 | """ 146 | 147 | assert (n_layers >= 0) # layers can't be negative 148 | 149 | source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device) # source node ids 150 | timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1) # timestamps of interactions for those nodes 151 | source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch)) # Phi(0) 152 | source_node_features = self.node_features[source_nodes_torch, :] # source node features of shape [source_nodes, node_feat_dim] 153 | 154 | # h^0 (t) = S(t) + V(t) 155 | if self.use_memory: 156 | source_node_features = memory[source_nodes, :] + source_node_features # node_dim = 172 = mem_dim for wiki and reddit 157 | 158 | if n_layers == 0: 159 | # h^0 (t) 160 | return source_node_features 161 | else: 162 | # n_neighbors TEMPORAL neighbors of source_nodes, as well as the corrsponding edges and associated interaction timestamps 163 | # neighbors: Arrary of shape [source_nodes, n_neighbors] 164 | # edge_idxs: Arrary of shape [source_nodes, n_neighbors] 165 | # edge_times: Arrary of shape [source_nodes, n_neighbors] 166 | neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(source_nodes, timestamps, n_neighbors=n_neighbors) 167 | 168 | # neighbors_torch, edge_idxs, edge_deltas_torch are what we need 169 | neighbors_torch = torch.from_numpy(neighbors).long().to(self.device) 170 | edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device) 171 | edge_deltas = timestamps[:, np.newaxis] - edge_times # This is t - t_N in the paper of shape [source_nodes, num_temp_neighbors] 172 | edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device) 173 | 174 | # Recrsively calculate L-hops neighbors' embeddings of source_nodes 175 | neighbors = neighbors.flatten() # [source_nodes * num_temp_neighbors] 176 | # neighbor_embeddings of shape [len(neighbors), emb_dim] 177 | neighbor_embeddings = self.compute_embedding(memory, 178 | neighbors, 179 | np.repeat(timestamps, n_neighbors), 180 | n_layers=n_layers - 1, 181 | n_neighbors=n_neighbors) 182 | 183 | effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1 184 | # h^(l-1)_j (t) : neighbor_embeddings of shape [source_nodes, n_neighbors, emb_dim] 185 | neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1) 186 | # Phi(t - t_j) : time_embeddings of shape [source_nodes, n_neighbors, n_time_features = node_feat_dim] 187 | edge_time_embeddings = self.time_encoder(edge_deltas_torch) 188 | # e_ij : edge_features of shape [source_nodes, n_neighbors, n_edge_features = node_feat_dim] 189 | edge_features = self.edge_features[edge_idxs, :] 190 | 191 | # mask is a zeros tensor of shape [source_nodes, n_neighbors] 192 | # neighbors_torch will not be affected 193 | mask = neighbors_torch == 0 194 | 195 | # [source_nodes, emb_dim] 196 | source_embedding = self.aggregate(n_layers, source_node_features, 197 | source_nodes_time_embedding, 198 | neighbor_embeddings, 199 | edge_time_embeddings, 200 | edge_features, 201 | mask) 202 | # h^1 (t) to h^n_layers (t) 203 | return source_embedding 204 | 205 | def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding, 206 | neighbor_embeddings, edge_time_embeddings, edge_features, mask): 207 | """ 208 | Sum or attention aggregation for the graph-based embedding calculation 209 | 210 | INPUTS: 211 | n_layers: L in the paper, i.e. L-hops 212 | source_node_features: h^0 (t) = S(t) + V(t) 213 | source_nodes_time_embedding: Phi(0) of shape [source_nodes, n_time_features] 214 | neighbor_embeddings: Neighbor embeddings of l-1 layer with shape [source_nodes, emb_dim] 215 | edge_time_embeddings: Phi(t - t_j) of shape [source_nodes, n_time_features] 216 | edge_features: e_ij (i.e. edge_features) of shape [source_nodes, n_neighbors, n_edge_features = node_feat_dim] 217 | mask: A zeros tensor of shape [source_nodes, n_neighbors]; Has been used by Graph Attn Module 218 | """ 219 | return None 220 | 221 | 222 | class GraphSumEmbedding(GraphEmbedding): 223 | 224 | def __init__(self, node_features, edge_features, neighbor_finder, time_encoder, n_layers, 225 | n_node_features, n_edge_features, n_time_features, embedding_dimension, device, 226 | n_heads=2, dropout=0.1, use_memory=True): 227 | 228 | super(GraphSumEmbedding, self).__init__(node_features=node_features, 229 | edge_features=edge_features, 230 | neighbor_finder=neighbor_finder, 231 | time_encoder=time_encoder, n_layers=n_layers, 232 | n_node_features=n_node_features, 233 | n_edge_features=n_edge_features, 234 | n_time_features=n_time_features, 235 | embedding_dimension=embedding_dimension, 236 | device=device, 237 | n_heads=n_heads, dropout=dropout, 238 | use_memory=use_memory) 239 | 240 | # mapping neighbors_features: [source_nodes, emb_dim + edge_feat_dim + time_feat_dim] --> [source_nodes, emb_dim] 241 | self.linear_1 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_time_features + n_edge_features, 242 | embedding_dimension) for _ in range(n_layers)]) 243 | 244 | self.linear_2 = torch.nn.ModuleList([torch.nn.Linear(embedding_dimension + n_node_features + n_time_features, 245 | embedding_dimension) for _ in range(n_layers)]) 246 | 247 | def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding, 248 | neighbor_embeddings, edge_time_embeddings, edge_features, mask): 249 | 250 | # neighbors_features: h_j ^ (l-1) (t) || e_ij || Phi(t - t_j) 251 | neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features], dim=2) 252 | neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features) 253 | # h^wave_i (t) in paper 254 | neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1)) 255 | 256 | # source_features: [source_nodes, n_node_features + n_time_features] 257 | source_features = torch.cat([source_node_features, source_nodes_time_embedding.squeeze()], dim=1) # TODO: Why raw features? 258 | # [source_nodes, emd_dim + n_node_features + n_time_features] 259 | source_embedding = torch.cat([neighbors_sum, source_features], dim=1) 260 | # [source_nodes, emb_dim] 261 | source_embedding = self.linear_2[n_layer - 1](source_embedding) 262 | 263 | return source_embedding 264 | 265 | 266 | class GraphAttentionEmbedding(GraphEmbedding): 267 | 268 | def __init__(self, node_features, edge_features, neighbor_finder, time_encoder, n_layers, 269 | n_node_features, n_edge_features, n_time_features, embedding_dimension, device, 270 | n_heads=2, dropout=0.1, use_memory=True): 271 | 272 | super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, 273 | neighbor_finder, time_encoder, n_layers, 274 | n_node_features, n_edge_features, 275 | n_time_features, 276 | embedding_dimension, device, 277 | n_heads, dropout, 278 | use_memory) 279 | 280 | self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(n_node_features=n_node_features, 281 | n_neighbors_features=n_node_features, 282 | n_edge_features=n_edge_features, 283 | time_dim=n_time_features, 284 | n_head=n_heads, 285 | dropout=dropout, 286 | output_dimension=n_node_features) for _ in range(n_layers)]) 287 | 288 | def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding, 289 | neighbor_embeddings, edge_time_embeddings, edge_features, mask): 290 | 291 | attention_model = self.attention_models[n_layer - 1] 292 | 293 | source_embedding, _ = attention_model(source_node_features, 294 | source_nodes_time_embedding, 295 | neighbor_embeddings, 296 | edge_time_embeddings, 297 | edge_features, 298 | mask) 299 | 300 | return source_embedding 301 | 302 | 303 | ''' 304 | ############### 305 | REFERENCE ENTRY 306 | ############### 307 | ''' 308 | def get_embedding_module(module_type, node_features, edge_features, neighbor_finder, 309 | time_encoder, n_layers, n_node_features, n_edge_features, n_time_features, 310 | embedding_dimension, device, n_heads=2, dropout=0.1, n_neighbors=None, 311 | use_memory=True): 312 | 313 | if module_type == "graph_attention": 314 | return GraphAttentionEmbedding(node_features=node_features, 315 | edge_features=edge_features, 316 | neighbor_finder=neighbor_finder, 317 | time_encoder=time_encoder, 318 | n_layers=n_layers, 319 | n_node_features=n_node_features, 320 | n_edge_features=n_edge_features, 321 | n_time_features=n_time_features, 322 | embedding_dimension=embedding_dimension, 323 | device=device, 324 | n_heads=n_heads, dropout=dropout, use_memory=use_memory) 325 | elif module_type == "graph_sum": 326 | return GraphSumEmbedding(node_features=node_features, 327 | edge_features=edge_features, 328 | neighbor_finder=neighbor_finder, 329 | time_encoder=time_encoder, 330 | n_layers=n_layers, 331 | n_node_features=n_node_features, 332 | n_edge_features=n_edge_features, 333 | n_time_features=n_time_features, 334 | embedding_dimension=embedding_dimension, 335 | device=device, 336 | n_heads=n_heads, dropout=dropout, use_memory=use_memory) 337 | 338 | elif module_type == "identity": 339 | return IdentityEmbedding(node_features=node_features, 340 | edge_features=edge_features, 341 | neighbor_finder=neighbor_finder, 342 | time_encoder=time_encoder, 343 | n_layers=n_layers, 344 | n_node_features=n_node_features, 345 | n_edge_features=n_edge_features, 346 | n_time_features=n_time_features, 347 | embedding_dimension=embedding_dimension, 348 | device=device, 349 | dropout=dropout) 350 | elif module_type == "time": 351 | return TimeEmbedding(node_features=node_features, 352 | edge_features=edge_features, 353 | neighbor_finder=neighbor_finder, 354 | time_encoder=time_encoder, 355 | n_layers=n_layers, 356 | n_node_features=n_node_features, 357 | n_edge_features=n_edge_features, 358 | n_time_features=n_time_features, 359 | embedding_dimension=embedding_dimension, 360 | device=device, 361 | dropout=dropout, 362 | n_neighbors=n_neighbors) 363 | else: 364 | raise ValueError("Embedding Module {} not supported".format(module_type)) -------------------------------------------------------------------------------- /net/tgn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 26 22:58:18 2020 4 | 5 | @author: Ming Jin 6 | """ 7 | 8 | import logging 9 | import numpy as np 10 | import torch 11 | from collections import defaultdict 12 | 13 | from utils.utils import MergeLayer 14 | from modules.memory import Memory 15 | from modules.message_aggregator import get_message_aggregator 16 | from modules.message_function import get_message_function 17 | from modules.memory_updater import get_memory_updater 18 | from modules.embedding_module import get_embedding_module 19 | from modules.time_encoding import TimeEncode 20 | 21 | 22 | class TGN(torch.nn.Module): 23 | """ 24 | TGN model 25 | 26 | INIT INPUTS: 27 | neighbor_finder: NeighborFinder instance 28 | node_features: Nodes raw features of shape [n_nodes, node_feat_dim] 29 | edge_features: Edges raw features of shape [n_interactinon, edge_feat_dim] 30 | n_layers: 'L' in the paper 31 | n_heads: Number of attention heads 32 | dropout: For nn.MultiheadAttention() 33 | use_memory: Bool variable, whether to augment the model with a node memory 34 | memory_update_at_start: Bool variable, whether to update memory at the start of the batch 35 | message_dimension: Node message dimension for m_i(t), default 100 36 | memory_dimension: Node memory dimension for s_i(t), default 172 37 | embedding_module_type: How to calculate embedding, default 'graph_attention' 38 | message_function: How to calculate node message, default 'mlp' 39 | mean_time_shift_src: 40 | std_time_shift_src: 41 | mean_time_shift_dst: 42 | std_time_shift_dst: 43 | n_neighbors: How many temporal neighbos to be extracted 44 | aggregator_type: How to aggregate messages, default 'last' 45 | memory_updater_type: How to update node memory 46 | use_destination_embedding_in_message: 47 | use_source_embedding_in_message: 48 | """ 49 | 50 | def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2, 51 | n_heads=2, dropout=0.1, use_memory=True, memory_update_at_start=True, 52 | message_dimension=100, memory_dimension=172, embedding_module_type="graph_attention", 53 | message_function="mlp", mean_time_shift_src=0, std_time_shift_src=1, 54 | mean_time_shift_dst=0, std_time_shift_dst=1, n_neighbors=None, aggregator_type="last", 55 | memory_updater_type="gru", use_destination_embedding_in_message=False, 56 | use_source_embedding_in_message=False): 57 | 58 | super(TGN, self).__init__() 59 | 60 | self.n_layers = n_layers 61 | self.neighbor_finder = neighbor_finder 62 | self.device = device 63 | self.logger = logging.getLogger(__name__) 64 | 65 | self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device) # node features to tensor 66 | self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device) # edge features to tensor 67 | 68 | self.n_node_features = self.node_raw_features.shape[1] # node_feat_dim 69 | self.n_nodes = self.node_raw_features.shape[0] # n_nodes 70 | self.n_edge_features = self.edge_raw_features.shape[1] # edge_feat_dim 71 | self.embedding_dimension = self.n_node_features # emb_dim = node_feat_dim 72 | self.n_neighbors = n_neighbors 73 | self.embedding_module_type = embedding_module_type 74 | self.use_destination_embedding_in_message = use_destination_embedding_in_message 75 | self.use_source_embedding_in_message = use_source_embedding_in_message 76 | 77 | self.use_memory = use_memory 78 | self.time_encoder = TimeEncode(dimension=self.n_node_features) # encodes time to shape [node_feat_dim] 79 | self.memory = None 80 | 81 | self.mean_time_shift_src = mean_time_shift_src 82 | self.std_time_shift_src = std_time_shift_src 83 | self.mean_time_shift_dst = mean_time_shift_dst 84 | self.std_time_shift_dst = std_time_shift_dst 85 | 86 | if self.use_memory: 87 | 88 | self.memory_dimension = memory_dimension 89 | self.memory_update_at_start = memory_update_at_start 90 | # m_raw_i = (s_i || s_j || t || e) 91 | raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + self.time_encoder.dimension # raw message dim 92 | message_dimension = message_dimension if message_function != "identity" else raw_message_dimension # message dim 93 | 94 | self.memory = Memory(n_nodes=self.n_nodes, 95 | memory_dimension=self.memory_dimension, 96 | input_dimension=message_dimension, 97 | device=device) 98 | 99 | self.message_function = get_message_function(module_type=message_function, 100 | raw_message_dimension=raw_message_dimension, 101 | message_dimension=message_dimension) # message function 102 | 103 | self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type, device=device) # message aggregator 104 | 105 | # self.memory_updater = GRUMemoryUpdater(memory=self.memory, 106 | # message_dimension=message_dimension, 107 | # memory_dimension=self.memory_dimension, device=device) 108 | 109 | self.memory_updater = get_memory_updater(module_type=memory_updater_type, 110 | memory=self.memory, message_dimension=message_dimension, 111 | memory_dimension=self.memory_dimension, device=device) # memory updator 112 | 113 | self.embedding_module_type = embedding_module_type 114 | 115 | # self.embedding_module = get_embedding_module(module_type=embedding_module_type, 116 | # node_features=self.node_raw_features, 117 | # edge_features=self.edge_raw_features, 118 | # neighbor_finder=self.neighbor_finder, 119 | # time_encoder=self.time_encoder, 120 | # n_layers=self.n_layers, 121 | # n_node_features=self.n_node_features, 122 | # n_edge_features=self.n_edge_features, 123 | # n_time_features=self.n_node_features, 124 | # embedding_dimension=self.embedding_dimension, 125 | # device=self.device, 126 | # n_heads=n_heads, dropout=dropout, 127 | # use_memory=use_memory, 128 | # n_neighbors=self.n_neighbors) 129 | 130 | self.embedding_module = get_embedding_module(module_type=embedding_module_type, 131 | node_features=self.node_raw_features, 132 | edge_features=self.edge_raw_features, 133 | neighbor_finder=self.neighbor_finder, 134 | time_encoder=self.time_encoder, 135 | n_layers=self.n_layers, 136 | n_node_features=self.n_node_features, 137 | n_edge_features=self.n_edge_features, 138 | n_time_features=self.n_node_features, 139 | embedding_dimension=self.embedding_dimension, 140 | device=self.device, 141 | n_heads=n_heads, dropout=dropout, 142 | use_memory=use_memory, 143 | n_neighbors=self.n_neighbors) # embedding module 144 | 145 | # MLP to compute probability on an edge given two node embeddings 146 | self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features, self.n_node_features, 1) 147 | 148 | def set_neighbor_finder(self, neighbor_finder): 149 | self.neighbor_finder = neighbor_finder 150 | self.embedding_module.neighbor_finder = neighbor_finder 151 | 152 | def get_updated_memory(self, nodes, messages): 153 | """ 154 | Get (but not persist) updated nodes' memory by using messages (AGG-->MSG-->MEM, while in paper the order is MSG-->AGG-->MEM) 155 | 156 | INPUTS: 157 | nodes: A list of length n_nodes; Node ids 158 | message: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]}; Messages in previous batch 159 | 160 | OUTPUTS: 161 | updated_memory: A tensor of shape [unique_nodes, memory_dimension] 162 | updated_last_update: A tensor of shape [unique_nodes] 163 | """ 164 | # Aggregate messages for the same nodes 165 | unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages) 166 | 167 | if len(unique_nodes) > 0: 168 | unique_messages = self.message_function.compute_message(unique_messages) 169 | 170 | updated_memory, updated_last_update = self.memory_updater.get_updated_memory(unique_nodes, 171 | unique_messages, 172 | timestamps=unique_timestamps) 173 | 174 | return updated_memory, updated_last_update 175 | 176 | 177 | def update_memory(self, nodes, messages): 178 | """ 179 | Updated nodes' memory by using messages (AGG-->MSG-->MEM, while in paper the order is MSG-->AGG-->MEM) 180 | 181 | INPUTS: 182 | nodes: A list of length len(nodes); Node ids 183 | message: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]}; Messages in previous batch 184 | """ 185 | # Aggregate messages for the same nodes 186 | unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(nodes, messages) 187 | 188 | if len(unique_nodes) > 0: 189 | unique_messages = self.message_function.compute_message(unique_messages) 190 | 191 | # Update nodes' memory with the aggregated messages 192 | # Notice: update_memory() updates with no returns 193 | self.memory_updater.update_memory(unique_nodes, unique_messages, 194 | timestamps=unique_timestamps) 195 | 196 | 197 | # def get_raw_messages(self, source_nodes, destination_nodes, edge_times, edge_idxs): 198 | # """ 199 | # Get source_nodes' raw messages m_raw(t) = {[S(t-1), e(t)], t} 200 | 201 | # INPUTS: 202 | # source_nodes: Array of shape [batch_size]; Nodes' raw message to be calculated 203 | # destination_nodes: Array of shape [batch_size]; 204 | # edge_times: Array of shape [batch_size]; Timestamps of interactions (i.e. Current timestamps) for source_nodes 205 | # edge_idxs: Array of shape [batch_size]; Index of interactions (at edge_times) for source_nodes 206 | 207 | # OUTPUTS: 208 | # unique_sources: Array of shape [unique source nodes] 209 | # messages: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]} 210 | # where [message_x] is [S_i(t-1), S_j(t-1), e_ij(t), Phi(t-(t-1))], timestamp_x is the timestamp for each message_x 211 | # """ 212 | # edge_times = torch.from_numpy(edge_times).float().to(self.device) 213 | # edge_features = self.edge_raw_features[edge_idxs] # e_ij(t), or e(t) 214 | # source_memory = self.memory.get_memory(source_nodes) # S_i(t-1) 215 | # destination_memory = self.memory.get_memory(destination_nodes) # S_j(t-1) 216 | # source_time_delta = edge_times - self.memory.last_update[source_nodes] 217 | # source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(source_nodes), -1) # Phi(t-t^wave) 218 | 219 | # source_message = torch.cat([source_memory, destination_memory, edge_features, source_time_delta_encoding], dim=1) 220 | 221 | # messages = defaultdict(list) 222 | # unique_sources = np.unique(source_nodes) 223 | 224 | # for i in range(len(source_nodes)): 225 | # messages[source_nodes[i]].append((source_message[i], edge_times[i])) 226 | 227 | # return unique_sources, messages 228 | 229 | def get_raw_messages(self, source_nodes, source_node_embedding, 230 | destination_nodes, destination_node_embedding, 231 | edge_times, edge_idxs): 232 | """ 233 | Get source_nodes' raw messages m_raw_i(t) = {[S_i(t-1), S_j(t-1), e(t), Phi(t-(t_last)], t} 234 | 235 | INPUTS: 236 | source_nodes: Array of shape [batch_size]; Nodes' raw message to be calculated 237 | destination_nodes: Array of shape [batch_size]; 238 | edge_times: Array of shape [batch_size]; Timestamps of interactions (i.e. Current timestamps) for source_nodes 239 | edge_idxs: Array of shape [batch_size]; Index of interactions (at edge_times) for source_nodes 240 | source_node_embedding: z_i(t) with shape [batch_size, emb_dim=node_dim=mem_dim] 241 | destination_node_embedding: z_j(t) with shape [batch_size, emb_dim=node_dim=mem_dim] 242 | 243 | OUTPUTS: 244 | unique_sources: Array of shape [unique source nodes] 245 | messages: A dictionary {node_id:[([message_1], timestamp_1), ([message_2], timestamp_2), ...]} 246 | where [message_x] is [S_i(t-1), S_j(t-1), e_ij(t), Phi(t-(t-1))], timestamp_x is the timestamp for each message_x 247 | """ 248 | edge_times = torch.from_numpy(edge_times).float().to(self.device) 249 | edge_features = self.edge_raw_features[edge_idxs] # e_ij(t), or e(t) 250 | 251 | # s_i(t-1) or z_i(t) 252 | source_memory = self.memory.get_memory(source_nodes) if not self.use_source_embedding_in_message else source_node_embedding 253 | # s_j(t-1) or z_j(t) 254 | destination_memory = self.memory.get_memory(destination_nodes) if not self.use_destination_embedding_in_message else destination_node_embedding 255 | 256 | source_time_delta = edge_times - self.memory.last_update[source_nodes] 257 | source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(len(source_nodes), -1) # Phi(t-t^wave) 258 | 259 | unique_sources = np.unique(source_nodes) 260 | 261 | source_message = torch.cat([source_memory, destination_memory, edge_features, source_time_delta_encoding], dim=1) 262 | messages = defaultdict(list) 263 | for i in range(len(source_nodes)): 264 | messages[source_nodes[i]].append((source_message[i], edge_times[i])) 265 | 266 | return unique_sources, messages 267 | 268 | def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors=20): 269 | """ 270 | Compute temporal embeddings for sources, destinations, and negatively sampled destinations. 271 | Corresponding to algorithm 1 and 2 in the paper. 272 | 273 | INPUTS: 274 | source_nodes: Array of shape [batch_size]; Source node ids. 275 | destination_nodes: Array of shape [batch_size]; Destination node ids 276 | negative_nodes: Array of shape [batch_size]; Ids of negative sampled destination 277 | edge_times: Array of shape [batch_size]; Timestamps of interactions (i.e. Current timestamps) for those nodes (i.e. src, dest, neg) 278 | edge_idxs: Array of shape [batch_size]; Index of interactions 279 | n_neighbors: A number of temporal neighbor to consider in each layer (hop) 280 | 281 | OUTPUTS: Temporal embeddings for sources, destinations and negatives 282 | source_node_embedding: A tensor of shape [source_nodes, emb_dim] 283 | destination_node_embedding: A tensor of shape [destination_nodes, emb_dim] 284 | negative_node_embedding: A tensor of shape [negative_nodes, emb_dim] 285 | """ 286 | 287 | n_samples = len(source_nodes) 288 | nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes]) # all nodes 289 | positives = np.concatenate([source_nodes, destination_nodes]) # positive pairs 290 | timestamps = np.concatenate([edge_times, edge_times, edge_times]) # (current) timestamps for those nodes (i.e. V_2(t_1) and V_2(t_2)) 291 | 292 | memory = None 293 | time_diffs = None 294 | 295 | if self.use_memory: 296 | 297 | ### Line 5-7 in Algorithm 2: Update memory first with previous batch messages, and then calculate embeddings 298 | if self.memory_update_at_start: 299 | # update memory for ALL nodes with messages stored in previous batches 300 | memory, last_update = self.get_updated_memory(list(range(self.n_nodes)), self.memory.messages) 301 | ### Line 3.5 in Algorithm 1: Use previous batch memory and calculate embeddings 302 | else: 303 | memory = self.memory.get_memory(list(range(self.n_nodes))) 304 | last_update = self.memory.last_update 305 | 306 | # Compute differences between the time the memory of a node was last updated, 307 | # and the time for which we want to compute the embedding of a node 308 | source_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[source_nodes].long() 309 | source_time_diffs = (source_time_diffs - self.mean_time_shift_src) / self.std_time_shift_src 310 | destination_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[destination_nodes].long() 311 | destination_time_diffs = (destination_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst 312 | negative_time_diffs = torch.LongTensor(edge_times).to(self.device) - last_update[negative_nodes].long() 313 | negative_time_diffs = (negative_time_diffs - self.mean_time_shift_dst) / self.std_time_shift_dst 314 | 315 | # time_diffs, i.e. delta_t, is for TimeEmbedding method 316 | time_diffs = torch.cat([source_time_diffs, destination_time_diffs, negative_time_diffs], dim=0) 317 | 318 | # Compute the embeddings for [source_nodes, destination_nodes, negative_nodes] 319 | # If memory_update_at_start is True: Line 8 in algorithm 2; The procedure is same as Figure 2 (right) in the paper 320 | # If memory_update_at_start is False: Line 4 in algorithm 1; The procedure is same as Figure 2 (left) in the paper 321 | node_embedding = self.embedding_module.compute_embedding(memory=memory, 322 | source_nodes=nodes, 323 | timestamps=timestamps, 324 | n_layers=self.n_layers, 325 | n_neighbors=n_neighbors, 326 | time_diffs=time_diffs) 327 | 328 | source_node_embedding = node_embedding[:n_samples] 329 | destination_node_embedding = node_embedding[n_samples: 2 * n_samples] 330 | negative_node_embedding = node_embedding[2 * n_samples:] 331 | 332 | if self.use_memory: 333 | 334 | ### Line 12 in algorithm 2: If memory_update_at_start, we persist the update to memory (i.e. S(t-1)) here 335 | if self.memory_update_at_start: 336 | # Persist the updates to the memory only for sources and destinations 337 | self.update_memory(positives, self.memory.messages) 338 | # Remove messages for the positives, we have already updated the memory using positives old message 339 | self.memory.clear_messages(positives) 340 | 341 | ### Line 7 in algorithm 1 342 | ### Line 11 in algorithm 2 343 | # get raw message on source nodes 344 | unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes, 345 | source_node_embedding, 346 | destination_nodes, 347 | destination_node_embedding, 348 | edge_times, edge_idxs) 349 | # get raw message on destination nodes 350 | unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes, 351 | destination_node_embedding, 352 | source_nodes, 353 | source_node_embedding, 354 | edge_times, edge_idxs) 355 | 356 | ### Line 11 in Algorithm 2: If memory_update_at_start, we then store the new raw message 357 | if self.memory_update_at_start: 358 | self.memory.store_raw_messages(unique_sources, source_id_to_messages) 359 | self.memory.store_raw_messages(unique_destinations, destination_id_to_messages) 360 | ### Line 7-9 in Algorithm 1: If not memory_update_at_start, we update memory here with new raw message 361 | else: 362 | self.update_memory(unique_sources, source_id_to_messages) 363 | self.update_memory(unique_destinations, destination_id_to_messages) 364 | 365 | return source_node_embedding, destination_node_embedding, negative_node_embedding 366 | 367 | 368 | def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors=20): 369 | """ 370 | Line 5 in algorithm 1; Line 9 in algorithm 2 371 | 372 | Compute probabilities for edges between sources and destination and between sources and 373 | negatives by first computing temporal embeddings using the TGN encoder and then feeding them 374 | into the MLP decoder. 375 | 376 | INPUTS: 377 | source_nodes: Array of shape [batch_size]; Source node ids. 378 | destination_nodes: Array of shape [batch_size]; Destination node ids. 379 | negative_nodes: Array of shape [batch_size]; Negative node ids. 380 | edge_times: Array of shape [batch_size]; Timestamps of interactions (i.e. Current timestamps) for those nodes (i.e. src, dest, neg) 381 | edge_idxs: Array of shape [batch_size]; Index of interactions 382 | n_neighbors: A number of temporal neighbor to consider in each layer (i.e. Each hop) 383 | 384 | OUTPUTS: 385 | Probabilities for both the positive and negative edges 386 | """ 387 | n_samples = len(source_nodes) 388 | 389 | # get node embeddings for all nodes first 390 | source_node_embedding, destination_node_embedding, negative_node_embedding = self.compute_temporal_embeddings( 391 | source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, n_neighbors) 392 | 393 | # then calculate the P_pos and P_neg 394 | score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0), 395 | torch.cat([destination_node_embedding, negative_node_embedding])).squeeze(dim=0) 396 | 397 | pos_score = score[:n_samples] 398 | neg_score = score[n_samples:] 399 | 400 | return pos_score.sigmoid(), neg_score.sigmoid() --------------------------------------------------------------------------------