├── TGAT+TGSL
│   ├── data
│   │   └── .keep
│   ├── config.py
│   ├── utils.py
│   ├── process.py
│   ├── view_learner.py
│   ├── neighbor_finder.py
│   ├── MTL.py
│   ├── TGAT.py
│   └── train.py
├── GraphMixer+TGSL
│   ├── data
│   │   └── .keep
│   ├── utils
│   │   ├── metrics.py
│   │   ├── EarlyStopping.py
│   │   ├── DataLoader.py
│   │   ├── load_configs.py
│   │   └── utils.py
│   ├── process.py
│   ├── view_learner.py
│   ├── models
│   │   ├── modules.py
│   │   └── GraphMixer.py
│   └── MTL.py
├── figures
│   └── TGSL.png
├── LICENSE
└── README.md

/TGAT+TGSL/data/.keep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/GraphMixer+TGSL/data/.keep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figures/TGSL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ViktorAxelsen/TGSL/HEAD/figures/TGSL.png
--------------------------------------------------------------------------------
/TGAT+TGSL/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Training Configuration
3 | '''
4 | class Config:
5 |     BATCH_SIZE = 200
6 |     EPOCH = 50
7 |     SSL_EPOCH = 10
8 |     LR = 1e-4
9 |     SSL_LR = 1e-4
10 | 
11 |     N_DEGREE = 20
12 |     N_HEAD = 2
13 |     N_LAYER = 2
14 |     DROPOUT = 0.1
15 |     NODE_DIM = 100
16 |     TIME_DIM = 100
17 |     MAX_ROUND = 8
18 | 
19 | 
20 | '''
21 | Wiki Dataset Configuration
22 | '''
23 | class WikiConfig(Config):
24 |     pass
25 | 
26 | 
27 | '''
28 | Reddit Dataset Configuration
29 | '''
30 | class RedditConfig(Config):
31 |     pass
32 | 
33 | 
34 | '''
35 | Escorts Dataset Configuration
36 | '''
37 | class ESConfig(Config):
38 |     pass
39 | 
40 | 
41 | if __name__ == '__main__':
42 |     config = Config()
43 | 
44 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 ViktorAxelsen
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /GraphMixer+TGSL/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.metrics import average_precision_score, roc_auc_score 3 | 4 | 5 | def get_link_prediction_metrics(predicts: torch.Tensor, labels: torch.Tensor): 6 | """ 7 | get metrics for the link prediction task 8 | :param predicts: Tensor, shape (num_samples, ) 9 | :param labels: Tensor, shape (num_samples, ) 10 | :return: 11 | dictionary of metrics {'metric_name_1': metric_1, ...} 12 | """ 13 | predicts = predicts.cpu().detach().numpy() 14 | pred_label = predicts.squeeze() > 0.5 15 | labels = labels.cpu().numpy() 16 | 17 | acc = (pred_label == labels).mean() 18 | average_precision = average_precision_score(y_true=labels, y_score=predicts) 19 | roc_auc = roc_auc_score(y_true=labels, y_score=predicts) 20 | 21 | return {'acc': acc, 'roc_auc': roc_auc, 'average_precision': average_precision} 22 | 23 | 24 | def get_node_classification_metrics(predicts: torch.Tensor, labels: torch.Tensor): 25 | """ 26 | get metrics for the node classification task 27 | :param predicts: Tensor, shape (num_samples, ) 28 | :param labels: Tensor, shape (num_samples, ) 29 | :return: 30 | dictionary of metrics {'metric_name_1': metric_1, ...} 31 | """ 32 | predicts = predicts.cpu().detach().numpy() 33 | labels = labels.cpu().numpy() 34 | 35 | roc_auc = roc_auc_score(y_true=labels, y_score=predicts) 36 | 37 | return {'roc_auc': roc_auc} 38 | -------------------------------------------------------------------------------- /TGAT+TGSL/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import numpy as np 4 | import random 5 | import dgl 6 | 7 | 8 | def get_device(index=3): 9 | return torch.device("cuda:" + str(index) if torch.cuda.is_available() else "cpu") 10 | 11 | 12 | def show_time(): 13 | time_stamp = '\033[1;31;40m[' + str(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + ']\033[0m' 14 | 15 | return time_stamp 16 | 17 | 18 | def set_seed(seed): 19 | random.seed(seed) 20 | np.random.seed(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | dgl.seed(seed) 24 | dgl.random.seed(seed) 25 | 26 | 27 | class EarlyStopMonitor(object): 28 | def __init__(self, max_round=3, higher_better=True, tolerance=1e-3): 29 | self.max_round = max_round 30 | self.num_round = 0 31 | 32 | self.epoch_count = 0 33 | self.best_epoch = 0 34 | 35 | self.last_best = None 36 | self.higher_better = higher_better 37 | self.tolerance = tolerance 38 | 39 | def early_stop_check(self, curr_val): 40 | if not self.higher_better: 41 | curr_val *= -1 42 | if self.last_best is None: 43 | self.last_best = curr_val 44 | elif (curr_val - self.last_best) / np.abs(self.last_best) > self.tolerance: 45 | self.last_best = curr_val 46 | self.num_round = 0 47 | self.best_epoch = self.epoch_count 48 | else: 49 | self.num_round += 1 50 | 51 | self.epoch_count += 1 52 | 53 | return self.num_round >= self.max_round 54 | 55 | 56 | class RandEdgeSampler(object): 57 | def __init__(self, src_list, dst_list): 58 | self.src_list = np.unique(src_list) 59 | self.dst_list = np.unique(dst_list) 60 | 61 | def sample(self, size): 62 | src_index = np.random.randint(0, len(self.src_list), size) 63 | dst_index = np.random.randint(0, len(self.dst_list), size) 64 | return self.src_list[src_index], self.dst_list[dst_index] 65 | 66 | 67 | if __name__ == 
'__main__': 68 | pass -------------------------------------------------------------------------------- /TGAT+TGSL/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def preprocess(path, data_name): 8 | u_list, i_list, ts_list, label_list = [], [], [], [] 9 | feat_l = [] 10 | idx_list = [] 11 | 12 | with open(path) as f: 13 | s = next(f) 14 | print(s) 15 | if data_name == 'escorts': 16 | for _ in range(5): 17 | s = next(f) 18 | print(s) 19 | previous_time = -1 20 | for idx, line in enumerate(f): 21 | if data_name == 'escorts': 22 | e = line.strip().split() 23 | u = int(e[0]) 24 | i = int(e[1]) 25 | ts = float(e[3]) 26 | assert ts >= previous_time 27 | previous_time = ts 28 | label = int(e[2]) 29 | feat = np.zeros(172) 30 | else: 31 | e = line.strip().split(',') 32 | u = int(e[0]) 33 | i = int(e[1]) 34 | ts = float(e[2]) 35 | label = int(e[3]) 36 | feat = np.array([float(x) for x in e[4:]]) 37 | 38 | u_list.append(u) 39 | i_list.append(i) 40 | ts_list.append(ts) 41 | label_list.append(label) 42 | idx_list.append(idx) 43 | 44 | feat_l.append(feat) 45 | return pd.DataFrame({'u': u_list, 46 | 'i': i_list, 47 | 'ts': ts_list, 48 | 'label': label_list, 49 | 'idx': idx_list}), np.array(feat_l) 50 | 51 | 52 | def reindex(df): 53 | assert (df.u.max() - df.u.min() + 1 == len(df.u.unique())) 54 | assert (df.i.max() - df.i.min() + 1 == len(df.i.unique())) 55 | 56 | upper_u = df.u.max() + 1 57 | new_i = df.i + upper_u 58 | 59 | new_df = df.copy() 60 | print(new_df.u.max()) 61 | print(new_df.i.max()) 62 | 63 | new_df.i = new_i 64 | new_df.u += 1 65 | new_df.i += 1 66 | new_df.idx += 1 67 | 68 | print(new_df.u.max()) 69 | print(new_df.i.max()) 70 | 71 | return new_df 72 | 73 | 74 | def run(data_name): 75 | if data_name == 'escorts': 76 | PATH = './data/{}.edges'.format(data_name) 77 | else: 78 | PATH = './data/{}.csv'.format(data_name) 79 | OUT_DF = './processed/ml_{}.csv'.format(data_name) 80 | OUT_FEAT = './processed/ml_{}.npy'.format(data_name) 81 | OUT_NODE_FEAT = './processed/ml_{}_node.npy'.format(data_name) 82 | 83 | df, feat = preprocess(PATH, data_name) 84 | new_df = reindex(df) 85 | 86 | print(feat.shape) 87 | empty = np.zeros(feat.shape[1])[np.newaxis, :] 88 | feat = np.vstack([empty, feat]) 89 | 90 | max_idx = max(new_df.u.max(), new_df.i.max()) 91 | rand_feat = np.zeros((max_idx + 1, feat.shape[1])) 92 | 93 | print(feat.shape) 94 | new_df.to_csv(OUT_DF) 95 | np.save(OUT_FEAT, feat) 96 | np.save(OUT_NODE_FEAT, rand_feat) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--dataset', type=str, help='dataset', choices=['wikipedia', 'reddit', 'escorts'], default='wikipedia') 102 | args = parser.parse_args() 103 | os.makedirs(f"./processed/", exist_ok=True) 104 | os.makedirs(f"./log/", exist_ok=True) 105 | run(args.dataset) -------------------------------------------------------------------------------- /GraphMixer+TGSL/process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def preprocess(path, data_name): 8 | u_list, i_list, ts_list, label_list = [], [], [], [] 9 | feat_l = [] 10 | idx_list = [] 11 | 12 | with open(path) as f: 13 | s = next(f) 14 | print(s) 15 | if data_name == 'escorts': 16 | for _ in range(5): 17 | s = next(f) 18 | print(s) 19 | previous_time = -1 20 | for 
idx, line in enumerate(f): 21 | if data_name == 'escorts': 22 | e = line.strip().split() 23 | u = int(e[0]) 24 | i = int(e[1]) 25 | ts = float(e[3]) 26 | assert ts >= previous_time 27 | previous_time = ts 28 | label = int(e[2]) 29 | feat = np.zeros(172) 30 | else: 31 | e = line.strip().split(',') 32 | u = int(e[0]) 33 | i = int(e[1]) 34 | ts = float(e[2]) 35 | label = int(e[3]) 36 | feat = np.array([float(x) for x in e[4:]]) 37 | 38 | u_list.append(u) 39 | i_list.append(i) 40 | ts_list.append(ts) 41 | label_list.append(label) 42 | idx_list.append(idx) 43 | 44 | feat_l.append(feat) 45 | return pd.DataFrame({'u': u_list, 46 | 'i': i_list, 47 | 'ts': ts_list, 48 | 'label': label_list, 49 | 'idx': idx_list}), np.array(feat_l) 50 | 51 | 52 | def reindex(df): 53 | assert (df.u.max() - df.u.min() + 1 == len(df.u.unique())) 54 | assert (df.i.max() - df.i.min() + 1 == len(df.i.unique())) 55 | 56 | upper_u = df.u.max() + 1 57 | new_i = df.i + upper_u 58 | 59 | new_df = df.copy() 60 | print(new_df.u.max()) 61 | print(new_df.i.max()) 62 | 63 | new_df.i = new_i 64 | new_df.u += 1 65 | new_df.i += 1 66 | new_df.idx += 1 67 | 68 | print(new_df.u.max()) 69 | print(new_df.i.max()) 70 | 71 | return new_df 72 | 73 | 74 | def run(data_name): 75 | if data_name == 'escorts': 76 | PATH = './data/{}.edges'.format(data_name) 77 | else: 78 | PATH = './data/{}.csv'.format(data_name) 79 | OUT_DF = './processed_data/{}/ml_{}.csv'.format(data_name, data_name) 80 | OUT_FEAT = './processed_data/{}/ml_{}.npy'.format(data_name, data_name) 81 | OUT_NODE_FEAT = './processed_data/{}/ml_{}_node.npy'.format(data_name, data_name) 82 | 83 | df, feat = preprocess(PATH, data_name) 84 | new_df = reindex(df) 85 | 86 | print(feat.shape) 87 | empty = np.zeros(feat.shape[1])[np.newaxis, :] 88 | feat = np.vstack([empty, feat]) 89 | 90 | max_idx = max(new_df.u.max(), new_df.i.max()) 91 | rand_feat = np.zeros((max_idx + 1, feat.shape[1])) 92 | 93 | print(feat.shape) 94 | new_df.to_csv(OUT_DF) 95 | np.save(OUT_FEAT, feat) 96 | np.save(OUT_NODE_FEAT, rand_feat) 97 | 98 | 99 | if __name__ == '__main__': 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--dataset', type=str, help='dataset', choices=['wikipedia', 'reddit', 'escorts'], default='wikipedia') 102 | args = parser.parse_args() 103 | os.makedirs("./processed_data/{}/".format(args.dataset), exist_ok=True) 104 | run(args.dataset) -------------------------------------------------------------------------------- /GraphMixer+TGSL/utils/EarlyStopping.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import logging 5 | import numpy as np 6 | 7 | 8 | class EarlyStopping(object): 9 | 10 | def __init__(self, patience: int, save_model_folder: str, save_model_name: str, logger: logging.Logger, model_name: str = None, prefix=None, tolerance=1e-3): 11 | """ 12 | Early stop strategy. 
13 | :param patience: int, max patience 14 | :param save_model_folder: str, save model folder 15 | :param save_model_name: str, save model name 16 | :param logger: Logger 17 | :param model_name: str, model name 18 | """ 19 | self.patience = patience 20 | self.tolerance = tolerance 21 | self.counter = 0 22 | self.best_metrics = {} 23 | self.early_stop = False 24 | self.logger = logger 25 | self.save_model_path = os.path.join(save_model_folder, f"{save_model_name}-{prefix}.pkl") 26 | self.model_name = model_name 27 | 28 | def step(self, metrics: list, model: nn.Module): 29 | """ 30 | execute the early stop strategy for each evaluation process 31 | :param metrics: list, list of metrics, each element is a tuple (str, float, boolean) -> (metric_name, metric_value, whether higher means better) 32 | :param model: nn.Module 33 | :return: 34 | """ 35 | metrics_compare_results = [] 36 | for metric_tuple in metrics: 37 | metric_name, metric_value, higher_better = metric_tuple[0], metric_tuple[1], metric_tuple[2] 38 | 39 | if higher_better: 40 | if self.best_metrics.get(metric_name) is None or (metric_value - self.best_metrics.get(metric_name)) / np.abs(self.best_metrics.get(metric_name)) >= self.tolerance: 41 | metrics_compare_results.append(True) 42 | else: 43 | metrics_compare_results.append(False) 44 | else: 45 | if self.best_metrics.get(metric_name) is None or metric_value <= self.best_metrics.get(metric_name): 46 | metrics_compare_results.append(True) 47 | else: 48 | metrics_compare_results.append(False) 49 | # all the computed metrics are better than the best metrics 50 | if torch.all(torch.tensor(metrics_compare_results)): 51 | for metric_tuple in metrics: 52 | metric_name, metric_value = metric_tuple[0], metric_tuple[1] 53 | self.best_metrics[metric_name] = metric_value 54 | self.save_checkpoint(model) 55 | self.counter = 0 56 | # metrics are not better at the epoch 57 | else: 58 | self.counter += 1 59 | if self.counter >= self.patience: 60 | self.early_stop = True 61 | 62 | return self.early_stop 63 | 64 | def save_checkpoint(self, model: nn.Module): 65 | """ 66 | saves model at self.save_model_path 67 | :param model: nn.Module 68 | :return: 69 | """ 70 | self.logger.info(f"save model {self.save_model_path}") 71 | torch.save(model.state_dict(), self.save_model_path) 72 | 73 | def load_checkpoint(self, model: nn.Module, map_location: str = None): 74 | """ 75 | load model at self.save_model_path 76 | :param model: nn.Module 77 | :param map_location: str, how to remap the storage locations 78 | :return: 79 | """ 80 | self.logger.info(f"load model {self.save_model_path}") 81 | model.load_state_dict(torch.load(self.save_model_path, map_location=map_location)) 82 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Time-aware Graph Structure Learning via Sequence Prediction on Temporal Graphs 2 | 3 | 4 | Official implementation of the CIKM'23 Full/Long paper: Time-aware Graph Structure Learning via Sequence Prediction on Temporal Graphs. 
[[arXiv](https://arxiv.org/abs/2306.07699)]
5 | 
6 | 
7 | ![Method](./figures/TGSL.png)
8 | 
9 | 
10 | ## Environment Setup
11 | 
12 | ```bash
13 | # python==3.8
14 | pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html
15 | pip install dgl==1.0.0+cu113 -f https://data.dgl.ai/wheels/cu113/repo.html
16 | pip install scikit-learn
17 | ```
18 | 
19 | 
20 | ## TGAT + TGSL
21 | 
22 | For TGAT, we use the official implementation code: [TGAT](https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs)
23 | 
24 | ### Datasets & Pre-processing
25 | 
26 | Download the Wikipedia and Reddit datasets from [here](http://snap.stanford.edu/jodie/) and Escorts from [here](https://networkrepository.com/escorts.php). Save the downloaded files in the `data/` folder.
27 | 
28 | 
29 | Run the following commands to pre-process the datasets.
30 | 
31 | ```bash
32 | # wikipedia
33 | python process.py --dataset wikipedia
34 | # reddit
35 | python process.py --dataset reddit
36 | # escorts
37 | python process.py --dataset escorts
38 | ```
39 | 
40 | 
41 | ### Training & Evaluation
42 | 
43 | Run the following commands to start training and evaluation.
44 | 
45 | ```bash
46 | # wikipedia
47 | python train.py --dataset wikipedia --uniform --agg_method attn --attn_mode prod --cuda 0 --prefix tgat_tgsl_wiki --seed 2023 --tau 0.1 --ratio 0.8 --can_type 1st --can_nn 10 --rnn_layer 3 --coe 0.2 --K 512 --gtau 1.0
48 | # reddit
49 | python train.py --dataset reddit --uniform --agg_method attn --attn_mode prod --cuda 0 --prefix tgat_tgsl_reddit --seed 2023 --tau 0.1 --ratio 0.4 --can_type 1st --can_nn 20 --rnn_layer 1 --coe 0.2 --K 512 --gtau 1.0
50 | # escorts
51 | python train.py --dataset escorts --uniform --agg_method attn --attn_mode prod --cuda 0 --prefix tgat_tgsl_escorts --seed 2023 --tau 0.1 --ratio 0.064 --can_type 3rd --can_nn 5 --rnn_layer 1 --coe 0.7 --K 512 --gtau 1.0
52 | ```
53 | 
54 | 
55 | ## GraphMixer + TGSL
56 | 
57 | For GraphMixer, we use [DyGLib](https://github.com/yule-BUAA/DyGLib) to implement TGSL under both transductive and inductive settings.
58 | 
59 | GraphMixer official implementation: [GraphMixer](https://github.com/CongWeilin/GraphMixer)
60 | 
61 | ### Datasets & Pre-processing
62 | 
63 | The same as for TGAT + TGSL.
64 | 
65 | 
66 | ### Training & Evaluation
67 | 
68 | Run the following commands to start training and evaluation.
69 | 
70 | 
71 | ```bash
72 | # wikipedia
73 | python train_link_prediction.py --dataset_name wikipedia --prefix graphmixer_tgsl_wiki --log_name graphmixer_tgsl_wiki --model_name GraphMixer --load_best_configs --num_runs 1 --gpu 0 --tau 0.1 --ratio 0.008 --can_nn 10 --can_type 3rd --rnn_layer 1 --coe 0.2 --K 512 --gtau 1.0
74 | # reddit
75 | python train_link_prediction.py --dataset_name reddit --prefix graphmixer_tgsl_reddit --log_name graphmixer_tgsl_reddit --model_name GraphMixer --load_best_configs --num_runs 1 --gpu 0 --tau 0.1 --ratio 0.4 --can_nn 20 --can_type 1st --rnn_layer 2 --coe 0.2 --K 512 --gtau 1.0
76 | # escorts
77 | python train_link_prediction.py --dataset_name escorts --prefix graphmixer_tgsl_escorts --log_name graphmixer_tgsl_escorts --model_name GraphMixer --load_best_configs --num_runs 1 --gpu 0 --tau 0.1 --ratio 0.002 --can_nn 20 --can_type 3rd --rnn_layer 1 --coe 0.2 --K 512 --gtau 1.0
78 | ```
79 | 
80 | 
81 | ## Acknowledgments
82 | 
83 | We implement TGSL based on the publicly released code of [TGAT](https://github.com/StatsDLMathsRecomSys/Inductive-representation-learning-on-temporal-graphs) and [DyGLib](https://github.com/yule-BUAA/DyGLib); many thanks to their authors.
84 | 
85 | 
86 | 
87 | ## Citation
88 | 
89 | ```bibtex
90 | @article{TGSL,
91 |     title={Time-aware Graph Structure Learning via Sequence Prediction on Temporal Graphs},
92 |     author={Haozhen Zhang and Xueting Han* and Xi Xiao and Jing Bai},
93 |     journal={arXiv preprint arXiv:2306.07699},
94 |     year={2023}
95 | }
96 | ```
--------------------------------------------------------------------------------
/TGAT+TGSL/view_learner.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | 
6 | class GraphMixerTE(torch.nn.Module):
7 |     def __init__(self, time_dim):
8 |         super(GraphMixerTE, self).__init__()
9 |         self.w = nn.Parameter(torch.tensor([int(np.sqrt(time_dim)) ** (-(i - 1) / int(np.sqrt(time_dim))) for i in range(1, time_dim + 1)]), requires_grad=False)
10 |         self.transformation = nn.Sequential(
11 |             nn.Linear(time_dim, time_dim),
12 |             nn.ReLU(inplace=True),
13 |             nn.Linear(time_dim, time_dim)
14 |         )
15 | 
16 |     def forward(self, ts):
17 |         map_ts = ts * self.w.view(1, -1)
18 |         harmonic = torch.sin(map_ts)
19 |         harmonic = self.transformation(harmonic)
20 | 
21 |         return harmonic
22 | 
23 | 
24 | class TimeMapping(torch.nn.Module):
25 |     def __init__(self, time_dim=172):
26 |         super(TimeMapping, self).__init__()
27 |         self.w = nn.Parameter(torch.tensor([int(np.sqrt(time_dim)) ** (-(i - 1) / int(np.sqrt(time_dim))) for i in range(1, time_dim + 1)]), requires_grad=False)
28 |         self.transformation = nn.Linear(time_dim, time_dim, bias=False)
29 | 
30 |     def forward(self, ts):
31 |         map_ts = ts * self.w.view(1, -1)
32 |         harmonic = torch.sin(map_ts)
33 |         harmonic = self.transformation(harmonic)
34 | 
35 |         return harmonic + 1
36 | 
37 | 
38 | def gcn_reduce(nodes):
39 |     selfh = nodes.data['h']
40 |     msgs = torch.mean(nodes.mailbox['h'], dim=1)
41 |     msgs = torch.cat((msgs, selfh), dim=1)
42 | 
43 |     return {'h': msgs}
44 | 
45 | 
46 | def gcn_msg(edges):
47 |     h = torch.cat((edges.src['h'], edges.data['h'], edges.data['ts_enc']), dim=1)
48 | 
49 |     return {'h': h}
50 | 
51 | 
52 | class NodeApplyModule(nn.Module):
53 |     def __init__(self, in_feats, out_feats):
54 |         super(NodeApplyModule, self).__init__()
55 |         self.fc = nn.Linear(in_feats, out_feats, bias=True)
56 | 
57 |     def forward(self, node):
58 |         h = self.fc(node.data['h'])
59 | 
60 |         return {'h': h}
61 | 
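# Dimensionality note (editorial, added for clarity): with node, edge, and time
# features all of width in_feats, gcn_msg above emits [src_h | edge_h | ts_enc]
# (3 * in_feats) per edge; gcn_reduce averages the incoming messages and
# concatenates the node's own state, giving 4 * in_feats, which is why
# TimeAwareGCN further below builds NodeApplyModule with in_feats * 4 inputs.
# The EdgeApplyModule defined next consumes [src_h | edge_h | ts_enc | dst_h],
# also 4 * in_feats wide.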
62 | 63 | class EdgeApplyModule(nn.Module): 64 | def __init__(self, in_feats, out_feats): 65 | super(EdgeApplyModule, self).__init__() 66 | self.fc = nn.Linear(in_feats * 4, out_feats, bias=True) 67 | 68 | def forward(self, edge): 69 | h = torch.cat((edge.src['h'], edge.data['h'], edge.data['ts_enc'], edge.dst['h']), dim=1) 70 | h = self.fc(h) 71 | 72 | return {'h': h} 73 | 74 | 75 | class TimeAwareGCN(nn.Module): 76 | def __init__(self, in_feats, out_feats): 77 | super(TimeAwareGCN, self).__init__() 78 | self.apply_mod = NodeApplyModule(in_feats * 4, out_feats) 79 | self.apply_mod_e = EdgeApplyModule(in_feats, out_feats) 80 | 81 | def forward(self, g, features, efeatures): 82 | g.ndata['h'] = features 83 | g.edata['h'] = efeatures 84 | 85 | g.update_all(gcn_msg, gcn_reduce) 86 | g.apply_nodes(func=self.apply_mod) 87 | g.apply_edges(func=self.apply_mod_e) 88 | 89 | return g.ndata.pop('h'), g.edata.pop('h') 90 | 91 | 92 | class ETGNN(nn.Module): 93 | def __init__(self, in_dim, hidden_dim, train_src_l, train_dst_l, mlp_dim=64, time_dim=172): 94 | super(ETGNN, self).__init__() 95 | self.train_src_l = train_src_l 96 | self.train_dst_l = train_dst_l 97 | self.gcn1 = TimeAwareGCN(in_dim, hidden_dim) 98 | self.gcn2 = TimeAwareGCN(hidden_dim, hidden_dim) 99 | self.act = nn.ReLU(inplace=True) 100 | self.mlp = nn.Sequential( 101 | nn.Linear(hidden_dim * 4, hidden_dim * 2), 102 | nn.ReLU(inplace=True), 103 | nn.Linear(hidden_dim * 2, mlp_dim), 104 | nn.ReLU(inplace=True), 105 | nn.Linear(mlp_dim, 1) 106 | ) 107 | self.time_encoder = GraphMixerTE(time_dim=time_dim) 108 | self.init_emb() 109 | 110 | def init_emb(self): 111 | for m in self.modules(): 112 | if isinstance(m, nn.Linear): 113 | torch.nn.init.xavier_uniform_(m.weight.data) 114 | if m.bias is not None: 115 | m.bias.data.fill_(0.0) 116 | 117 | def forward(self, g): 118 | g.edata['ts_enc'] = self.time_encoder(g.edata['ts']) 119 | res, eres = self.gcn1(g, g.ndata['feat'], g.edata['edge_feat']) 120 | res = self.act(res) 121 | eres = self.act(eres) 122 | res, eres = self.gcn2(g, res, eres) 123 | eres = eres[:len(self.train_src_l), :] 124 | 125 | return eres 126 | -------------------------------------------------------------------------------- /GraphMixer+TGSL/view_learner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class GraphMixerTE(torch.nn.Module): 7 | def __init__(self, time_dim): 8 | super(GraphMixerTE, self).__init__() 9 | self.w = nn.Parameter(torch.tensor([int(np.sqrt(time_dim)) ** (-(i - 1) / int(np.sqrt(time_dim))) for i in range(1, time_dim + 1)]), requires_grad=False) 10 | self.transformation = nn.Sequential( 11 | nn.Linear(time_dim, time_dim), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(time_dim, time_dim) 14 | ) 15 | 16 | def forward(self, ts): 17 | map_ts = ts * self.w.view(1, -1) 18 | harmonic = torch.sin(map_ts) 19 | harmonic = self.transformation(harmonic) 20 | 21 | return harmonic 22 | 23 | 24 | class TimeMapping(torch.nn.Module): 25 | def __init__(self, time_dim=172): 26 | super(TimeMapping, self).__init__() 27 | self.w = nn.Parameter(torch.tensor([int(np.sqrt(time_dim)) ** (-(i - 1) / int(np.sqrt(time_dim))) for i in range(1, time_dim + 1)]), requires_grad=False) 28 | self.transformation = nn.Sequential( 29 | nn.Linear(time_dim, time_dim, bias=False) 30 | ) 31 | 32 | def forward(self, ts): 33 | map_ts = ts * self.w.view(1, -1) 34 | harmonic = torch.sin(map_ts) 35 | harmonic = self.transformation(harmonic) 36 | 
37 | return harmonic + 1 38 | 39 | 40 | def gcn_reduce(nodes): 41 | selfh = nodes.data['h'] 42 | msgs = torch.mean(nodes.mailbox['h'], dim=1) 43 | msgs = torch.cat((msgs, selfh), dim=1) 44 | 45 | return {'h': msgs} 46 | 47 | 48 | def gcn_msg(edges): 49 | h = torch.cat((edges.src['h'], edges.data['h'], edges.data['ts_enc']), dim=1) 50 | 51 | return {'h': h} 52 | 53 | 54 | class NodeApplyModule(nn.Module): 55 | def __init__(self, in_feats, out_feats): 56 | super(NodeApplyModule, self).__init__() 57 | self.fc = nn.Linear(in_feats, out_feats, bias=True) 58 | 59 | def forward(self, node): 60 | h = self.fc(node.data['h']) 61 | 62 | return {'h': h} 63 | 64 | 65 | class EdgeApplyModule(nn.Module): 66 | def __init__(self, in_feats, out_feats): 67 | super(EdgeApplyModule, self).__init__() 68 | self.fc = nn.Linear(in_feats * 4, out_feats, bias=True) 69 | 70 | def forward(self, edge): 71 | h = torch.cat((edge.src['h'], edge.data['h'], edge.data['ts_enc'], edge.dst['h']), dim=1) 72 | h = self.fc(h) 73 | 74 | return {'h': h} 75 | 76 | 77 | class TimeAwareGCN(nn.Module): 78 | def __init__(self, in_feats, out_feats): 79 | super(TimeAwareGCN, self).__init__() 80 | self.apply_mod = NodeApplyModule(in_feats * 4, out_feats) 81 | self.apply_mod_e = EdgeApplyModule(in_feats, out_feats) 82 | 83 | def forward(self, g, features, efeatures): 84 | g.ndata['h'] = features 85 | g.edata['h'] = efeatures 86 | 87 | g.update_all(gcn_msg, gcn_reduce) 88 | g.apply_nodes(func=self.apply_mod) 89 | g.apply_edges(func=self.apply_mod_e) 90 | 91 | return g.ndata.pop('h'), g.edata.pop('h') 92 | 93 | 94 | class ETGNN(nn.Module): 95 | def __init__(self, in_dim, hidden_dim, train_src_l, train_dst_l, mlp_dim=64, time_dim=172): 96 | super(ETGNN, self).__init__() 97 | self.train_src_l = train_src_l 98 | self.train_dst_l = train_dst_l 99 | 100 | self.gcn1 = TimeAwareGCN(in_dim, hidden_dim) 101 | self.gcn2 = TimeAwareGCN(hidden_dim, hidden_dim) 102 | self.act = nn.ReLU(inplace=True) 103 | 104 | self.mlp = nn.Sequential( 105 | nn.Linear(hidden_dim * 4, hidden_dim * 2), 106 | nn.ReLU(inplace=True), 107 | nn.Linear(hidden_dim * 2, mlp_dim), 108 | nn.ReLU(inplace=True), 109 | nn.Linear(mlp_dim, 1) 110 | ) 111 | self.time_encoder = GraphMixerTE(time_dim=time_dim) 112 | self.init_emb() 113 | 114 | def init_emb(self): 115 | for m in self.modules(): 116 | if isinstance(m, nn.Linear): 117 | torch.nn.init.xavier_uniform_(m.weight.data) 118 | if m.bias is not None: 119 | m.bias.data.fill_(0.0) 120 | 121 | def forward(self, g): 122 | g.edata['ts_enc'] = self.time_encoder(g.edata['ts']) 123 | res, eres = self.gcn1(g, g.ndata['feat'], g.edata['edge_feat']) 124 | res = self.act(res) 125 | eres = self.act(eres) 126 | res, eres = self.gcn2(g, res, eres) 127 | eres = eres[:len(self.train_src_l), :] 128 | 129 | return eres 130 | -------------------------------------------------------------------------------- /TGAT+TGSL/neighbor_finder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class NeighborFinder: 5 | def __init__(self, adj_list, uniform=False): 6 | """ 7 | Params 8 | ------ 9 | node_idx_l: List[int] 10 | node_ts_l: List[int] 11 | off_set_l: List[int], such that node_idx_l[off_set_l[i]:off_set_l[i + 1]] = adjacent_list[i] 12 | """ 13 | 14 | node_idx_l, node_ts_l, edge_idx_l, off_set_l = self.init_off_set(adj_list) 15 | self.node_idx_l = node_idx_l 16 | self.node_ts_l = node_ts_l 17 | self.edge_idx_l = edge_idx_l 18 | 19 | self.off_set_l = off_set_l 20 | 21 | self.uniform = uniform 
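    # Layout note (editorial, added for clarity): init_off_set below flattens
    # adj_list into CSR-style arrays sorted by timestamp per node. For example,
    # adj_list = [[(1, 0, 5.0)], [(0, 0, 5.0), (2, 1, 3.0)]] yields
    # node_idx_l = [1, 2, 0], edge_idx_l = [0, 1, 0], node_ts_l = [5.0, 3.0, 5.0],
    # off_set_l = [0, 1, 3], so node i's temporal neighbors live in
    # node_idx_l[off_set_l[i]:off_set_l[i + 1]].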
22 | 
23 |     def init_off_set(self, adj_list):
24 |         """
25 |         Params
26 |         ------
27 |         adj_list: List[List[int]]
28 | 
29 |         """
30 |         n_idx_l = []
31 |         n_ts_l = []
32 |         e_idx_l = []
33 |         off_set_l = [0]
34 |         for i in range(len(adj_list)):
35 |             curr = adj_list[i]
36 |             # curr = sorted(curr, key=lambda x: x[1])
37 |             curr = sorted(curr, key=lambda x: x[2])
38 |             n_idx_l.extend([x[0] for x in curr])
39 |             e_idx_l.extend([x[1] for x in curr])
40 |             n_ts_l.extend([x[2] for x in curr])
41 | 
42 | 
43 |             off_set_l.append(len(n_idx_l))
44 |         n_idx_l = np.array(n_idx_l)
45 |         n_ts_l = np.array(n_ts_l)
46 |         e_idx_l = np.array(e_idx_l)
47 |         off_set_l = np.array(off_set_l)
48 | 
49 |         assert(len(n_idx_l) == len(n_ts_l))
50 |         assert(off_set_l[-1] == len(n_ts_l))
51 | 
52 |         return n_idx_l, n_ts_l, e_idx_l, off_set_l
53 | 
54 |     def find_before(self, src_idx, cut_time):
55 |         """
56 | 
57 |         Params
58 |         ------
59 |         src_idx: int
60 |         cut_time: float
61 |         """
62 |         node_idx_l = self.node_idx_l
63 |         node_ts_l = self.node_ts_l
64 |         edge_idx_l = self.edge_idx_l
65 |         off_set_l = self.off_set_l
66 | 
67 |         neighbors_idx = node_idx_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]
68 |         neighbors_ts = node_ts_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]
69 |         neighbors_e_idx = edge_idx_l[off_set_l[src_idx]:off_set_l[src_idx + 1]]
70 | 
71 |         if len(neighbors_idx) == 0 or len(neighbors_ts) == 0:
72 |             return neighbors_idx, neighbors_e_idx, neighbors_ts
73 | 
74 |         left = 0
75 |         right = len(neighbors_idx) - 1
76 | 
77 |         while left + 1 < right:
78 |             mid = (left + right) // 2
79 |             curr_t = neighbors_ts[mid]
80 |             if curr_t < cut_time:
81 |                 left = mid
82 |             else:
83 |                 right = mid
84 | 
85 |         if neighbors_ts[right] < cut_time:
86 |             return neighbors_idx[:right], neighbors_e_idx[:right], neighbors_ts[:right]
87 |         else:
88 |             return neighbors_idx[:left], neighbors_e_idx[:left], neighbors_ts[:left]
89 | 
90 |     def get_temporal_neighbor(self, src_idx_l, cut_time_l, num_neighbors=20):
91 |         """
92 |         Params
93 |         ------
94 |         src_idx_l: List[int]
95 |         cut_time_l: List[float],
96 |         num_neighbors: int
97 |         """
98 |         assert(len(src_idx_l) == len(cut_time_l))
99 | 
100 |         out_ngh_node_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.int32)
101 |         out_ngh_t_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.float32)
102 |         out_ngh_eidx_batch = np.zeros((len(src_idx_l), num_neighbors)).astype(np.int32)
103 | 
104 |         for i, (src_idx, cut_time) in enumerate(zip(src_idx_l, cut_time_l)):
105 |             ngh_idx, ngh_eidx, ngh_ts = self.find_before(src_idx, cut_time)
106 | 
107 |             if len(ngh_idx) > 0:
108 |                 if self.uniform:
109 |                     sampled_idx = np.random.randint(0, len(ngh_idx), num_neighbors)
110 | 
111 |                     out_ngh_node_batch[i, :] = ngh_idx[sampled_idx]
112 |                     out_ngh_t_batch[i, :] = ngh_ts[sampled_idx]
113 |                     out_ngh_eidx_batch[i, :] = ngh_eidx[sampled_idx]
114 | 
115 |                     # resort based on time
116 |                     pos = out_ngh_t_batch[i, :].argsort()
117 |                     out_ngh_node_batch[i, :] = out_ngh_node_batch[i, :][pos]
118 |                     out_ngh_t_batch[i, :] = out_ngh_t_batch[i, :][pos]
119 |                     out_ngh_eidx_batch[i, :] = out_ngh_eidx_batch[i, :][pos]
120 |                 else:
121 |                     ngh_ts = ngh_ts[:num_neighbors]
122 |                     ngh_idx = ngh_idx[:num_neighbors]
123 |                     ngh_eidx = ngh_eidx[:num_neighbors]
124 | 
125 |                     assert(len(ngh_idx) <= num_neighbors)
126 |                     assert(len(ngh_ts) <= num_neighbors)
127 |                     assert(len(ngh_eidx) <= num_neighbors)
128 | 
129 |                     out_ngh_node_batch[i, num_neighbors - len(ngh_idx):] = ngh_idx
130 |                     out_ngh_t_batch[i, num_neighbors - len(ngh_ts):] = ngh_ts
131 |                     out_ngh_eidx_batch[i, num_neighbors - len(ngh_eidx):] = ngh_eidx
132 | 133 | return out_ngh_node_batch, out_ngh_eidx_batch, out_ngh_t_batch 134 | 135 | def find_k_hop(self, k, src_idx_l, cut_time_l, num_neighbors=20): 136 | """Sampling the k-hop sub graph 137 | """ 138 | x, y, z = self.get_temporal_neighbor(src_idx_l, cut_time_l, num_neighbors) 139 | node_records = [x] 140 | eidx_records = [y] 141 | t_records = [z] 142 | for _ in range(k -1): 143 | ngn_node_est, ngh_t_est = node_records[-1], t_records[-1] # [N, *([num_neighbors] * (k - 1))] 144 | orig_shape = ngn_node_est.shape 145 | ngn_node_est = ngn_node_est.flatten() 146 | ngn_t_est = ngh_t_est.flatten() 147 | out_ngh_node_batch, out_ngh_eidx_batch, out_ngh_t_batch = self.get_temporal_neighbor(ngn_node_est, ngn_t_est, num_neighbors) 148 | out_ngh_node_batch = out_ngh_node_batch.reshape(*orig_shape, num_neighbors) # [N, *([num_neighbors] * k)] 149 | out_ngh_eidx_batch = out_ngh_eidx_batch.reshape(*orig_shape, num_neighbors) 150 | out_ngh_t_batch = out_ngh_t_batch.reshape(*orig_shape, num_neighbors) 151 | 152 | node_records.append(out_ngh_node_batch) 153 | eidx_records.append(out_ngh_eidx_batch) 154 | t_records.append(out_ngh_t_batch) 155 | return node_records, eidx_records, t_records 156 | 157 | 158 | 159 | -------------------------------------------------------------------------------- /TGAT+TGSL/MTL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pickle 5 | 6 | from neighbor_finder import NeighborFinder 7 | 8 | 9 | class MTL(nn.Module): 10 | def __init__(self, base_encoder_k, encoder, view_learner, edge_rnn, sample_time_encoder, len_full_edge, 11 | train_e_idx_l, train_node_set, train_ts_l, e_feat, device, dim=172, K=600, m=0.999, tau=0.1, gtau=1.0, 12 | ratio=0.9, can_nn=20, rnn_nn=20, can_type='3rd'): 13 | super(MTL, self).__init__() 14 | self.K = K 15 | self.m = m 16 | self.tau = tau 17 | self.gtau = gtau 18 | self.ratio = ratio 19 | self.can_nn = can_nn 20 | self.rnn_nn = rnn_nn 21 | self.can_type = can_type 22 | 23 | self.encoder_k = base_encoder_k 24 | self.encoder = encoder 25 | self.view_learner = view_learner 26 | self.edge_rnn = edge_rnn 27 | self.sample_time_encoder = sample_time_encoder 28 | 29 | self.e_feat = e_feat 30 | self.len_full_edge = len_full_edge 31 | self.train_e_idx_l = train_e_idx_l 32 | self.train_node_set = np.array(list(train_node_set)) 33 | self.train_ts_l = train_ts_l 34 | self.max_train_ts_l = max(train_ts_l) 35 | self.device = device 36 | 37 | for param_q, param_k in zip( 38 | self.encoder.parameters(), self.encoder_k.parameters() 39 | ): 40 | param_k.data.copy_(param_q.data) # initialize 41 | param_k.requires_grad = False # not update by gradient 42 | 43 | self.register_buffer("queue", torch.randn(dim, K)) 44 | self.queue = nn.functional.normalize(self.queue, dim=0) 45 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 46 | 47 | @torch.no_grad() 48 | def _momentum_update_key_encoder(self): 49 | for param_q, param_k in zip( 50 | self.encoder.parameters(), self.encoder_k.parameters() 51 | ): 52 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 53 | 54 | @torch.no_grad() 55 | def _dequeue_and_enqueue(self, keys): 56 | batch_size = keys.shape[0] 57 | ptr = int(self.queue_ptr) 58 | if ptr + batch_size > self.K: 59 | keys_prev = keys[:self.K - ptr] 60 | keys_aft = keys[self.K - ptr:] 61 | self.queue[:, ptr: self.K] = keys_prev.T 62 | self.queue[:, : len(keys_aft)] = keys_aft.T 63 | ptr = len(keys_aft) 64 | 
self.queue_ptr[0] = ptr 65 | else: 66 | self.queue[:, ptr : ptr + batch_size] = keys.T 67 | ptr = (ptr + batch_size) % self.K # move pointer 68 | self.queue_ptr[0] = ptr 69 | 70 | def forward(self, src_l_cut, dst_l_cut, dst_l_fake, ts_l_cut, NUM_NEIGHBORS, g, adj_list, adj_list_pickle): 71 | src_embed = self.encoder.tem_conv(src_l_cut, ts_l_cut, 2, NUM_NEIGHBORS) 72 | target_embed = self.encoder.tem_conv(dst_l_cut, ts_l_cut, 2, NUM_NEIGHBORS) 73 | background_embed = self.encoder.tem_conv(dst_l_fake, ts_l_cut, 2, NUM_NEIGHBORS) 74 | pos_score = self.encoder.affinity_score(src_embed, target_embed).squeeze(dim=-1) 75 | neg_score = self.encoder.affinity_score(src_embed, background_embed).squeeze(dim=-1) 76 | 77 | train_node_cut = np.array(list(set(np.append(src_l_cut, dst_l_cut)))) 78 | max_ts = np.array([self.max_train_ts_l + 1] * len(train_node_cut)) 79 | 80 | neighbor_finder = NeighborFinder(adj_list, uniform=True) 81 | neighbor_node_idx, neighbor_edge_idx, neighbor_ts = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=self.rnn_nn) 82 | 83 | train_edge_feat = self.view_learner(g) 84 | full_edge_feat = torch.zeros(self.len_full_edge + 1, train_edge_feat.shape[1], device=self.device) 85 | full_edge_feat[self.train_e_idx_l - 1] = train_edge_feat 86 | 87 | neighbor_edge_idx = neighbor_edge_idx.reshape(-1) - 1 88 | neighbor_edge_feat = full_edge_feat[neighbor_edge_idx] 89 | neighbor_edge_feat = neighbor_edge_feat.reshape(neighbor_node_idx.shape[0], neighbor_node_idx.shape[1], -1) 90 | 91 | neighbor_edge_feat = neighbor_edge_feat.transpose(0, 1) 92 | _, (h_n, _) = self.edge_rnn(neighbor_edge_feat) 93 | context_vec = h_n[-1] 94 | 95 | neighbor_finder.uniform = True 96 | if self.can_type == '1st': 97 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=self.can_nn) 98 | 99 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20] 100 | dst_node_idx_aug = candidate_node_idx # [bs, 20] 101 | elif self.can_type == '3rd': 102 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.find_k_hop(3, train_node_cut, max_ts, num_neighbors=self.can_nn) 103 | candidate_node_idx = candidate_node_idx[-1].reshape(train_node_cut.shape[0], -1) 104 | candidate_edge_idx = candidate_edge_idx[-1].reshape(train_node_cut.shape[0], -1) 105 | candidate_ts = candidate_ts[-1].reshape(train_node_cut.shape[0], -1) 106 | 107 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20] 108 | dst_node_idx_aug = candidate_node_idx # [bs, 20] 109 | elif self.can_type == 'random': 110 | candidate_node_idx = np.random.choice(self.train_node_set, size=train_node_cut.shape[0] * self.can_nn, replace=True).reshape(train_node_cut.shape[0], -1) # [bs, 20] 111 | candidate_edge_idx = np.array([0] * (train_node_cut.shape[0] * self.can_nn)).reshape(train_node_cut.shape[0], -1) 112 | candidate_ts = np.random.rand(train_node_cut.shape[0], self.can_nn) * self.max_train_ts_l 113 | 114 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20] 115 | dst_node_idx_aug = candidate_node_idx # [bs, 20] 116 | elif self.can_type == 'mix': 117 | candidate_node_idx_1st, candidate_edge_idx_1st, candidate_ts_1st = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=self.can_nn) 118 | 
candidate_node_idx_3rd, candidate_edge_idx_3rd, candidate_ts_3rd = neighbor_finder.find_k_hop(3, train_node_cut, max_ts, num_neighbors=self.can_nn)
119 |             candidate_node_idx_3rd = candidate_node_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
120 |             candidate_edge_idx_3rd = candidate_edge_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
121 |             candidate_ts_3rd = candidate_ts_3rd[-1].reshape(train_node_cut.shape[0], -1)
122 | 
123 |             candidate_node_idx = np.concatenate((candidate_node_idx_1st, candidate_node_idx_3rd), axis=-1)
124 |             candidate_edge_idx = np.concatenate((candidate_edge_idx_1st, candidate_edge_idx_3rd), axis=-1)
125 |             candidate_ts = np.concatenate((candidate_ts_1st, candidate_ts_3rd), axis=-1)
126 | 
127 |             src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1)  # [bs, 20]
128 |             dst_node_idx_aug = candidate_node_idx  # [bs, 20]
129 |         else:
130 |             pass
131 | 
132 |         candidate_edge_idx = candidate_edge_idx.reshape(-1) - 1
133 |         candidate_edge_feat = full_edge_feat[candidate_edge_idx]  # [bs * 20, 172]
134 |         candidate_edge_feat = candidate_edge_feat.reshape(candidate_node_idx.shape[0], candidate_node_idx.shape[1], -1)
135 | 
136 |         ts_aug = np.random.rand(candidate_ts.shape[0], candidate_ts.shape[1]) * self.max_train_ts_l
137 |         delta_ts_sample = ts_aug - candidate_ts
138 |         delta_ts_sample_context = ts_aug - np.ones_like(candidate_ts) * self.max_train_ts_l
139 |         delta_ts_sample_embedding = self.sample_time_encoder(torch.tensor(delta_ts_sample.reshape(-1, 1), dtype=torch.float32).to(self.device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
140 |         delta_ts_sample_context_embedding = self.sample_time_encoder(torch.tensor(delta_ts_sample_context.reshape(-1, 1), dtype=torch.float32).to(self.device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
141 | 
142 |         context_vec = context_vec.unsqueeze(1).expand_as(candidate_edge_feat)
143 |         context_vec = context_vec * delta_ts_sample_context_embedding
144 |         candidate_edge_feat = candidate_edge_feat * delta_ts_sample_embedding
145 |         aug_edge_logits = torch.sum(context_vec * candidate_edge_feat, dim=-1)  # [bs, 20]
146 | 
147 |         # Gumbel-Top-K
148 |         bias = 0.0 + 0.0001  # If bias is 0, we run into problems
149 |         eps = (bias - (1 - bias)) * torch.rand(aug_edge_logits.size()) + (1 - bias)
150 |         gate_inputs = torch.log(eps) - torch.log(1 - eps)
151 |         gate_inputs = gate_inputs.to(aug_edge_logits.device)
152 |         gate_inputs = (gate_inputs + aug_edge_logits) / self.gtau
153 |         z = torch.sigmoid(gate_inputs).squeeze()  # [bs, 20]
154 |         __, sorted_idx = z.sort(dim=-1, descending=True)
155 |         k = int(self.ratio * z.size(1))
156 |         keep = sorted_idx[:, :k]  # [bs, k]
157 | 
158 |         aug_edge_logits = torch.sigmoid(gate_inputs).squeeze()  # [bs, 20]
159 |         aug_edge_weight = torch.gather(aug_edge_logits, dim=1, index=keep)  # [bs, k]
160 |         ts_aug = torch.gather(torch.tensor(ts_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy()  # [bs, k]
161 |         src_node_idx_aug = torch.gather(torch.tensor(src_node_idx_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy()  # [bs, k]
162 |         dst_node_idx_aug = torch.gather(torch.tensor(dst_node_idx_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy()  # [bs, k]
163 |         candidate_edge_feat = torch.gather(candidate_edge_feat, dim=1, index=keep.unsqueeze(2).repeat(1, 1, candidate_edge_feat.shape[2]))  # [bs, k, 172]
164 | 
165 |         aug_edge_weight = aug_edge_weight.reshape(-1)
166 |         ts_aug = ts_aug.reshape(-1)
167 |         src_node_idx_aug = src_node_idx_aug.reshape(-1)
168 |         dst_node_idx_aug = dst_node_idx_aug.reshape(-1)
169 | 
170 |         temp_eid = self.len_full_edge
171 |         new_eid_list = []
172 |         adj_list_aug = pickle.loads(adj_list_pickle)
173 |         for src, dst, ts in zip(src_node_idx_aug, dst_node_idx_aug, ts_aug):
174 |             adj_list_aug[src].append((dst, temp_eid, ts))
175 |             adj_list_aug[dst].append((src, temp_eid, ts))
176 |             new_eid_list.append(temp_eid)
177 |             temp_eid += 1
178 | 
179 |         train_ngh_finder_aug = NeighborFinder(adj_list_aug, uniform=self.encoder.ngh_finder.uniform)
180 | 
181 |         new_eid_list = np.array(new_eid_list)
182 |         full_aug_edge_weight = torch.ones(temp_eid, device=self.device)
183 |         full_aug_edge_weight[new_eid_list - 1] = aug_edge_weight
184 | 
185 |         candidate_edge_feat = candidate_edge_feat.reshape(-1, candidate_edge_feat.shape[2]).detach().cpu().numpy()  # [bs * k, 172]
186 |         e_feat_aug = np.concatenate((self.e_feat, candidate_edge_feat), axis=0)
187 |         e_feat_th_aug = torch.nn.Parameter(torch.from_numpy(e_feat_aug.astype(np.float32)))
188 |         edge_raw_embed_aug = torch.nn.Embedding.from_pretrained(e_feat_th_aug, padding_idx=0, freeze=True).to(self.device)
189 | 
190 |         ngh_finder_ori = self.encoder.ngh_finder
191 |         self.encoder.ngh_finder = train_ngh_finder_aug
192 |         edge_raw_embed_ori = self.encoder.edge_raw_embed
193 |         self.encoder.edge_raw_embed = edge_raw_embed_aug
194 | 
195 |         max_ts = np.array([self.max_train_ts_l + 1] * len(train_node_cut))
196 | 
197 |         src_embed_ed = self.encoder.tem_conv(src_l_cut, ts_l_cut, 2, NUM_NEIGHBORS, full_aug_edge_weight)
198 |         target_embed_ed = self.encoder.tem_conv(dst_l_cut, ts_l_cut, 2, NUM_NEIGHBORS, full_aug_edge_weight)
199 |         background_embed_ed = self.encoder.tem_conv(dst_l_fake, ts_l_cut, 2, NUM_NEIGHBORS, full_aug_edge_weight)
200 |         pos_score_ed = self.encoder.affinity_score(src_embed_ed, target_embed_ed).squeeze(dim=-1)
201 |         neg_score_ed = self.encoder.affinity_score(src_embed_ed, background_embed_ed).squeeze(dim=-1)
202 | 
203 |         q = self.encoder.tem_conv(train_node_cut, max_ts, 2, 20, full_aug_edge_weight)  # queries: NxC
204 |         q = nn.functional.normalize(q, dim=1)
205 | 
206 |         # Recover
207 |         self.encoder.ngh_finder = ngh_finder_ori
208 |         self.encoder.edge_raw_embed = edge_raw_embed_ori
209 | 
210 |         with torch.no_grad():  # no gradient to keys
211 |             self._momentum_update_key_encoder()  # update the key encoder
212 |             k = self.encoder_k.tem_conv(train_node_cut, max_ts, 2, 20)  # keys: NxC
213 |             k = nn.functional.normalize(k, dim=1)
214 | 
215 |         l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
216 |         l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
217 |         logits = torch.cat([l_pos, l_neg], dim=1)
218 |         logits /= self.tau
219 |         labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
220 |         self._dequeue_and_enqueue(k)
221 | 
222 |         return pos_score.sigmoid(), neg_score.sigmoid(), pos_score_ed.sigmoid(), neg_score_ed.sigmoid(), logits, labels
223 | 
--------------------------------------------------------------------------------
/GraphMixer+TGSL/utils/DataLoader.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader
2 | import numpy as np
3 | import random
4 | import pandas as pd
5 | 
6 | 
7 | class CustomizedDataset(Dataset):
8 |     def __init__(self, indices_list: list):
9 |         """
10 |         Customized dataset.
11 | :param indices_list: list, list of indices 12 | """ 13 | super(CustomizedDataset, self).__init__() 14 | 15 | self.indices_list = indices_list 16 | 17 | def __getitem__(self, idx: int): 18 | """ 19 | get item at the index in self.indices_list 20 | :param idx: int, the index 21 | :return: 22 | """ 23 | return self.indices_list[idx] 24 | 25 | def __len__(self): 26 | return len(self.indices_list) 27 | 28 | 29 | def get_idx_data_loader(indices_list: list, batch_size: int, shuffle: bool): 30 | """ 31 | get data loader that iterates over indices 32 | :param indices_list: list, list of indices 33 | :param batch_size: int, batch size 34 | :param shuffle: boolean, whether to shuffle the data 35 | :return: data_loader, DataLoader 36 | """ 37 | dataset = CustomizedDataset(indices_list=indices_list) 38 | 39 | data_loader = DataLoader(dataset=dataset, 40 | batch_size=batch_size, 41 | shuffle=shuffle, 42 | drop_last=False) 43 | return data_loader 44 | 45 | 46 | class Data: 47 | 48 | def __init__(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, node_interact_times: np.ndarray, edge_ids: np.ndarray, labels: np.ndarray): 49 | """ 50 | Data object to store the nodes interaction information. 51 | :param src_node_ids: ndarray 52 | :param dst_node_ids: ndarray 53 | :param node_interact_times: ndarray 54 | :param edge_ids: ndarray 55 | :param labels: ndarray 56 | """ 57 | self.src_node_ids = src_node_ids 58 | self.dst_node_ids = dst_node_ids 59 | self.node_interact_times = node_interact_times 60 | self.edge_ids = edge_ids 61 | self.labels = labels 62 | self.num_interactions = len(src_node_ids) 63 | self.unique_node_ids = set(src_node_ids) | set(dst_node_ids) 64 | self.num_unique_nodes = len(self.unique_node_ids) 65 | 66 | 67 | def get_link_prediction_data(dataset_name: str, val_ratio: float, test_ratio: float): 68 | """ 69 | generate data for link prediction task (inductive & transductive settings) 70 | :param dataset_name: str, dataset name 71 | :param val_ratio: float, validation data ratio 72 | :param test_ratio: float, test data ratio 73 | :return: node_raw_features, edge_raw_features, (np.ndarray), 74 | full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data, (Data object) 75 | """ 76 | # Load data and train val test split 77 | graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(dataset_name, dataset_name)) 78 | edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(dataset_name, dataset_name)) 79 | node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(dataset_name, dataset_name)) 80 | 81 | NODE_FEAT_DIM = EDGE_FEAT_DIM = 172 82 | assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {dataset_name} is bigger than {NODE_FEAT_DIM}!' 83 | assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {dataset_name} is bigger than {EDGE_FEAT_DIM}!' 
84 | # padding the features of edges and nodes to the same dimension (172 for all the datasets) 85 | if node_raw_features.shape[1] < NODE_FEAT_DIM: 86 | node_zero_padding = np.zeros((node_raw_features.shape[0], 172 - node_raw_features.shape[1])) 87 | node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1) 88 | if edge_raw_features.shape[1] < EDGE_FEAT_DIM: 89 | edge_zero_padding = np.zeros((edge_raw_features.shape[0], 172 - edge_raw_features.shape[1])) 90 | edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1) 91 | 92 | assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], "Unaligned feature dimensions after feature padding!" 93 | 94 | # get the timestamp of validate and test set 95 | val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)])) 96 | 97 | src_node_ids = graph_df.u.values.astype(int) 98 | dst_node_ids = graph_df.i.values.astype(int) 99 | node_interact_times = graph_df.ts.values.astype(np.float64) 100 | edge_ids = graph_df.idx.values.astype(int) 101 | labels = graph_df.label.values 102 | 103 | full_data = Data(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids, node_interact_times=node_interact_times, edge_ids=edge_ids, labels=labels) 104 | 105 | # the setting of seed follows previous works 106 | random.seed(2020) 107 | 108 | # union to get node set 109 | node_set = set(src_node_ids) | set(dst_node_ids) 110 | num_total_unique_node_ids = len(node_set) 111 | 112 | # compute nodes which appear at test time 113 | test_node_set = set(src_node_ids[node_interact_times > val_time]).union(set(dst_node_ids[node_interact_times > val_time])) 114 | # sample nodes which we keep as new nodes (to test inductiveness), so then we have to remove all their edges from training 115 | new_test_node_set = set(random.sample(test_node_set, int(0.1 * num_total_unique_node_ids))) 116 | 117 | # mask for each source and destination to denote whether they are new test nodes 118 | new_test_source_mask = graph_df.u.map(lambda x: x in new_test_node_set).values 119 | new_test_destination_mask = graph_df.i.map(lambda x: x in new_test_node_set).values 120 | 121 | # mask, which is true for edges with both destination and source not being new test nodes (because we want to remove all edges involving any new test node) 122 | observed_edges_mask = np.logical_and(~new_test_source_mask, ~new_test_destination_mask) 123 | 124 | # for train data, we keep edges happening before the validation time which do not involve any new node, used for inductiveness 125 | train_mask = np.logical_and(node_interact_times <= val_time, observed_edges_mask) 126 | 127 | train_data = Data(src_node_ids=src_node_ids[train_mask], dst_node_ids=dst_node_ids[train_mask], 128 | node_interact_times=node_interact_times[train_mask], 129 | edge_ids=edge_ids[train_mask], labels=labels[train_mask]) 130 | 131 | # define the new nodes sets for testing inductiveness of the model 132 | train_node_set = set(train_data.src_node_ids).union(train_data.dst_node_ids) 133 | assert len(train_node_set & new_test_node_set) == 0 134 | # new nodes that are not in the training set 135 | new_node_set = node_set - train_node_set 136 | 137 | val_mask = np.logical_and(node_interact_times <= test_time, node_interact_times > val_time) 138 | test_mask = node_interact_times > test_time 139 | 140 | # new edges with new nodes in the val and test set (for inductive evaluation) 141 | edge_contains_new_node_mask = np.array([(src_node_id in 
new_node_set or dst_node_id in new_node_set) 142 | for src_node_id, dst_node_id in zip(src_node_ids, dst_node_ids)]) 143 | new_node_val_mask = np.logical_and(val_mask, edge_contains_new_node_mask) 144 | new_node_test_mask = np.logical_and(test_mask, edge_contains_new_node_mask) 145 | 146 | # validation and test data 147 | val_data = Data(src_node_ids=src_node_ids[val_mask], dst_node_ids=dst_node_ids[val_mask], 148 | node_interact_times=node_interact_times[val_mask], edge_ids=edge_ids[val_mask], labels=labels[val_mask]) 149 | 150 | test_data = Data(src_node_ids=src_node_ids[test_mask], dst_node_ids=dst_node_ids[test_mask], 151 | node_interact_times=node_interact_times[test_mask], edge_ids=edge_ids[test_mask], labels=labels[test_mask]) 152 | 153 | # validation and test with edges that at least has one new node (not in training set) 154 | new_node_val_data = Data(src_node_ids=src_node_ids[new_node_val_mask], dst_node_ids=dst_node_ids[new_node_val_mask], 155 | node_interact_times=node_interact_times[new_node_val_mask], 156 | edge_ids=edge_ids[new_node_val_mask], labels=labels[new_node_val_mask]) 157 | 158 | new_node_test_data = Data(src_node_ids=src_node_ids[new_node_test_mask], dst_node_ids=dst_node_ids[new_node_test_mask], 159 | node_interact_times=node_interact_times[new_node_test_mask], 160 | edge_ids=edge_ids[new_node_test_mask], labels=labels[new_node_test_mask]) 161 | 162 | print("The dataset has {} interactions, involving {} different nodes".format(full_data.num_interactions, full_data.num_unique_nodes)) 163 | print("The training dataset has {} interactions, involving {} different nodes".format( 164 | train_data.num_interactions, train_data.num_unique_nodes)) 165 | print("The validation dataset has {} interactions, involving {} different nodes".format( 166 | val_data.num_interactions, val_data.num_unique_nodes)) 167 | print("The test dataset has {} interactions, involving {} different nodes".format( 168 | test_data.num_interactions, test_data.num_unique_nodes)) 169 | print("The new node validation dataset has {} interactions, involving {} different nodes".format( 170 | new_node_val_data.num_interactions, new_node_val_data.num_unique_nodes)) 171 | print("The new node test dataset has {} interactions, involving {} different nodes".format( 172 | new_node_test_data.num_interactions, new_node_test_data.num_unique_nodes)) 173 | print("{} nodes were used for the inductive testing, i.e. are never seen during training".format(len(new_test_node_set))) 174 | 175 | return node_raw_features, edge_raw_features, full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data 176 | 177 | 178 | def get_node_classification_data(dataset_name: str, val_ratio: float, test_ratio: float): 179 | """ 180 | generate data for node classification task 181 | :param dataset_name: str, dataset name 182 | :param val_ratio: float, validation data ratio 183 | :param test_ratio: float, test data ratio 184 | :return: 185 | """ 186 | # Load data and train val test split 187 | graph_df = pd.read_csv('./processed_data/{}/ml_{}.csv'.format(dataset_name, dataset_name)) 188 | edge_raw_features = np.load('./processed_data/{}/ml_{}.npy'.format(dataset_name, dataset_name)) 189 | node_raw_features = np.load('./processed_data/{}/ml_{}_node.npy'.format(dataset_name, dataset_name)) 190 | 191 | NODE_FEAT_DIM = EDGE_FEAT_DIM = 172 192 | assert NODE_FEAT_DIM >= node_raw_features.shape[1], f'Node feature dimension in dataset {dataset_name} is bigger than {NODE_FEAT_DIM}!' 
193 | assert EDGE_FEAT_DIM >= edge_raw_features.shape[1], f'Edge feature dimension in dataset {dataset_name} is bigger than {EDGE_FEAT_DIM}!'
194 | # padding the features of edges and nodes to the same dimension (172 for all the datasets)
195 | if node_raw_features.shape[1] < NODE_FEAT_DIM:
196 | node_zero_padding = np.zeros((node_raw_features.shape[0], 172 - node_raw_features.shape[1]))
197 | node_raw_features = np.concatenate([node_raw_features, node_zero_padding], axis=1)
198 | if edge_raw_features.shape[1] < EDGE_FEAT_DIM:
199 | edge_zero_padding = np.zeros((edge_raw_features.shape[0], 172 - edge_raw_features.shape[1]))
200 | edge_raw_features = np.concatenate([edge_raw_features, edge_zero_padding], axis=1)
201 | 
202 | assert NODE_FEAT_DIM == node_raw_features.shape[1] and EDGE_FEAT_DIM == edge_raw_features.shape[1], "Unaligned feature dimensions after feature padding!"
203 | 
204 | # get the timestamps that split off the validation and test sets
205 | val_time, test_time = list(np.quantile(graph_df.ts, [(1 - val_ratio - test_ratio), (1 - test_ratio)]))
206 | 
207 | src_node_ids = graph_df.u.values.astype(int)  # np.long was removed from NumPy; use the builtin int, as in the link prediction loader above
208 | dst_node_ids = graph_df.i.values.astype(int)
209 | node_interact_times = graph_df.ts.values.astype(np.float64)
210 | edge_ids = graph_df.idx.values.astype(int)
211 | labels = graph_df.label.values
212 | 
213 | # The setting of seed follows previous works
214 | random.seed(2020)
215 | 
216 | train_mask = node_interact_times <= val_time
217 | val_mask = np.logical_and(node_interact_times <= test_time, node_interact_times > val_time)
218 | test_mask = node_interact_times > test_time
219 | 
220 | full_data = Data(src_node_ids=src_node_ids, dst_node_ids=dst_node_ids, node_interact_times=node_interact_times, edge_ids=edge_ids, labels=labels)
221 | train_data = Data(src_node_ids=src_node_ids[train_mask], dst_node_ids=dst_node_ids[train_mask],
222 | node_interact_times=node_interact_times[train_mask],
223 | edge_ids=edge_ids[train_mask], labels=labels[train_mask])
224 | val_data = Data(src_node_ids=src_node_ids[val_mask], dst_node_ids=dst_node_ids[val_mask],
225 | node_interact_times=node_interact_times[val_mask], edge_ids=edge_ids[val_mask], labels=labels[val_mask])
226 | test_data = Data(src_node_ids=src_node_ids[test_mask], dst_node_ids=dst_node_ids[test_mask],
227 | node_interact_times=node_interact_times[test_mask], edge_ids=edge_ids[test_mask], labels=labels[test_mask])
228 | 
229 | return node_raw_features, edge_raw_features, full_data, train_data, val_data, test_data
230 | 
-------------------------------------------------------------------------------- /GraphMixer+TGSL/models/modules.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | class TimeEncoder(nn.Module):
8 | 
9 | def __init__(self, time_dim: int, parameter_requires_grad: bool = True):
10 | """
11 | Time encoder.
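Maps a scalar timestamp t to a time_dim-dimensional vector sin(w * t + b), the standard
temporal-graph time encoding; the frequencies w are initialized to the fixed geometric
series 1 / 10^linspace(0, 9, time_dim) (see the weight initialization below).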
12 | :param time_dim: int, dimension of time encodings 13 | :param parameter_requires_grad: boolean, whether the parameter in TimeEncoder needs gradient 14 | """ 15 | super(TimeEncoder, self).__init__() 16 | 17 | self.time_dim = time_dim 18 | # trainable parameters for time encoding 19 | self.w = nn.Linear(1, time_dim) 20 | self.w.weight = nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, time_dim, dtype=np.float32))).reshape(time_dim, -1)) 21 | self.w.bias = nn.Parameter(torch.zeros(time_dim)) 22 | 23 | if not parameter_requires_grad: 24 | self.w.weight.requires_grad = False 25 | self.w.bias.requires_grad = False 26 | 27 | def forward(self, timestamps: torch.Tensor): 28 | """ 29 | compute time encodings of time in timestamps 30 | :param timestamps: Tensor, shape (batch_size, seq_len) 31 | :return: 32 | """ 33 | # Tensor, shape (batch_size, seq_len, 1) 34 | timestamps = timestamps.unsqueeze(dim=2) 35 | 36 | # Tensor, shape (batch_size, seq_len, time_dim) 37 | output = torch.sin(self.w(timestamps)) 38 | 39 | return output 40 | 41 | 42 | class MergeLayer(nn.Module): 43 | 44 | def __init__(self, input_dim1: int, input_dim2: int, hidden_dim: int, output_dim: int): 45 | """ 46 | Merge Layer to merge two inputs via: input_dim1 + input_dim2 -> hidden_dim -> output_dim. 47 | :param input_dim1: int, dimension of first input 48 | :param input_dim2: int, dimension of the second input 49 | :param hidden_dim: int, hidden dimension 50 | :param output_dim: int, dimension of the output 51 | """ 52 | super().__init__() 53 | self.fc1 = nn.Linear(input_dim1 + input_dim2, hidden_dim) 54 | self.fc2 = nn.Linear(hidden_dim, output_dim) 55 | self.act = nn.ReLU() 56 | 57 | def forward(self, input_1: torch.Tensor, input_2: torch.Tensor): 58 | """ 59 | merge and project the inputs 60 | :param input_1: Tensor, shape (*, input_dim1) 61 | :param input_2: Tensor, shape (*, input_dim2) 62 | :return: 63 | """ 64 | # Tensor, shape (*, input_dim1 + input_dim2) 65 | x = torch.cat([input_1, input_2], dim=1) 66 | # Tensor, shape (*, output_dim) 67 | h = self.fc2(self.act(self.fc1(x))) 68 | return h 69 | 70 | 71 | class MLPClassifier(nn.Module): 72 | def __init__(self, input_dim: int, dropout: float = 0.1): 73 | """ 74 | Multi-Layer Perceptron Classifier. 75 | :param input_dim: int, dimension of input 76 | :param dropout: float, dropout rate 77 | """ 78 | super().__init__() 79 | self.fc1 = nn.Linear(input_dim, 80) 80 | self.fc2 = nn.Linear(80, 10) 81 | self.fc3 = nn.Linear(10, 1) 82 | self.act = nn.ReLU() 83 | self.dropout = nn.Dropout(dropout) 84 | 85 | def forward(self, x: torch.Tensor): 86 | """ 87 | multi-layer perceptron classifier forward process 88 | :param x: Tensor, shape (*, input_dim) 89 | :return: 90 | """ 91 | # Tensor, shape (*, 80) 92 | x = self.dropout(self.act(self.fc1(x))) 93 | # Tensor, shape (*, 10) 94 | x = self.dropout(self.act(self.fc2(x))) 95 | # Tensor, shape (*, 1) 96 | return self.fc3(x) 97 | 98 | 99 | class MultiHeadAttention(nn.Module): 100 | 101 | def __init__(self, node_feat_dim: int, edge_feat_dim: int, time_feat_dim: int, 102 | num_heads: int = 2, dropout: float = 0.1): 103 | """ 104 | Multi-head Attention module. 
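Computes the temporal attention of each target node over its sampled historical neighbors:
queries are built from [node feature || time encoding] and keys/values from
[neighbor feature || edge feature || time encoding] (see forward below).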
105 | :param node_feat_dim: int, dimension of node features 106 | :param edge_feat_dim: int, dimension of edge features 107 | :param time_feat_dim: int, dimension of time features (time encodings) 108 | :param num_heads: int, number of attention heads 109 | :param dropout: float, dropout rate 110 | """ 111 | super(MultiHeadAttention, self).__init__() 112 | 113 | self.node_feat_dim = node_feat_dim 114 | self.edge_feat_dim = edge_feat_dim 115 | self.time_feat_dim = time_feat_dim 116 | self.num_heads = num_heads 117 | 118 | self.query_dim = node_feat_dim + time_feat_dim 119 | self.key_dim = node_feat_dim + edge_feat_dim + time_feat_dim 120 | 121 | assert self.query_dim % num_heads == 0, "The sum of node_feat_dim and time_feat_dim should be divided by num_heads!" 122 | 123 | self.head_dim = self.query_dim // num_heads 124 | 125 | self.query_projection = nn.Linear(self.query_dim, num_heads * self.head_dim, bias=False) 126 | self.key_projection = nn.Linear(self.key_dim, num_heads * self.head_dim, bias=False) 127 | self.value_projection = nn.Linear(self.key_dim, num_heads * self.head_dim, bias=False) 128 | 129 | self.scaling_factor = self.head_dim ** -0.5 130 | 131 | self.layer_norm = nn.LayerNorm(self.query_dim) 132 | 133 | self.residual_fc = nn.Linear(num_heads * self.head_dim, self.query_dim) 134 | 135 | self.dropout = nn.Dropout(dropout) 136 | 137 | def forward(self, node_features: torch.Tensor, node_time_features: torch.Tensor, neighbor_node_features: torch.Tensor, 138 | neighbor_node_time_features: torch.Tensor, neighbor_node_edge_features: torch.Tensor, neighbor_masks: np.ndarray): 139 | """ 140 | temporal attention forward process 141 | :param node_features: Tensor, shape (batch_size, node_feat_dim) 142 | :param node_time_features: Tensor, shape (batch_size, 1, time_feat_dim) 143 | :param neighbor_node_features: Tensor, shape (batch_size, num_neighbors, node_feat_dim) 144 | :param neighbor_node_time_features: Tensor, shape (batch_size, num_neighbors, time_feat_dim) 145 | :param neighbor_node_edge_features: Tensor, shape (batch_size, num_neighbors, edge_feat_dim) 146 | :param neighbor_masks: ndarray, shape (batch_size, num_neighbors), used to create mask of neighbors for nodes in the batch 147 | :return: 148 | """ 149 | # Tensor, shape (batch_size, 1, node_feat_dim) 150 | node_features = torch.unsqueeze(node_features, dim=1) 151 | 152 | # Tensor, shape (batch_size, 1, node_feat_dim + time_feat_dim) 153 | query = residual = torch.cat([node_features, node_time_features], dim=2) 154 | # shape (batch_size, 1, num_heads, self.head_dim) 155 | query = self.query_projection(query).reshape(query.shape[0], query.shape[1], self.num_heads, self.head_dim) 156 | 157 | # Tensor, shape (batch_size, num_neighbors, node_feat_dim + edge_feat_dim + time_feat_dim) 158 | key = value = torch.cat([neighbor_node_features, neighbor_node_edge_features, neighbor_node_time_features], dim=2) 159 | # Tensor, shape (batch_size, num_neighbors, num_heads, self.head_dim) 160 | key = self.key_projection(key).reshape(key.shape[0], key.shape[1], self.num_heads, self.head_dim) 161 | # Tensor, shape (batch_size, num_neighbors, num_heads, self.head_dim) 162 | value = self.value_projection(value).reshape(value.shape[0], value.shape[1], self.num_heads, self.head_dim) 163 | 164 | # Tensor, shape (batch_size, num_heads, 1, self.head_dim) 165 | query = query.permute(0, 2, 1, 3) 166 | # Tensor, shape (batch_size, num_heads, num_neighbors, self.head_dim) 167 | key = key.permute(0, 2, 1, 3) 168 | # Tensor, shape (batch_size, num_heads, 
num_neighbors, self.head_dim) 169 | value = value.permute(0, 2, 1, 3) 170 | 171 | # Tensor, shape (batch_size, num_heads, 1, num_neighbors) 172 | attention = torch.einsum('bhld,bhnd->bhln', query, key) 173 | attention = attention * self.scaling_factor 174 | 175 | # Tensor, shape (batch_size, 1, num_neighbors) 176 | attention_mask = torch.from_numpy(neighbor_masks).to(node_features.device).unsqueeze(dim=1) 177 | attention_mask = attention_mask == 0 178 | # Tensor, shape (batch_size, self.num_heads, 1, num_neighbors) 179 | attention_mask = torch.stack([attention_mask for _ in range(self.num_heads)], dim=1) 180 | 181 | # Tensor, shape (batch_size, self.num_heads, 1, num_neighbors) 182 | # note that if a node has no valid neighbor (whose neighbor_masks are all zero), directly set the masks to -np.inf will make the 183 | # attention scores after softmax be nan. Therefore, we choose a very large negative number (-1e10 following TGAT) instead of -np.inf to tackle this case 184 | attention = attention.masked_fill(attention_mask, -1e10) 185 | 186 | # Tensor, shape (batch_size, num_heads, 1, num_neighbors) 187 | attention_scores = self.dropout(torch.softmax(attention, dim=-1)) 188 | 189 | # Tensor, shape (batch_size, num_heads, 1, self.head_dim) 190 | attention_output = torch.einsum('bhln,bhnd->bhld', attention_scores, value) 191 | 192 | # Tensor, shape (batch_size, 1, num_heads * self.head_dim), where num_heads * self.head_dim is equal to node_feat_dim + time_feat_dim 193 | attention_output = attention_output.permute(0, 2, 1, 3).flatten(start_dim=2) 194 | 195 | # Tensor, shape (batch_size, 1, node_feat_dim + time_feat_dim) 196 | output = self.dropout(self.residual_fc(attention_output)) 197 | 198 | # Tensor, shape (batch_size, 1, node_feat_dim + time_feat_dim) 199 | output = self.layer_norm(output + residual) 200 | 201 | # Tensor, shape (batch_size, node_feat_dim + time_feat_dim) 202 | output = output.squeeze(dim=1) 203 | # Tensor, shape (batch_size, num_heads, num_neighbors) 204 | attention_scores = attention_scores.squeeze(dim=2) 205 | 206 | return output, attention_scores 207 | 208 | 209 | class TransformerEncoder(nn.Module): 210 | 211 | def __init__(self, attention_dim: int, num_heads: int, dropout: float = 0.1): 212 | """ 213 | Transformer encoder. 
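A post-norm Transformer block: PyTorch MultiheadAttention followed by a two-layer
position-wise feed-forward network, each wrapped in dropout, a residual connection, and LayerNorm.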
214 | :param attention_dim: int, dimension of the attention vector 215 | :param num_heads: int, number of attention heads 216 | :param dropout: float, dropout rate 217 | """ 218 | super(TransformerEncoder, self).__init__() 219 | # use the MultiheadAttention implemented by PyTorch 220 | self.multi_head_attention = nn.MultiheadAttention(embed_dim=attention_dim, num_heads=num_heads, dropout=dropout) 221 | 222 | self.dropout = nn.Dropout(dropout) 223 | 224 | self.linear_layers = nn.ModuleList([ 225 | nn.Linear(in_features=attention_dim, out_features=4 * attention_dim), 226 | nn.Linear(in_features=4 * attention_dim, out_features=attention_dim) 227 | ]) 228 | self.norm_layers = nn.ModuleList([ 229 | nn.LayerNorm(attention_dim), 230 | nn.LayerNorm(attention_dim) 231 | ]) 232 | 233 | def forward(self, inputs_query: torch.Tensor, inputs_key: torch.Tensor = None, inputs_value: torch.Tensor = None, 234 | neighbor_masks: np.ndarray = None): 235 | """ 236 | encode the inputs by Transformer encoder 237 | :param inputs_query: Tensor, shape (batch_size, target_seq_length, self.attention_dim) 238 | :param inputs_key: Tensor, shape (batch_size, source_seq_length, self.attention_dim) 239 | :param inputs_value: Tensor, shape (batch_size, source_seq_length, self.attention_dim) 240 | :param neighbor_masks: ndarray, shape (batch_size, source_seq_length), used to create mask of neighbors for nodes in the batch 241 | :return: 242 | """ 243 | if inputs_key is None or inputs_value is None: 244 | assert inputs_key is None and inputs_value is None 245 | inputs_key = inputs_value = inputs_query 246 | # note that the MultiheadAttention module accept input data with shape (seq_length, batch_size, input_dim), so we need to transpose the input 247 | # transposed_inputs_query, Tensor, shape (target_seq_length, batch_size, self.attention_dim) 248 | # transposed_inputs_key, Tensor, shape (source_seq_length, batch_size, self.attention_dim) 249 | # transposed_inputs_value, Tensor, shape (source_seq_length, batch_size, self.attention_dim) 250 | transposed_inputs_query, transposed_inputs_key, transposed_inputs_value = inputs_query.transpose(0, 1), inputs_key.transpose(0, 1), inputs_value.transpose(0, 1) 251 | 252 | if neighbor_masks is not None: 253 | # Tensor, shape (batch_size, source_seq_length) 254 | neighbor_masks = torch.from_numpy(neighbor_masks).to(inputs_query.device) == 0 255 | 256 | # Tensor, shape (batch_size, target_seq_length, self.attention_dim) 257 | hidden_states = self.multi_head_attention(query=transposed_inputs_query, key=transposed_inputs_key, 258 | value=transposed_inputs_value, key_padding_mask=neighbor_masks)[0].transpose(0, 1) 259 | # Tensor, shape (batch_size, target_seq_length, self.attention_dim) 260 | outputs = self.norm_layers[0](inputs_query + self.dropout(hidden_states)) 261 | # Tensor, shape (batch_size, target_seq_length, self.attention_dim) 262 | hidden_states = self.linear_layers[1](self.dropout(F.relu(self.linear_layers[0](outputs)))) 263 | # Tensor, shape (batch_size, target_seq_length, self.attention_dim) 264 | outputs = self.norm_layers[1](outputs + self.dropout(hidden_states)) 265 | 266 | return outputs 267 | -------------------------------------------------------------------------------- /GraphMixer+TGSL/utils/load_configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import torch 4 | 5 | 6 | def get_link_prediction_args(is_evaluation: bool = False): 7 | """ 8 | get the args for the link prediction task 9 | 
:param is_evaluation: boolean, whether in the evaluation process 10 | :return: 11 | """ 12 | # arguments 13 | parser = argparse.ArgumentParser('Interface for the link prediction task') 14 | parser.add_argument('--dataset_name', type=str, help='dataset to be used', default='wikipedia', 15 | choices=['escorts', 'wikipedia', 'reddit', 'mooc', 'lastfm', 'enron', 'SocialEvo', 'uci', 'Flights', 'CanParl', 'USLegis', 'UNtrade', 'UNvote', 'Contacts']) 16 | parser.add_argument('--batch_size', type=int, default=200, help='batch size') 17 | parser.add_argument('--model_name', type=str, default='DyGFormer', help='name of the model, note that EdgeBank is only applicable for evaluation', 18 | choices=['JODIE', 'DyRep', 'TGAT', 'TGN', 'CAWN', 'EdgeBank', 'TCL', 'GraphMixer', 'DyGFormer']) 19 | parser.add_argument('--gpu', type=int, default=0, help='number of gpu to use') 20 | parser.add_argument('--num_neighbors', type=int, default=20, help='number of neighbors to sample for each node') 21 | parser.add_argument('--sample_neighbor_strategy', default='recent', choices=['uniform', 'recent', 'time_interval_aware'], help='how to sample historical neighbors') 22 | parser.add_argument('--time_scaling_factor', default=1e-6, type=float, help='the hyperparameter that controls the sampling preference with time interval, ' 23 | 'a large time_scaling_factor tends to sample more on recent links, 0.0 corresponds to uniform sampling, ' 24 | 'it works when sample_neighbor_strategy == time_interval_aware') 25 | parser.add_argument('--num_walk_heads', type=int, default=8, help='number of heads used for the attention in walk encoder') 26 | parser.add_argument('--num_heads', type=int, default=2, help='number of heads used in attention layer') 27 | parser.add_argument('--num_layers', type=int, default=2, help='number of model layers') 28 | parser.add_argument('--walk_length', type=int, default=1, help='length of each random walk') 29 | parser.add_argument('--time_gap', type=int, default=2000, help='time gap for neighbors to compute node features') 30 | parser.add_argument('--time_feat_dim', type=int, default=100, help='dimension of the time embedding') 31 | parser.add_argument('--position_feat_dim', type=int, default=172, help='dimension of the position embedding') 32 | parser.add_argument('--edge_bank_memory_mode', type=str, default='unlimited_memory', help='how memory of EdgeBank works', 33 | choices=['unlimited_memory', 'time_window_memory', 'repeat_threshold_memory']) 34 | parser.add_argument('--time_window_mode', type=str, default='fixed_proportion', help='how to select the time window size for time window memory', 35 | choices=['fixed_proportion', 'repeat_interval']) 36 | parser.add_argument('--patch_size', type=int, default=1, help='patch size') 37 | parser.add_argument('--channel_embedding_dim', type=int, default=50, help='dimension of each channel embedding') 38 | parser.add_argument('--max_input_sequence_length', type=int, default=32, help='maximal length of the input sequence of each node') 39 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate') 40 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 41 | parser.add_argument('--num_epochs', type=int, default=50, help='number of epochs') 42 | parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam', 'RMSprop'], help='name of optimizer') 43 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay') 44 | parser.add_argument('--patience', type=int, 
default=5, help='patience for early stopping') 45 | parser.add_argument('--val_ratio', type=float, default=0.15, help='ratio of validation set') 46 | parser.add_argument('--test_ratio', type=float, default=0.15, help='ratio of test set') 47 | parser.add_argument('--num_runs', type=int, default=1, help='number of runs') 48 | parser.add_argument('--test_interval_epochs', type=int, default=100, help='how many epochs to perform testing once') 49 | parser.add_argument('--negative_sample_strategy', type=str, default='random', choices=['random', 'historical', 'inductive'], 50 | help='strategy for the negative edge sampling') 51 | parser.add_argument('--load_best_configs', action='store_true', default=False, help='whether to load the best configurations') 52 | 53 | parser.add_argument('--tau', type=float, default=0.1) 54 | parser.add_argument('--gtau', type=float, default=1.0) 55 | parser.add_argument('--K', type=int, default=512) 56 | parser.add_argument('--coe', type=float, default=0.2) 57 | parser.add_argument('--ratio', type=float, default=0.02) 58 | parser.add_argument('--infer_bs', type=int, default=200) 59 | parser.add_argument('--can_nn', type=int, default=20) 60 | parser.add_argument('--rnn_nn', type=int, default=20) 61 | parser.add_argument('--rnn_layer', type=int, default=1) 62 | parser.add_argument('--can_type', type=str, choices=['1st', '3rd', 'random', 'mix'], default='3rd') 63 | parser.add_argument('--log_name', type=str, default='t') 64 | parser.add_argument('--prefix', type=str, default='t') 65 | 66 | try: 67 | args = parser.parse_args() 68 | args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu' 69 | except: 70 | parser.print_help() 71 | sys.exit() 72 | 73 | if args.model_name == 'EdgeBank': 74 | assert is_evaluation, 'EdgeBank is only applicable for evaluation!' 
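# TGSL-specific hyperparameters below; this summary is inferred from how MTL.py consumes them, so treat it as a sketch:
# --tau: temperature of the InfoNCE contrastive logits
# --gtau: temperature of the Gumbel gates used for candidate edge selection
# --K: size of the MoCo-style queue of negative keys
# --coe: presumably the weighting coefficient of the auxiliary loss (not used in this file)
# --ratio: fraction of candidate edges kept by the Gumbel-Top-K selection
# --infer_bs: presumably the batch size used at inference time
# --can_nn / --rnn_nn: number of candidate neighbors / neighbors fed to the context RNN
# --rnn_layer: number of layers of the context RNN
# --can_type: how candidate edges are proposed (1st-hop, 3rd-hop, random, or a mix)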
75 | 76 | if args.load_best_configs: 77 | load_link_prediction_best_configs(args=args) 78 | 79 | return args 80 | 81 | 82 | def load_link_prediction_best_configs(args: argparse.Namespace): 83 | """ 84 | load the best configurations for the link prediction task 85 | :param args: argparse.Namespace 86 | :return: 87 | """ 88 | # model specific settings 89 | args.num_layers = 2 90 | if args.dataset_name in ['wikipedia']: 91 | args.num_neighbors = 30 92 | elif args.dataset_name in ['reddit', 'lastfm']: 93 | args.num_neighbors = 10 94 | else: 95 | args.num_neighbors = 20 96 | if args.dataset_name in ['wikipedia', 'reddit', 'enron']: 97 | args.dropout = 0.5 98 | elif args.dataset_name in ['mooc', 'uci', 'USLegis']: 99 | args.dropout = 0.4 100 | elif args.dataset_name in ['lastfm', 'UNvote']: 101 | args.dropout = 0.0 102 | elif args.dataset_name in ['SocialEvo']: 103 | args.dropout = 0.3 104 | elif args.dataset_name in ['Flights', 'CanParl']: 105 | args.dropout = 0.2 106 | else: 107 | args.dropout = 0.1 108 | if args.dataset_name in ['CanParl', 'UNtrade', 'UNvote']: 109 | args.sample_neighbor_strategy = 'uniform' 110 | else: 111 | args.sample_neighbor_strategy = 'recent' 112 | 113 | 114 | def get_node_classification_args(): 115 | """ 116 | get the args for the node classification task 117 | :return: 118 | """ 119 | # arguments 120 | parser = argparse.ArgumentParser('Interface for the node classification task') 121 | parser.add_argument('--dataset_name', type=str, help='dataset to be used', default='wikipedia', choices=['wikipedia', 'reddit']) 122 | parser.add_argument('--batch_size', type=int, default=200, help='batch size') 123 | parser.add_argument('--model_name', type=str, default='DyGFormer', help='name of the model', 124 | choices=['JODIE', 'DyRep', 'TGAT', 'TGN', 'CAWN', 'TCL', 'GraphMixer', 'DyGFormer']) 125 | parser.add_argument('--gpu', type=int, default=0, help='number of gpu to use') 126 | parser.add_argument('--num_neighbors', type=int, default=20, help='number of neighbors to sample for each node') 127 | parser.add_argument('--sample_neighbor_strategy', default='recent', 128 | choices=['uniform', 'recent', 'time_interval_aware'], help='how to sample historical neighbors') 129 | parser.add_argument('--time_scaling_factor', default=1e-6, type=float, help='the hyperparameter that controls the sampling preference with time interval, ' 130 | 'a large time_scaling_factor tends to sample more on recent links, 0.0 corresponds to uniform sampling, ' 131 | 'it works when sample_neighbor_strategy == time_interval_aware') 132 | parser.add_argument('--num_walk_heads', type=int, default=8, help='number of heads used for the attention in walk encoder') 133 | parser.add_argument('--num_heads', type=int, default=2, help='number of heads used in attention layer') 134 | parser.add_argument('--num_layers', type=int, default=2, help='number of model layers') 135 | parser.add_argument('--walk_length', type=int, default=1, help='length of each random walk') 136 | parser.add_argument('--time_gap', type=int, default=2000, help='time gap for neighbors to compute node features') 137 | parser.add_argument('--time_feat_dim', type=int, default=100, help='dimension of the time embedding') 138 | parser.add_argument('--position_feat_dim', type=int, default=172, help='dimension of the position embedding') 139 | parser.add_argument('--edge_bank_memory_mode', type=str, default='unlimited_memory', help='how memory of EdgeBank works', 140 | choices=['unlimited_memory', 'time_window_memory', 'repeat_threshold_memory']) 141 
| parser.add_argument('--time_window_mode', type=str, default='fixed_proportion', help='how to select the time window size for time window memory', 142 | choices=['fixed_proportion', 'repeat_interval']) 143 | parser.add_argument('--patch_size', type=int, default=1, help='patch size') 144 | parser.add_argument('--channel_embedding_dim', type=int, default=50, help='dimension of each channel embedding') 145 | parser.add_argument('--max_input_sequence_length', type=int, default=32, help='maximal length of the input sequence of each node') 146 | parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate') 147 | parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate') 148 | parser.add_argument('--num_epochs', type=int, default=100, help='number of epochs') 149 | parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adam', 'RMSprop'], help='name of optimizer') 150 | parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay') 151 | parser.add_argument('--patience', type=int, default=20, help='patience for early stopping') 152 | parser.add_argument('--val_ratio', type=float, default=0.15, help='ratio of validation set') 153 | parser.add_argument('--test_ratio', type=float, default=0.15, help='ratio of test set') 154 | parser.add_argument('--num_runs', type=int, default=5, help='number of runs') 155 | parser.add_argument('--test_interval_epochs', type=int, default=10, help='how many epochs to perform testing once') 156 | parser.add_argument('--load_best_configs', action='store_true', default=False, help='whether to load the best configurations') 157 | 158 | try: 159 | args = parser.parse_args() 160 | args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu' 161 | except: 162 | parser.print_help() 163 | sys.exit() 164 | 165 | assert args.dataset_name in ['wikipedia', 'reddit'], f'Wrong value for dataset_name {args.dataset_name}!' 
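# wikipedia and reddit are the only supported datasets here because they ship dynamic node labels
# (user state-change/ban events), which node classification requires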
166 | if args.load_best_configs: 167 | load_node_classification_best_configs(args=args) 168 | 169 | return args 170 | 171 | 172 | def load_node_classification_best_configs(args: argparse.Namespace): 173 | """ 174 | load the best configurations for the node classification task 175 | :param args: argparse.Namespace 176 | :return: 177 | """ 178 | # model specific settings 179 | if args.model_name == 'TGAT': 180 | args.num_neighbors = 20 181 | args.num_layers = 2 182 | args.dropout = 0.1 183 | if args.dataset_name in ['reddit']: 184 | args.sample_neighbor_strategy = 'uniform' 185 | else: 186 | args.sample_neighbor_strategy = 'recent' 187 | elif args.model_name in ['JODIE', 'DyRep', 'TGN']: 188 | args.num_neighbors = 10 189 | args.num_layers = 1 190 | args.dropout = 0.1 191 | args.sample_neighbor_strategy = 'recent' 192 | elif args.model_name == 'CAWN': 193 | args.time_scaling_factor = 1e-6 194 | args.num_neighbors = 32 195 | args.dropout = 0.1 196 | args.sample_neighbor_strategy = 'time_interval_aware' 197 | elif args.model_name == 'TCL': 198 | args.num_neighbors = 20 199 | args.num_layers = 2 200 | args.dropout = 0.1 201 | if args.dataset_name in ['reddit']: 202 | args.sample_neighbor_strategy = 'uniform' 203 | else: 204 | args.sample_neighbor_strategy = 'recent' 205 | elif args.model_name == 'GraphMixer': 206 | args.num_layers = 2 207 | if args.dataset_name in ['reddit']: 208 | args.num_neighbors = 10 209 | else: 210 | args.num_neighbors = 30 211 | args.dropout = 0.5 212 | args.sample_neighbor_strategy = 'recent' 213 | elif args.model_name == 'DyGFormer': 214 | args.num_layers = 2 215 | if args.dataset_name in ['reddit']: 216 | args.max_input_sequence_length = 64 217 | args.patch_size = 2 218 | else: 219 | args.max_input_sequence_length = 32 220 | args.patch_size = 1 221 | assert args.max_input_sequence_length % args.patch_size == 0 222 | if args.dataset_name in ['reddit']: 223 | args.dropout = 0.2 224 | else: 225 | args.dropout = 0.1 226 | else: 227 | raise ValueError(f"Wrong value for model_name {args.model_name}!") 228 | -------------------------------------------------------------------------------- /GraphMixer+TGSL/MTL.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import pickle 5 | 6 | from utils.utils import NeighborSampler 7 | 8 | 9 | class MTL(nn.Module): 10 | def __init__(self, base_encoder_k, encoder, view_learner, edge_rnn, sample_time_encoder, len_full_edge, 11 | train_e_idx_l, train_node_set, train_ts_l, e_feat, device, dim=172, K=600, m=0.999, tau=0.1, gtau=1.0, 12 | ratio=0.9, can_nn=20, rnn_nn=20, can_type='3rd'): 13 | super(MTL, self).__init__() 14 | self.K = K 15 | self.m = m 16 | self.tau = tau 17 | self.gtau = gtau 18 | self.ratio = ratio 19 | self.can_nn = can_nn 20 | self.rnn_nn = rnn_nn 21 | self.can_type = can_type 22 | 23 | self.encoder_k = base_encoder_k 24 | self.encoder = encoder 25 | self.view_learner = view_learner 26 | self.edge_rnn = edge_rnn 27 | self.sample_time_encoder = sample_time_encoder 28 | 29 | self.e_feat = e_feat 30 | self.len_full_edge = len_full_edge 31 | self.train_e_idx_l = train_e_idx_l 32 | self.train_node_set = np.array(list(train_node_set)) 33 | self.train_ts_l = train_ts_l 34 | self.max_train_ts_l = max(train_ts_l) 35 | self.device = device 36 | 37 | for param_q, param_k in zip( 38 | self.encoder.parameters(), self.encoder_k.parameters() 39 | ): 40 | param_k.data.copy_(param_q.data) # initialize 41 | param_k.requires_grad = False # 
not update by gradient 42 | 43 | self.register_buffer("queue", torch.randn(dim, K)) 44 | self.queue = nn.functional.normalize(self.queue, dim=0) 45 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 46 | 47 | @torch.no_grad() 48 | def _momentum_update_key_encoder(self): 49 | for param_q, param_k in zip( 50 | self.encoder.parameters(), self.encoder_k.parameters() 51 | ): 52 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 53 | 54 | @torch.no_grad() 55 | def _dequeue_and_enqueue(self, keys): 56 | batch_size = keys.shape[0] 57 | ptr = int(self.queue_ptr) 58 | if ptr + batch_size > self.K: 59 | keys_prev = keys[:self.K - ptr] 60 | keys_aft = keys[self.K - ptr:] 61 | self.queue[:, ptr: self.K] = keys_prev.T 62 | self.queue[:, : len(keys_aft)] = keys_aft.T 63 | ptr = len(keys_aft) 64 | self.queue_ptr[0] = ptr 65 | else: 66 | self.queue[:, ptr : ptr + batch_size] = keys.T 67 | ptr = (ptr + batch_size) % self.K # move pointer 68 | self.queue_ptr[0] = ptr 69 | 70 | def forward(self, src_l_cut, dst_l_cut, src_l_fake, dst_l_fake, ts_l_cut, eid_l_cut, NUM_NEIGHBORS, time_gap, g, adj_list, adj_list_pickle): 71 | batch_src_node_embeddings, batch_dst_node_embeddings = self.encoder[0].compute_src_dst_node_temporal_embeddings(src_node_ids=src_l_cut, dst_node_ids=dst_l_cut, node_interact_times=ts_l_cut, num_neighbors=NUM_NEIGHBORS, time_gap=time_gap) 72 | batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = self.encoder[0].compute_src_dst_node_temporal_embeddings(src_node_ids=src_l_fake, dst_node_ids=dst_l_fake, node_interact_times=ts_l_cut, num_neighbors=NUM_NEIGHBORS, time_gap=time_gap) 73 | pos_score = self.encoder[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid() 74 | neg_score = self.encoder[1](input_1=batch_neg_src_node_embeddings, input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid() 75 | 76 | train_node_cut = np.array(list(set(np.append(src_l_cut, dst_l_cut)))) 77 | max_ts = np.array([self.max_train_ts_l + 1] * len(train_node_cut)) 78 | 79 | neighbor_finder = NeighborSampler(adj_list, sample_neighbor_strategy='uniform') 80 | neighbor_node_idx, neighbor_edge_idx, neighbor_ts = neighbor_finder.get_historical_neighbors(train_node_cut, max_ts, num_neighbors=self.rnn_nn) 81 | 82 | train_edge_feat = self.view_learner(g) 83 | full_edge_feat = torch.zeros(self.len_full_edge + 1, train_edge_feat.shape[1], device=self.device) 84 | full_edge_feat[self.train_e_idx_l - 1] = train_edge_feat 85 | 86 | neighbor_edge_idx = neighbor_edge_idx.reshape(-1) - 1 87 | neighbor_edge_feat = full_edge_feat[neighbor_edge_idx] # [bs * 20, 172] 88 | neighbor_edge_feat = neighbor_edge_feat.reshape(neighbor_node_idx.shape[0], neighbor_node_idx.shape[1], -1) # [bs, 20, 172] 89 | 90 | neighbor_edge_feat = neighbor_edge_feat.transpose(0, 1) 91 | _, (h_n, _) = self.edge_rnn(neighbor_edge_feat) 92 | context_vec = h_n[-1] # [bs, 172] 93 | 94 | neighbor_finder.sample_neighbor_strategy = 'uniform' 95 | if self.can_type == '1st': 96 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.get_historical_neighbors(train_node_cut, max_ts, num_neighbors=self.can_nn) 97 | 98 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20] 99 | dst_node_idx_aug = candidate_node_idx # [bs, 20] 100 | elif self.can_type == '3rd': 101 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.get_multi_hop_neighbors(3, train_node_cut, max_ts, 
num_neighbors=self.can_nn)
102 | candidate_node_idx = candidate_node_idx[-1].reshape(train_node_cut.shape[0], -1)
103 | candidate_edge_idx = candidate_edge_idx[-1].reshape(train_node_cut.shape[0], -1)
104 | candidate_ts = candidate_ts[-1].reshape(train_node_cut.shape[0], -1)
105 | 
106 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
107 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
108 | elif self.can_type == 'random':
109 | candidate_node_idx = np.random.choice(self.train_node_set, size=train_node_cut.shape[0] * self.can_nn, replace=True).reshape(train_node_cut.shape[0], -1) # [bs, 20]
110 | candidate_edge_idx = np.array([0] * (train_node_cut.shape[0] * self.can_nn)).reshape(train_node_cut.shape[0], -1)
111 | candidate_ts = np.random.rand(train_node_cut.shape[0], self.can_nn) * self.max_train_ts_l
112 | 
113 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
114 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
115 | elif self.can_type == 'mix':
116 | candidate_node_idx_1st, candidate_edge_idx_1st, candidate_ts_1st = neighbor_finder.get_historical_neighbors(train_node_cut, max_ts, num_neighbors=self.can_nn)
117 | candidate_node_idx_3rd, candidate_edge_idx_3rd, candidate_ts_3rd = neighbor_finder.get_multi_hop_neighbors(3, train_node_cut, max_ts, num_neighbors=self.can_nn)
118 | candidate_node_idx_3rd = candidate_node_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
119 | candidate_edge_idx_3rd = candidate_edge_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
120 | candidate_ts_3rd = candidate_ts_3rd[-1].reshape(train_node_cut.shape[0], -1)
121 | 
122 | candidate_node_idx = np.concatenate((candidate_node_idx_1st, candidate_node_idx_3rd), axis=-1)
123 | candidate_edge_idx = np.concatenate((candidate_edge_idx_1st, candidate_edge_idx_3rd), axis=-1)
124 | candidate_ts = np.concatenate((candidate_ts_1st, candidate_ts_3rd), axis=-1)
125 | 
126 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
127 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
128 | else:
129 | raise ValueError(f"Wrong value for can_type {self.can_type}!")  # fail fast instead of silently leaving the candidates undefined
130 | 
131 | candidate_edge_idx = candidate_edge_idx.reshape(-1) - 1
132 | candidate_edge_feat = full_edge_feat[candidate_edge_idx] # [bs * 20, 172]
133 | candidate_edge_feat = candidate_edge_feat.reshape(candidate_node_idx.shape[0], candidate_node_idx.shape[1], -1) # [bs, 20, 172]
134 | 
135 | ts_aug = np.random.rand(candidate_ts.shape[0], candidate_ts.shape[1]) * self.max_train_ts_l
136 | delta_ts_sample = ts_aug - candidate_ts
137 | delta_ts_sample_context = ts_aug - np.ones_like(candidate_ts) * self.max_train_ts_l
138 | delta_ts_sample_embedding = self.sample_time_encoder(torch.tensor(delta_ts_sample.reshape(-1, 1), dtype=torch.float32).to(self.device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
139 | delta_ts_sample_context_embedding = self.sample_time_encoder(torch.tensor(delta_ts_sample_context.reshape(-1, 1), dtype=torch.float32).to(self.device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
140 | 
141 | context_vec = context_vec.unsqueeze(1).expand_as(candidate_edge_feat)
142 | context_vec = context_vec * delta_ts_sample_context_embedding
143 | candidate_edge_feat = candidate_edge_feat * delta_ts_sample_embedding
144 | aug_edge_logits = torch.sum(context_vec * candidate_edge_feat, dim=-1) # [bs, 20]
145 | 
146 | # Gumbel-Top-K
147 | bias = 0.0 + 0.0001 # If bias is 0, we run into
problems 148 | eps = (bias - (1 - bias)) * torch.rand(aug_edge_logits.size()) + (1 - bias) 149 | gate_inputs = torch.log(eps) - torch.log(1 - eps) 150 | gate_inputs = gate_inputs.to(aug_edge_logits.device) 151 | gate_inputs = (gate_inputs + aug_edge_logits) / self.gtau 152 | z = torch.sigmoid(gate_inputs).squeeze() # [bs, 20] 153 | __, sorted_idx = z.sort(dim=-1, descending=True) 154 | k = int(self.ratio * z.size(1)) 155 | keep = sorted_idx[:, :k] # [bs, k] 156 | 157 | aug_edge_logits = torch.sigmoid(gate_inputs).squeeze() # [bs, 20] 158 | aug_edge_weight = torch.gather(aug_edge_logits, dim=1, index=keep) # [bs, k] 159 | ts_aug = torch.gather(torch.tensor(ts_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy() # [bs, k] 160 | src_node_idx_aug = torch.gather(torch.tensor(src_node_idx_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy() # [bs, k] 161 | dst_node_idx_aug = torch.gather(torch.tensor(dst_node_idx_aug, device=self.device), dim=1, index=keep).detach().cpu().numpy() # [bs, k] 162 | candidate_edge_feat = torch.gather(candidate_edge_feat, dim=1, index=keep.unsqueeze(2).repeat(1, 1, candidate_edge_feat.shape[2])) # [bs, k, 172] 163 | 164 | aug_edge_weight = aug_edge_weight.reshape(-1) 165 | ts_aug = ts_aug.reshape(-1) 166 | src_node_idx_aug = src_node_idx_aug.reshape(-1) 167 | dst_node_idx_aug = dst_node_idx_aug.reshape(-1) 168 | 169 | temp_eid = self.len_full_edge 170 | new_eid_list = [] 171 | adj_list_aug = pickle.loads(adj_list_pickle) 172 | for src, dst, ts in zip(src_node_idx_aug, dst_node_idx_aug, ts_aug): 173 | adj_list_aug[src].append((dst, temp_eid, ts)) 174 | adj_list_aug[dst].append((src, temp_eid, ts)) 175 | new_eid_list.append(temp_eid) 176 | temp_eid += 1 177 | 178 | train_ngh_finder_aug = NeighborSampler(adj_list_aug, sample_neighbor_strategy=self.encoder[0].neighbor_sampler.sample_neighbor_strategy, seed=self.encoder[0].neighbor_sampler.seed) 179 | 180 | new_eid_list = np.array(new_eid_list) 181 | full_aug_edge_weight = torch.ones(temp_eid, device=self.device) 182 | full_aug_edge_weight[new_eid_list - 1] = aug_edge_weight 183 | 184 | candidate_edge_feat = candidate_edge_feat.reshape(-1, candidate_edge_feat.shape[2]).detach().cpu().numpy() # [bs * k, 172] 185 | e_feat_aug = np.concatenate((self.e_feat, candidate_edge_feat), axis=0) 186 | edge_raw_embed_aug = torch.from_numpy(e_feat_aug.astype(np.float32)).to(self.device) 187 | 188 | ngh_finder_ori = self.encoder[0].neighbor_sampler 189 | self.encoder[0].set_neighbor_sampler(train_ngh_finder_aug) 190 | edge_raw_embed_ori = self.encoder[0].edge_raw_features 191 | self.encoder[0].edge_raw_features = edge_raw_embed_aug 192 | 193 | max_ts = np.array([self.max_train_ts_l + 1] * len(train_node_cut)) 194 | 195 | batch_src_node_embeddings, batch_dst_node_embeddings = self.encoder[0].compute_src_dst_node_temporal_embeddings(src_node_ids=src_l_cut, dst_node_ids=dst_l_cut, node_interact_times=ts_l_cut, num_neighbors=NUM_NEIGHBORS, time_gap=time_gap, full_aug_edge_weight=full_aug_edge_weight) 196 | batch_neg_src_node_embeddings, batch_neg_dst_node_embeddings = self.encoder[0].compute_src_dst_node_temporal_embeddings(src_node_ids=src_l_fake, dst_node_ids=dst_l_fake, node_interact_times=ts_l_cut, num_neighbors=NUM_NEIGHBORS, time_gap=time_gap, full_aug_edge_weight=full_aug_edge_weight) 197 | pos_score_ed = self.encoder[1](input_1=batch_src_node_embeddings, input_2=batch_dst_node_embeddings).squeeze(dim=-1).sigmoid() 198 | neg_score_ed = self.encoder[1](input_1=batch_neg_src_node_embeddings, 
input_2=batch_neg_dst_node_embeddings).squeeze(dim=-1).sigmoid()
199 | 
200 | q = self.encoder[0].compute_node_temporal_embeddings(train_node_cut, max_ts, NUM_NEIGHBORS, time_gap, full_aug_edge_weight=full_aug_edge_weight)
201 | q = nn.functional.normalize(q, dim=1)
202 | 
203 | # recover the original neighbor sampler and raw edge features
204 | self.encoder[0].set_neighbor_sampler(ngh_finder_ori)
205 | self.encoder[0].edge_raw_features = edge_raw_embed_ori
206 | 
207 | with torch.no_grad(): # no gradient to keys
208 | self._momentum_update_key_encoder() # update the key encoder
209 | k = self.encoder_k[0].compute_node_temporal_embeddings(train_node_cut, max_ts, NUM_NEIGHBORS, time_gap)
210 | k = nn.functional.normalize(k, dim=1)
211 | 
212 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
213 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
214 | logits = torch.cat([l_pos, l_neg], dim=1)
215 | logits /= self.tau
216 | labels = torch.zeros(logits.shape[0], dtype=torch.long, device=self.device)
217 | self._dequeue_and_enqueue(k)
218 | 
219 | return pos_score, neg_score, pos_score_ed, neg_score_ed, logits, labels
220 | 
-------------------------------------------------------------------------------- /GraphMixer+TGSL/models/GraphMixer.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | 
5 | from models.modules import TimeEncoder
6 | from utils.utils import NeighborSampler
7 | 
8 | 
9 | class GraphMixer(nn.Module):
10 | 
11 | def __init__(self, node_raw_features: np.ndarray, edge_raw_features: np.ndarray, neighbor_sampler: NeighborSampler,
12 | time_feat_dim: int, num_tokens: int, num_layers: int = 2, token_dim_expansion_factor: float = 0.5,
13 | channel_dim_expansion_factor: float = 4.0, dropout: float = 0.1, device: str = 'cpu'):
14 | """
15 | GraphMixer model.
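Combines a link encoder (an MLP-Mixer over each node's recently sampled interactions) with a
node encoder (a masked average of neighbor features within the last time_gap interactions,
plus the node's own features), and projects the concatenation to node_feat_dim.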
16 | :param node_raw_features: ndarray, shape (num_nodes + 1, node_feat_dim) 17 | :param edge_raw_features: ndarray, shape (num_edges + 1, edge_feat_dim) 18 | :param neighbor_sampler: neighbor sampler 19 | :param time_feat_dim: int, dimension of time features (encodings) 20 | :param num_tokens: int, number of tokens 21 | :param num_layers: int, number of transformer layers 22 | :param token_dim_expansion_factor: float, dimension expansion factor for tokens 23 | :param channel_dim_expansion_factor: float, dimension expansion factor for channels 24 | :param dropout: float, dropout rate 25 | :param device: str, device 26 | """ 27 | super(GraphMixer, self).__init__() 28 | 29 | self.node_raw_features = torch.from_numpy(node_raw_features.astype(np.float32)).to(device) 30 | self.edge_raw_features = torch.from_numpy(edge_raw_features.astype(np.float32)).to(device) 31 | 32 | self.neighbor_sampler = neighbor_sampler 33 | self.node_feat_dim = self.node_raw_features.shape[1] 34 | self.edge_feat_dim = self.edge_raw_features.shape[1] 35 | self.time_feat_dim = time_feat_dim 36 | self.num_tokens = num_tokens 37 | self.num_layers = num_layers 38 | self.token_dim_expansion_factor = token_dim_expansion_factor 39 | self.channel_dim_expansion_factor = channel_dim_expansion_factor 40 | self.dropout = dropout 41 | self.device = device 42 | 43 | self.num_channels = self.edge_feat_dim 44 | # in GraphMixer, the time encoding function is not trainable 45 | self.time_encoder = TimeEncoder(time_dim=time_feat_dim, parameter_requires_grad=False) 46 | self.projection_layer = nn.Linear(self.edge_feat_dim + time_feat_dim, self.num_channels) 47 | 48 | self.mlp_mixers = nn.ModuleList([ 49 | MLPMixer(num_tokens=self.num_tokens, num_channels=self.num_channels, 50 | token_dim_expansion_factor=self.token_dim_expansion_factor, 51 | channel_dim_expansion_factor=self.channel_dim_expansion_factor, dropout=self.dropout) 52 | for _ in range(self.num_layers) 53 | ]) 54 | 55 | self.output_layer = nn.Linear(in_features=self.num_channels + self.node_feat_dim, out_features=self.node_feat_dim, bias=True) 56 | 57 | def compute_src_dst_node_temporal_embeddings(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, 58 | node_interact_times: np.ndarray, num_neighbors: int = 20, time_gap: int = 2000, 59 | full_aug_edge_weight=None): 60 | """ 61 | compute source and destination node temporal embeddings 62 | :param src_node_ids: ndarray, shape (batch_size, ) 63 | :param dst_node_ids: ndarray, shape (batch_size, ) 64 | :param node_interact_times: ndarray, shape (batch_size, ) 65 | :param num_neighbors: int, number of neighbors to sample for each node 66 | :param time_gap: int, time gap for neighbors to compute node features 67 | :return: 68 | """ 69 | # Tensor, shape (batch_size, node_feat_dim) 70 | src_node_embeddings = self.compute_node_temporal_embeddings(node_ids=src_node_ids, node_interact_times=node_interact_times, 71 | num_neighbors=num_neighbors, time_gap=time_gap, full_aug_edge_weight=full_aug_edge_weight) 72 | # Tensor, shape (batch_size, node_feat_dim) 73 | dst_node_embeddings = self.compute_node_temporal_embeddings(node_ids=dst_node_ids, node_interact_times=node_interact_times, 74 | num_neighbors=num_neighbors, time_gap=time_gap, full_aug_edge_weight=full_aug_edge_weight) 75 | 76 | return src_node_embeddings, dst_node_embeddings 77 | 78 | def compute_node_temporal_embeddings(self, node_ids: np.ndarray, node_interact_times: np.ndarray, 79 | num_neighbors: int = 20, time_gap: int = 2000, full_aug_edge_weight=None): 80 | """ 81 | given 
node ids node_ids, and the corresponding time node_interact_times, return the temporal embeddings of nodes in node_ids 82 | :param node_ids: ndarray, shape (batch_size, ), node ids 83 | :param node_interact_times: ndarray, shape (batch_size, ), node interaction times 84 | :param num_neighbors: int, number of neighbors to sample for each node 85 | :param time_gap: int, time gap for neighbors to compute node features 86 | :return: 87 | """ 88 | # link encoder ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 89 | # get temporal neighbors, including neighbor ids, edge ids and time information 90 | # neighbor_node_ids, ndarray, shape (batch_size, num_neighbors) 91 | # neighbor_edge_ids, ndarray, shape (batch_size, num_neighbors) 92 | # neighbor_times, ndarray, shape (batch_size, num_neighbors) 93 | neighbor_node_ids, neighbor_edge_ids, neighbor_times = \ 94 | self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids, 95 | node_interact_times=node_interact_times, 96 | num_neighbors=num_neighbors) 97 | 98 | # Tensor, shape (batch_size, num_neighbors, edge_feat_dim) 99 | nodes_edge_raw_features = self.edge_raw_features[torch.from_numpy(neighbor_edge_ids)] 100 | # Tensor, shape (batch_size, num_neighbors, time_feat_dim) 101 | nodes_neighbor_time_features = self.time_encoder(timestamps=torch.from_numpy(node_interact_times[:, np.newaxis] - neighbor_times).float().to(self.device)) 102 | 103 | # ndarray, set the time features to all zeros for the padded timestamp 104 | nodes_neighbor_time_features[torch.from_numpy(neighbor_node_ids == 0)] = 0.0 105 | 106 | # Tensor, shape (batch_size, num_neighbors, edge_feat_dim + time_feat_dim) 107 | combined_features = torch.cat([nodes_edge_raw_features, nodes_neighbor_time_features], dim=-1) 108 | 109 | if full_aug_edge_weight is not None: 110 | src_ngh_eidx_batch = torch.from_numpy(neighbor_edge_ids).reshape(-1) - 1 111 | aug_edge_weight = full_aug_edge_weight[src_ngh_eidx_batch] 112 | aug_edge_weight = aug_edge_weight.reshape(-1, num_neighbors) 113 | aug_edge_weight = aug_edge_weight.unsqueeze(2).expand_as(combined_features) 114 | combined_features = combined_features * aug_edge_weight 115 | 116 | # Tensor, shape (batch_size, num_neighbors, num_channels) 117 | combined_features = self.projection_layer(combined_features) 118 | 119 | for mlp_mixer in self.mlp_mixers: 120 | # Tensor, shape (batch_size, num_neighbors, num_channels) 121 | combined_features = mlp_mixer(input_tensor=combined_features) 122 | 123 | # Tensor, shape (batch_size, num_channels) 124 | combined_features = torch.mean(combined_features, dim=1) 125 | 126 | # node encoder ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 127 | # get temporal neighbors of nodes, including neighbor ids 128 | # time_gap_neighbor_node_ids, ndarray, shape (batch_size, time_gap) 129 | time_gap_neighbor_node_ids, time_gap_neighbor_edge_ids, _ = self.neighbor_sampler.get_historical_neighbors(node_ids=node_ids, 130 | node_interact_times=node_interact_times, 131 | num_neighbors=time_gap) 132 | 133 | # Tensor, shape (batch_size, time_gap, node_feat_dim) 134 | nodes_time_gap_neighbor_node_raw_features = self.node_raw_features[torch.from_numpy(time_gap_neighbor_node_ids)] 135 | 136 | if full_aug_edge_weight is not None: 137 | src_ngh_eidx_batch = torch.from_numpy(time_gap_neighbor_edge_ids).reshape(-1) - 1 138 | aug_edge_weight = full_aug_edge_weight[src_ngh_eidx_batch] 139 | aug_edge_weight = aug_edge_weight.reshape(-1, time_gap) 140 | aug_edge_weight = 
aug_edge_weight.unsqueeze(2).expand_as(nodes_time_gap_neighbor_node_raw_features) 141 | nodes_time_gap_neighbor_node_raw_features = nodes_time_gap_neighbor_node_raw_features * aug_edge_weight 142 | 143 | # Tensor, shape (batch_size, time_gap) 144 | valid_time_gap_neighbor_node_ids_mask = torch.from_numpy((time_gap_neighbor_node_ids > 0).astype(np.float32)) 145 | # note that if a node has no valid neighbor (whose valid_time_gap_neighbor_node_ids_mask are all zero), directly set the mask to -np.inf will make the 146 | # scores after softmax be nan. Therefore, we choose a very large negative number (-1e10) instead of -np.inf to tackle this case 147 | # Tensor, shape (batch_size, time_gap) 148 | valid_time_gap_neighbor_node_ids_mask[valid_time_gap_neighbor_node_ids_mask == 0] = -1e10 149 | # Tensor, shape (batch_size, time_gap) 150 | scores = torch.softmax(valid_time_gap_neighbor_node_ids_mask, dim=1).to(self.device) 151 | 152 | # Tensor, shape (batch_size, node_feat_dim), average over the time_gap neighbors 153 | nodes_time_gap_neighbor_node_agg_features = torch.mean(nodes_time_gap_neighbor_node_raw_features * scores.unsqueeze(dim=-1), dim=1) 154 | 155 | # Tensor, shape (batch_size, node_feat_dim), add features of nodes in node_ids 156 | output_node_features = nodes_time_gap_neighbor_node_agg_features + self.node_raw_features[torch.from_numpy(node_ids)] 157 | 158 | # Tensor, shape (batch_size, node_feat_dim) 159 | node_embeddings = self.output_layer(torch.cat([combined_features, output_node_features], dim=1)) 160 | 161 | return node_embeddings 162 | 163 | def set_neighbor_sampler(self, neighbor_sampler: NeighborSampler): 164 | """ 165 | set neighbor sampler to neighbor_sampler and reset the random state (for reproducing the results for uniform and time_interval_aware sampling) 166 | :param neighbor_sampler: NeighborSampler, neighbor sampler 167 | :return: 168 | """ 169 | self.neighbor_sampler = neighbor_sampler 170 | if self.neighbor_sampler.sample_neighbor_strategy in ['uniform', 'time_interval_aware']: 171 | assert self.neighbor_sampler.seed is not None 172 | self.neighbor_sampler.reset_random_state() 173 | 174 | 175 | class FeedForwardNet(nn.Module): 176 | 177 | def __init__(self, input_dim: int, dim_expansion_factor: float, dropout: float = 0.0): 178 | """ 179 | two-layered MLP with GELU activation function. 180 | :param input_dim: int, dimension of input 181 | :param dim_expansion_factor: float, dimension expansion factor 182 | :param dropout: float, dropout rate 183 | """ 184 | super(FeedForwardNet, self).__init__() 185 | 186 | self.input_dim = input_dim 187 | self.dim_expansion_factor = dim_expansion_factor 188 | self.dropout = dropout 189 | 190 | self.ffn = nn.Sequential(nn.Linear(in_features=input_dim, out_features=int(dim_expansion_factor * input_dim)), 191 | nn.GELU(), 192 | nn.Dropout(dropout), 193 | nn.Linear(in_features=int(dim_expansion_factor * input_dim), out_features=input_dim), 194 | nn.Dropout(dropout)) 195 | 196 | def forward(self, x: torch.Tensor): 197 | """ 198 | feed forward net forward process 199 | :param x: Tensor, shape (*, input_dim) 200 | :return: 201 | """ 202 | return self.ffn(x) 203 | 204 | 205 | class MLPMixer(nn.Module): 206 | 207 | def __init__(self, num_tokens: int, num_channels: int, token_dim_expansion_factor: float = 0.5, 208 | channel_dim_expansion_factor: float = 4.0, dropout: float = 0.0): 209 | """ 210 | MLP Mixer. 
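Alternates token mixing (across the num_tokens sampled links) and channel mixing (across
feature channels); each step is a LayerNorm followed by a FeedForwardNet with a residual connection.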
211 | :param num_tokens: int, number of tokens 212 | :param num_channels: int, number of channels 213 | :param token_dim_expansion_factor: float, dimension expansion factor for tokens 214 | :param channel_dim_expansion_factor: float, dimension expansion factor for channels 215 | :param dropout: float, dropout rate 216 | """ 217 | super(MLPMixer, self).__init__() 218 | 219 | self.token_norm = nn.LayerNorm(num_tokens) 220 | self.token_feedforward = FeedForwardNet(input_dim=num_tokens, dim_expansion_factor=token_dim_expansion_factor, 221 | dropout=dropout) 222 | 223 | self.channel_norm = nn.LayerNorm(num_channels) 224 | self.channel_feedforward = FeedForwardNet(input_dim=num_channels, dim_expansion_factor=channel_dim_expansion_factor, 225 | dropout=dropout) 226 | 227 | def forward(self, input_tensor: torch.Tensor): 228 | """ 229 | mlp mixer to compute over tokens and channels 230 | :param input_tensor: Tensor, shape (batch_size, num_tokens, num_channels) 231 | :return: 232 | """ 233 | # mix tokens 234 | # Tensor, shape (batch_size, num_channels, num_tokens) 235 | hidden_tensor = self.token_norm(input_tensor.permute(0, 2, 1)) 236 | # Tensor, shape (batch_size, num_tokens, num_channels) 237 | hidden_tensor = self.token_feedforward(hidden_tensor).permute(0, 2, 1) 238 | # Tensor, shape (batch_size, num_tokens, num_channels), residual connection 239 | output_tensor = hidden_tensor + input_tensor 240 | 241 | # mix channels 242 | # Tensor, shape (batch_size, num_tokens, num_channels) 243 | hidden_tensor = self.channel_norm(output_tensor) 244 | # Tensor, shape (batch_size, num_tokens, num_channels) 245 | hidden_tensor = self.channel_feedforward(hidden_tensor) 246 | # Tensor, shape (batch_size, num_tokens, num_channels), residual connection 247 | output_tensor = hidden_tensor + output_tensor 248 | 249 | return output_tensor 250 | -------------------------------------------------------------------------------- /TGAT+TGSL/TGAT.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class MergeLayer(torch.nn.Module): 8 | def __init__(self, dim1, dim2, dim3, dim4): 9 | super().__init__() 10 | #self.layer_norm = torch.nn.LayerNorm(dim1 + dim2) 11 | self.fc1 = torch.nn.Linear(dim1 + dim2, dim3) 12 | self.fc2 = torch.nn.Linear(dim3, dim4) 13 | self.act = torch.nn.ReLU(inplace=True) 14 | 15 | torch.nn.init.xavier_normal_(self.fc1.weight) 16 | torch.nn.init.xavier_normal_(self.fc2.weight) 17 | 18 | def forward(self, x1, x2): 19 | x = torch.cat([x1, x2], dim=1) 20 | #x = self.layer_norm(x) 21 | h = self.act(self.fc1(x)) 22 | return self.fc2(h) 23 | 24 | 25 | class ScaledDotProductAttention(torch.nn.Module): 26 | ''' Scaled Dot-Product Attention ''' 27 | 28 | def __init__(self, temperature, attn_dropout=0.1): 29 | super().__init__() 30 | self.temperature = temperature 31 | self.dropout = torch.nn.Dropout(attn_dropout) 32 | self.softmax = torch.nn.Softmax(dim=2) 33 | 34 | def forward(self, q, k, v, mask=None): 35 | 36 | attn = torch.bmm(q, k.transpose(1, 2)) 37 | attn = attn / self.temperature 38 | 39 | if mask is not None: 40 | attn = attn.masked_fill(mask, -1e10) 41 | 42 | attn = self.softmax(attn) # [n * b, l_q, l_k] 43 | attn = self.dropout(attn) # [n * b, l_v, d] 44 | 45 | output = torch.bmm(attn, v) 46 | 47 | return output, attn 48 | 49 | class MultiHeadAttention(nn.Module): 50 | ''' Multi-Head Attention module ''' 51 | 52 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 
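# n_head: number of attention heads; d_model: model (input/output) dimension;
# d_k / d_v: per-head key/query and value dimensions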
53 | super().__init__() 54 | 55 | self.n_head = n_head 56 | self.d_k = d_k 57 | self.d_v = d_v 58 | 59 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 60 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 61 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 62 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 63 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 64 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 65 | 66 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), attn_dropout=dropout) 67 | self.layer_norm = nn.LayerNorm(d_model) 68 | 69 | self.fc = nn.Linear(n_head * d_v, d_model) 70 | 71 | nn.init.xavier_normal_(self.fc.weight) 72 | 73 | self.dropout = nn.Dropout(dropout) 74 | 75 | 76 | def forward(self, q, k, v, mask=None): 77 | 78 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 79 | 80 | sz_b, len_q, _ = q.size() 81 | sz_b, len_k, _ = k.size() 82 | sz_b, len_v, _ = v.size() 83 | 84 | residual = q 85 | 86 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 87 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 88 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 89 | 90 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 91 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 92 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 93 | 94 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 95 | output, attn = self.attention(q, k, v, mask=mask) 96 | 97 | output = output.view(n_head, sz_b, len_q, d_v) 98 | 99 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 100 | 101 | output = self.dropout(self.fc(output)) 102 | output = self.layer_norm(output + residual) 103 | #output = self.layer_norm(output) 104 | 105 | return output, attn 106 | 107 | 108 | class MapBasedMultiHeadAttention(nn.Module): 109 | ''' Multi-Head Attention module ''' 110 | 111 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 112 | super().__init__() 113 | 114 | self.n_head = n_head 115 | self.d_k = d_k 116 | self.d_v = d_v 117 | 118 | self.wq_node_transform = nn.Linear(d_model, n_head * d_k, bias=False) 119 | self.wk_node_transform = nn.Linear(d_model, n_head * d_k, bias=False) 120 | self.wv_node_transform = nn.Linear(d_model, n_head * d_k, bias=False) 121 | 122 | self.layer_norm = nn.LayerNorm(d_model) 123 | 124 | self.fc = nn.Linear(n_head * d_v, d_model) 125 | 126 | self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True) 127 | self.weight_map = nn.Linear(2 * d_k, 1, bias=False) 128 | 129 | nn.init.xavier_normal_(self.fc.weight) 130 | 131 | self.dropout = torch.nn.Dropout(dropout) 132 | self.softmax = torch.nn.Softmax(dim=2) 133 | 134 | self.dropout = nn.Dropout(dropout) 135 | 136 | 137 | def forward(self, q, k, v, mask=None): 138 | 139 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 140 | 141 | sz_b, len_q, _ = q.size() 142 | 143 | sz_b, len_k, _ = k.size() 144 | sz_b, len_v, _ = v.size() 145 | 146 | residual = q 147 | 148 | q = self.wq_node_transform(q).view(sz_b, len_q, n_head, d_k) 149 | 150 | k = self.wk_node_transform(k).view(sz_b, len_k, n_head, d_k) 151 | 152 | v = self.wv_node_transform(v).view(sz_b, len_v, n_head, d_v) 153 | 154 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 155 | q = torch.unsqueeze(q, dim=2) # [(n*b), lq, 1, dk] 156 | q = q.expand(q.shape[0], q.shape[1], len_k, 
q.shape[3]) # [(n*b), lq, lk, dk] 157 | 158 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 159 | k = torch.unsqueeze(k, dim=1) # [(n*b), 1, lk, dk] 160 | k = k.expand(k.shape[0], len_q, k.shape[2], k.shape[3]) # [(n*b), lq, lk, dk] 161 | 162 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 163 | 164 | mask = mask.repeat(n_head, 1, 1) # (n*b) x lq x lk 165 | 166 | ## Map based Attention 167 | #output, attn = self.attention(q, k, v, mask=mask) 168 | q_k = torch.cat([q, k], dim=3) # [(n*b), lq, lk, dk * 2] 169 | attn = self.weight_map(q_k).squeeze(dim=3) # [(n*b), lq, lk] 170 | 171 | if mask is not None: 172 | attn = attn.masked_fill(mask, -1e10) 173 | 174 | attn = self.softmax(attn) # [n * b, l_q, l_k] 175 | attn = self.dropout(attn) # [n * b, l_q, l_k] 176 | 177 | # [n * b, l_q, l_k] * [n * b, l_v, d_v] >> [n * b, l_q, d_v] 178 | output = torch.bmm(attn, v) 179 | 180 | output = output.view(n_head, sz_b, len_q, d_v) 181 | 182 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 183 | 184 | output = self.dropout(self.act(self.fc(output))) 185 | output = self.layer_norm(output + residual) 186 | 187 | return output, attn 188 | 189 | 190 | def expand_last_dim(x, num): 191 | view_size = list(x.size()) + [1] 192 | expand_size = list(x.size()) + [num] 193 | return x.view(view_size).expand(expand_size) 194 | 195 | 196 | class GraphMixerTE(torch.nn.Module): 197 | def __init__(self, time_dim=172): 198 | super(GraphMixerTE, self).__init__() 199 | self.w = nn.Parameter(torch.tensor([int(np.sqrt(time_dim)) ** (-(i - 1) / int(np.sqrt(time_dim))) for i in range(1, time_dim + 1)]), requires_grad=False) 200 | self.transformation = nn.Sequential( 201 | nn.Linear(time_dim, time_dim), 202 | nn.ReLU(inplace=True), 203 | nn.Linear(time_dim, time_dim) 204 | ) 205 | 206 | def forward(self, ts): 207 | batch_size = ts.size(0) 208 | seq_len = ts.size(1) 209 | 210 | ts = ts.view(batch_size, seq_len, 1) 211 | map_ts = ts * self.w.view(1, 1, -1) 212 | harmonic = torch.cos(map_ts) 213 | harmonic = self.transformation(harmonic) 214 | 215 | return harmonic 216 | 217 | 218 | class PosEncode(torch.nn.Module): 219 | def __init__(self, expand_dim, seq_len): 220 | super().__init__() 221 | 222 | self.pos_embeddings = nn.Embedding(num_embeddings=seq_len, embedding_dim=expand_dim) 223 | 224 | def forward(self, ts): 225 | # ts: [N, L] 226 | order = ts.argsort() 227 | ts_emb = self.pos_embeddings(order) 228 | return ts_emb 229 | 230 | 231 | class EmptyEncode(torch.nn.Module): 232 | def __init__(self, expand_dim): 233 | super().__init__() 234 | self.expand_dim = expand_dim 235 | 236 | def forward(self, ts): 237 | out = torch.zeros_like(ts).float() 238 | out = torch.unsqueeze(out, dim=-1) 239 | out = out.expand(out.shape[0], out.shape[1], self.expand_dim) 240 | return out 241 | 242 | 243 | class LSTMPool(torch.nn.Module): 244 | def __init__(self, feat_dim, edge_dim, time_dim): 245 | super(LSTMPool, self).__init__() 246 | self.feat_dim = feat_dim 247 | self.time_dim = time_dim 248 | self.edge_dim = edge_dim 249 | 250 | self.att_dim = feat_dim + edge_dim + time_dim 251 | 252 | self.act = torch.nn.ReLU(inplace=True) 253 | 254 | self.lstm = torch.nn.LSTM(input_size=self.att_dim, 255 | hidden_size=self.feat_dim, 256 | num_layers=1, 257 | batch_first=True) 258 | self.merger = MergeLayer(feat_dim, feat_dim, feat_dim, feat_dim) 259 | 260 | def forward(self, src, src_t, seq, seq_t, seq_e, mask): 261 | # seq [B, N, D] 262 | # mask [B, N] 263 | 
seq_x = torch.cat([seq, seq_e, seq_t], dim=2)
264 |
265 | _, (hn, _) = self.lstm(seq_x) # run the LSTM over the N neighbor tokens; hn: [num_layers, B, feat_dim]
266 |
267 | hn = hn[-1, :, :] # final hidden state of the last layer summarizes the neighborhood: [B, feat_dim]
268 |
269 | out = self.merger.forward(hn, src)
270 | return out, None
271 |
272 |
273 | class MeanPool(torch.nn.Module):
274 | def __init__(self, feat_dim, edge_dim):
275 | super(MeanPool, self).__init__()
276 | self.edge_dim = edge_dim
277 | self.feat_dim = feat_dim
278 | self.act = torch.nn.ReLU(inplace=True)
279 | self.merger = MergeLayer(edge_dim + feat_dim, feat_dim, feat_dim, feat_dim)
280 |
281 | def forward(self, src, src_t, seq, seq_t, seq_e, mask):
282 | # seq [B, N, D]
283 | # mask [B, N]
284 | src_x = src
285 | seq_x = torch.cat([seq, seq_e], dim=2) #[B, N, De + D]
286 | hn = seq_x.mean(dim=1) #[B, De + D]
287 | output = self.merger(hn, src_x)
288 | return output, None
289 |
290 |
291 | class AttnModel(torch.nn.Module):
292 | """Attention based temporal layers
293 | """
294 | def __init__(self, feat_dim, edge_dim, time_dim,
295 | attn_mode='prod', n_head=2, drop_out=0.1):
296 | """
297 | args:
298 | feat_dim: dim for the node features
299 | edge_dim: dim for the temporal edge features
300 | time_dim: dim for the time encoding
301 | attn_mode: choose from 'prod' and 'map'
302 | n_head: number of heads in attention
303 | drop_out: probability of dropping a neuron (dropout rate)
304 | """
305 | super(AttnModel, self).__init__()
306 |
307 | self.feat_dim = feat_dim
308 | self.time_dim = time_dim
309 |
310 | self.edge_in_dim = (feat_dim + edge_dim + time_dim)
311 | self.model_dim = self.edge_in_dim
312 | #self.edge_fc = torch.nn.Linear(self.edge_in_dim, self.feat_dim, bias=False)
313 |
314 | self.merger = MergeLayer(self.model_dim, feat_dim, feat_dim, feat_dim)
315 |
316 | #self.act = torch.nn.ReLU()
317 | # print(self.model_dim, n_head)
318 | assert(self.model_dim % n_head == 0)
319 | self.logger = logging.getLogger(__name__)
320 | self.attn_mode = attn_mode
321 |
322 | if attn_mode == 'prod':
323 | self.multi_head_target = MultiHeadAttention(n_head,
324 | d_model=self.model_dim,
325 | d_k=self.model_dim // n_head,
326 | d_v=self.model_dim // n_head,
327 | dropout=drop_out)
328 | self.logger.info('Using scaled prod attention')
329 |
330 | elif attn_mode == 'map':
331 | self.multi_head_target = MapBasedMultiHeadAttention(n_head,
332 | d_model=self.model_dim,
333 | d_k=self.model_dim // n_head,
334 | d_v=self.model_dim // n_head,
335 | dropout=drop_out)
336 | self.logger.info('Using map based attention')
337 | else:
338 | raise ValueError('attn_mode can only be prod or map')
339 |
340 |
341 | def forward(self, src, src_t, seq, seq_t, seq_e, mask, weight=None):
342 | """Attention based temporal attention forward pass
343 | args:
344 | src: float Tensor of shape [B, D]
345 | src_t: float Tensor of shape [B, Dt], Dt == D
346 | seq: float Tensor of shape [B, N, D]
347 | seq_t: float Tensor of shape [B, N, Dt]
348 | seq_e: float Tensor of shape [B, N, De], De == D
349 | mask: boolean Tensor of shape [B, N], where a true value indicates a null value in the sequence.
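weight: optional float Tensor of shape [B, N]; when provided (e.g. the soft
    weights TGSL assigns to augmented edges), v is rescaled element-wise
    by it before attention.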
350 |
351 | returns:
352 | output, weight
353 |
354 | output: float Tensor of shape [B, D]
355 | weight: float Tensor of shape [B, N]
356 | """
357 |
358 | src_ext = torch.unsqueeze(src, dim=1) # src [B, 1, D]
359 | src_e_ph = torch.zeros_like(src_ext)
360 | q = torch.cat([src_ext, src_e_ph, src_t], dim=2) # [B, 1, D + De + Dt]
361 | k = torch.cat([seq, seq_e, seq_t], dim=2) # [B, N, D + De + Dt]
362 | v = k
363 | if weight is not None:
364 | weight = weight.unsqueeze(2).expand_as(v)
365 | v = v * weight
366 | # print(q.shape, k.shape)
367 |
368 | mask = torch.unsqueeze(mask, dim=2) # mask [B, N, 1]
369 | mask = mask.permute([0, 2, 1]) #mask [B, 1, N]
370 |
371 | # # target-attention
372 | output, attn = self.multi_head_target(q=q, k=k, v=v, mask=mask) # output: [B, 1, D + De + Dt], attn: [B, 1, N]
373 | output = output.squeeze(dim=1) # squeeze only the length-1 query dim so batch_size == 1 is handled correctly
374 | attn = attn.squeeze(dim=1)
375 |
376 | output = self.merger(output, src)
377 | return output, attn
378 |
379 |
380 | class TGAN(torch.nn.Module):
381 | def __init__(self, ngh_finder, n_feat, e_feat, attn_mode='prod', use_time='time', agg_method='attn', num_layers=3,
382 | n_head=4, null_idx=0, drop_out=0.1, seq_len=None):
383 | super(TGAN, self).__init__()
384 | self.num_layers = num_layers
385 | self.ngh_finder = ngh_finder
386 | self.null_idx = null_idx
387 | self.logger = logging.getLogger(__name__)
388 | self.n_feat_th = torch.nn.Parameter(torch.from_numpy(n_feat.astype(np.float32)))
389 | self.e_feat_th = torch.nn.Parameter(torch.from_numpy(e_feat.astype(np.float32)))
390 | self.edge_raw_embed = torch.nn.Embedding.from_pretrained(self.e_feat_th, padding_idx=0, freeze=True)
391 | self.node_raw_embed = torch.nn.Embedding.from_pretrained(self.n_feat_th, padding_idx=0, freeze=True)
392 |
393 | self.feat_dim = self.n_feat_th.shape[1]
394 |
395 | self.n_feat_dim = self.feat_dim
396 | self.e_feat_dim = self.feat_dim
397 | self.model_dim = self.feat_dim
398 |
399 | if agg_method == 'attn':
400 | self.logger.info('Aggregation uses attention model')
401 | self.attn_model_list = torch.nn.ModuleList([AttnModel(self.feat_dim,
402 | self.feat_dim,
403 | self.feat_dim,
404 | attn_mode=attn_mode,
405 | n_head=n_head,
406 | drop_out=drop_out) for _ in range(num_layers)])
407 | elif agg_method == 'lstm':
408 | self.logger.info('Aggregation uses LSTM model')
409 | self.attn_model_list = torch.nn.ModuleList([LSTMPool(self.feat_dim,
410 | self.feat_dim,
411 | self.feat_dim) for _ in range(num_layers)])
412 | elif agg_method == 'mean':
413 | self.logger.info('Aggregation uses constant mean model')
414 | self.attn_model_list = torch.nn.ModuleList([MeanPool(self.feat_dim,
415 | self.feat_dim) for _ in range(num_layers)])
416 | else:
417 |
418 | raise ValueError('invalid agg_method value, use attn, lstm or mean')
419 |
420 | if use_time == 'time':
421 | self.logger.info('Using time encoding')
422 | self.time_encoder = GraphMixerTE(time_dim=self.n_feat_th.shape[1])
423 | elif use_time == 'pos':
424 | assert(seq_len is not None)
425 | self.logger.info('Using positional encoding')
426 | self.time_encoder = PosEncode(expand_dim=self.n_feat_th.shape[1], seq_len=seq_len)
427 | elif use_time == 'empty':
428 | self.logger.info('Using empty encoding')
429 | self.time_encoder = EmptyEncode(expand_dim=self.n_feat_th.shape[1])
430 | else:
431 | raise ValueError('invalid time option!')
432 |
433 | self.affinity_score = MergeLayer(self.feat_dim, self.feat_dim, self.feat_dim, 1)
434 |
435 | def forward(self, src_idx_l, target_idx_l, cut_time_l, num_neighbors=20):
436 | src_embed =
self.tem_conv(src_idx_l, cut_time_l, self.num_layers, num_neighbors) 437 | target_embed = self.tem_conv(target_idx_l, cut_time_l, self.num_layers, num_neighbors) 438 | score = self.affinity_score(src_embed, target_embed).squeeze(dim=-1) 439 | 440 | return score 441 | 442 | def contrast(self, src_idx_l, target_idx_l, background_idx_l, cut_time_l, num_neighbors=20, full_aug_edge_weight=None): 443 | src_embed = self.tem_conv(src_idx_l, cut_time_l, self.num_layers, num_neighbors, full_aug_edge_weight=full_aug_edge_weight) 444 | target_embed = self.tem_conv(target_idx_l, cut_time_l, self.num_layers, num_neighbors, full_aug_edge_weight=full_aug_edge_weight) 445 | background_embed = self.tem_conv(background_idx_l, cut_time_l, self.num_layers, num_neighbors, full_aug_edge_weight=full_aug_edge_weight) 446 | pos_score = self.affinity_score(src_embed, target_embed).squeeze(dim=-1) 447 | neg_score = self.affinity_score(src_embed, background_embed).squeeze(dim=-1) 448 | return pos_score.sigmoid(), neg_score.sigmoid() 449 | 450 | def tem_conv(self, src_idx_l, cut_time_l, curr_layers, num_neighbors=20, full_aug_edge_weight=None): 451 | assert(curr_layers >= 0) 452 | 453 | device = self.n_feat_th.device 454 | 455 | batch_size = len(src_idx_l) 456 | 457 | src_node_batch_th = torch.from_numpy(src_idx_l).long().to(device) 458 | cut_time_l_th = torch.from_numpy(cut_time_l).float().to(device) 459 | 460 | cut_time_l_th = torch.unsqueeze(cut_time_l_th, dim=1) 461 | # query node always has the start time -> time span == 0 462 | src_node_t_embed = self.time_encoder(torch.zeros_like(cut_time_l_th)) 463 | src_node_feat = self.node_raw_embed(src_node_batch_th) 464 | 465 | if curr_layers == 0: 466 | return src_node_feat 467 | else: 468 | src_node_conv_feat = self.tem_conv(src_idx_l, 469 | cut_time_l, 470 | curr_layers=curr_layers - 1, 471 | num_neighbors=num_neighbors, 472 | full_aug_edge_weight=full_aug_edge_weight) 473 | 474 | 475 | src_ngh_node_batch, src_ngh_eidx_batch, src_ngh_t_batch = self.ngh_finder.get_temporal_neighbor( 476 | src_idx_l, 477 | cut_time_l, 478 | num_neighbors=num_neighbors) 479 | 480 | src_ngh_node_batch_th = torch.from_numpy(src_ngh_node_batch).long().to(device) 481 | src_ngh_eidx_batch = torch.from_numpy(src_ngh_eidx_batch).long().to(device) 482 | 483 | src_ngh_t_batch_delta = cut_time_l[:, np.newaxis] - src_ngh_t_batch 484 | src_ngh_t_batch_th = torch.from_numpy(src_ngh_t_batch_delta).float().to(device) 485 | 486 | # get previous layer's node features 487 | src_ngh_node_batch_flat = src_ngh_node_batch.flatten() #reshape(batch_size, -1) 488 | src_ngh_t_batch_flat = src_ngh_t_batch.flatten() #reshape(batch_size, -1) 489 | src_ngh_node_conv_feat = self.tem_conv(src_ngh_node_batch_flat, 490 | src_ngh_t_batch_flat, 491 | curr_layers=curr_layers - 1, 492 | num_neighbors=num_neighbors, 493 | full_aug_edge_weight=full_aug_edge_weight) 494 | src_ngh_feat = src_ngh_node_conv_feat.view(batch_size, num_neighbors, -1) 495 | 496 | # get edge time features and node features 497 | src_ngh_t_embed = self.time_encoder(src_ngh_t_batch_th) 498 | src_ngn_edge_feat = self.edge_raw_embed(src_ngh_eidx_batch) 499 | 500 | if full_aug_edge_weight is not None: 501 | src_ngh_eidx_batch = src_ngh_eidx_batch.reshape(-1) - 1 502 | aug_edge_weight = full_aug_edge_weight[src_ngh_eidx_batch] 503 | aug_edge_weight = aug_edge_weight.reshape(-1, num_neighbors) 504 | 505 | # print(src_ngh_eidx_batch.shape) 506 | # print(src_ngn_edge_feat.shape) 507 | 508 | # attention aggregation 509 | mask = src_ngh_node_batch_th == 0 510 | 
attn_m = self.attn_model_list[curr_layers - 1] 511 | 512 | # print(src_ngn_edge_feat.shape) 513 | local, weight = attn_m(src_node_conv_feat, 514 | src_node_t_embed, 515 | src_ngh_feat, 516 | src_ngh_t_embed, 517 | src_ngn_edge_feat, 518 | mask, 519 | weight=aug_edge_weight if full_aug_edge_weight is not None else None) 520 | return local -------------------------------------------------------------------------------- /TGAT+TGSL/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import random 5 | import argparse 6 | import torch 7 | import dgl 8 | import pandas as pd 9 | import numpy as np 10 | import pickle 11 | 12 | from sklearn.metrics import average_precision_score, roc_auc_score 13 | 14 | from TGAT import TGAN 15 | from neighbor_finder import NeighborFinder 16 | from utils import EarlyStopMonitor, RandEdgeSampler, set_seed, get_device, show_time 17 | from config import * 18 | from view_learner import ETGNN, TimeMapping 19 | from MTL import MTL 20 | 21 | torch.autograd.set_detect_anomaly(True) 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--dataset', type=str, help='dataset', default='wikipedia') 24 | parser.add_argument('--cuda', type=str, required=True, help='idx for the gpu to use') 25 | parser.add_argument('--prefix', type=str, default='', help='prefix to name the checkpoints') 26 | parser.add_argument('--agg_method', type=str, choices=['attn', 'lstm', 'mean'], help='local aggregation method', default='attn') 27 | parser.add_argument('--attn_mode', type=str, choices=['prod', 'map'], default='prod', help='use dot product attention or mapping based') 28 | parser.add_argument('--time', type=str, choices=['time', 'pos', 'empty'], help='how to use time information', default='time') 29 | parser.add_argument('--uniform', action='store_true', help='take uniform sampling from temporal neighbors') 30 | parser.add_argument('--seed', type=int, default=2023) 31 | parser.add_argument('--patience', type=int, default=3) 32 | parser.add_argument('--tolerance', type=float, default=1e-3) 33 | parser.add_argument('--tau', type=float, default=0.1) 34 | parser.add_argument('--gtau', type=float, default=1.0) 35 | parser.add_argument('--K', type=int, default=512) 36 | parser.add_argument('--batch_size', type=int, default=200) 37 | parser.add_argument('--coe', type=float, default=0.2) 38 | parser.add_argument('--ratio', type=float, default=0.02) 39 | parser.add_argument('--infer_bs', type=int, default=200) 40 | parser.add_argument('--can_nn', type=int, default=20) 41 | parser.add_argument('--rnn_nn', type=int, default=20) 42 | parser.add_argument('--rnn_layer', type=int, default=1) 43 | parser.add_argument('--can_type', type=str, choices=['1st', '3rd', 'random', 'mix'], default='3rd') 44 | args = parser.parse_args() 45 | 46 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda) 47 | 48 | if args.dataset == 'wikipedia': 49 | config = WikiConfig() 50 | elif args.dataset == 'reddit': 51 | config = RedditConfig() 52 | elif args.dataset == 'escorts': 53 | config = ESConfig() 54 | else: 55 | raise Exception('Dataset Error') 56 | 57 | DATA = args.dataset 58 | AGG_METHOD = args.agg_method 59 | ATTN_MODE = args.attn_mode 60 | UNIFORM = args.uniform 61 | USE_TIME = args.time 62 | 63 | BATCH_SIZE = config.BATCH_SIZE 64 | NUM_EPOCH = config.EPOCH 65 | LEARNING_RATE = config.LR 66 | 67 | NUM_NEIGHBORS = config.N_DEGREE 68 | NUM_HEADS = config.N_HEAD 69 | NUM_LAYER = config.N_LAYER 70 | DROP_OUT = config.DROPOUT 
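# Example invocation (illustrative values; all flags are defined in the argparse
# block above, and --cuda is required):
#   python train.py --dataset wikipedia --cuda 0 --prefix tgsl --can_type 3rd --ratio 0.02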
71 | NODE_DIM = config.NODE_DIM 72 | TIME_DIM = config.TIME_DIM 73 | SEED = args.seed 74 | PATIENCE = args.patience 75 | TOLERANCE = args.tolerance 76 | TAU = args.tau 77 | GTAU = args.gtau 78 | K = args.K 79 | SSL_BATCH_SIZE = args.batch_size 80 | COE = args.coe 81 | RATIO = args.ratio 82 | INFER_BS = args.infer_bs 83 | NUM_CAN_NN = args.can_nn 84 | NUM_RNN_NN = args.rnn_nn 85 | CAN_TYPE = args.can_type 86 | NUM_RNN_LAYER = args.rnn_layer 87 | 88 | NUM_NEG = 1 89 | SEQ_LEN = NUM_NEIGHBORS 90 | 91 | set_seed(SEED) 92 | 93 | os.makedirs(f"./saved_models/", exist_ok=True) 94 | os.makedirs(f"./saved_checkpoints/", exist_ok=True) 95 | MODEL_SAVE_PATH = f'./saved_models/{args.prefix}-{args.agg_method}-{args.attn_mode}-{args.dataset}.pth' 96 | get_checkpoint_path = lambda epoch: f'./saved_checkpoints/{args.prefix}-{args.agg_method}-{args.attn_mode}-{args.dataset}-{epoch}.pth' 97 | 98 | 99 | 100 | def eval_one_epoch(hint, tgan, view_learner, edge_rnn, sample_time_encoder, sampler, src, dst, ts, label): 101 | val_acc, val_ap, val_f1, val_auc = [], [], [], [] 102 | with torch.no_grad(): 103 | tgan = tgan.eval() 104 | view_learner = view_learner.eval() 105 | edge_rnn = edge_rnn.eval() 106 | sample_time_encoder = sample_time_encoder.eval() 107 | TEST_BATCH_SIZE = INFER_BS 108 | num_test_instance = len(src) 109 | num_test_batch = math.ceil(num_test_instance / TEST_BATCH_SIZE) 110 | train_edge_feat = view_learner(g) 111 | full_edge_feat = torch.zeros(e_feat.shape[0] + 1, train_edge_feat.shape[1], device=device) 112 | full_edge_feat[train_e_idx_l - 1] = train_edge_feat 113 | for k in range(num_test_batch): 114 | s_idx = k * TEST_BATCH_SIZE 115 | e_idx = min(num_test_instance - 1, s_idx + TEST_BATCH_SIZE) 116 | src_l_cut = src[s_idx:e_idx] 117 | dst_l_cut = dst[s_idx:e_idx] 118 | ts_l_cut = ts[s_idx:e_idx] 119 | size = len(src_l_cut) 120 | src_l_fake, dst_l_fake = sampler.sample(size) 121 | 122 | train_node_cut = np.array(list(set(np.append(src_l_cut, dst_l_cut)).intersection(train_node_set))) 123 | max_ts = np.array([max_train_ts_l + 1] * len(train_node_cut)) 124 | neighbor_finder = NeighborFinder(adj_list, uniform=True) 125 | neighbor_node_idx, neighbor_edge_idx, neighbor_ts = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=NUM_RNN_NN) 126 | 127 | neighbor_edge_idx = neighbor_edge_idx.reshape(-1) - 1 128 | neighbor_edge_feat = full_edge_feat[neighbor_edge_idx] # [bs * 20, 172] 129 | neighbor_edge_feat = neighbor_edge_feat.reshape(neighbor_node_idx.shape[0], neighbor_node_idx.shape[1], -1) # [bs, 20, 172] 130 | 131 | neighbor_edge_feat = neighbor_edge_feat.transpose(0, 1) 132 | _, (h_n, _) = edge_rnn(neighbor_edge_feat) 133 | context_vec = h_n[-1] # [bs, 172] 134 | 135 | neighbor_finder.uniform = True 136 | if CAN_TYPE == '1st': 137 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=NUM_CAN_NN) 138 | 139 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20] 140 | dst_node_idx_aug = candidate_node_idx # [bs, 20] 141 | elif CAN_TYPE == '3rd': 142 | candidate_node_idx, candidate_edge_idx, candidate_ts = neighbor_finder.find_k_hop(3, train_node_cut, max_ts, num_neighbors=NUM_CAN_NN) 143 | candidate_node_idx = candidate_node_idx[-1].reshape(train_node_cut.shape[0], -1) 144 | candidate_edge_idx = candidate_edge_idx[-1].reshape(train_node_cut.shape[0], -1) 145 | candidate_ts = candidate_ts[-1].reshape(train_node_cut.shape[0], 
-1)
146 |
147 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
148 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
149 | elif CAN_TYPE == 'random':
150 | candidate_node_idx = np.random.choice(np.array(list(train_node_set)), size=train_node_cut.shape[0] * NUM_CAN_NN, replace=True).reshape(train_node_cut.shape[0], -1) # [bs, 20]
151 | candidate_edge_idx = np.array([0] * (train_node_cut.shape[0] * NUM_CAN_NN)).reshape(train_node_cut.shape[0], -1)
152 | candidate_ts = np.random.rand(train_node_cut.shape[0], NUM_CAN_NN) * max_train_ts_l
153 |
154 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
155 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
156 | elif CAN_TYPE == 'mix':
157 | candidate_node_idx_1st, candidate_edge_idx_1st, candidate_ts_1st = neighbor_finder.get_temporal_neighbor(train_node_cut, max_ts, num_neighbors=NUM_CAN_NN)
158 | candidate_node_idx_3rd, candidate_edge_idx_3rd, candidate_ts_3rd = neighbor_finder.find_k_hop(3, train_node_cut, max_ts, num_neighbors=NUM_CAN_NN)
159 | candidate_node_idx_3rd = candidate_node_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
160 | candidate_edge_idx_3rd = candidate_edge_idx_3rd[-1].reshape(train_node_cut.shape[0], -1)
161 | candidate_ts_3rd = candidate_ts_3rd[-1].reshape(train_node_cut.shape[0], -1)
162 |
163 | candidate_node_idx = np.concatenate((candidate_node_idx_1st, candidate_node_idx_3rd), axis=-1)
164 | candidate_edge_idx = np.concatenate((candidate_edge_idx_1st, candidate_edge_idx_3rd), axis=-1)
165 | candidate_ts = np.concatenate((candidate_ts_1st, candidate_ts_3rd), axis=-1)
166 |
167 | src_node_idx_aug = np.repeat(train_node_cut.reshape(train_node_cut.shape[0], 1), candidate_node_idx.shape[1], axis=1) # [bs, 20]
168 | dst_node_idx_aug = candidate_node_idx # [bs, 20]
169 | else:
170 | raise ValueError('invalid can_type, use 1st, 3rd, random or mix')
171 |
172 | candidate_edge_idx = candidate_edge_idx.reshape(-1) - 1
173 | candidate_edge_feat = full_edge_feat[candidate_edge_idx] # [bs * 20, 172]
174 | candidate_edge_feat = candidate_edge_feat.reshape(candidate_node_idx.shape[0], candidate_node_idx.shape[1], -1) # [bs, 20, 172]
175 |
176 | ts_aug = np.random.rand(candidate_ts.shape[0], candidate_ts.shape[1]) * max_train_ts_l
177 | delta_ts_sample = ts_aug - candidate_ts
178 | delta_ts_sample_context = ts_aug - np.ones_like(candidate_ts) * max_train_ts_l
179 | delta_ts_sample_embedding = sample_time_encoder(torch.tensor(delta_ts_sample.reshape(-1, 1), dtype=torch.float32).to(device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
180 | delta_ts_sample_context_embedding = sample_time_encoder(torch.tensor(delta_ts_sample_context.reshape(-1, 1), dtype=torch.float32).to(device)).reshape(ts_aug.shape[0], ts_aug.shape[1], -1)
181 |
182 | context_vec = context_vec.unsqueeze(1).expand_as(candidate_edge_feat)
183 | context_vec = context_vec * delta_ts_sample_context_embedding
184 | candidate_edge_feat = candidate_edge_feat * delta_ts_sample_embedding
185 | aug_edge_logits = torch.sum(context_vec * candidate_edge_feat, dim=-1) # [bs, 20]
186 |
187 | # Gumbel-Top-K: perturb the edge logits with logistic noise, squash with a temperature (GTAU)-scaled sigmoid, and keep the top-k entries as soft edge weights
188 | bias = 0.0001 # keep eps strictly inside (0, 1) so the logs below stay finite
189 | eps = (bias - (1 - bias)) * torch.rand(aug_edge_logits.size()) + (1 - bias)
190 | gate_inputs = torch.log(eps) - torch.log(1 - eps)
191 | gate_inputs = gate_inputs.to(aug_edge_logits.device)
192 | gate_inputs = (gate_inputs + aug_edge_logits) / GTAU
193 | z = torch.sigmoid(gate_inputs).squeeze() # [bs, 20]
194 | __, sorted_idx = z.sort(dim=-1, descending=True)
195 | keep = sorted_idx[:, :int(RATIO * z.size(1))] # [bs, k]
196 |
197 | aug_edge_logits = z # [bs, 20], reuse the relaxed edge scores computed above
198 | aug_edge_weight = torch.gather(aug_edge_logits, dim=1, index=keep) # [bs, k]
199 | ts_aug = torch.gather(torch.tensor(ts_aug, device=device), dim=1, index=keep).detach().cpu().numpy() # [bs, k]
200 | src_node_idx_aug = torch.gather(torch.tensor(src_node_idx_aug, device=device), dim=1, index=keep).detach().cpu().numpy() # [bs, k]
201 | dst_node_idx_aug = torch.gather(torch.tensor(dst_node_idx_aug, device=device), dim=1, index=keep).detach().cpu().numpy() # [bs, k]
202 | candidate_edge_feat = torch.gather(candidate_edge_feat, dim=1, index=keep.unsqueeze(2).repeat(1, 1, candidate_edge_feat.shape[2])) # [bs, k, 172]
203 |
204 | aug_edge_weight = aug_edge_weight.reshape(-1)
205 | ts_aug = ts_aug.reshape(-1)
206 | src_node_idx_aug = src_node_idx_aug.reshape(-1)
207 | dst_node_idx_aug = dst_node_idx_aug.reshape(-1)
208 |
209 | temp_eid = e_feat.shape[0]
210 | new_eid_list = []
211 | adj_list_aug = pickle.loads(full_adj_list_pickle) # cheap deep copy of the full adjacency list
212 | for src_aug, dst_aug, ts_aug_temp in zip(src_node_idx_aug, dst_node_idx_aug, ts_aug):
213 | adj_list_aug[src_aug].append((dst_aug, temp_eid, ts_aug_temp))
214 | adj_list_aug[dst_aug].append((src_aug, temp_eid, ts_aug_temp))
215 | new_eid_list.append(temp_eid)
216 | temp_eid += 1
217 | train_ngh_finder_aug = NeighborFinder(adj_list_aug, uniform=tgan.ngh_finder.uniform)
218 |
219 | new_eid_list = np.array(new_eid_list)
220 | full_aug_edge_weight = torch.ones(temp_eid, device=device)
221 | full_aug_edge_weight[new_eid_list - 1] = aug_edge_weight
222 |
223 | candidate_edge_feat = candidate_edge_feat.reshape(-1, candidate_edge_feat.shape[
224 | 2]).detach().cpu().numpy() # [bs * k, 172]
225 | e_feat_aug = np.concatenate((e_feat, candidate_edge_feat), axis=0)
226 | e_feat_th_aug = torch.nn.Parameter(torch.from_numpy(e_feat_aug.astype(np.float32)))
227 | edge_raw_embed_aug = torch.nn.Embedding.from_pretrained(e_feat_th_aug, padding_idx=0, freeze=True).to(
228 | device)
229 |
230 | ngh_finder_ori = tgan.ngh_finder
231 | tgan.ngh_finder = train_ngh_finder_aug
232 | edge_raw_embed_ori = tgan.edge_raw_embed
233 | tgan.edge_raw_embed = edge_raw_embed_aug
234 |
235 | pos_prob, neg_prob = tgan.contrast(src_l_cut, dst_l_cut, dst_l_fake, ts_l_cut, NUM_NEIGHBORS, full_aug_edge_weight=full_aug_edge_weight)
236 |
237 | # Recover
238 | tgan.ngh_finder = ngh_finder_ori
239 | tgan.edge_raw_embed = edge_raw_embed_ori
240 |
241 | pred_score = np.concatenate([(pos_prob).cpu().numpy(), (neg_prob).cpu().numpy()])
242 | pred_label = pred_score > 0.5
243 | true_label = np.concatenate([np.ones(size), np.zeros(size)])
244 |
245 | val_acc.append((pred_label == true_label).mean())
246 | val_ap.append(average_precision_score(true_label, pred_score))
247 | val_auc.append(roc_auc_score(true_label, pred_score))
248 |
249 | return np.mean(val_acc), np.mean(val_ap), np.mean(val_f1), np.mean(val_auc) # note: val_f1 is never populated here
250 |
251 | # Load data and train val test split
252 | g_df = pd.read_csv('./processed/ml_{}.csv'.format(DATA))
253 | e_feat = np.load('./processed/ml_{}.npy'.format(DATA))
254 | n_feat = np.load('./processed/ml_{}_node.npy'.format(DATA))
255 |
256 | if e_feat.shape[1] < 172:
257 | edge_zero_padding = np.zeros((e_feat.shape[0], 172 - e_feat.shape[1]))
258 | e_feat = np.concatenate([e_feat, edge_zero_padding], axis=1)
259 | if n_feat.shape[1] < 172:
260 | node_zero_padding = np.zeros((n_feat.shape[0], 172 -
n_feat.shape[1])) 261 | n_feat = np.concatenate([n_feat, node_zero_padding], axis=1) 262 | 263 | val_time, test_time = list(np.quantile(g_df.ts, [0.70, 0.85])) 264 | 265 | src_l = g_df.u.values 266 | dst_l = g_df.i.values 267 | e_idx_l = g_df.idx.values 268 | label_l = g_df.label.values 269 | ts_l = g_df.ts.values 270 | 271 | max_src_index = src_l.max() 272 | max_idx = max(src_l.max(), dst_l.max()) 273 | 274 | total_node_set = set(np.unique(np.hstack([g_df.u.values, g_df.i.values]))) 275 | num_total_unique_nodes = len(total_node_set) 276 | 277 | mask_node_set = set(random.sample(set(src_l[ts_l > val_time]).union(set(dst_l[ts_l > val_time])), int(0.1 * num_total_unique_nodes))) 278 | mask_src_flag = g_df.u.map(lambda x: x in mask_node_set).values 279 | mask_dst_flag = g_df.i.map(lambda x: x in mask_node_set).values 280 | none_node_flag = (1 - mask_src_flag) * (1 - mask_dst_flag) 281 | 282 | valid_train_flag = (ts_l <= val_time) * (none_node_flag > 0) 283 | 284 | train_src_l = src_l[valid_train_flag] 285 | train_dst_l = dst_l[valid_train_flag] 286 | train_ts_l = ts_l[valid_train_flag] 287 | train_e_idx_l = e_idx_l[valid_train_flag] 288 | train_label_l = label_l[valid_train_flag] 289 | 290 | # define the new nodes sets for testing inductiveness of the model 291 | train_node_set = set(train_src_l).union(train_dst_l) 292 | assert(len(train_node_set - mask_node_set) == len(train_node_set)) 293 | new_node_set = total_node_set - train_node_set 294 | 295 | # select validation and test dataset 296 | valid_val_flag = (ts_l <= test_time) * (ts_l > val_time) 297 | valid_test_flag = ts_l > test_time 298 | 299 | is_new_node_edge = np.array([(a in new_node_set or b in new_node_set) for a, b in zip(src_l, dst_l)]) 300 | nn_val_flag = valid_val_flag * is_new_node_edge 301 | nn_test_flag = valid_test_flag * is_new_node_edge 302 | 303 | # validation and test with all edges 304 | val_src_l = src_l[valid_val_flag] 305 | val_dst_l = dst_l[valid_val_flag] 306 | val_ts_l = ts_l[valid_val_flag] 307 | val_e_idx_l = e_idx_l[valid_val_flag] 308 | val_label_l = label_l[valid_val_flag] 309 | 310 | test_src_l = src_l[valid_test_flag] 311 | test_dst_l = dst_l[valid_test_flag] 312 | test_ts_l = ts_l[valid_test_flag] 313 | test_e_idx_l = e_idx_l[valid_test_flag] 314 | test_label_l = label_l[valid_test_flag] 315 | # validation and test with edges that at least has one new node (not in training set) 316 | nn_val_src_l = src_l[nn_val_flag] 317 | nn_val_dst_l = dst_l[nn_val_flag] 318 | nn_val_ts_l = ts_l[nn_val_flag] 319 | nn_val_e_idx_l = e_idx_l[nn_val_flag] 320 | nn_val_label_l = label_l[nn_val_flag] 321 | 322 | nn_test_src_l = src_l[nn_test_flag] 323 | nn_test_dst_l = dst_l[nn_test_flag] 324 | nn_test_ts_l = ts_l[nn_test_flag] 325 | nn_test_e_idx_l = e_idx_l[nn_test_flag] 326 | nn_test_label_l = label_l[nn_test_flag] 327 | 328 | # Initialize the data structure for graph and edge sampling 329 | # build the graph for fast query 330 | # graph only contains the training data (with 10% nodes removal) 331 | adj_list = [[] for _ in range(max_idx + 1)] 332 | for src, dst, eidx, ts in zip(train_src_l, train_dst_l, train_e_idx_l, train_ts_l): 333 | adj_list[src].append((dst, eidx, ts)) 334 | adj_list[dst].append((src, eidx, ts)) 335 | train_ngh_finder = NeighborFinder(adj_list, uniform=UNIFORM) 336 | 337 | # full graph with all the data for the test and validation purpose 338 | full_adj_list = [[] for _ in range(max_idx + 1)] 339 | for src, dst, eidx, ts in zip(src_l, dst_l, e_idx_l, ts_l): 340 | full_adj_list[src].append((dst, eidx, 
ts)) 341 | full_adj_list[dst].append((src, eidx, ts)) 342 | full_ngh_finder = NeighborFinder(full_adj_list, uniform=UNIFORM) 343 | 344 | adj_list_pickle = pickle.dumps(adj_list, -1) 345 | full_adj_list_pickle = pickle.dumps(full_adj_list, -1) 346 | 347 | 348 | train_rand_sampler = RandEdgeSampler(train_src_l, train_dst_l) 349 | val_rand_sampler = RandEdgeSampler(src_l, dst_l) 350 | nn_val_rand_sampler = RandEdgeSampler(nn_val_src_l, nn_val_dst_l) 351 | test_rand_sampler = RandEdgeSampler(src_l, dst_l) 352 | nn_test_rand_sampler = RandEdgeSampler(nn_test_src_l, nn_test_dst_l) 353 | 354 | 355 | device = get_device(index=0) 356 | max_train_ts_l = max(train_ts_l) 357 | ''' 358 | +++++++++++++++++++++++++++++ MTL Stage +++++++++++++++++++++++++++++ 359 | ''' 360 | # DGL Graph Construction 361 | g = dgl.graph((train_src_l, train_dst_l)) 362 | 363 | ndata = [] 364 | for ind in range(g.num_nodes()): 365 | if ind in train_node_set: 366 | ndata.append(n_feat[ind]) 367 | else: 368 | ndata.append([0] * n_feat.shape[1]) 369 | 370 | edata_feat = e_feat[train_e_idx_l] 371 | edata_ts = train_ts_l.reshape(train_ts_l.shape[0], -1) 372 | 373 | g.ndata['feat'] = torch.tensor(np.array(ndata), dtype=torch.float32) 374 | g.edata['edge_feat'] = torch.tensor(np.array(edata_feat), dtype=torch.float32) 375 | g.edata['ts'] = torch.tensor(np.array(edata_ts), dtype=torch.float32) 376 | 377 | g = dgl.add_self_loop(g) 378 | g = dgl.add_reverse_edges(g, copy_ndata=True, copy_edata=True) 379 | g = g.to(device) 380 | 381 | SSL_Encoder_k = TGAN(train_ngh_finder, n_feat, e_feat, num_layers=NUM_LAYER, use_time=USE_TIME, agg_method=AGG_METHOD, 382 | attn_mode=ATTN_MODE, seq_len=SEQ_LEN, n_head=NUM_HEADS, drop_out=0.0) 383 | view_learner = ETGNN(in_dim=n_feat.shape[1], hidden_dim=n_feat.shape[1], train_src_l=train_src_l, train_dst_l=train_dst_l) 384 | tgan = TGAN(train_ngh_finder, n_feat, e_feat, num_layers=NUM_LAYER, use_time=USE_TIME, agg_method=AGG_METHOD, 385 | attn_mode=ATTN_MODE, seq_len=SEQ_LEN, n_head=NUM_HEADS, drop_out=DROP_OUT) 386 | edge_rnn = torch.nn.LSTM(input_size=n_feat.shape[1], hidden_size=n_feat.shape[1], num_layers=NUM_RNN_LAYER, bidirectional=False) 387 | sample_time_encoder = TimeMapping() 388 | model = MTL(base_encoder_k=SSL_Encoder_k, encoder=tgan, view_learner=view_learner, edge_rnn=edge_rnn, 389 | sample_time_encoder=sample_time_encoder, len_full_edge=e_feat.shape[0], train_e_idx_l=train_e_idx_l, 390 | train_node_set=train_node_set, train_ts_l=train_ts_l, e_feat=e_feat, device=device, K=K, ratio=RATIO, 391 | can_nn=NUM_CAN_NN, rnn_nn=NUM_RNN_NN, can_type=CAN_TYPE, tau=TAU, gtau=GTAU) 392 | model = model.to(device) 393 | 394 | 395 | optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE) 396 | ssl_criterion = torch.nn.CrossEntropyLoss().to(device) 397 | criterion = torch.nn.BCELoss().to(device) 398 | early_stopper = EarlyStopMonitor(max_round=PATIENCE, tolerance=TOLERANCE) 399 | 400 | 401 | num_instance = len(train_src_l) 402 | num_batch = math.ceil(num_instance / BATCH_SIZE) 403 | 404 | print('Fine-tuning Stage: num of training instances: {}'.format(num_instance)) 405 | print('Fine-tuning Stage: num of batches per epoch: {}'.format(num_batch)) 406 | 407 | for epoch in range(NUM_EPOCH): 408 | start_time = time.time() 409 | model.encoder.ngh_finder = train_ngh_finder 410 | acc, ap, f1, auc, m_loss = [], [], [], [], [] 411 | print('{} start {} epoch'.format(show_time(), epoch)) 412 | for k in range(num_batch): 413 | s_idx = k * BATCH_SIZE 414 | e_idx = min(num_instance - 1, s_idx + BATCH_SIZE) 
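# Each batch below optimizes (i) BCE link-prediction losses on observed edges vs.
# randomly sampled negatives -- the four criterion terms; the *_ed pair presumably
# scores the learned/augmented view produced inside MTL -- and (ii) the
# self-supervised contrastive term (cross-entropy over output/target), scaled by COE.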
415 | src_l_cut, dst_l_cut = train_src_l[s_idx:e_idx], train_dst_l[s_idx:e_idx]
416 | ts_l_cut = train_ts_l[s_idx:e_idx]
417 | label_l_cut = train_label_l[s_idx:e_idx]
418 | size = len(src_l_cut)
419 | src_l_fake, dst_l_fake = train_rand_sampler.sample(size)
420 |
421 | with torch.no_grad():
422 | pos_label = torch.ones(size, dtype=torch.float, device=device)
423 | neg_label = torch.zeros(size, dtype=torch.float, device=device)
424 |
425 | optimizer.zero_grad()
426 | model = model.train()
427 |
428 | pos_prob, neg_prob, pos_prob_ed, neg_prob_ed, output, target = model(src_l_cut, dst_l_cut, dst_l_fake, ts_l_cut, NUM_NEIGHBORS, g, adj_list, adj_list_pickle)
429 |
430 | loss = criterion(pos_prob, pos_label)
431 | loss += criterion(neg_prob, neg_label)
432 | loss += criterion(pos_prob_ed, pos_label)
433 | loss += criterion(neg_prob_ed, neg_label)
434 | loss += ssl_criterion(output, target) * COE
435 | loss.backward()
436 | optimizer.step()
437 | # get training results
438 | with torch.no_grad():
439 | model = model.eval()
440 | pred_score = np.concatenate([(pos_prob).cpu().detach().numpy(), (neg_prob).cpu().detach().numpy()])
441 | pred_label = pred_score > 0.5
442 | true_label = np.concatenate([np.ones(size), np.zeros(size)])
443 | acc.append((pred_label == true_label).mean())
444 | ap.append(average_precision_score(true_label, pred_score))
445 | m_loss.append(loss.item())
446 | auc.append(roc_auc_score(true_label, pred_score))
447 |
448 |
449 | end_time = time.time()
450 | print('epoch: {} took {:.2f}s'.format(epoch, end_time - start_time))
451 | # validation phase use all information
452 | tgan.ngh_finder = full_ngh_finder
453 | val_acc, val_ap, val_f1, val_auc = eval_one_epoch('val for old nodes', tgan, view_learner, edge_rnn, sample_time_encoder, val_rand_sampler, val_src_l,
454 | val_dst_l, val_ts_l, val_label_l)
455 |
456 | nn_val_acc, nn_val_ap, nn_val_f1, nn_val_auc = eval_one_epoch('val for new nodes', tgan, view_learner, edge_rnn, sample_time_encoder, nn_val_rand_sampler, nn_val_src_l, # new-node split uses its own sampler
457 | nn_val_dst_l, nn_val_ts_l, nn_val_label_l)
458 |
459 | print('epoch: {}:'.format(epoch))
460 | print('Epoch mean loss: {}'.format(np.mean(m_loss)))
461 | print('train acc: {}, val acc: {}, new node val acc: {}'.format(np.mean(acc), val_acc, nn_val_acc))
462 | print('train auc: {}, val auc: {}, new node val auc: {}'.format(np.mean(auc), val_auc, nn_val_auc))
463 | print('train ap: {}, val ap: {}, new node val ap: {}'.format(np.mean(ap), val_ap, nn_val_ap))
464 | # print('train f1: {}, val f1: {}, new node val f1: {}'.format(np.mean(f1), val_f1, nn_val_f1))
465 |
466 | if early_stopper.early_stop_check(val_ap):
467 | print('No improvement over {} epochs, stopping training'.format(early_stopper.max_round))
468 | print(f'Loading the best model at epoch {early_stopper.best_epoch}')
469 | best_model_path = get_checkpoint_path(early_stopper.best_epoch)
470 | tgan.load_state_dict(torch.load(best_model_path))
471 | view_learner.load_state_dict(torch.load(best_model_path[:-4] + '-ViewLearner.pth'))
472 | edge_rnn.load_state_dict(torch.load(best_model_path[:-4] + '-edgernn.pth'))
473 | sample_time_encoder.load_state_dict(torch.load(best_model_path[:-4] + '-ste.pth'))
474 | print(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
475 | tgan.eval()
476 | view_learner.eval()
477 | edge_rnn.eval()
478 | sample_time_encoder.eval()
479 | break
480 | else:
481 | torch.save(tgan.state_dict(), get_checkpoint_path(epoch))
482 | torch.save(view_learner.state_dict(), get_checkpoint_path(epoch)[:-4] +
'-ViewLearner.pth') 483 | torch.save(edge_rnn.state_dict(), get_checkpoint_path(epoch)[:-4] + '-edgernn.pth') 484 | torch.save(sample_time_encoder.state_dict(), get_checkpoint_path(epoch)[:-4] + '-ste.pth') 485 | 486 | 487 | # testing phase use all information 488 | tgan.ngh_finder = full_ngh_finder 489 | test_acc, test_ap, test_f1, test_auc = eval_one_epoch('test for old nodes', tgan, view_learner, edge_rnn, sample_time_encoder, test_rand_sampler, test_src_l, 490 | test_dst_l, test_ts_l, test_label_l) 491 | 492 | nn_test_acc, nn_test_ap, nn_test_f1, nn_test_auc = eval_one_epoch('test for new nodes', tgan, view_learner, edge_rnn, sample_time_encoder, nn_test_rand_sampler, nn_test_src_l, 493 | nn_test_dst_l, nn_test_ts_l, nn_test_label_l) 494 | 495 | print('Test statistics: Old nodes -- acc: {}, auc: {}, ap: {}'.format(test_acc, test_auc, test_ap)) 496 | print('Test statistics: New nodes -- acc: {}, auc: {}, ap: {}'.format(nn_test_acc, nn_test_auc, nn_test_ap)) 497 | 498 | print('Saving TGAN model') 499 | torch.save(tgan.state_dict(), MODEL_SAVE_PATH) 500 | torch.save(view_learner.state_dict(), MODEL_SAVE_PATH[:-4] + '-ViewLearner.pth') 501 | torch.save(edge_rnn.state_dict(), MODEL_SAVE_PATH[:-4] + '-edgernn.pth') 502 | torch.save(sample_time_encoder.state_dict(), MODEL_SAVE_PATH[:-4] + '-ste.pth') 503 | print('TGAN models saved') 504 | -------------------------------------------------------------------------------- /GraphMixer+TGSL/utils/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import pandas as pd 6 | 7 | 8 | def set_random_seed(seed: int = 0): 9 | """ 10 | set random seed 11 | :param seed: int, random seed 12 | :return: 13 | """ 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | if torch.cuda.is_available(): 18 | torch.cuda.manual_seed_all(seed) 19 | torch.backends.cudnn.deterministic = True 20 | torch.backends.cudnn.benchmark = False 21 | 22 | 23 | def convert_to_gpu(*data, device: str): 24 | """ 25 | convert data from cpu to gpu, accelerate the running speed 26 | :param data: can be any type, including Tensor, Module, ... 
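(items are moved in input order; a single item is returned unwrapped,
multiple items are returned as a tuple)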
27 | :param device: str 28 | """ 29 | res = [] 30 | for item in data: 31 | item = item.to(device) 32 | res.append(item) 33 | if len(res) > 1: 34 | res = tuple(res) 35 | else: 36 | res = res[0] 37 | return res 38 | 39 | 40 | def get_parameter_sizes(model: nn.Module): 41 | """ 42 | get parameter size of trainable parameters in model 43 | :param model: nn.Module 44 | :return: 45 | """ 46 | return sum([p.numel() for p in model.parameters() if p.requires_grad]) 47 | 48 | 49 | def create_optimizer(model: nn.Module, optimizer_name: str, learning_rate: float, weight_decay: float = 0.0): 50 | """ 51 | create optimizer 52 | :param model: nn.Module 53 | :param optimizer_name: str, optimizer name 54 | :param learning_rate: float, learning rate 55 | :param weight_decay: float, weight decay 56 | :return: 57 | """ 58 | if optimizer_name == 'Adam': 59 | optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay) 60 | elif optimizer_name == 'SGD': 61 | optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay) 62 | elif optimizer_name == 'RMSprop': 63 | optimizer = torch.optim.RMSprop(params=model.parameters(), lr=learning_rate, weight_decay=weight_decay) 64 | else: 65 | raise ValueError(f"Wrong value for optimizer {optimizer_name}!") 66 | 67 | return optimizer 68 | 69 | 70 | class NeighborSampler: 71 | 72 | def __init__(self, adj_list: list, sample_neighbor_strategy: str = 'uniform', time_scaling_factor: float = 0.0, seed: int = None): 73 | """ 74 | Neighbor sampler. 75 | :param adj_list: list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp) 76 | :param sample_neighbor_strategy: str, how to sample historical neighbors, 'uniform', 'recent', or 'time_interval_aware' 77 | :param time_scaling_factor: float, a hyper-parameter that controls the sampling preference with time interval, 78 | a large time_scaling_factor tends to sample more on recent links, this parameter works when sample_neighbor_strategy == 'time_interval_aware' 79 | :param seed: int, random seed 80 | """ 81 | self.sample_neighbor_strategy = sample_neighbor_strategy 82 | self.seed = seed 83 | 84 | # list of each node's neighbor ids, edge ids and interaction times, which are sorted by interaction times 85 | self.nodes_neighbor_ids = [] 86 | self.nodes_edge_ids = [] 87 | self.nodes_neighbor_times = [] 88 | 89 | if self.sample_neighbor_strategy == 'time_interval_aware': 90 | self.nodes_neighbor_sampled_probabilities = [] 91 | self.time_scaling_factor = time_scaling_factor 92 | 93 | # the list at the first position in adj_list is empty, hence, sorted() will return an empty list for the first position 94 | # its corresponding value in self.nodes_neighbor_ids, self.nodes_edge_ids, self.nodes_neighbor_times will also be empty with length 0 95 | for node_idx, per_node_neighbors in enumerate(adj_list): 96 | # per_node_neighbors is a list of tuples (neighbor_id, edge_id, timestamp) 97 | # sort the list based on timestamps, sorted() function is stable 98 | # Note that sort the list based on edge id is also correct, as the original data file ensures the interactions are chronological 99 | sorted_per_node_neighbors = sorted(per_node_neighbors, key=lambda x: x[2]) 100 | self.nodes_neighbor_ids.append(np.array([x[0] for x in sorted_per_node_neighbors])) 101 | self.nodes_edge_ids.append(np.array([x[1] for x in sorted_per_node_neighbors])) 102 | self.nodes_neighbor_times.append(np.array([x[2] for x in sorted_per_node_neighbors])) 103 | 104 | 
# additional for time interval aware sampling strategy (proposed in CAWN paper) 105 | if self.sample_neighbor_strategy == 'time_interval_aware': 106 | self.nodes_neighbor_sampled_probabilities.append(self.compute_sampled_probabilities(np.array([x[2] for x in sorted_per_node_neighbors]))) 107 | 108 | if self.seed is not None: 109 | self.random_state = np.random.RandomState(self.seed) 110 | 111 | def compute_sampled_probabilities(self, node_neighbor_times: np.ndarray): 112 | """ 113 | compute the sampled probabilities of historical neighbors based on their interaction times 114 | :param node_neighbor_times: ndarray, shape (num_historical_neighbors, ) 115 | :return: 116 | """ 117 | if len(node_neighbor_times) == 0: 118 | return np.array([]) 119 | # compute the time delta with regard to the last time in node_neighbor_times 120 | node_neighbor_times = node_neighbor_times - np.max(node_neighbor_times) 121 | # compute the normalized sampled probabilities of historical neighbors 122 | exp_node_neighbor_times = np.exp(self.time_scaling_factor * node_neighbor_times) 123 | sampled_probabilities = exp_node_neighbor_times / np.cumsum(exp_node_neighbor_times) 124 | # note that the first few values in exp_node_neighbor_times may be all zero, which make the corresponding values in sampled_probabilities 125 | # become nan (divided by zero), so we replace the nan by a very large negative number -1e10 to denote the sampled probabilities 126 | sampled_probabilities[np.isnan(sampled_probabilities)] = -1e10 127 | return sampled_probabilities 128 | 129 | def find_neighbors_before(self, node_id: int, interact_time: float, return_sampled_probabilities: bool = False): 130 | """ 131 | extracts all the interactions happening before interact_time (less than interact_time) for node_id in the overall interaction graph 132 | the returned interactions are sorted by time. 133 | :param node_id: int, node id 134 | :param interact_time: float, interaction time 135 | :param return_sampled_probabilities: boolean, whether return the sampled probabilities of neighbors 136 | :return: neighbors, edge_ids, timestamps and sampled_probabilities (if return_sampled_probabilities is True) with shape (historical_nodes_num, ) 137 | """ 138 | # return index i, which satisfies list[i - 1] < v <= list[i] 139 | # return 0 for the first position in self.nodes_neighbor_times since the value at the first position is empty 140 | i = np.searchsorted(self.nodes_neighbor_times[node_id], interact_time) 141 | 142 | if return_sampled_probabilities: 143 | return self.nodes_neighbor_ids[node_id][:i], self.nodes_edge_ids[node_id][:i], self.nodes_neighbor_times[node_id][:i], \ 144 | self.nodes_neighbor_sampled_probabilities[node_id][:i] 145 | else: 146 | return self.nodes_neighbor_ids[node_id][:i], self.nodes_edge_ids[node_id][:i], self.nodes_neighbor_times[node_id][:i], None 147 | 148 | def get_historical_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20): 149 | """ 150 | get historical neighbors of nodes in node_ids with interactions before the corresponding time in node_interact_times 151 | :param node_ids: ndarray, shape (batch_size, ) or (*, ), node ids 152 | :param node_interact_times: ndarray, shape (batch_size, ) or (*, ), node interaction times 153 | :param num_neighbors: int, number of neighbors to sample for each node 154 | :return: 155 | """ 156 | assert num_neighbors > 0, 'Number of sampled neighbors for each node should be greater than 0!' 
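# Minimal usage sketch (illustrative; adj_list built as in get_neighbor_sampler below):
#   sampler = NeighborSampler(adj_list, sample_neighbor_strategy='recent')
#   nbr_ids, edge_ids, nbr_ts = sampler.get_historical_neighbors(
#       node_ids=np.array([1, 2]), node_interact_times=np.array([10.0, 20.0]), num_neighbors=20)
# Rows are time-sorted; when a node has fewer than num_neighbors earlier
# interactions, the unfilled (for 'recent': leading) entries remain zero.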
157 | # All interactions described in the following three matrices are sorted in each row by time 158 | # each entry in position (i,j) represents the id of the j-th dst node of src node node_ids[i] with an interaction before node_interact_times[i] 159 | # ndarray, shape (batch_size, num_neighbors) 160 | nodes_neighbor_ids = np.zeros((len(node_ids), num_neighbors)).astype(int) 161 | # each entry in position (i,j) represents the id of the edge with src node node_ids[i] and dst node nodes_neighbor_ids[i][j] with an interaction before node_interact_times[i] 162 | # ndarray, shape (batch_size, num_neighbors) 163 | nodes_edge_ids = np.zeros((len(node_ids), num_neighbors)).astype(int) 164 | # each entry in position (i,j) represents the interaction time between src node node_ids[i] and dst node nodes_neighbor_ids[i][j], before node_interact_times[i] 165 | # ndarray, shape (batch_size, num_neighbors) 166 | nodes_neighbor_times = np.zeros((len(node_ids), num_neighbors)).astype(np.float32) 167 | 168 | # extracts all neighbors ids, edge ids and interaction times of nodes in node_ids, which happened before the corresponding time in node_interact_times 169 | for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)): 170 | # find neighbors that interacted with node_id before time node_interact_time 171 | node_neighbor_ids, node_edge_ids, node_neighbor_times, node_neighbor_sampled_probabilities = \ 172 | self.find_neighbors_before(node_id=node_id, interact_time=node_interact_time, return_sampled_probabilities=self.sample_neighbor_strategy == 'time_interval_aware') 173 | 174 | if len(node_neighbor_ids) > 0: 175 | if self.sample_neighbor_strategy in ['uniform', 'time_interval_aware']: 176 | # when self.sample_neighbor_strategy == 'uniform', we shuffle the data before sampling with node_neighbor_sampled_probabilities as None 177 | # when self.sample_neighbor_strategy == 'time_interval_aware', we sample neighbors based on node_neighbor_sampled_probabilities 178 | # for time_interval_aware sampling strategy, we additionally use softmax to make the sum of sampled probabilities be 1 179 | if node_neighbor_sampled_probabilities is not None: 180 | # for extreme case that node_neighbor_sampled_probabilities only contains -1e10, which will make the denominator of softmax be zero, 181 | # torch.softmax() function can tackle this case 182 | node_neighbor_sampled_probabilities = torch.softmax(torch.from_numpy(node_neighbor_sampled_probabilities).float(), dim=0).numpy() 183 | if self.seed is None: 184 | sampled_indices = np.random.choice(a=len(node_neighbor_ids), size=num_neighbors, p=node_neighbor_sampled_probabilities) 185 | else: 186 | sampled_indices = self.random_state.choice(a=len(node_neighbor_ids), size=num_neighbors, p=node_neighbor_sampled_probabilities) 187 | 188 | nodes_neighbor_ids[idx, :] = node_neighbor_ids[sampled_indices] 189 | nodes_edge_ids[idx, :] = node_edge_ids[sampled_indices] 190 | nodes_neighbor_times[idx, :] = node_neighbor_times[sampled_indices] 191 | 192 | # resort based on timestamps, return the ids in sorted increasing order, note this maybe unstable when multiple edges happen at the same time 193 | # (we still do this though this is unnecessary for TGAT or CAWN to guarantee the order of nodes, 194 | # since TGAT computes in an order-agnostic manner with relative time encoding, and CAWN computes for each walk while the sampled nodes are in different walks) 195 | sorted_position = nodes_neighbor_times[idx, :].argsort() 196 | nodes_neighbor_ids[idx, :] = 
nodes_neighbor_ids[idx, :][sorted_position] 197 | nodes_edge_ids[idx, :] = nodes_edge_ids[idx, :][sorted_position] 198 | nodes_neighbor_times[idx, :] = nodes_neighbor_times[idx, :][sorted_position] 199 | elif self.sample_neighbor_strategy == 'recent': 200 | # Take most recent interactions with number num_neighbors 201 | node_neighbor_ids = node_neighbor_ids[-num_neighbors:] 202 | node_edge_ids = node_edge_ids[-num_neighbors:] 203 | node_neighbor_times = node_neighbor_times[-num_neighbors:] 204 | 205 | # put the neighbors' information at the back positions 206 | nodes_neighbor_ids[idx, num_neighbors - len(node_neighbor_ids):] = node_neighbor_ids 207 | nodes_edge_ids[idx, num_neighbors - len(node_edge_ids):] = node_edge_ids 208 | nodes_neighbor_times[idx, num_neighbors - len(node_neighbor_times):] = node_neighbor_times 209 | else: 210 | raise ValueError(f'Not implemented error for sample_neighbor_strategy {self.sample_neighbor_strategy}!') 211 | 212 | # three ndarrays, with shape (batch_size, num_neighbors) 213 | return nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times 214 | 215 | def get_multi_hop_neighbors(self, num_hops: int, node_ids: np.ndarray, node_interact_times: np.ndarray, num_neighbors: int = 20): 216 | """ 217 | get historical neighbors of nodes in node_ids within num_hops hops 218 | :param num_hops: int, number of sampled hops 219 | :param node_ids: ndarray, shape (batch_size, ), node ids 220 | :param node_interact_times: ndarray, shape (batch_size, ), node interaction times 221 | :param num_neighbors: int, number of neighbors to sample for each node 222 | :return: 223 | """ 224 | assert num_hops > 0, 'Number of sampled hops should be greater than 0!' 225 | 226 | # get the temporal neighbors at the first hop 227 | # nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times -> ndarray, shape (batch_size, num_neighbors) 228 | nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=node_ids, 229 | node_interact_times=node_interact_times, 230 | num_neighbors=num_neighbors) 231 | # three lists to store the neighbor ids, edge ids and interaction timestamp information 232 | nodes_neighbor_ids_list = [nodes_neighbor_ids] 233 | nodes_edge_ids_list = [nodes_edge_ids] 234 | nodes_neighbor_times_list = [nodes_neighbor_times] 235 | for hop in range(1, num_hops): 236 | # get information of neighbors sampled at the current hop 237 | # three ndarrays, with shape (batch_size * num_neighbors ** hop, num_neighbors) 238 | nodes_neighbor_ids, nodes_edge_ids, nodes_neighbor_times = self.get_historical_neighbors(node_ids=nodes_neighbor_ids_list[-1].flatten(), 239 | node_interact_times=nodes_neighbor_times_list[-1].flatten(), 240 | num_neighbors=num_neighbors) 241 | # three ndarrays with shape (batch_size, num_neighbors ** (hop + 1)) 242 | nodes_neighbor_ids = nodes_neighbor_ids.reshape(len(node_ids), -1) 243 | nodes_edge_ids = nodes_edge_ids.reshape(len(node_ids), -1) 244 | nodes_neighbor_times = nodes_neighbor_times.reshape(len(node_ids), -1) 245 | 246 | nodes_neighbor_ids_list.append(nodes_neighbor_ids) 247 | nodes_edge_ids_list.append(nodes_edge_ids) 248 | nodes_neighbor_times_list.append(nodes_neighbor_times) 249 | 250 | # tuple, each element in the tuple is a list of num_hops ndarrays, each with shape (batch_size, num_neighbors ** current_hop) 251 | return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list 252 | 253 | def get_all_first_hop_neighbors(self, node_ids: np.ndarray, node_interact_times: np.ndarray): 254 | """ 255 | 
get historical neighbors of nodes in node_ids at the first hop with max_num_neighbors as the maximal number of neighbors (make the computation feasible) 256 | :param node_ids: ndarray, shape (batch_size, ), node ids 257 | :param node_interact_times: ndarray, shape (batch_size, ), node interaction times 258 | :return: 259 | """ 260 | # three lists to store the first-hop neighbor ids, edge ids and interaction timestamp information, with batch_size as the list length 261 | nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list = [], [], [] 262 | # get the temporal neighbors at the first hop 263 | for idx, (node_id, node_interact_time) in enumerate(zip(node_ids, node_interact_times)): 264 | # find neighbors that interacted with node_id before time node_interact_time 265 | node_neighbor_ids, node_edge_ids, node_neighbor_times, _ = self.find_neighbors_before(node_id=node_id, 266 | interact_time=node_interact_time, 267 | return_sampled_probabilities=False) 268 | nodes_neighbor_ids_list.append(node_neighbor_ids) 269 | nodes_edge_ids_list.append(node_edge_ids) 270 | nodes_neighbor_times_list.append(node_neighbor_times) 271 | 272 | return nodes_neighbor_ids_list, nodes_edge_ids_list, nodes_neighbor_times_list 273 | 274 | def reset_random_state(self): 275 | """ 276 | reset the random state by self.seed 277 | :return: 278 | """ 279 | self.random_state = np.random.RandomState(self.seed) 280 | 281 | 282 | def get_neighbor_sampler(data: pd.DataFrame, sample_neighbor_strategy: str = 'uniform', time_scaling_factor: float = 0.0, seed: int = None): 283 | """ 284 | get neighbor sampler 285 | :param data: DataFrame 286 | :param sample_neighbor_strategy: str, how to sample historical neighbors, 'uniform', 'recent', or 'time_interval_aware'' 287 | :param time_scaling_factor: float, a hyper-parameter that controls the sampling preference with time interval, 288 | a large time_scaling_factor tends to sample more on recent links, this parameter works when sample_neighbor_strategy == 'time_interval_aware' 289 | :param seed: int, random seed 290 | :return: 291 | """ 292 | max_node_id = max(data.src_node_ids.max(), data.dst_node_ids.max()) 293 | # the adjacency vector stores edges for each node (source or destination), undirected 294 | # adj_list, list of list, where each element is a list of triple tuple (node_id, edge_id, timestamp) 295 | # the list at the first position in adj_list is empty 296 | adj_list = [[] for _ in range(max_node_id + 1)] 297 | for src_node_id, dst_node_id, edge_id, node_interact_time in zip(data.src_node_ids, data.dst_node_ids, data.edge_ids, data.node_interact_times): 298 | adj_list[src_node_id].append((dst_node_id, edge_id, node_interact_time)) 299 | adj_list[dst_node_id].append((src_node_id, edge_id, node_interact_time)) 300 | 301 | return NeighborSampler(adj_list=adj_list, sample_neighbor_strategy=sample_neighbor_strategy, time_scaling_factor=time_scaling_factor, seed=seed), adj_list 302 | 303 | 304 | class NegativeEdgeSampler(object): 305 | 306 | def __init__(self, src_node_ids: np.ndarray, dst_node_ids: np.ndarray, interact_times: np.ndarray = None, last_observed_time: float = None, 307 | negative_sample_strategy: str = 'random', seed: int = None): 308 | """ 309 | Negative Edge Sampler, which supports three strategies: "random", "historical", "inductive". 
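In brief (see the sampling methods below): "random" draws source/destination
nodes uniformly; "historical" prefers edges observed before the current batch,
filling any shortfall with random pairs; "inductive" further restricts the
candidate pool using the edges observed up to last_observed_time.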
310 |         :param src_node_ids: ndarray, (num_src_nodes, ), source node ids, num_src_nodes == num_dst_nodes
311 |         :param dst_node_ids: ndarray, (num_dst_nodes, ), destination node ids
312 |         :param interact_times: ndarray, (num_src_nodes, ), interaction timestamps
313 |         :param last_observed_time: float, time of the last observation (for inductive negative sampling strategy)
314 |         :param negative_sample_strategy: str, negative sampling strategy, can be "random", "historical", "inductive"
315 |         :param seed: int, random seed
316 |         """
317 |         self.seed = seed
318 |         self.negative_sample_strategy = negative_sample_strategy
319 |         self.src_node_ids = src_node_ids
320 |         self.dst_node_ids = dst_node_ids
321 |         self.interact_times = interact_times
322 |         self.unique_src_node_ids = np.unique(src_node_ids)
323 |         self.unique_dst_node_ids = np.unique(dst_node_ids)
324 |         self.unique_interact_times = np.unique(interact_times)
325 |         self.earliest_time = min(self.unique_interact_times)
326 |         self.last_observed_time = last_observed_time
327 | 
328 |         if self.negative_sample_strategy != 'random':
329 |             # all the possible edges that connect source nodes in self.unique_src_node_ids with destination nodes in self.unique_dst_node_ids
330 |             self.possible_edges = set((src_node_id, dst_node_id) for src_node_id in self.unique_src_node_ids for dst_node_id in self.unique_dst_node_ids)
331 | 
332 |         if self.negative_sample_strategy == 'inductive':
333 |             # set of observed edges
334 |             self.observed_edges = self.get_unique_edges_between_start_end_time(self.earliest_time, self.last_observed_time)
335 | 
336 |         if self.seed is not None:
337 |             self.random_state = np.random.RandomState(self.seed)
338 | 
339 |     def get_unique_edges_between_start_end_time(self, start_time: float, end_time: float):
340 |         """
341 |         get the unique edges that happened between the start and end time
342 |         :param start_time: float, start timestamp
343 |         :param end_time: float, end timestamp
344 |         :return: a set of edges, where each edge is a tuple of (src_node_id, dst_node_id)
345 |         """
346 |         selected_time_interval = np.logical_and(self.interact_times >= start_time, self.interact_times <= end_time)
347 |         # return the unique (source, destination) node pairs in the selected time interval
348 |         return set((src_node_id, dst_node_id) for src_node_id, dst_node_id in zip(self.src_node_ids[selected_time_interval], self.dst_node_ids[selected_time_interval]))
349 | 
350 |     def sample(self, size: int, batch_src_node_ids: np.ndarray = None, batch_dst_node_ids: np.ndarray = None,
351 |                current_batch_start_time: float = 0.0, current_batch_end_time: float = 0.0):
352 |         """
353 |         sample negative edges, supporting the random, historical and inductive sampling strategies
354 |         :param size: int, number of sampled negative edges
355 |         :param batch_src_node_ids: ndarray, shape (batch_size, ), source node ids in the current batch
356 |         :param batch_dst_node_ids: ndarray, shape (batch_size, ), destination node ids in the current batch
357 |         :param current_batch_start_time: float, start time in the current batch
358 |         :param current_batch_end_time: float, end time in the current batch
359 |         :return: two ndarrays, negative source and destination node ids, each with shape (size, )
360 |         """
361 |         if self.negative_sample_strategy == 'random':
362 |             negative_src_node_ids, negative_dst_node_ids = self.random_sample(size=size)
363 |         elif self.negative_sample_strategy == 'historical':
364 |             negative_src_node_ids, negative_dst_node_ids = self.historical_sample(size=size, batch_src_node_ids=batch_src_node_ids,
365 |                                                                                   batch_dst_node_ids=batch_dst_node_ids,
366 |                                                                                   current_batch_start_time=current_batch_start_time,
367 |                                                                                   current_batch_end_time=current_batch_end_time)
368 |         elif self.negative_sample_strategy == 'inductive':
369 |             negative_src_node_ids, negative_dst_node_ids = self.inductive_sample(size=size, batch_src_node_ids=batch_src_node_ids,
370 |                                                                                  batch_dst_node_ids=batch_dst_node_ids,
371 |                                                                                  current_batch_start_time=current_batch_start_time,
372 |                                                                                  current_batch_end_time=current_batch_end_time)
373 |         else:
374 |             raise ValueError(f'Unsupported negative_sample_strategy {self.negative_sample_strategy}!')
375 |         return negative_src_node_ids, negative_dst_node_ids
376 | 
377 |     def random_sample(self, size: int):
378 |         """
379 |         the random sampling strategy used by previous works
380 |         :param size: int, number of sampled negative edges
381 |         :return: two ndarrays of sampled source and destination node ids, each with shape (size, )
382 |         """
383 |         if self.seed is None:
384 |             random_sample_edge_src_node_indices = np.random.randint(0, len(self.unique_src_node_ids), size)
385 |             random_sample_edge_dst_node_indices = np.random.randint(0, len(self.unique_dst_node_ids), size)
386 |         else:
387 |             random_sample_edge_src_node_indices = self.random_state.randint(0, len(self.unique_src_node_ids), size)
388 |             random_sample_edge_dst_node_indices = self.random_state.randint(0, len(self.unique_dst_node_ids), size)
389 |         return self.unique_src_node_ids[random_sample_edge_src_node_indices], self.unique_dst_node_ids[random_sample_edge_dst_node_indices]
390 | 
391 |     def random_sample_with_collision_check(self, size: int, batch_src_node_ids: np.ndarray, batch_dst_node_ids: np.ndarray):
392 |         """
393 |         random sampling strategy with collision check, which guarantees that the sampled edges do not appear in the current batch,
394 |         used by the historical and inductive sampling strategies
395 |         :param size: int, number of sampled negative edges
396 |         :param batch_src_node_ids: ndarray, shape (batch_size, ), source node ids in the current batch
397 |         :param batch_dst_node_ids: ndarray, shape (batch_size, ), destination node ids in the current batch
398 |         :return: two ndarrays of sampled source and destination node ids, each with shape (size, )
399 |         """
400 |         assert batch_src_node_ids is not None and batch_dst_node_ids is not None
401 |         batch_edges = set((batch_src_node_id, batch_dst_node_id) for batch_src_node_id, batch_dst_node_id in zip(batch_src_node_ids, batch_dst_node_ids))
402 |         possible_random_edges = list(self.possible_edges - batch_edges)
403 |         assert len(possible_random_edges) > 0
404 |         # if replace is True, a value in the list can be selected multiple times; otherwise, each value can be selected at most once
405 |         random_edge_indices = self.random_state.choice(len(possible_random_edges), size=size, replace=len(possible_random_edges) < size)
406 |         return np.array([possible_random_edges[random_edge_idx][0] for random_edge_idx in random_edge_indices]), \
407 |                np.array([possible_random_edges[random_edge_idx][1] for random_edge_idx in random_edge_indices])
408 | 
409 |     def historical_sample(self, size: int, batch_src_node_ids: np.ndarray, batch_dst_node_ids: np.ndarray,
410 |                           current_batch_start_time: float, current_batch_end_time: float):
411 |         """
412 |         historical sampling strategy: first randomly sample among historical edges that are not in the current batch;
413 |         if the number of such historical edges is smaller than size, fill the remainder with randomly sampled edges
414 |         :param size: int, number of sampled negative edges
415 |         :param batch_src_node_ids: ndarray, shape (batch_size, ), source node ids in the current batch
416 |         :param batch_dst_node_ids: ndarray, shape (batch_size, ), destination node ids in the current batch
417 |         :param current_batch_start_time: float, start time in the current batch
418 |         :param current_batch_end_time: float, end time in the current batch
419 |         :return: two ndarrays of negative source and destination node ids, each with shape (size, )
420 |         """
421 |         assert self.seed is not None
422 |         # get historical edges up to current_batch_start_time
423 |         historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
424 |         # get edges in the current batch
425 |         current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
426 |         # get source and destination node ids of unique historical edges
427 |         unique_historical_edges = historical_edges - current_batch_edges
428 |         unique_historical_edges_src_node_ids = np.array([edge[0] for edge in unique_historical_edges])
429 |         unique_historical_edges_dst_node_ids = np.array([edge[1] for edge in unique_historical_edges])
430 | 
431 |         # if the sample size exceeds the number of unique historical edges, fill the remainder with randomly sampled edges (with collision check)
432 |         if size > len(unique_historical_edges):
433 |             num_random_sample_edges = size - len(unique_historical_edges)
434 |             random_sample_src_node_ids, random_sample_dst_node_ids = self.random_sample_with_collision_check(size=num_random_sample_edges,
435 |                                                                                                              batch_src_node_ids=batch_src_node_ids,
436 |                                                                                                              batch_dst_node_ids=batch_dst_node_ids)
437 | 
438 |             negative_src_node_ids = np.concatenate([random_sample_src_node_ids, unique_historical_edges_src_node_ids])
439 |             negative_dst_node_ids = np.concatenate([random_sample_dst_node_ids, unique_historical_edges_dst_node_ids])
440 |         else:
441 |             historical_sample_edge_node_indices = self.random_state.choice(len(unique_historical_edges), size=size, replace=False)
442 |             negative_src_node_ids = unique_historical_edges_src_node_ids[historical_sample_edge_node_indices]
443 |             negative_dst_node_ids = unique_historical_edges_dst_node_ids[historical_sample_edge_node_indices]
444 | 
445 |         # Note that if one of the inputs to np.concatenate is empty, the output array will have float dtype.
446 |         # Hence, cast to int to guarantee valid indices.
447 |         return negative_src_node_ids.astype(int), negative_dst_node_ids.astype(int)
448 | 
449 |     def inductive_sample(self, size: int, batch_src_node_ids: np.ndarray, batch_dst_node_ids: np.ndarray,
450 |                          current_batch_start_time: float, current_batch_end_time: float):
451 |         """
452 |         inductive sampling strategy: first randomly sample among inductive edges that are neither in self.observed_edges nor in the current batch;
453 |         if the number of such inductive edges is smaller than size, fill the remainder with randomly sampled edges
454 |         :param size: int, number of sampled negative edges
455 |         :param batch_src_node_ids: ndarray, shape (batch_size, ), source node ids in the current batch
456 |         :param batch_dst_node_ids: ndarray, shape (batch_size, ), destination node ids in the current batch
457 |         :param current_batch_start_time: float, start time in the current batch
458 |         :param current_batch_end_time: float, end time in the current batch
459 |         :return: two ndarrays of negative source and destination node ids, each with shape (size, )
460 |         """
461 |         assert self.seed is not None
462 |         # get historical edges up to current_batch_start_time
463 |         historical_edges = self.get_unique_edges_between_start_end_time(start_time=self.earliest_time, end_time=current_batch_start_time)
464 |         # get edges in the current batch
465 |         current_batch_edges = self.get_unique_edges_between_start_end_time(start_time=current_batch_start_time, end_time=current_batch_end_time)
466 |         # keep historical edges that are 1) not in self.observed_edges and 2) not in the current batch, then split into source and destination node ids
467 |         unique_inductive_edges = historical_edges - self.observed_edges - current_batch_edges
468 |         unique_inductive_edges_src_node_ids = np.array([edge[0] for edge in unique_inductive_edges])
469 |         unique_inductive_edges_dst_node_ids = np.array([edge[1] for edge in unique_inductive_edges])
470 | 
471 |         # if the sample size exceeds the number of unique inductive edges, fill the remainder with randomly sampled edges
472 |         if size > len(unique_inductive_edges):
473 |             num_random_sample_edges = size - len(unique_inductive_edges)
474 |             random_sample_src_node_ids, random_sample_dst_node_ids = self.random_sample_with_collision_check(size=num_random_sample_edges,
475 |                                                                                                              batch_src_node_ids=batch_src_node_ids,
476 |                                                                                                              batch_dst_node_ids=batch_dst_node_ids)
477 | 
478 |             negative_src_node_ids = np.concatenate([random_sample_src_node_ids, unique_inductive_edges_src_node_ids])
479 |             negative_dst_node_ids = np.concatenate([random_sample_dst_node_ids, unique_inductive_edges_dst_node_ids])
480 |         else:
481 |             inductive_sample_edge_node_indices = self.random_state.choice(len(unique_inductive_edges), size=size, replace=False)
482 |             negative_src_node_ids = unique_inductive_edges_src_node_ids[inductive_sample_edge_node_indices]
483 |             negative_dst_node_ids = unique_inductive_edges_dst_node_ids[inductive_sample_edge_node_indices]
484 | 
485 |         # Note that if one of the inputs to np.concatenate is empty, the output array will have float dtype.
486 |         # Hence, cast to int to guarantee valid indices.
487 |         return negative_src_node_ids.astype(int), negative_dst_node_ids.astype(int)
488 | 
489 |     def reset_random_state(self):
490 |         """
491 |         reset the random state by self.seed
492 |         :return:
493 |         """
494 |         self.random_state = np.random.RandomState(self.seed)
495 | 
--------------------------------------------------------------------------------
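A minimal usage sketch of the neighbor sampler above (not part of the repository files). The toy DataFrame values are made up, but the column names src_node_ids, dst_node_ids, edge_ids and node_interact_times are the ones get_neighbor_sampler actually reads; the import path is an assumption based on the repo layout.

import numpy as np
import pandas as pd
# assumed import path within this repo:
# from utils.utils import get_neighbor_sampler

# toy interaction stream (illustrative values only); node ids start from 1
data = pd.DataFrame({
    'src_node_ids': np.array([1, 1, 2, 3]),
    'dst_node_ids': np.array([2, 3, 3, 4]),
    'edge_ids': np.array([1, 2, 3, 4]),
    'node_interact_times': np.array([1.0, 2.0, 3.0, 4.0]),
})

# build the sampler and its undirected adjacency list
sampler, adj_list = get_neighbor_sampler(data, sample_neighbor_strategy='recent', seed=0)

# sample 2-hop temporal neighborhoods for two query nodes; each returned list
# holds one ndarray per hop with shape (batch_size, num_neighbors ** (hop + 1))
neighbor_ids_list, edge_ids_list, neighbor_times_list = sampler.get_multi_hop_neighbors(
    num_hops=2,
    node_ids=np.array([2, 3]),
    node_interact_times=np.array([3.5, 4.5]),
    num_neighbors=2,
)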
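Similarly, a hedged sketch of the historical negative sampling strategy for evaluation; the array values and batch times are illustrative, and NegativeEdgeSampler is the class defined in this file. Note that the historical and inductive strategies assert that a seed is provided.

import numpy as np

src = np.array([1, 1, 2, 3])
dst = np.array([2, 3, 3, 4])
ts = np.array([1.0, 2.0, 3.0, 4.0])

# historical strategy: negatives are drawn from previously observed edges
# that do not occur in the current batch
neg_sampler = NegativeEdgeSampler(src_node_ids=src, dst_node_ids=dst, interact_times=ts,
                                  negative_sample_strategy='historical', seed=0)
neg_src, neg_dst = neg_sampler.sample(size=2,
                                      batch_src_node_ids=np.array([3]),
                                      batch_dst_node_ids=np.array([4]),
                                      current_batch_start_time=4.0,
                                      current_batch_end_time=4.0)
# for the 'inductive' strategy, also pass last_observed_time so that
# self.observed_edges can be precomputed in __init__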
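The dtype caveat noted in historical_sample and inductive_sample is easy to verify with a self-contained snippet:

import numpy as np

empty = np.array([])               # empty arrays default to float64
ids = np.array([5, 7])             # int64
merged = np.concatenate([empty, ids])
print(merged.dtype)                # float64 -> cannot be used as an index array
print(merged.astype(int))          # [5 7], safe to index with again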