├── data
├── ml_UCI-Msg.npy
└── ml_UCI-Msg_node.npy
├── train_example.sh
├── .idea
├── vcs.xml
├── inspectionProfiles
│ ├── profiles_settings.xml
│ └── Project_Default.xml
├── .gitignore
├── misc.xml
├── modules.xml
├── DGNN.iml
├── remote-mappings.xml
└── deployment.xml
├── README.md
├── modules
├── merge.py
├── message_function.py
├── memory.py
├── message_aggregator.py
├── propagater.py
└── update.py
├── models
├── time_encoding.py
└── dgnn.py
├── LICENSE
├── utils
├── preprocess_data.py
├── preprocess_uci_msg.py
├── utils.py
└── data_processing.py
├── evaluation.py
├── train.py
└── trainUCI.py
/data/ml_UCI-Msg.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyd1502/DGNN/HEAD/data/ml_UCI-Msg.npy
--------------------------------------------------------------------------------
/data/ml_UCI-Msg_node.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wyd1502/DGNN/HEAD/data/ml_UCI-Msg_node.npy
--------------------------------------------------------------------------------
/train_example.sh:
--------------------------------------------------------------------------------
1 | python trainUCI.py --bs 100 --prefix ucitest2.3 --n_runs 15 --different_new_nodes --gpu 1 --lr 0.0000005 --patience 25 --threshold 25
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../../:\Users\41374\Downloads\DGNN\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/DGNN.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/remote-mappings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DGNN
2 | PyTorch implementation of "Streaming Graph Neural Network"
3 | (https://arxiv.org/abs/1810.10627)
4 | (I am not the author; this is an attempt to reproduce the paper's experimental results.)
5 |
6 |
7 | The best result of my implementation on the UCI dataset is:
8 | mrr: 0.0259, recall@20: 0.1276, recall@50:0.2078.
9 |
10 | The result reported in the original paper is:
11 | mrr:0.0342 recall@20:0.1284 recall@50: 0.2547.
12 |
13 | Any changes to the existing implementation are welcome.
14 |
15 | The main difficulty is processing several events in one batch (batch size > 1). I borrow the message-aggregation idea from TGN (http://arxiv.org/abs/2006.10637) to solve this problem. In addition, my implementation uses 2 propagation modules instead of the 4 used in the paper.
16 |
17 |
--------------------------------------------------------------------------------
/modules/merge.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from collections import defaultdict
5 |
6 |
class MemoryMerge(nn.Module):
    """Fuse two memory tensors with independent linear maps and a shared bias.

    forward() computes ``memory_s @ W1 + memory_g @ W2 + bias``. All three
    parameters start at zero, so a freshly built module outputs zeros.
    """

    def __init__(self, memory_dimension, device='cpu'):
        super(MemoryMerge, self).__init__()
        self.device = device
        square = (memory_dimension, memory_dimension)
        # Registration order (W1, W2, bias) is kept stable for optimizer/state-dict ordering.
        self.W1 = nn.Parameter(torch.zeros(square).to(self.device))
        self.W2 = nn.Parameter(torch.zeros(square).to(self.device))
        self.bias = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
        # Kept as an attribute for compatibility; forward() does not apply it.
        self.act = torch.nn.ReLU()

    def forward(self, memory_s, memory_g):
        projected_s = torch.matmul(memory_s, self.W1)
        projected_g = torch.matmul(memory_g, self.W2)
        return projected_s + projected_g + self.bias
18 |
--------------------------------------------------------------------------------
/models/time_encoding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
class TimeEncode(torch.nn.Module):
    """Cosine time encoding in the style of TGAT.

    A scalar time ``t`` is mapped to ``cos(t * w_k + b_k)`` for each of the
    ``dimension`` output channels. The frequencies ``w_k`` are initialised to
    ``1 / 10 ** linspace(0, 9, dimension)`` and the phases ``b_k`` to zero.
    """

    def __init__(self, dimension):
        super(TimeEncode, self).__init__()
        self.dimension = dimension
        self.w = torch.nn.Linear(1, dimension)

        frequencies = 1 / 10 ** np.linspace(0, 9, dimension)
        weight = torch.from_numpy(frequencies).float().reshape(dimension, -1)
        self.w.weight = torch.nn.Parameter(weight)
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

    def forward(self, t):
        """Encode ``t`` of shape [batch_size, seq_len] into [batch_size, seq_len, dimension]."""
        # A trailing unit axis lets Linear(1, dimension) act on every scalar independently.
        return torch.cos(self.w(t.unsqueeze(dim=2)))
26 |
--------------------------------------------------------------------------------
/modules/message_function.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 |
class MessageFunction(nn.Module):
    """Compute messages from a pair of memory vectors and optional edge features.

    message = ReLU(memory_s @ W1 + memory_g @ W2 [+ edge_fea @ W3] + bias)

    W1 and W2 are Xavier-uniform initialised; W3 and the bias start at zero.
    """

    def __init__(self, memory_dimension, message_dimension, edge_dimension=0, device="cpu"):
        super(MessageFunction, self).__init__()
        self.device = device
        self.W1 = nn.Parameter(torch.zeros((memory_dimension, message_dimension,)).to(self.device))
        self.W2 = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
        self.W3 = nn.Parameter(torch.zeros((edge_dimension, message_dimension)).to(self.device))
        nn.init.xavier_uniform_(self.W1)
        nn.init.xavier_uniform_(self.W2)
        self.bias = nn.Parameter(torch.zeros(message_dimension).to(self.device))
        self.act = nn.ReLU()

    def compute_message(self, memory_s, memory_g, edge_fea=None):
        """Return ReLU-activated messages.

        :param memory_s: tensor whose last dimension is ``memory_dimension``.
        :param memory_g: tensor with the same leading shape as ``memory_s``.
        :param edge_fea: optional tensor whose last dimension is ``edge_dimension``.
            When None, the edge term is omitted. Previously the advertised default
            crashed inside ``torch.matmul``.
        :return: tensor whose last dimension is ``message_dimension``.
        """
        pre_activation = torch.matmul(memory_s, self.W1) + torch.matmul(memory_g, self.W2,)
        if edge_fea is not None:
            pre_activation = pre_activation + torch.matmul(edge_fea, self.W3)
        return self.act(pre_activation + self.bias)
22 |
23 |
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 wyd1502
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
13 |
14 |
15 |
29 |
30 |
31 |
--------------------------------------------------------------------------------
/utils/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import pandas as pd
4 | from pathlib import Path
5 | import argparse
6 |
7 |
def preprocess(data_name):
    """Parse a comma-separated interaction file into a DataFrame and a feature matrix.

    The first line is treated as a header and skipped. Every following line is
    ``user,item,timestamp,label,feat_0,feat_1,...``.

    :param data_name: path of the CSV file to read.
    :return: ``(df, features)``. ``df`` has the columns ``u, i, ts, label, idx``,
        where ``idx`` is the 0-based line number after the header. ``features`` is
        ``np.array`` built from the per-line feature vectors.
    """
    columns = {'u': [], 'i': [], 'ts': [], 'label': [], 'idx': []}
    features = []

    with open(data_name) as handle:
        next(handle)  # discard the header row
        for row_number, raw in enumerate(handle):
            fields = raw.strip().split(',')
            columns['u'].append(int(fields[0]))
            columns['i'].append(int(fields[1]))
            columns['ts'].append(float(fields[2]))
            columns['label'].append(float(fields[3]))
            columns['idx'].append(row_number)
            features.append(np.array([float(value) for value in fields[4:]]))

    return pd.DataFrame(columns), np.array(features)
37 |
38 |
def reindex(df, bipartite=True):
    """Return a copy of ``df`` with node ids and edge indices shifted up by one.

    For a bipartite graph, item ids are first offset by ``max(u) + 1`` so that
    user and item ids occupy disjoint ranges. ``df`` itself is not modified.

    :raises AssertionError: when ``bipartite`` is set and the user or item ids
        are not a contiguous range.
    """
    shifted = df.copy()
    if bipartite:
        # Users and items must each form a contiguous id range.
        assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
        assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))
        shifted.i = df.i + (df.u.max() + 1)
    for column in ('u', 'i', 'idx'):
        shifted[column] += 1
    return shifted
58 |
59 |
def run(data_name, bipartite=True, node_feat_dim=172):
    """Preprocess ``./data/<data_name>.csv`` and write the model-ready files.

    Writes:
      * ``./data/ml_<data_name>.csv``: the re-indexed interactions;
      * ``./data/ml_<data_name>.npy``: edge features with an all-zero row prepended;
      * ``./data/ml_<data_name>_node.npy``: all-zero node features of shape
        ``[max_node_id + 1, node_feat_dim]``.

    :param data_name: dataset name, used to build every file path.
    :param bipartite: forwarded to ``reindex``.
    :param node_feat_dim: width of the node-feature matrix (previously hard-coded to 172).
    """
    Path("data/").mkdir(parents=True, exist_ok=True)
    PATH = './data/{}.csv'.format(data_name)
    OUT_DF = './data/ml_{}.csv'.format(data_name)
    OUT_FEAT = './data/ml_{}.npy'.format(data_name)
    OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

    df, feat = preprocess(PATH)
    new_df = reindex(df, bipartite)

    # reindex() shifts edge indices to start at 1; row 0 becomes an all-zero placeholder.
    empty = np.zeros(feat.shape[1])[np.newaxis, :]
    feat = np.vstack([empty, feat])

    max_idx = max(new_df.u.max(), new_df.i.max())
    # Node features are all zeros (the old name ``rand_feat`` suggested otherwise).
    node_feat = np.zeros((max_idx + 1, node_feat_dim))

    new_df.to_csv(OUT_DF)
    np.save(OUT_FEAT, feat)
    np.save(OUT_NODE_FEAT, node_feat)
79 |
# Command-line entry point. The guard stops sys.argv parsing and a full
# preprocessing run from happening as a side effect of importing this module.
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Interface for DGNN data preprocessing')
    parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
                        default='wikipedia')

    args = parser.parse_args()

    run(args.data)
--------------------------------------------------------------------------------
/utils/preprocess_uci_msg.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import pandas as pd
4 | from pathlib import Path
5 | import argparse
6 |
7 |
def preprocess(data_name):
    """Parse a whitespace-separated UCI-Msg edge list.

    The first line is skipped as a header. Each further non-blank line must start
    with ``source destination timestamp``; any extra columns are ignored. The dataset
    has no labels or edge features, so every edge gets label 0 and a 6-wide all-zero
    feature vector.

    :param data_name: path of the text file to read.
    :return: ``(df, features)``. ``df`` has the columns ``u, i, ts, label, idx``,
        where ``idx`` numbers the parsed edges consecutively from 0.
    """
    u_list, i_list, ts_list, label_list = [], [], [], []
    feat_l = []
    idx_list = []

    with open(data_name) as f:
        next(f, None)  # skip the header line; tolerate a completely empty file
        for line in f:
            # split() with no argument collapses runs of spaces/tabs. split(' ')
            # produced empty fields there, and int('') then raised.
            e = line.split()
            if not e:
                continue  # blank line: nothing to parse
            u_list.append(int(e[0]))
            i_list.append(int(e[1]))
            ts_list.append(float(e[2]))
            label_list.append(float(0))
            # Consecutive numbering, identical to enumerate() when no lines are skipped.
            idx_list.append(len(idx_list))
            feat_l.append(np.array([0, 0, 0, 0, 0, 0]))

    return pd.DataFrame({'u': u_list,
                         'i': i_list,
                         'ts': ts_list,
                         'label': label_list,
                         'idx': idx_list}), np.array(feat_l)
37 |
38 |
def reindex(df, bipartite=True):
    """Return a re-indexed copy of ``df`` with timestamps rebased to start at 0.

    ``ts`` becomes ``ts - min(ts)``. The node ids ``u``/``i`` and the edge index
    ``idx`` are shifted up by one. For a bipartite graph, item ids are first offset
    by ``max(u) + 1`` so that they do not collide with user ids.

    :raises AssertionError: when ``bipartite`` is set and the user or item ids
        are not a contiguous range.
    """
    result = df.copy()
    result.ts = df.ts - df.ts.min()
    if bipartite:
        assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
        assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))
        result.i = df.i + (df.u.max() + 1)
    for column in ('u', 'i', 'idx'):
        result[column] += 1
    return result
59 |
60 |
def run(data_name, bipartite=True, node_feat_dim=100):
    """Preprocess ``./data/<data_name>.txt`` and write the model-ready files.

    Writes:
      * ``./data/ml_<data_name>.csv``: the re-indexed interactions;
      * ``./data/ml_<data_name>.npy``: edge features with an all-zero row prepended;
      * ``./data/ml_<data_name>_node.npy``: all-zero node features of shape
        ``[max_node_id + 1, node_feat_dim]``.

    :param data_name: dataset name, used to build every file path.
    :param bipartite: forwarded to ``reindex``.
    :param node_feat_dim: width of the node-feature matrix (previously hard-coded to 100).
    """
    Path("data/").mkdir(parents=True, exist_ok=True)
    PATH = './data/{}.txt'.format(data_name)
    OUT_DF = './data/ml_{}.csv'.format(data_name)
    OUT_FEAT = './data/ml_{}.npy'.format(data_name)
    OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

    df, feat = preprocess(PATH)
    new_df = reindex(df, bipartite)

    # reindex() shifts edge indices to start at 1; row 0 becomes an all-zero placeholder.
    empty = np.zeros(feat.shape[1])[np.newaxis, :]
    feat = np.vstack([empty, feat])

    max_idx = max(new_df.u.max(), new_df.i.max())
    # Node features are all zeros (the old name ``rand_feat`` suggested otherwise).
    node_feat = np.zeros((max_idx + 1, node_feat_dim))

    new_df.to_csv(OUT_DF)
    np.save(OUT_FEAT, feat)
    np.save(OUT_NODE_FEAT, node_feat)
80 |
# Command-line entry point (UCI-Msg is preprocessed as a non-bipartite graph).
# The guard stops sys.argv parsing and preprocessing from running on import.
if __name__ == "__main__":
    parser = argparse.ArgumentParser('Interface for DGNN data preprocessing')
    parser.add_argument('--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
                        default='UCI-Msg')

    args = parser.parse_args()

    run(args.data, False)
--------------------------------------------------------------------------------
/modules/memory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from collections import defaultdict
5 | from copy import deepcopy
6 |
7 |
class Memory(nn.Module):
    """Per-node memory: a cell state, a hidden state and a last-update timestamp.

    ``self.memory`` is the list ``[cell, hidden]``. Raw messages waiting to be
    applied are buffered per node in ``self.messages``, a ``defaultdict(list)``.
    Each entry is presumably a ``(message, timestamp)`` tuple of tensors, as the
    backup/detach helpers index ``[0]`` and ``[1]``; confirm with callers.
    """

    def __init__(self, n_nodes, memory_dimension, message_dimension=None,
                 device="cpu"):
        super(Memory, self).__init__()
        self.n_nodes = n_nodes
        self.memory_dimension = memory_dimension
        self.message_dimension = message_dimension
        self.device = device

        self.__init_memory__()

    def __init_memory__(self, seed=0):
        """
        (Re)initialise the memory and drop all buffered messages.

        Cell and hidden states are filled with Xavier-normal noise (not zeros), and
        ``last_update`` is reset to zeros. Torch's *global* RNG is reseeded with
        ``seed`` first, so every call produces the same initial memory.
        """
        # Treat memory as parameter so that it is saved and loaded together with the model
        torch.manual_seed(seed)
        self.cell = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                 requires_grad=False)
        self.hidden = nn.Parameter(torch.zeros((self.n_nodes, self.memory_dimension)).to(self.device),
                                   requires_grad=False)
        # Index 0 = cell, index 1 = hidden; other modules read memory[0] / memory[1].
        self.memory = [self.cell, self.hidden]
        self.last_update = nn.Parameter(torch.zeros(self.n_nodes).to(self.device),
                                        requires_grad=False)
        # In-place initialisers; the underscore-less aliases are deprecated.
        # The order (hidden, then cell) is kept so the RNG draws are unchanged.
        nn.init.xavier_normal_(self.hidden)
        nn.init.xavier_normal_(self.cell)
        self.messages = defaultdict(list)

    def store_raw_messages(self, nodes, node_id_to_messages):
        """Append each node's new messages to its buffer."""
        for node in nodes:
            self.messages[node].extend(node_id_to_messages[node])

    def get_memory(self, node_idxs):
        """Return ``[cell[node_idxs], hidden[node_idxs]]``."""
        return [self.memory[i][node_idxs, :] for i in range(2)]

    def set_cell(self, node_idxs, values):
        self.cell[node_idxs, :] = values

    def set_hidden(self, node_idxs, values):
        self.hidden[node_idxs, :] = values

    def get_last_update(self, node_idxs):
        return self.last_update[node_idxs]

    def backup_memory(self):
        """Return deep copies: ``([cell, hidden], last_update, messages)``."""
        messages_clone = {}
        for k, v in self.messages.items():
            messages_clone[k] = [(x[0].clone(), x[1].clone()) for x in v]

        return [self.memory[i].data.clone() for i in range(2)], self.last_update.data.clone(), messages_clone

    def restore_memory(self, memory_backup):
        """Restore state from a ``backup_memory()`` result.

        The ``.data`` of the existing Parameters is assigned rather than replacing
        the objects, so the ``self.memory`` list keeps pointing at them.
        """
        self.cell.data, self.hidden.data, self.last_update.data = \
            memory_backup[0][0].clone(), memory_backup[0][1].clone(), memory_backup[1].clone()
        self.messages = defaultdict(list)
        for k, v in memory_backup[2].items():
            self.messages[k] = [(x[0].clone(), x[1].clone()) for x in v]

    def detach_memory(self):
        """Cut the autograd history of the memory and of every buffered message tensor."""
        self.hidden.detach_()
        self.cell.detach_()

        # Detach all stored messages
        for k, v in self.messages.items():
            new_node_messages = []
            for message in v:
                new_node_messages.append((message[0].detach(), message[1]))

            self.messages[k] = new_node_messages

    def clear_messages(self, nodes):
        """Empty the message buffer of each given node."""
        for node in nodes:
            self.messages[node] = []
83 |
--------------------------------------------------------------------------------
/modules/message_aggregator.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | import torch
3 | import numpy as np
4 |
5 |
class MessageAggregator(torch.nn.Module):
    """
    Base class for message aggregators. Given node ids and the messages stored
    for them, an aggregator combines the messages that share a node id into one.
    """
    def __init__(self, device):
        super(MessageAggregator, self).__init__()
        self.device = device

    def aggregate(self, node_ids, messages=None):
        """
        Combine the pending messages of every node in ``node_ids``.

        Subclasses override this. In this file they read ``messages[node_id]`` as a
        list of ``(message, timestamp)`` pairs and return
        ``(node_ids_with_messages, stacked_messages, stacked_timestamps)``.
        The base implementation does nothing and returns None.
        """

    def group_by_id(self, node_ids, messages, timestamps):
        """Group ``(messages[k], timestamps[k])`` pairs under ``node_ids[k]``.

        :return: a ``defaultdict(list)`` mapping each node id to its pairs, in input order.
        """
        grouped = defaultdict(list)
        for position, node_id in enumerate(node_ids):
            pair = (messages[position], timestamps[position])
            grouped[node_id].append(pair)
        return grouped
32 |
33 |
class LastMessageAggregator(MessageAggregator):
    """Aggregator that keeps only the most recent stored message of each node."""

    def __init__(self, device):
        super(LastMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages=None):
        """Only keep the last message for each node.

        :param node_ids: node ids; duplicates are removed (and the ids sorted) by ``np.unique``.
        :param messages: mapping node id -> list of ``(message, timestamp)`` tensor pairs.
        :return: ``(node_ids_with_messages, stacked_messages, stacked_timestamps)``. The
            last two are empty lists when no node has a pending message.
        """
        latest = [(node_id, messages[node_id][-1])
                  for node_id in np.unique(node_ids)
                  if len(messages[node_id]) > 0]
        if not latest:
            return [], [], []

        updated_ids = [node_id for node_id, _ in latest]
        stacked_messages = torch.stack([entry[0] for _, entry in latest])
        stacked_timestamps = torch.stack([entry[1] for _, entry in latest])
        return updated_ids, stacked_messages, stacked_timestamps
56 |
57 |
class MeanMessageAggregator(MessageAggregator):
    """Aggregator that averages all stored messages of each node."""

    def __init__(self, device):
        super(MeanMessageAggregator, self).__init__(device)

    def aggregate(self, node_ids, messages=None):
        """Average the pending messages of each node.

        The timestamp reported for a node is that of its most recent message.

        :param node_ids: node ids; duplicates are removed (and the ids sorted) by ``np.unique``.
        :param messages: mapping node id -> list of ``(message, timestamp)`` tensor pairs.
        :return: ``(node_ids_with_messages, mean_messages, last_timestamps)``. The last
            two are empty lists when no node has a pending message.
        """
        unique_node_ids = np.unique(node_ids)
        unique_messages = []
        unique_timestamps = []
        to_update_node_ids = []

        for node_id in unique_node_ids:
            node_messages = messages[node_id]
            if len(node_messages) > 0:
                to_update_node_ids.append(node_id)
                unique_messages.append(torch.mean(torch.stack([m[0] for m in node_messages]), dim=0))
                unique_timestamps.append(node_messages[-1][1])

        if not to_update_node_ids:
            return to_update_node_ids, [], []
        return to_update_node_ids, torch.stack(unique_messages), torch.stack(unique_timestamps)
82 |
83 |
def get_message_aggregator(aggregator_type, device):
    """Build the message aggregator named by ``aggregator_type``.

    :param aggregator_type: ``"last"`` (keep the newest message) or ``"mean"``
        (average the pending messages).
    :param device: forwarded to the aggregator constructor.
    :raises ValueError: for any other name.
    """
    choices = (("last", LastMessageAggregator), ("mean", MeanMessageAggregator))
    for name, aggregator_cls in choices:
        if aggregator_type == name:
            return aggregator_cls(device=device)
    raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
91 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | from sklearn.metrics import average_precision_score, roc_auc_score
6 |
def choose_target(model, memory_s, memory_g, src_mem):
    """Rank all nodes by cosine similarity to each query embedding.

    The candidate embedding of each node is ``model.memory_merge(memory_s[1], memory_g[1])``
    (index 1 of a ``[cell, hidden]`` memory list is the hidden state).

    :param src_mem: query embeddings, shape [batch_size, mem_d].
    :return: ``(similarities, indices)``, both of shape [batch_size, num_nodes] and
        sorted in descending order of similarity along dim 1.
    """
    u = model.memory_merge(memory_s[1], memory_g[1])  # [num_nodes, mem_d]
    # normalize() divides by max(||x||, eps). For non-zero rows this equals x / ||x||,
    # but all-zero rows yield zeros instead of NaN (NaN made the sort order undefined).
    u_normalized = torch.nn.functional.normalize(u, dim=1)  # [num_nodes, mem_d]
    src_mem_normalized = torch.nn.functional.normalize(src_mem, dim=1)  # [batch_size, mem_d]
    cos_similarity = torch.matmul(src_mem_normalized, u_normalized.t())  # [batch_size, num_nodes]
    cos_similarity, idx = torch.sort(cos_similarity, descending=True)
    return cos_similarity, idx
16 |
def recall(des_node, idx, top_k):
    """Fraction of rows whose true node is among the first ``top_k`` ranked ids.

    :param des_node: true node id per row (length batch_size).
    :param idx: ranked node ids, shape [batch_size, num_candidates], best first.
    :param top_k: cut-off rank.
    :return: hit rate as a NumPy float.
    """
    shortlist = idx[:, :top_k]  # [batch_size, top_k]
    hits = np.array([node in shortlist[row] for row, node in enumerate(des_node)])
    return hits.sum() / hits.size
23 |
def MRR(des_node, idx):
    """Mean reciprocal rank of each row's true node within that row's ranking.

    :param des_node: true node id per row (length batch_size).
    :param idx: ranked node ids, shape [batch_size, num_candidates], best first.
    :return: mean of ``1 / rank`` with 1-based ranks, as a NumPy float. A node that
        does not appear in its row contributes 0 instead of raising.
    """
    reciprocal_ranks = []
    for row, node in enumerate(des_node):
        positions = np.flatnonzero(idx[row].cpu().numpy() == node)
        # Take the first match explicitly. float() on a 1-element array is deprecated
        # in NumPy, and it raised outright when the node was missing.
        reciprocal_ranks.append(1.0 / (positions[0] + 1) if positions.size else 0.0)
    return np.array(reciprocal_ranks).mean()
29 |
30 |
def eval_edge_prediction(model, negative_edge_sampler, data, n_neighbors, batch_size=200):
    """Evaluate link prediction as a ranking over all nodes.

    ``data`` is replayed in order in batches of ``batch_size``. For each batch the model
    is called with ``test=True`` to get source and destination embeddings. Every source
    embedding ranks all nodes (via ``choose_target``) and is scored against its true
    destination, and vice versa; the two scores are averaged.

    :param n_neighbors: not used here; kept for call-site compatibility.
    :return: ``(mrr, recall@20, recall@50)``, each the unweighted mean of the per-batch values.
    """
    # Ensures the random sampler uses a seed for evaluation (i.e. we sample always the same
    # negatives for validation / test set)
    assert negative_edge_sampler.seed is not None
    negative_edge_sampler.reset_random_state()

    val_mrr, val_recall_20, val_recall_50 = [], [], []
    with torch.no_grad():
        model = model.eval()
        # While usually the test batch size is as big as it fits in memory, here we keep it the same
        # size as the training batch size, since it allows the memory to be updated more frequently,
        # and later test batches to access information from interactions in previous test batches
        # through the memory
        TEST_BATCH_SIZE = batch_size
        num_test_instance = len(data.sources)
        num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE)

        for k in range(num_test_batch):
            s_idx = k * TEST_BATCH_SIZE
            e_idx = min(num_test_instance, s_idx + TEST_BATCH_SIZE)
            sources_batch = data.sources[s_idx:e_idx]
            destinations_batch = data.destinations[s_idx:e_idx]
            timestamps_batch = data.timestamps[s_idx:e_idx]
            edge_idxs_batch = data.edge_idxs[s_idx: e_idx]

            size = len(sources_batch)
            # Negatives only feed the model call; the ranking metrics below do not use them.
            _, negative_samples = negative_edge_sampler.sample(size)

            src_mem, des_mem = model(sources_batch, destinations_batch,
                                     negative_samples, timestamps_batch,
                                     edge_idxs_batch, test=True)

            # Only the ranked node ids are needed; the similarity scores are discarded.
            _, src_idx = choose_target(model, model.memory_s.memory, model.memory_g.memory, src_mem)
            _, des_idx = choose_target(model, model.memory_s.memory, model.memory_g.memory, des_mem)
            # Symmetric scoring: sources should retrieve destinations and vice versa.
            recall_20 = (recall(destinations_batch, src_idx, 20) + recall(sources_batch, des_idx, 20)) / 2
            recall_50 = (recall(destinations_batch, src_idx, 50) + recall(sources_batch, des_idx, 50)) / 2
            mrr = (MRR(destinations_batch, src_idx) + MRR(sources_batch, des_idx)) / 2

            val_mrr.append(mrr)
            val_recall_20.append(recall_20)
            val_recall_50.append(recall_50)

    return np.mean(val_mrr), np.mean(val_recall_20), np.mean(val_recall_50)
75 |
--------------------------------------------------------------------------------
/modules/propagater.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 |
class Propagater(nn.Module):
    """Propagate each updated node's message into its temporal neighbours' memory.

    For every node in ``unique_node_ids``, ``neighbor_finder`` returns ``n_neighbors``
    neighbours. Each neighbour's cell state is increased by

        mask * [delta < tau] * exp(-alpha * delta) * attention * (message @ W_s)

    where ``delta`` is the time since the connecting edge, ``mask`` keeps only edges
    with ``delta > 0``, and ``attention`` is a softmax over the neighbours of
    ``<neighbour cell, source cell>``. Neighbour hidden states become ``tanh(cell)``.
    """

    def __init__(self, memory, message_dimension, memory_dimension, mean_time_shift_src, neighbor_finder, n_neighbors,
                 tau=2, device='cpu'):
        super(Propagater, self).__init__()
        self.memory = memory
        # Not used by forward(); kept so existing state dicts still load.
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device
        self.alpha = 2
        self.neighbor_finder = neighbor_finder
        self.n_neighbors = n_neighbors
        self.tau = tau * mean_time_shift_src

        self.tanh = nn.Tanh()
        self.W_s = nn.Parameter(torch.zeros((message_dimension, memory_dimension)).to(self.device))
        # Not used by forward(); kept so existing state dicts still load.
        self.bias = nn.Parameter(torch.zeros(memory_dimension).to(self.device))

    def compute_time_discount(self, edge_delta):
        """Return ``exp(-alpha * edge_delta)`` elementwise."""
        time_intervals = self.alpha * edge_delta
        time_discount = torch.exp(-time_intervals)
        return time_discount

    def compute_attention_weight(self, memory, sources_idx, neighbors):
        """Softmax over neighbours of ``<neighbour cell, source cell>``.

        :param neighbors: neighbour cell states, shape (batch, n_neighbors, mem_d).
        :return: attention weights, shape (batch, n_neighbors, 1).
        """
        batch_s = neighbors.shape[0]
        sources = memory[0][sources_idx].view(batch_s, 1, -1)
        softmax = nn.Softmax(dim=1)
        att = softmax(torch.matmul(neighbors, sources.transpose(1, 2)))
        return att

    def forward(self, memory, unique_node_ids, unique_messages, timestamps, inplace=True):
        """Apply the propagation step to ``memory`` (a ``[cell, hidden]`` list).

        :param unique_messages: one message per node, shape (batch, message_dimension).
        :param timestamps: event time per node (tensor of length batch).
        :param inplace: if True, ``memory`` is updated and returned; otherwise updated
            copies are returned and ``memory`` is left untouched.
        """
        if len(unique_messages) == 0:
            return memory
        timestamps = timestamps.cpu().numpy()
        neighbors, _edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
            unique_node_ids, timestamps, n_neighbors=self.n_neighbors)
        neighbors_flat = torch.from_numpy(neighbors).long().to(self.device).flatten()
        batch_s, _ = neighbors.shape

        edge_deltas = timestamps[:, np.newaxis] - edge_times
        edge_delta_re = torch.from_numpy(edge_deltas).float().to(self.device).view(batch_s, self.n_neighbors, -1)
        mask_re = (edge_delta_re > 0).long()
        time_threshold = (edge_delta_re < self.tau).long()
        time_discounts = self.compute_time_discount(edge_delta_re)

        neighbors_cell = memory[0][neighbors_flat].view(batch_s, self.n_neighbors, -1)
        att = self.compute_attention_weight(memory, unique_node_ids, neighbors_cell)  # (b_s, n_neighbors, 1)

        # Every neighbour of node b receives node b's message. The former
        # ``repeat((n_neighbors, 1)).view(...)`` tiled the whole batch, so for
        # batch_s > 1 neighbours were paired with other nodes' messages.
        unique_messages_re = unique_messages.unsqueeze(1).expand(-1, self.n_neighbors, -1)
        projected = torch.matmul(unique_messages_re, self.W_s)  # (b_s, n_neighbors, memory_size)

        increment = (mask_re * time_threshold * time_discounts * att * projected).view(
            batch_s * self.n_neighbors, -1)
        C_v = memory[0][neighbors_flat] + increment
        h_v = self.tanh(C_v)
        if inplace:
            memory[0][neighbors_flat] = C_v
            memory[1][neighbors_flat] = h_v
            return memory
        # ``.data`` alone shares storage with the caller's tensors, so clone it to keep
        # inplace=False side-effect free (this matches MemoryUpdater.update_memory).
        memory_cell = memory[0].data.clone()
        memory_hidden = memory[1].data.clone()
        memory_cell[neighbors_flat] = C_v
        memory_hidden[neighbors_flat] = h_v
        return [memory_cell, memory_hidden]
69 |
70 |
71 |
--------------------------------------------------------------------------------
/modules/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from collections import defaultdict
5 |
6 |
class MemoryUpdater(nn.Module):
    """Time-aware LSTM update for nodes that received a message.

    Before the usual LSTM gates run, the previous cell state is split into a
    short-term part ``C_I = tanh(W_d @ cell + b_d)`` and a long-term part
    ``cell - C_I``. The short-term part is discounted by ``exp(-alpha * delta_t)``,
    where ``delta_t`` is the time since the node's last update, and the two parts
    are recombined into ``C*``, which replaces ``cell`` in the LSTM update.
    """
    def __init__(self, memory, message_dimension, memory_dimension, mean_time_shift_src, device):
        super(MemoryUpdater, self).__init__()
        self.memory = memory
        # Not used by update_memory(); kept so existing state dicts still load.
        self.layer_norm = torch.nn.LayerNorm(memory_dimension)
        self.message_dimension = message_dimension
        self.device = device
        self.alpha = 2

        self.sig = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.W_d = nn.Parameter(torch.zeros((memory_dimension,memory_dimension)).to(self.device))
        self.b_d = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
        self.W_f = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
        self.U_f = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
        self.b_f = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
        self.W_i = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
        self.U_i = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
        self.b_i = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
        self.W_o = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
        self.U_o = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
        self.b_o = nn.Parameter(torch.zeros(memory_dimension).to(self.device))
        self.W_c = nn.Parameter(torch.zeros((memory_dimension, message_dimension)).to(self.device))
        self.U_c = nn.Parameter(torch.zeros((memory_dimension, memory_dimension)).to(self.device))
        self.b_c = nn.Parameter(torch.zeros(memory_dimension).to(self.device))

    def update_memory(self, unique_node_ids, unique_messages, timestamps, inplace=True):
        """Update cell/hidden state and last-update time of ``unique_node_ids``.

        :param unique_messages: one message per node, shape [bs, message_dimension].
        :param timestamps: new update time per node (tensor of length bs).
        :param inplace: if True, write into ``self.memory``; otherwise return updated copies.
        :return: ``([cell, hidden], last_update)``.
        """
        if len(unique_messages) == 0:
            return self.memory.memory, self.memory.last_update
        hidden = self.memory.memory[1][unique_node_ids]  # hidden [bs, memory_size]
        cell = self.memory.memory[0][unique_node_ids]  # cell [bs, memory_size]
        messages = unique_messages
        time_discounts = self.compute_time_discount(unique_node_ids, timestamps)  # [bs]
        bs = hidden.shape[0]
        # Short-term memory and its time-discounted version.
        C_vi = self.tanh(torch.matmul(self.W_d, cell.t()).t() + self.b_d)  # [bs, memory_size]
        C_v_discount = torch.mul(C_vi, time_discounts.view(bs, 1))
        # Long-term memory = cell minus the *full* short-term memory. The old code
        # subtracted the discounted term, so C_v_star collapsed back to ``cell`` and
        # the time discount had no effect.
        C_v_long = cell - C_vi
        C_v_star = C_v_long + C_v_discount
        f_t = self.sig(torch.matmul(self.W_f, messages.t()).t() + torch.matmul(self.U_f, hidden.t()).t() + self.b_f)
        i_t = self.sig(torch.matmul(self.W_i, messages.t()).t() + torch.matmul(self.U_i, hidden.t()).t() + self.b_i)
        o_t = self.sig(torch.matmul(self.W_o, messages.t()).t() + torch.matmul(self.U_o, hidden.t()).t() + self.b_o)
        C_hat_t = self.tanh(torch.matmul(self.W_c, messages.t()).t() + torch.matmul(self.U_c, hidden.t()).t() + self.b_c)
        C_v_t = torch.mul(f_t, C_v_star) + torch.mul(i_t, C_hat_t)
        h_v_t = torch.mul(o_t, self.tanh(C_v_t))
        if inplace:
            self.memory.memory[0][unique_node_ids] = C_v_t
            self.memory.memory[1][unique_node_ids] = h_v_t
            self.memory.last_update[unique_node_ids] = timestamps
            return self.memory.memory, self.memory.last_update
        else:
            memory_cell = self.memory.memory[0].data.clone()
            memory_hidden = self.memory.memory[1].data.clone()
            last_update = self.memory.last_update.data.clone()
            memory_cell[unique_node_ids] = C_v_t
            memory_hidden[unique_node_ids] = h_v_t
            last_update[unique_node_ids] = timestamps
            return [memory_cell, memory_hidden], last_update

    def compute_time_discount(self, unique_node_ids, timestamps):
        """Return ``exp(-alpha * (timestamps - last_update[unique_node_ids]))``."""
        time_intervals = timestamps - self.memory.last_update[unique_node_ids]
        time_intervals = self.alpha * time_intervals
        time_discount = torch.exp(-time_intervals)
        return time_discount
70 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
class MergeLayer(torch.nn.Module):
    """Two-layer MLP applied to the concatenation of two inputs.

    Computes ``fc2(ReLU(fc1([x1, x2])))`` with ``fc1: dim1 + dim2 -> dim3`` and
    ``fc2: dim3 -> dim4``. Both weight matrices are Xavier-normal initialised.
    """

    def __init__(self, dim1, dim2, dim3, dim4):
        super().__init__()
        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
        self.fc2 = torch.nn.Linear(dim3, dim4)
        self.act = torch.nn.ReLU()

        for layer in (self.fc1, self.fc2):
            torch.nn.init.xavier_normal_(layer.weight)

    def forward(self, x1, x2):
        joined = torch.cat([x1, x2], dim=1)
        return self.fc2(self.act(self.fc1(joined)))
19 |
20 |
class EarlyStopMonitor(object):
    """Patience-based early stopping on a scalar validation metric.

    Args:
        max_round: number of consecutive non-improving checks after which
            ``early_stop_check`` returns True.
        higher_better: True if larger metric values are better; otherwise the
            metric is negated internally before comparison.
        tolerance: minimum relative improvement over the best value so far
            that counts as progress.

    Attributes:
        best_epoch: index (0-based call count) of the check that produced the
            current best value.
    """

    def __init__(self, max_round=3, higher_better=True, tolerance=1e-10):
        self.max_round = max_round
        self.num_round = 0

        self.epoch_count = 0
        self.best_epoch = 0

        self.last_best = None
        self.higher_better = higher_better
        self.tolerance = tolerance

    def early_stop_check(self, curr_val):
        """Record one metric value and return True when patience is exhausted."""
        if not self.higher_better:
            # Negate into a new object: an in-place ``*=`` would flip the sign of
            # a caller's numpy array / torch tensor.
            curr_val = -curr_val
        if self.last_best is None:
            self.last_best = curr_val
        elif self._is_improvement(curr_val):
            self.last_best = curr_val
            self.num_round = 0
            self.best_epoch = self.epoch_count
        else:
            self.num_round += 1

        self.epoch_count += 1

        return self.num_round >= self.max_round

    def _is_improvement(self, curr_val):
        """Return True if curr_val beats last_best by more than the relative tolerance."""
        scale = np.abs(self.last_best)
        if scale == 0:
            # The relative change is undefined (division by zero). Any strictly
            # larger value is an improvement, which is what the +inf/nan/-inf
            # outcomes of the division amounted to, minus the RuntimeWarning.
            return curr_val > self.last_best
        return (curr_val - self.last_best) / scale > self.tolerance
48 |
49 |
class RandEdgeSampler(object):
    """Draw random (source, destination) node pairs from the observed node sets.

    Sources and destinations are sampled independently and uniformly (with
    replacement) from the distinct values of ``src_list`` and ``dst_list``.
    When ``seed`` is given, a private ``np.random.RandomState`` is used so the
    draws are reproducible and can be replayed with ``reset_random_state``;
    otherwise the global NumPy RNG is used.
    """

    def __init__(self, src_list, dst_list, seed=None):
        self.src_list = np.unique(src_list)
        self.dst_list = np.unique(dst_list)
        self.seed = seed
        if seed is not None:
            self.random_state = np.random.RandomState(self.seed)

    def sample(self, size):
        """Return ``size`` random sources and ``size`` random destinations."""
        rng = np.random if self.seed is None else self.random_state
        # Sources are drawn before destinations so seeded streams stay stable.
        picked_src = rng.randint(0, len(self.src_list), size)
        picked_dst = rng.randint(0, len(self.dst_list), size)
        return self.src_list[picked_src], self.dst_list[picked_dst]

    def reset_random_state(self):
        """Rewind the private RNG to its initial seed."""
        self.random_state = np.random.RandomState(self.seed)
72 |
73 |
def get_neighbor_finder(data, uniform):
    """Build a NeighborFinder over the interactions in ``data``, treating edges as undirected.

    Every interaction is recorded twice: once in the source's adjacency list and
    once in the destination's, each entry being ``(other_node, edge_idx, timestamp)``.
    The adjacency list has one slot per node id from 0 to the largest id seen.
    """
    n_slots = max(data.sources.max(), data.destinations.max()) + 1
    adjacency = [[] for _ in range(n_slots)]
    edges = zip(data.sources, data.destinations, data.edge_idxs, data.timestamps)
    for src, dst, e_idx, ts in edges:
        adjacency[src].append((dst, e_idx, ts))
        adjacency[dst].append((src, e_idx, ts))

    return NeighborFinder(adjacency, uniform=uniform)
84 |
85 |
class NeighborFinder:
    """Time-ordered neighbor lookup over a temporal interaction graph.

    Args:
        adj_list: ``adj_list[node]`` is a list of ``(neighbor, edge_idx, timestamp)``
            tuples describing the interactions of ``node``.
        uniform: if True, ``get_temporal_neighbor`` samples neighbors uniformly
            (with replacement) from a node's history; otherwise it keeps the most
            recent ones.
        seed: optional seed. When given, uniform sampling draws from a private
            ``np.random.RandomState(seed)`` (reproducible); otherwise it uses the
            global NumPy RNG.
    """

    def __init__(self, adj_list, uniform=False, seed=None):
        self.node_to_neighbors = []
        self.node_to_edge_idxs = []
        self.node_to_edge_timestamps = []

        for neighbors in adj_list:
            # Neighbors is a list of tuples (neighbor, edge_idx, timestamp).
            # Sort it by timestamp so find_before can binary-search it.
            sorted_neighbors = sorted(neighbors, key=lambda x: x[2])
            self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighbors]))
            self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighbors]))
            self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighbors]))

        self.uniform = uniform

        # Both attributes always exist. get_temporal_neighbor samples from
        # random_state when it is set, so a seed makes uniform sampling reproducible.
        self.seed = seed
        self.random_state = np.random.RandomState(seed) if seed is not None else None

    def find_before(self, src_idx, cut_time):
        """Return the interactions of ``src_idx`` that happened strictly before ``cut_time``.

        Returns 3 arrays sorted by time: neighbors, edge_idxs, timestamps. A node
        id beyond the adjacency list (a node absent from the graph this finder was
        built from) has no history, so three empty arrays are returned.
        """
        if src_idx >= len(self.node_to_edge_timestamps):
            empty = np.array([])
            return empty, empty, empty

        # searchsorted (side='left') excludes interactions at exactly cut_time.
        i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

        return (self.node_to_neighbors[src_idx][:i],
                self.node_to_edge_idxs[src_idx][:i],
                self.node_to_edge_timestamps[src_idx][:i])

    def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
        """Collect up to ``n_neighbors`` past interactions for each (node, cut time) pair.

        Params
        ------
        source_nodes: sequence of node ids.
        timestamps: sequence of cut times, one per entry of ``source_nodes``.
        n_neighbors: int, number of neighbor slots per node. When ``n_neighbors <= 0``
            a single all-zero slot is returned per node.

        Returns ``(neighbors, edge_idxs, edge_times)``. Each has shape
        ``(len(source_nodes), max(n_neighbors, 1))`` and dtypes int32, int32 and
        float32 respectively. Each row is sorted by time.
        - Most-recent mode: rows with fewer than ``n_neighbors`` past interactions
          are left-padded with zeros.
        - Uniform mode: sampling is with replacement, so every slot is filled.
        - Nodes with no past interactions get an all-zero row.
        """
        assert (len(source_nodes) == len(timestamps))

        tmp_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        shape = (len(source_nodes), tmp_n_neighbors)
        # NB! All interactions described in these matrices are sorted in each row by time
        neighbors = np.zeros(shape, dtype=np.int32)
        edge_times = np.zeros(shape, dtype=np.float32)
        edge_idxs = np.zeros(shape, dtype=np.int32)

        rng = np.random if self.random_state is None else self.random_state

        for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
            source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node,
                                                                                     timestamp)
            if len(source_neighbors) == 0 or n_neighbors <= 0:
                continue  # row stays all zeros

            if self.uniform:
                # Sample with replacement, then re-sort the row by time.
                sampled_idx = rng.randint(0, len(source_neighbors), n_neighbors)

                neighbors[i, :] = source_neighbors[sampled_idx]
                edge_times[i, :] = source_edge_times[sampled_idx]
                edge_idxs[i, :] = source_edge_idxs[sampled_idx]

                pos = edge_times[i, :].argsort()
                neighbors[i, :] = neighbors[i, :][pos]
                edge_times[i, :] = edge_times[i, :][pos]
                edge_idxs[i, :] = edge_idxs[i, :][pos]
            else:
                # Take the most recent interactions, right-aligned so any zero
                # padding ends up on the left (oldest) side.
                source_edge_times = source_edge_times[-n_neighbors:]
                source_neighbors = source_neighbors[-n_neighbors:]
                source_edge_idxs = source_edge_idxs[-n_neighbors:]

                start = n_neighbors - len(source_neighbors)
                neighbors[i, start:] = source_neighbors
                edge_times[i, start:] = source_edge_times
                edge_idxs[i, start:] = source_edge_idxs

        return neighbors, edge_idxs, edge_times
--------------------------------------------------------------------------------
/utils/data_processing.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import pandas as pd
4 |
5 |
class Data:
    """Column-oriented container for a set of timestamped interactions.

    Stores the parallel arrays as given and derives the interaction count and
    the set of node ids that appear as a source or a destination.
    """

    def __init__(self, sources, destinations, timestamps, edge_idxs, labels):
        self.sources = sources
        self.destinations = destinations
        self.timestamps = timestamps
        self.edge_idxs = edge_idxs
        self.labels = labels

        self.n_interactions = len(sources)
        self.unique_nodes = set(sources).union(destinations)
        self.n_unique_nodes = len(self.unique_nodes)
16 |
17 |
def get_data(dataset_name, different_new_nodes_between_val_and_test=False):
    """Load a preprocessed dataset and split it chronologically for link prediction.

    Reads ``./data/ml_<dataset_name>.csv`` (columns ``u``, ``i``, ``ts``, ``label``,
    ``idx``), ``./data/ml_<dataset_name>.npy`` (edge features) and
    ``./data/ml_<dataset_name>_node.npy`` (node features).

    Splits by the 0.80 and 0.90 quantiles of ``ts``:
    - Training: interactions up to the first quantile.
    - Validation: interactions in (first, second].
    - Test: interactions after the second quantile.

    ``int(0.1 * n_nodes)`` nodes active after the validation cut are held out of
    training entirely, so inductive performance can be measured on them.

    Args:
        dataset_name: suffix of the ``ml_<name>`` files under ``./data``.
        different_new_nodes_between_val_and_test: if True, split the held-out
            nodes into two halves, one used for new-node validation and the
            other for new-node testing.

    Returns:
        ``(node_features, edge_features, full_data, train_data, val_data,
        test_data, new_node_val_data, new_node_test_data)``.
    """
    ### Load data and train val test split
    graph_df = pd.read_csv('./data/ml_{}.csv'.format(dataset_name))
    edge_features = np.load('./data/ml_{}.npy'.format(dataset_name))
    node_features = np.load('./data/ml_{}_node.npy'.format(dataset_name))

    val_time, test_time = list(np.quantile(graph_df.ts, [0.80, 0.9]))

    sources = graph_df.u.values
    destinations = graph_df.i.values
    edge_idxs = graph_df.idx.values
    labels = graph_df.label.values
    timestamps = graph_df.ts.values

    full_data = Data(sources, destinations, timestamps, edge_idxs, labels)

    random.seed(2020)

    node_set = set(sources) | set(destinations)
    n_total_unique_nodes = len(node_set)

    # Compute nodes which appear at test time
    test_node_set = set(sources[timestamps > val_time]).union(
        set(destinations[timestamps > val_time]))
    # Sample nodes which we keep as new nodes (to test inductiveness), so that we have to remove
    # all their edges from training.
    # random.sample() rejects sets since Python 3.11. tuple(...) is the exact conversion older
    # Pythons applied internally, so the nodes sampled for seed 2020 are unchanged. The size is
    # clamped so a small dataset cannot request more nodes than exist.
    n_new_test_nodes = min(int(0.1 * n_total_unique_nodes), len(test_node_set))
    new_test_node_set = set(random.sample(tuple(test_node_set), n_new_test_nodes))

    # Mask saying for each source and destination whether they are new test nodes
    new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values
    new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values

    # Mask which is true for edges with both destination and source not being new test nodes
    # (because we want to remove all edges involving any new test node)
    observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask)

    # For train we keep edges happening before the validation time which do not involve any
    # new node used for inductiveness
    train_mask = np.logical_and(timestamps <= val_time, observed_edges_mask)

    train_data = Data(sources[train_mask], destinations[train_mask], timestamps[train_mask],
                      edge_idxs[train_mask], labels[train_mask])

    # define the new nodes sets for testing inductiveness of the model
    train_node_set = set(train_data.sources).union(train_data.destinations)
    assert len(train_node_set & new_test_node_set) == 0
    new_node_set = node_set - train_node_set

    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    test_mask = timestamps > test_time

    def _touches(node_subset):
        # True for every interaction whose source or destination is in node_subset.
        return np.array(
            [(a in node_subset or b in node_subset) for a, b in zip(sources, destinations)])

    if different_new_nodes_between_val_and_test:
        n_new_nodes = len(new_test_node_set) // 2
        val_new_node_set = set(list(new_test_node_set)[:n_new_nodes])
        test_new_node_set = set(list(new_test_node_set)[n_new_nodes:])

        new_node_val_mask = np.logical_and(val_mask, _touches(val_new_node_set))
        new_node_test_mask = np.logical_and(test_mask, _touches(test_new_node_set))
    else:
        edge_contains_new_node_mask = _touches(new_node_set)
        new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask)
        new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask)

    # validation and test with all edges
    val_data = Data(sources[val_mask], destinations[val_mask], timestamps[val_mask],
                    edge_idxs[val_mask], labels[val_mask])

    test_data = Data(sources[test_mask], destinations[test_mask], timestamps[test_mask],
                     edge_idxs[test_mask], labels[test_mask])

    # validation and test with edges that have at least one new node (not in training set)
    new_node_val_data = Data(sources[new_node_val_mask], destinations[new_node_val_mask],
                             timestamps[new_node_val_mask],
                             edge_idxs[new_node_val_mask], labels[new_node_val_mask])

    new_node_test_data = Data(sources[new_node_test_mask], destinations[new_node_test_mask],
                              timestamps[new_node_test_mask], edge_idxs[new_node_test_mask],
                              labels[new_node_test_mask])

    print("The dataset has {} interactions, involving {} different nodes".format(full_data.n_interactions,
                                                                                 full_data.n_unique_nodes))
    print("The training dataset has {} interactions, involving {} different nodes".format(
        train_data.n_interactions, train_data.n_unique_nodes))
    print("The validation dataset has {} interactions, involving {} different nodes".format(
        val_data.n_interactions, val_data.n_unique_nodes))
    print("The test dataset has {} interactions, involving {} different nodes".format(
        test_data.n_interactions, test_data.n_unique_nodes))
    print("The new node validation dataset has {} interactions, involving {} different nodes".format(
        new_node_val_data.n_interactions, new_node_val_data.n_unique_nodes))
    print("The new node test dataset has {} interactions, involving {} different nodes".format(
        new_node_test_data.n_interactions, new_node_test_data.n_unique_nodes))
    print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(
        len(new_test_node_set)))

    return node_features, edge_features, full_data, train_data, val_data, test_data, \
        new_node_val_data, new_node_test_data
121 |
122 |
def compute_time_statistics(sources, destinations, timestamps):
    """Mean and std of the gap since each node's previous interaction.

    Gaps are tracked separately for nodes in the source role and in the
    destination role. A node's first appearance in a role is measured from time 0.

    Returns:
        ``(mean_src, std_src, mean_dst, std_dst)``.
    """
    prev_src_time = {}
    prev_dst_time = {}
    src_gaps = []
    dst_gaps = []
    for k, source_id in enumerate(sources):
        dest_id = destinations[k]
        now = timestamps[k]
        src_gaps.append(now - prev_src_time.get(source_id, 0))
        dst_gaps.append(now - prev_dst_time.get(dest_id, 0))
        prev_src_time[source_id] = now
        prev_dst_time[dest_id] = now
    assert len(src_gaps) == len(sources)
    assert len(dst_gaps) == len(sources)

    return np.mean(src_gaps), np.std(src_gaps), np.mean(dst_gaps), np.std(dst_gaps)
148 |
--------------------------------------------------------------------------------
/models/dgnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from collections import defaultdict
4 | import logging
5 | import numpy as np
6 |
7 | from modules.memory import Memory
8 | from modules.message_aggregator import get_message_aggregator
9 | from modules.merge import MemoryMerge
10 | from modules.message_function import MessageFunction
11 | from modules.update import MemoryUpdater
12 | from modules.propagater import Propagater
13 |
14 |
class DGNN(nn.Module):
    """Temporal link-prediction model that keeps two memories per node.

    ``memory_s`` receives the messages a node produces as the source of an
    interaction, and ``memory_g`` those it produces as the destination. Each
    memory is advanced by a ``MemoryUpdater`` and then by a ``Propagater``.

    A pair (u, v) is scored as ``sigmoid(sum((merge(u) @ W_s) * (merge(v) @ W_g)))``.
    Here ``merge`` combines the ``[1]`` entry of both memories for a node.

    NOTE(review): ``dropout`` is accepted but unused. ``memory_update_at_start``,
    ``std_time_shift_src``/``std_time_shift_dst`` and ``embedding_dimension`` are
    stored but not read anywhere in this class.
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device,
                 dropout=0.1,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=200, n_neighbors=None, aggregator_type="last",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, threshold=2):
        super(DGNN, self).__init__()
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logging.getLogger(__name__)

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        self.n_node_features = self.node_raw_features.shape[1]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors

        self.threshold = threshold
        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst
        self.memory_dimension = memory_dimension
        self.memory_update_at_start = memory_update_at_start
        self.message_dimension = message_dimension
        self.memory_merge = MemoryMerge(self.memory_dimension, self.device)
        self.memory_s = Memory(n_nodes=self.n_nodes,
                               memory_dimension=self.memory_dimension,
                               message_dimension=message_dimension,
                               device=device)
        self.memory_g = Memory(n_nodes=self.n_nodes,
                               memory_dimension=self.memory_dimension,
                               message_dimension=message_dimension,
                               device=device)
        self.message_dim = message_dimension
        self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                         device=device)
        self.message_function = MessageFunction(memory_dimension=memory_dimension,
                                                message_dimension=message_dimension,
                                                edge_dimension=self.n_edge_features,
                                                device=self.device)
        self.memory_updater_s = MemoryUpdater(memory=self.memory_s, message_dimension=message_dimension,
                                              memory_dimension=self.memory_dimension,
                                              mean_time_shift_src=self.mean_time_shift_src / 2,
                                              device=self.device)
        self.memory_updater_g = MemoryUpdater(memory=self.memory_g, message_dimension=message_dimension,
                                              memory_dimension=self.memory_dimension,
                                              mean_time_shift_src=self.mean_time_shift_dst / 2,
                                              device=self.device)
        self.propagater_s = Propagater(memory=self.memory_s, message_dimension=message_dimension,
                                       memory_dimension=self.memory_dimension,
                                       mean_time_shift_src=self.mean_time_shift_src / 2,
                                       neighbor_finder=self.neighbor_finder, n_neighbors=self.n_neighbors,
                                       tau=self.threshold,
                                       device=self.device)
        self.propagater_g = Propagater(memory=self.memory_g, message_dimension=message_dimension,
                                       memory_dimension=self.memory_dimension,
                                       mean_time_shift_src=self.mean_time_shift_dst / 2,
                                       neighbor_finder=self.neighbor_finder, n_neighbors=self.n_neighbors,
                                       tau=self.threshold,
                                       device=self.device)
        # Source- and destination-side projections used by compute_loss. They must
        # not both start at zero: the score's gradient w.r.t. W_s is proportional to
        # W_g and vice versa. All-zero weights would pin every score at sigmoid(0) = 0.5
        # and never receive a gradient, so initialise them with Xavier-uniform.
        self.W_s = nn.Parameter(torch.zeros((memory_dimension, memory_dimension // 2)).to(self.device))
        self.W_g = nn.Parameter(torch.zeros((memory_dimension, memory_dimension // 2)).to(self.device))
        nn.init.xavier_uniform_(self.W_s)
        nn.init.xavier_uniform_(self.W_g)

    def update_memory(self, source_nodes, destination_nodes, messages_s, messages_g):
        """Aggregate the pending messages of the given nodes and apply them to both memories.

        The updaters are called without ``inplace``, so they use MemoryUpdater's
        default mode (presumably in place; confirm in modules/update.py).
        """
        # Aggregate messages for the same nodes
        unique_src_nodes, unique_src_messages, unique_src_timestamps = \
            self.message_aggregator.aggregate(source_nodes, messages_s)
        unique_des_nodes, unique_des_messages, unique_des_timestamps = \
            self.message_aggregator.aggregate(destination_nodes, messages_g)

        # Update the memory with the aggregated messages
        self.memory_updater_s.update_memory(unique_src_nodes, unique_src_messages,
                                            timestamps=unique_src_timestamps)
        self.memory_updater_g.update_memory(unique_des_nodes, unique_des_messages,
                                            timestamps=unique_des_timestamps)

    def propagate(self, source_nodes, destination_nodes, messages_s, messages_g):
        """Aggregate pending messages and run both propagaters on the stored memories."""
        unique_src_nodes, unique_src_messages, unique_src_timestamps = \
            self.message_aggregator.aggregate(source_nodes, messages_s)
        unique_des_nodes, unique_des_messages, unique_des_timestamps = \
            self.message_aggregator.aggregate(destination_nodes, messages_g)

        self.propagater_s(self.memory_s.memory, unique_src_nodes, unique_src_messages,
                          timestamps=unique_src_timestamps)
        self.propagater_g(self.memory_g.memory, unique_des_nodes, unique_des_messages,
                          timestamps=unique_des_timestamps)

    def compute_loss(self, memory_s, memory_g, source_nodes, destination_nodes):
        """Return the edge probability sigmoid(<merge(src) @ W_s, merge(dst) @ W_g>) per pair.

        Despite the name this returns probabilities, not a loss. The caller
        applies the criterion.
        """
        source_mem = self.memory_merge(memory_s[1][source_nodes], memory_g[1][source_nodes])
        destination_mem = self.memory_merge(memory_s[1][destination_nodes], memory_g[1][destination_nodes])
        source_emb = torch.matmul(source_mem, self.W_s)
        destination_emb = torch.matmul(destination_mem, self.W_g)
        score = torch.sum(source_emb * destination_emb, dim=1)
        return score.sigmoid()

    def forward(self, source_nodes, destination_nodes, negative_nodes, edge_times, edge_idxs, test=False):
        """Score a batch of positive and negative edges, then advance the memories.

        Returns ``(pos_score, neg_score)`` probabilities. With ``test=True`` it
        instead returns the merged memories of the source and destination nodes.
        In both cases the scores/memories are computed before this batch's own
        interactions are applied.
        """
        positives = np.concatenate([source_nodes, destination_nodes])
        # Score against a copy of the memories that already includes all pending
        # (previously stored) messages, without writing it back.
        memory_s, last_update_s, memory_g, last_update_g = \
            self.get_updated_memory(list(range(self.n_nodes)), list(range(self.n_nodes)),
                                    self.memory_s.messages, self.memory_g.messages)

        pos_score = self.compute_loss(memory_s, memory_g, source_nodes, destination_nodes)
        neg_score = self.compute_loss(memory_s, memory_g, source_nodes, negative_nodes)

        # Commit the pending messages of this batch's nodes to the stored memories,
        # then drop them so they are not applied again.
        self.update_memory(source_nodes, destination_nodes,
                           self.memory_s.messages, self.memory_g.messages)
        self.propagate(source_nodes, destination_nodes,
                       self.memory_s.messages, self.memory_g.messages)
        self.memory_s.clear_messages(positives)
        self.memory_g.clear_messages(positives)

        # Build messages from the current batch; they are applied on a later call.
        unique_sources, source_id_to_messages = self.get_messages(source_nodes,
                                                                  destination_nodes,
                                                                  edge_times, edge_idxs)
        unique_destinations, destination_id_to_messages = self.get_messages(destination_nodes,
                                                                            source_nodes,
                                                                            edge_times,
                                                                            edge_idxs)
        self.memory_s.store_raw_messages(unique_sources, source_id_to_messages)
        self.memory_g.store_raw_messages(unique_destinations, destination_id_to_messages)
        if not test:
            return pos_score, neg_score
        else:
            source_mem = self.memory_merge(memory_s[1][source_nodes], memory_g[1][source_nodes])
            destination_mem = self.memory_merge(memory_s[1][destination_nodes], memory_g[1][destination_nodes])
            return source_mem, destination_mem

    def get_messages(self, source_nodes, destination_nodes, edge_times, edge_idxs):
        """Compute one message per interaction and group them by ``source_nodes``.

        Returns ``(unique_sources, messages)``, where ``messages[node]`` is a list
        of ``(message, edge_time)`` pairs in batch order.
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]
        source_memory = self.memory_merge(self.memory_s.memory[1][source_nodes],
                                          self.memory_g.memory[1][source_nodes])
        destination_memory = self.memory_merge(self.memory_s.memory[1][destination_nodes],
                                               self.memory_g.memory[1][destination_nodes])

        source_message = self.message_function.compute_message(source_memory, destination_memory, edge_features)
        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)

        for i in range(len(source_nodes)):
            messages[source_nodes[i]].append((source_message[i], edge_times[i]))

        return unique_sources, messages

    def get_updated_memory(self, source_nodes, destination_nodes, message_s, message_g):
        """Return copies of both memories with pending messages applied (not written back).

        Returns ``(memory_s, last_update_s, memory_g, last_update_g)``.
        """
        unique_src_nodes, unique_src_messages, unique_src_timestamps = \
            self.message_aggregator.aggregate(source_nodes, message_s)
        unique_des_nodes, unique_des_messages, unique_des_timestamps = \
            self.message_aggregator.aggregate(destination_nodes, message_g)
        updated_src_memory, updated_src_last_update = self.memory_updater_s.update_memory(unique_src_nodes,
                                                                                          unique_src_messages,
                                                                                          timestamps=unique_src_timestamps,
                                                                                          inplace=False)
        updated_des_memory, updated_des_last_update = self.memory_updater_g.update_memory(unique_des_nodes,
                                                                                          unique_des_messages,
                                                                                          timestamps=unique_des_timestamps,
                                                                                          inplace=False)
        updated_src_memory = self.propagater_s(updated_src_memory, unique_src_nodes, unique_src_messages,
                                               timestamps=unique_src_timestamps, inplace=False)
        updated_des_memory = self.propagater_g(updated_des_memory, unique_des_nodes, unique_des_messages,
                                               timestamps=unique_des_timestamps, inplace=False)

        return updated_src_memory, updated_src_last_update, updated_des_memory, updated_des_last_update

    def set_neighbor_finder(self, neighbor_finder):
        """Swap the neighbor finder on the model and on both propagaters."""
        self.neighbor_finder = neighbor_finder
        self.propagater_s.neighbor_finder = neighbor_finder
        self.propagater_g.neighbor_finder = neighbor_finder
193 |
194 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import time
4 | import sys
5 | import argparse
6 | import torch
7 | import numpy as np
8 | import pickle
9 | from pathlib import Path
10 |
11 | from evaluation import eval_edge_prediction
12 | from models.dgnn import DGNN
13 | from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder
14 | from utils.data_processing import get_data, compute_time_statistics
15 |
16 |
17 |
### Argument and global variables
parser = argparse.ArgumentParser('DGGN training')
parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
                    default='reddit')
parser.add_argument('--bs', type=int, default=200, help='Batch_size')
parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints')
parser.add_argument('--n_degree', type=int, default=20, help='Number of neighbors to sample')
parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs')
parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate')
parser.add_argument('--patience', type=int, default=8, help='Patience for early stopping')
parser.add_argument('--n_runs', type=int, default=10, help='Number of runs')
parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability')
parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use')
parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding')
parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to '
                                                                  'backprop')

parser.add_argument('--aggregator', type=str, default="last", help='Type of message '
                                                                   'aggregator')

parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages')

parser.add_argument('--memory_dim', type=int, default=172, help='Dimensions of the memory for '
                                                                'each user')
parser.add_argument('--different_new_nodes', action='store_true',
                    help='Whether to use disjoint set of new nodes for train and val')
parser.add_argument('--uniform', action='store_true',
                    help='take uniform sampling from temporal neighbors')
parser.add_argument('--seed', type=int, default=0, help='random seed')

# argparse prints usage and exits on its own: status 0 for --help, 2 for invalid
# arguments. Intercepting that SystemExit would print the help twice and report
# argument errors as success, so no try/except here.
args = parser.parse_args()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
BATCH_SIZE = args.bs
NUM_NEIGHBORS = args.n_degree
NUM_NEG = 1
NUM_EPOCH = args.n_epoch
DROP_OUT = args.drop_out
GPU = args.gpu
SEQ_LEN = NUM_NEIGHBORS
DATA = args.data
LEARNING_RATE = args.lr
NODE_DIM = args.node_dim
MESSAGE_DIM = args.message_dim
MEMORY_DIM = args.memory_dim

Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth'
get_checkpoint_path = lambda \
        epoch: f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth'

### set up logger
# NOTE(review): basicConfig attaches its own stderr handler to the root logger,
# with no level filter of its own. Once the root level is set to DEBUG below,
# every record reaches the console through it, and WARN+ records are printed a
# second time by `ch`. Confirm which console verbosity is intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
Path("log/").mkdir(parents=True, exist_ok=True)
# '-' instead of ':' between time fields: colons are not allowed in Windows file names.
timenow = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
fh = logging.FileHandler('log/{}.log'.format(str(args.prefix) + timenow))
fh.setLevel(logging.DEBUG)
ch = logging.StreamHandler()
ch.setLevel(logging.WARN)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
logger.info(args)

### Extract data for training, validation and testing
node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, \
    new_node_test_data = get_data(DATA,
                                  different_new_nodes_between_val_and_test=args.different_new_nodes)

# Initialize training neighbor finder to retrieve temporal graph
train_ngh_finder = get_neighbor_finder(train_data, args.uniform)

# Initialize validation and test neighbor finder to retrieve temporal graph
full_ngh_finder = get_neighbor_finder(full_data, args.uniform)

# Initialize negative samplers. Set seeds for validation and testing so negatives are the same
# across different runs
# NB: in the inductive setting, negatives are sampled only amongst other new nodes
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations,
                                      seed=1)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources,
                                       new_node_test_data.destinations,
                                       seed=3)

# Set device
device_string = 'cuda:{}'.format(GPU) if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

# Compute time statistics
mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \
    compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps)
121 |
122 | for i in range(args.n_runs):
123 | results_path = "results/{}_{}.pkl".format(args.prefix, i) if i > 0 else "results/{}.pkl".format(args.prefix)
124 | Path("results/").mkdir(parents=True, exist_ok=True)
125 |
126 | # Initialize Model
127 | dgnn = DGNN(neighbor_finder=train_ngh_finder, node_features=node_features,
128 | edge_features=edge_features, device=device, dropout=DROP_OUT,
129 | message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM,
130 | aggregator_type=args.aggregator, n_neighbors=NUM_NEIGHBORS,
131 | mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
132 | mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst)
133 | for a in dgnn.parameters():
134 | if a.ndim > 1:
135 | torch.nn.init.xavier_uniform_(a)
136 |
137 | dgnn.memory_s.__init_memory__()
138 | dgnn.memory_g.__init_memory__()
139 |
140 | criterion = torch.nn.BCELoss()
141 | optimizer = torch.optim.Adam(dgnn.parameters(), lr=LEARNING_RATE)
142 | dgnn = dgnn.to(device)
143 |
144 | num_instance = len(train_data.sources)
145 | num_batch = math.ceil(num_instance / BATCH_SIZE)
146 |
147 | logger.info('num of training instances: {}'.format(num_instance))
148 | logger.info('num of batches per epoch: {}'.format(num_batch))
149 | idx_list = np.arange(num_instance)
150 |
151 | new_nodes_val_aps = []
152 | val_aps = []
153 | epoch_times = []
154 | total_epoch_times = []
155 | train_losses = []
156 |
157 | early_stopper = EarlyStopMonitor(max_round=args.patience)
158 | for epoch in range(NUM_EPOCH):
159 | start_epoch = time.time()
160 | ### Training
161 |
162 | # Reinitialize memory of the model at the start of each epoch
163 | dgnn.memory_s.__init_memory__()
164 | dgnn.memory_g.__init_memory__()
165 |
166 | # Train using only training graph
167 | dgnn.set_neighbor_finder(train_ngh_finder)
168 | m_loss = []
169 |
170 | logger.info('start {} epoch'.format(epoch))
171 | for k in range(0, num_batch, args.backprop_every):
172 | loss = 0
173 | optimizer.zero_grad()
174 |
175 | # Custom loop to allow to perform backpropagation only every a certain number of batches
176 | for j in range(args.backprop_every):
177 | batch_idx = k + j
178 |
179 | if batch_idx >= num_batch:
180 | continue
181 |
182 | start_idx = batch_idx * BATCH_SIZE
183 | end_idx = min(num_instance, start_idx + BATCH_SIZE)
184 | sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \
185 | train_data.destinations[start_idx:end_idx]
186 | edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx]
187 | timestamps_batch = train_data.timestamps[start_idx:end_idx]
188 |
189 | size = len(sources_batch)
190 | _, negatives_batch = train_rand_sampler.sample(size)
191 |
192 | with torch.no_grad():
193 | pos_label = torch.ones(size, dtype=torch.float, device=device)
194 | neg_label = torch.zeros(size, dtype=torch.float, device=device)
195 |
196 | dgnn = dgnn.train()
197 | pos_prob, neg_prob = dgnn(sources_batch, destinations_batch, negatives_batch,
198 | timestamps_batch, edge_idxs_batch)
199 |
200 | loss += criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label)
201 |
202 | loss /= args.backprop_every
203 |
204 | loss.backward()
205 | optimizer.step()
206 | m_loss.append(loss.item())
207 |
208 | # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to
209 | # the start of time
210 | dgnn.memory_s.detach_memory()
211 | dgnn.memory_g.detach_memory()
212 |
213 | epoch_time = time.time() - start_epoch
214 | epoch_times.append(epoch_time)
215 |
216 | ### Validation
217 | # Validation uses the full graph
218 | dgnn.set_neighbor_finder(full_ngh_finder)
219 |
220 | # Backup memory at the end of training, so later we can restore it and use it for the
221 | # validation on unseen nodes
222 | train_memory_backup_s = dgnn.memory_s.backup_memory()
223 | train_memory_backup_g = dgnn.memory_g.backup_memory()
224 |
225 | val_ap, val_auc = eval_edge_prediction(model=dgnn,
226 | negative_edge_sampler=val_rand_sampler,
227 | data=val_data,
228 | n_neighbors=NUM_NEIGHBORS)
229 |
230 | val_memory_backup_s = dgnn.memory_s.backup_memory()
231 | val_memory_backup_g = dgnn.memory_g.backup_memory()
232 | # Restore memory we had at the end of training to be used when validating on new nodes.
233 | # Also backup memory after validation so it can be used for testing (since test edges are
234 | # strictly later in time than validation edges)
235 | dgnn.memory_s.restore_memory(train_memory_backup_s)
236 | dgnn.memory_g.restore_memory(train_memory_backup_g)
237 |
238 | # Validate on unseen nodes
239 | nn_val_ap, nn_val_auc = eval_edge_prediction(model=dgnn,negative_edge_sampler=val_rand_sampler,
240 | data=new_node_val_data, n_neighbors=NUM_NEIGHBORS)
241 |
242 | # Restore memory we had at the end of validation
243 | dgnn.memory_s.restore_memory(val_memory_backup_s)
244 | dgnn.memory_g.restore_memory(val_memory_backup_g)
245 |
246 | new_nodes_val_aps.append(nn_val_ap)
247 | val_aps.append(val_ap)
248 | train_losses.append(np.mean(m_loss))
249 |
250 | # Save temporary results to disk
251 | pickle.dump({
252 | "val_aps": val_aps,
253 | "new_nodes_val_aps": new_nodes_val_aps,
254 | "train_losses": train_losses,
255 | "epoch_times": epoch_times,
256 | "total_epoch_times": total_epoch_times
257 | }, open(results_path, "wb"))
258 |
259 | total_epoch_time = time.time() - start_epoch
260 | total_epoch_times.append(total_epoch_time)
261 |
262 | logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
263 | logger.info('Epoch mean loss: {}'.format(np.mean(m_loss)))
264 | logger.info(
265 | 'val auc: {}, new node val auc: {}'.format(val_auc, nn_val_auc))
266 | logger.info(
267 | 'val ap: {}, new node val ap: {}'.format(val_ap, nn_val_ap))
268 |
269 | # Early stopping
270 | if early_stopper.early_stop_check(val_ap):
271 | logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
272 | logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
273 | best_model_path = get_checkpoint_path(early_stopper.best_epoch)
274 | dgnn.load_state_dict(torch.load(best_model_path))
275 | logger.info(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
276 | dgnn.eval()
277 | break
278 | else:
279 | torch.save(dgnn.state_dict(), get_checkpoint_path(epoch))
280 |
281 | # Training has finished, we have loaded the best model, and we want to backup its current
282 | # memory (which has seen validation edges) so that it can also be used when testing on unseen
283 | # nodes
284 | val_memory_backup_s = dgnn.memory_s.backup_memory()
285 | val_memory_backup_g = dgnn.memory_g.backup_memory()
286 |
287 | ### Test
288 | dgnn.propagater_g.neighbor_finder = full_ngh_finder
289 | dgnn.propagater_s.neighbor_finder = full_ngh_finder
290 | test_ap, test_auc = eval_edge_prediction(model=dgnn, negative_edge_sampler=test_rand_sampler,
291 | data=test_data, n_neighbors=NUM_NEIGHBORS)
292 |
293 | dgnn.memory_s.restore_memory(val_memory_backup_s)
294 | dgnn.memory_g.restore_memory(val_memory_backup_g)
295 |
296 | # Test on unseen nodes
297 | nn_test_ap, nn_test_auc = eval_edge_prediction(model=dgnn, negative_edge_sampler=nn_test_rand_sampler,
298 | data=new_node_test_data, n_neighbors=NUM_NEIGHBORS)
299 |
300 | logger.info(
301 | 'Test statistics: Old nodes -- auc: {}, ap: {}'.format(test_auc, test_ap))
302 | logger.info(
303 | 'Test statistics: New nodes -- auc: {}, ap: {}'.format(nn_test_auc, nn_test_ap))
304 | # Save results for this run
305 | pickle.dump({
306 | "val_aps": val_aps,
307 | "new_nodes_val_aps": new_nodes_val_aps,
308 | "test_ap": test_ap,
309 | "new_node_test_ap": nn_test_ap,
310 | "epoch_times": epoch_times,
311 | "train_losses": train_losses,
312 | "total_epoch_times": total_epoch_times
313 | }, open(results_path, "wb"))
314 |
315 | logger.info('Saving DGNN model')
316 | # Restore memory at the end of validation (save a model which is ready for testing)
317 | dgnn.memory_s.restore_memory(val_memory_backup_s)
318 | dgnn.memory_g.restore_memory(val_memory_backup_g)
319 | torch.save(dgnn.state_dict(), MODEL_SAVE_PATH)
320 | logger.info('DGNN model saved')
321 |
--------------------------------------------------------------------------------
/trainUCI.py:
--------------------------------------------------------------------------------
1 | import math
2 | import logging
3 | import time
4 | import sys
5 | import argparse
6 | import torch
7 | import numpy as np
8 | import pickle
9 | from pathlib import Path
10 |
11 | from evaluation import eval_edge_prediction
12 | from models.dgnn import DGNN
13 | from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder
14 | from utils.data_processing import get_data, compute_time_statistics
15 |
16 |
17 |
18 | ### Argument and global variables
19 | parser = argparse.ArgumentParser('DGGN training')
20 | parser.add_argument('-d', '--data', type=str, help='Dataset name (eg. wikipedia or reddit)',
21 | default='UCI-Msg')
22 | parser.add_argument('--bs', type=int, default=200, help='Batch_size')
23 | parser.add_argument('--prefix', type=str, default='', help='Prefix to name the checkpoints')
24 | parser.add_argument('--n_degree', type=int, default=20, help='Number of neighbors to sample')
25 | parser.add_argument('--n_epoch', type=int, default=50, help='Number of epochs')
26 | parser.add_argument('--lr', type=float, default=0.00001, help='Learning rate')
27 | parser.add_argument('--patience', type=int, default=8, help='Patience for early stopping')
28 | parser.add_argument('--n_runs', type=int, default=10, help='Number of runs')
29 | parser.add_argument('--drop_out', type=float, default=0.1, help='Dropout probability')
30 | parser.add_argument('--gpu', type=int, default=0, help='Idx for the gpu to use')
31 | parser.add_argument('--node_dim', type=int, default=100, help='Dimensions of the node embedding')
32 | parser.add_argument('--backprop_every', type=int, default=1, help='Every how many batches to '
33 | 'backprop')
34 |
35 | parser.add_argument('--aggregator', type=str, default="last", help='Type of message '
36 | 'aggregator')
37 |
38 | parser.add_argument('--message_dim', type=int, default=100, help='Dimensions of the messages')
39 |
40 | parser.add_argument('--memory_dim', type=int, default=100, help='Dimensions of the memory for '
41 | 'each user')
42 | parser.add_argument('--different_new_nodes', action='store_true',
43 | help='Whether to use disjoint set of new nodes for train and val')
44 | parser.add_argument('--uniform', action='store_true',
45 | help='take uniform sampling from temporal neighbors')
46 | parser.add_argument('--seed', type=int, default=0, help='random seed')
47 | parser.add_argument('--threshold', type=int, default=2, help='time threshold')
48 |
49 | try:
50 | args = parser.parse_args()
51 | except:
52 | parser.print_help()
53 | sys.exit(0)
# Seed the global torch / numpy RNGs so runs are reproducible for a given --seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Hyper-parameters copied out of the parsed command line.
BATCH_SIZE = args.bs
NUM_NEIGHBORS = args.n_degree
NUM_NEG = 1  # not referenced elsewhere in this script
NUM_EPOCH = args.n_epoch
DROP_OUT = args.drop_out
GPU = args.gpu
SEQ_LEN = NUM_NEIGHBORS  # not referenced elsewhere in this script
DATA = args.data
LEARNING_RATE = args.lr
NODE_DIM = args.node_dim  # not referenced elsewhere in this script
MESSAGE_DIM = args.message_dim
MEMORY_DIM = args.memory_dim

# Output locations: the final model of each run, plus one checkpoint per epoch (used to reload
# the best epoch after early stopping).
Path("./saved_models/").mkdir(parents=True, exist_ok=True)
Path("./saved_checkpoints/").mkdir(parents=True, exist_ok=True)
MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.data}.pth'


def get_checkpoint_path(epoch):
    """Return the checkpoint file path for ``epoch`` under the current prefix and dataset."""
    return f'./saved_checkpoints/{args.prefix}-{args.data}-{epoch}.pth'
74 |
### set up logger
# NOTE(review): logging.basicConfig attaches a stderr StreamHandler with no level of its own to
# the root logger, and `logger` below *is* the root logger. Every record that passes the root
# level therefore already reaches the console, and WARNING+ records are printed a second time
# by `ch`. Kept as-is to preserve the current console output.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
Path("log/").mkdir(parents=True, exist_ok=True)
# ':' is not a legal filename character on Windows (FileHandler would fail to open the file),
# so '-' separates the time fields in the log file name.
timenow = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime(time.time()))
# Full DEBUG-level log written to log/<prefix><timestamp>.log
fh = logging.FileHandler('log/{}.log'.format(str(args.prefix) + timenow))
fh.setLevel(logging.DEBUG)
# Extra console handler limited to warnings and errors.
ch = logging.StreamHandler()
ch.setLevel(logging.WARN)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)
logger.addHandler(fh)
logger.addHandler(ch)
logger.info(args)
91 |
### Extract data for training, validation and testing
# get_data yields eight objects: node / edge feature tables, the full interaction set, the
# train / validation / test splits, and the two new-node (inductive) evaluation splits.
(node_features, edge_features,
 full_data, train_data, val_data, test_data,
 new_node_val_data, new_node_test_data) = get_data(
    DATA, different_new_nodes_between_val_and_test=args.different_new_nodes)

# Temporal neighbour finders: training looks only at training edges, while validation and
# testing look up neighbours in the full interaction history.
train_ngh_finder = get_neighbor_finder(train_data, args.uniform)
full_ngh_finder = get_neighbor_finder(full_data, args.uniform)

# Negative edge samplers. Validation / test samplers get fixed seeds so the same negatives are
# drawn in every run; the new-node samplers draw only from the corresponding new-node split
# (inductive setting). The training sampler is left unseeded.
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources,
                                      new_node_val_data.destinations,
                                      seed=1)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources,
                                       new_node_test_data.destinations,
                                       seed=3)

# Use the requested GPU when CUDA is available, otherwise fall back to the CPU.
device_string = 'cpu'
if torch.cuda.is_available():
    device_string = 'cuda:{}'.format(GPU)
device = torch.device(device_string)

# Time-shift statistics over the full interaction set, passed to the model constructor
# (presumably for normalising source / destination time deltas — confirm in DGNN).
(mean_time_shift_src, std_time_shift_src,
 mean_time_shift_dst, std_time_shift_dst) = compute_time_statistics(
    full_data.sources, full_data.destinations, full_data.timestamps)
122 |
def _save_results(path, results):
    """Pickle ``results`` to ``path``, closing the file handle deterministically."""
    with open(path, "wb") as results_file:
        pickle.dump(results, results_file)


def _backup_memories(model):
    """Return a ``(memory_s, memory_g)`` pair of snapshots taken from ``model``."""
    return model.memory_s.backup_memory(), model.memory_g.backup_memory()


def _restore_memories(model, backups):
    """Restore both memories of ``model`` from a pair returned by ``_backup_memories``."""
    backup_s, backup_g = backups
    model.memory_s.restore_memory(backup_s)
    model.memory_g.restore_memory(backup_g)


def _load_best_checkpoint(model, epoch):
    """Load the checkpoint saved for ``epoch`` into ``model`` and put it in eval mode."""
    logger.info(f'Loading the best model at epoch {epoch}')
    best_model_path = get_checkpoint_path(epoch)
    model.load_state_dict(torch.load(best_model_path))
    logger.info(f'Loaded the best model at epoch {epoch} for inference')
    model.eval()


for i in range(args.n_runs):
    # The first run writes results/<prefix>.pkl, later runs results/<prefix>_<i>.pkl.
    if i > 0:
        results_path = "results/{}_{}.pkl".format(args.prefix, i)
    else:
        results_path = "results/{}.pkl".format(args.prefix)
    Path("results/").mkdir(parents=True, exist_ok=True)

    # Initialize Model
    dgnn = DGNN(neighbor_finder=train_ngh_finder, node_features=node_features,
                edge_features=edge_features, device=device, dropout=DROP_OUT,
                message_dimension=MESSAGE_DIM, memory_dimension=MEMORY_DIM,
                aggregator_type=args.aggregator, n_neighbors=NUM_NEIGHBORS,
                mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
                mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst,
                threshold=args.threshold)
    # Xavier-initialise every multi-dimensional parameter (weight matrices).
    for a in dgnn.parameters():
        if a.ndim > 1:
            torch.nn.init.xavier_uniform_(a)

    dgnn.memory_s.__init_memory__(args.seed)
    dgnn.memory_g.__init_memory__(args.seed)

    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.Adam(dgnn.parameters(), lr=LEARNING_RATE)
    dgnn = dgnn.to(device)

    num_instance = len(train_data.sources)
    num_batch = math.ceil(num_instance / BATCH_SIZE)

    logger.info('num of training instances: {}'.format(num_instance))
    logger.info('num of batches per epoch: {}'.format(num_batch))

    new_nodes_val_mrrs = []
    val_mrrs = []
    epoch_times = []
    total_epoch_times = []
    train_losses = []

    early_stopper = EarlyStopMonitor(max_round=args.patience)
    for epoch in range(NUM_EPOCH):
        start_epoch = time.time()
        ### Training

        # Reinitialize memory of the model at the start of each epoch
        dgnn.memory_s.__init_memory__(args.seed)
        dgnn.memory_g.__init_memory__(args.seed)

        # Train using only training graph
        dgnn.set_neighbor_finder(train_ngh_finder)
        m_loss = []

        logger.info('start {} epoch'.format(epoch))
        for k in range(0, num_batch, args.backprop_every):
            loss = 0
            optimizer.zero_grad()

            # Accumulate the loss of up to `args.backprop_every` consecutive batches (fewer for
            # the last group) and backpropagate once for the whole group.
            for batch_idx in range(k, min(k + args.backprop_every, num_batch)):
                start_idx = batch_idx * BATCH_SIZE
                end_idx = min(num_instance, start_idx + BATCH_SIZE)
                sources_batch = train_data.sources[start_idx:end_idx]
                destinations_batch = train_data.destinations[start_idx:end_idx]
                edge_idxs_batch = train_data.edge_idxs[start_idx:end_idx]
                timestamps_batch = train_data.timestamps[start_idx:end_idx]

                size = len(sources_batch)
                _, negatives_batch = train_rand_sampler.sample(size)

                # Constant targets: fresh tensors from torch.ones/zeros never require grad.
                pos_label = torch.ones(size, dtype=torch.float, device=device)
                neg_label = torch.zeros(size, dtype=torch.float, device=device)

                dgnn = dgnn.train()
                pos_prob, neg_prob = dgnn(sources_batch, destinations_batch, negatives_batch,
                                          timestamps_batch, edge_idxs_batch)

                loss += criterion(pos_prob, pos_label) + criterion(neg_prob, neg_label)

            loss /= args.backprop_every

            loss.backward()
            optimizer.step()
            m_loss.append(loss.item())

            # Detach memory after 'args.backprop_every' number of batches so we don't
            # backpropagate to the start of time
            dgnn.memory_s.detach_memory()
            dgnn.memory_g.detach_memory()

        epoch_time = time.time() - start_epoch
        epoch_times.append(epoch_time)

        ### Validation
        # Validation uses the full graph
        dgnn.set_neighbor_finder(full_ngh_finder)

        # Back up the memory at the end of training so it can be restored for the validation
        # on unseen nodes.
        train_memory_backup = _backup_memories(dgnn)

        val_mrr, val_recall_20, val_recall_50 = eval_edge_prediction(
            model=dgnn, negative_edge_sampler=val_rand_sampler, data=val_data,
            n_neighbors=NUM_NEIGHBORS)

        # Keep the post-validation memory (test edges are strictly later than validation edges)
        # and rewind to the end-of-training memory for the new-node validation.
        val_memory_backup = _backup_memories(dgnn)
        _restore_memories(dgnn, train_memory_backup)

        # Validate on unseen nodes. Negatives come from the dedicated new-node sampler (as for
        # the new-node test below); previously `val_rand_sampler` was used here by mistake and
        # `nn_val_rand_sampler` was never used.
        nn_val_mrr, nn_val_recall_20, nn_val_recall_50 = eval_edge_prediction(
            model=dgnn, negative_edge_sampler=nn_val_rand_sampler, data=new_node_val_data,
            n_neighbors=NUM_NEIGHBORS)

        # Restore memory we had at the end of validation
        _restore_memories(dgnn, val_memory_backup)

        new_nodes_val_mrrs.append(nn_val_mrr)
        val_mrrs.append(val_mrr)
        train_losses.append(np.mean(m_loss))

        # Save temporary results to disk
        _save_results(results_path, {
            "val_mrrs": val_mrrs,
            "new_nodes_val_aps": new_nodes_val_mrrs,
            "train_losses": train_losses,
            "epoch_times": epoch_times,
            "total_epoch_times": total_epoch_times,
        })

        total_epoch_time = time.time() - start_epoch
        total_epoch_times.append(total_epoch_time)

        logger.info('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
        logger.info('Epoch mean loss: {}'.format(np.mean(m_loss)))
        logger.info(
            'val mrr: {}, new node val mrr: {}'.format(val_mrr, nn_val_mrr))
        logger.info(
            'val recall 20: {}, new node val recall 20: {}'.format(val_recall_20, nn_val_recall_20))
        logger.info(
            'val recall 50: {}, new node val recall 50: {}'.format(val_recall_50, nn_val_recall_50))

        # Early stopping
        if early_stopper.early_stop_check(val_mrr):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            _load_best_checkpoint(dgnn, early_stopper.best_epoch)
            break
        torch.save(dgnn.state_dict(), get_checkpoint_path(epoch))
    else:
        # All NUM_EPOCH epochs ran without early stopping. Every epoch saved a checkpoint, so
        # reload the best one too; previously the *last* epoch was tested in this case.
        if NUM_EPOCH > 0:
            _load_best_checkpoint(dgnn, early_stopper.best_epoch)

    # Training has finished and the best model has been loaded. Back up its memory (which has
    # seen validation edges) so it can also be used when testing on unseen nodes.
    val_memory_backup = _backup_memories(dgnn)

    ### Test
    dgnn.propagater_g.neighbor_finder = full_ngh_finder
    dgnn.propagater_s.neighbor_finder = full_ngh_finder
    test_mrr, test_recall_20, test_recall_50 = eval_edge_prediction(
        model=dgnn, negative_edge_sampler=test_rand_sampler, data=test_data,
        n_neighbors=NUM_NEIGHBORS)

    _restore_memories(dgnn, val_memory_backup)

    # Test on unseen nodes
    nn_test_mrr, nn_test_recall_20, nn_test_recall_50 = eval_edge_prediction(
        model=dgnn, negative_edge_sampler=nn_test_rand_sampler, data=new_node_test_data,
        n_neighbors=NUM_NEIGHBORS)

    logger.info(
        'Test statistics: Old nodes -- mrr: {}, recall_20: {}, recall_50:{}'.format(
            test_mrr, test_recall_20, test_recall_50))
    logger.info(
        'Test statistics: New nodes -- mrr: {}, recall_20: {}, recall_50:{}'.format(
            nn_test_mrr, nn_test_recall_20, nn_test_recall_50))

    # Save results for this run. The "*_ap(s)" keys hold MRR values; they keep the key names
    # used by train.py's result files. "val_mrrs" is included as well so the final file
    # contains every key written by the per-epoch dumps above.
    _save_results(results_path, {
        "val_aps": val_mrrs,
        "val_mrrs": val_mrrs,
        "new_nodes_val_aps": new_nodes_val_mrrs,
        "test_ap": test_mrr,
        "new_node_test_ap": nn_test_mrr,
        "epoch_times": epoch_times,
        "train_losses": train_losses,
        "total_epoch_times": total_epoch_times,
    })

    logger.info('Saving DGNN model')
    # Restore memory at the end of validation (save a model which is ready for testing)
    _restore_memories(dgnn, val_memory_backup)
    torch.save(dgnn.state_dict(), MODEL_SAVE_PATH)
    logger.info('DGNN model saved')
324 |
--------------------------------------------------------------------------------