├── .gitignore ├── figures └── topology.png ├── pytorch_models └── 2022-11-04 16-47-13 │ └── yelp_GDN.pkl ├── config └── gdn_yelpchi.yml ├── utils ├── data_process.py └── utils.py ├── models ├── model.py ├── layers.py └── graphsage.py ├── README.md ├── main.py └── model_handler.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /figures/topology.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blacksingular/wsdm_GDN/HEAD/figures/topology.png -------------------------------------------------------------------------------- /pytorch_models/2022-11-04 16-47-13/yelp_GDN.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/blacksingular/wsdm_GDN/HEAD/pytorch_models/2022-11-04 16-47-13/yelp_GDN.pkl -------------------------------------------------------------------------------- /config/gdn_yelpchi.yml: -------------------------------------------------------------------------------- 1 | # Data 2 | data_name: 'yelp' 3 | data_dir: 'data/' 4 | train_ratio: 0.6 5 | test_ratio: 0.33 6 | save_dir: './pytorch_models/' 7 | 8 | # Model 9 | model: 'GDN' 10 | multi_relation: 'GNN' 11 | 12 | 13 | # Model architecture 14 | emb_size: 64 15 | 16 | thres: 0.2 17 | 18 | seed: 42 19 | 20 | # Run multiple times with different random seeds 21 | # seed: 22 | # - 42 23 | # - 448 24 | # - 854 25 | # - 29493 26 | # - 88867 27 | 28 | 29 | # hyper-parameters 30 | optimizer: 'adam' 31 | lr_1: 0.001 32 | lr_2: 0.00001 33 | weight_decay: 0.001 34 | weight_decay_2: 0.1 35 | batch_size: 1024 36 | test_batch_size: 1024 37 | num_epochs: 101 38 | valid_epochs: 5 39 | 40 | add_constraint: True 41 | Beta: 0.1 42 | # - 0.1 43 | # - 0.5 44 | # - 1 45 | # - 2 46 | 47 | topk: 10 48 | # - 6 49 | # - 12 50 | # - 18 51 | # - 24 52 | biased_split: False 53 | # Device 54 | no_cuda: True 55 | cuda_id: '2' 56 | -------------------------------------------------------------------------------- /utils/data_process.py: -------------------------------------------------------------------------------- 1 | from utils.utils import sparse_to_adjlist 2 | from scipy.io import loadmat 3 | 4 | """ 5 | Read data and save the adjacency matrices to adjacency lists 6 | """ 7 | 8 | 9 | if __name__ == "__main__": 10 | 11 | prefix = 'data/' 12 | 13 | yelp = loadmat('data/YelpChi.mat') 14 | net_rur = yelp['net_rur'] 15 | net_rtr = yelp['net_rtr'] 16 | net_rsr = yelp['net_rsr'] 17 | yelp_homo = yelp['homo'] 18 | 19 | sparse_to_adjlist(net_rur, prefix + 'yelp_rur_adjlists.pickle') 20 | sparse_to_adjlist(net_rtr, prefix + 'yelp_rtr_adjlists.pickle') 21 | sparse_to_adjlist(net_rsr, prefix + 'yelp_rsr_adjlists.pickle') 22 | sparse_to_adjlist(yelp_homo, prefix + 'yelp_homo_adjlists.pickle') 23 | 24 | # amz = loadmat('data/Amazon.mat') 25 | # net_upu = amz['net_upu'] 26 | # net_usu = amz['net_usu'] 27 | # net_uvu = amz['net_uvu'] 28 | # amz_homo = amz['homo'] 29 | 30 | # sparse_to_adjlist(net_upu, prefix + 'amz_upu_adjlists.pickle') 31 | # sparse_to_adjlist(net_usu, prefix + 'amz_usu_adjlists.pickle') 32 | # sparse_to_adjlist(net_uvu, prefix + 'amz_uvu_adjlists.pickle') 33 | # sparse_to_adjlist(amz_homo, prefix + 'amz_homo_adjlists.pickle') 34 | -------------------------------------------------------------------------------- /models/model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | """ 6 | GDN Model 7 | """ 8 | 9 | 10 | class GDNLayer(nn.Module): 11 | """ 12 | One GDN layer 13 | """ 14 | 15 | def __init__(self, num_classes, inter1): 16 | """ 17 | Initialize GDN model 18 | :param num_classes: number of classes (2 in our paper) 19 | :param inter1: the inter-relation aggregator that outputs the final embedding 20 | """ 21 | super(GDNLayer, self).__init__() 22 | self.inter1 = inter1 23 | self.xent = nn.CrossEntropyLoss() 24 | self.softmax = nn.Softmax(dim=-1) 25 | self.KLDiv = nn.KLDivLoss(reduction='batchmean') 26 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6) 27 | # the parameter to transform the final embedding 28 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, inter1.embed_dim)) 29 | init.xavier_uniform_(self.weight) 30 | 31 | 32 | def forward(self, nodes, labels): 33 | embeds1 = self.inter1(nodes, labels) 34 | scores = self.weight.mm(embeds1) 35 | return scores.t() 36 | 37 | def to_prob(self, nodes, labels): 38 | gnn_logits = self.forward(nodes, labels) 39 | gnn_scores = self.softmax(gnn_logits) 40 | return gnn_scores 41 | 42 | def loss(self, nodes, labels): 43 | gnn_scores = self.forward(nodes, labels) 44 | # GNN loss 45 | gnn_loss = self.xent(gnn_scores, labels.squeeze()) 46 | return gnn_loss 47 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GDN: Alleviating Structural Distribution Shift in Graph Anomaly Detection 2 | PyTorch implementation of 3 | 4 | Alleviating Structural Distribution Shift in Graph Anomaly Detection (WSDM 2023) 5 | 6 | # Overview 7 | This work addresses the structural distribution shift (SDS) problem from a feature view. We observe that the degree of heterophily differs between the training and test environments in graph anomaly detection (GAD), leading to poor generalization of the classifier. To address this issue, we propose GDN, which resists high heterophily for anomalies while benefiting the learning of normal nodes from 8 | homophily. Since different labels correspond to differences in the critical anomaly features that contribute most to GAD, we tease out these anomaly features and constrain them to mitigate the effect of heterophilous neighbors and make them invariant. To better estimate the prior distribution of the anomaly features, we devise a prototype vector that infers and updates this distribution during training. For normal nodes, we constrain the remaining features to preserve node connectivity and reinforce the influence of the homophilous neighborhood. 9 | 10 |
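The prototype update and the gradient-based feature separation described above can be summarised in a few lines of PyTorch. The snippet below is a simplified, self-contained sketch (toy tensor sizes, `feats.sum()` standing in for the classification loss), not the actual training code; the real logic lives in `InterAgg.update_label_vector` / `InterAgg.fl_loss` in `models/layers.py` and in the `add_constraint` branch of `model_handler.py`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# toy setup (hypothetical sizes): 100 nodes, 16 raw features, the first 10 nodes are labelled anomalies
feats = torch.randn(100, 16, requires_grad=True)
pos_index = torch.arange(10)
task_loss = feats.sum()                                     # stand-in for the GNN cross-entropy loss

# 1) prototype vector of anomaly features:
#    similarity-weighted mean of labelled anomalies (cf. InterAgg.update_label_vector, temperature t=5)
x_pos = feats[pos_index]
proto = x_pos.mean(dim=0, keepdim=True).detach()            # initial estimate
weights = torch.softmax(F.cosine_similarity(proto, x_pos, dim=1) / 5.0, dim=0)
proto = (weights.reshape(1, -1) @ x_pos).detach()           # updated prototype, kept out of the graph

# 2) feature separation: dimensions with the largest loss gradient are treated as anomaly features
grad = torch.autograd.grad(task_loss, feats, retain_graph=True)[0].abs()
anomaly_dims = grad.mean(dim=0).topk(k=4).indices           # cf. the topk step in model_handler.py

# 3) feature-level constraint: pull labelled anomalies towards the prototype on those dimensions only
#    (a KL divergence between softmax-normalised features, cf. InterAgg.fl_loss)
log_p = F.log_softmax(x_pos[:, anomaly_dims], dim=-1)
target = proto[:, anomaly_dims].repeat(len(pos_index), 1).softmax(dim=-1)
constraint = F.kl_div(log_p, target, reduction='batchmean')
(task_loss + 0.1 * constraint).backward()                   # Beta = 0.1 as in the config
```

In the full training loop the same KL term is also computed against the normal-node prototype and the two are combined as `Beta * exp(loss_pos - loss_neg)`; an analogous constraint on the remaining dimensions ties anomalous nodes to their relation neighbors (see `fn_loss`).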

11 | ![Illustration of GDN](figures/topology.png) 12 |
13 | 14 | Illustration of GDN. The feature separation module 15 | separates the node features into two sets. Two constraints 16 | are leveraged to assist the separation. Blank positions in the node 17 | representation indicate entries that are zero when calculating the losses. 18 | 19 | # Dataset 20 | YelpChi and Amazon can be downloaded from [here](https://github.com/YingtongDou/CARE-GNN/tree/master/data) or [dgl.data.FraudDataset](https://docs.dgl.ai/api/python/dgl.data.html#fraud-dataset). 21 | 22 | Run `python -m utils.data_process` from the repository root to pre-process the data. 23 | 24 | # Dependencies 25 | Please set up the environment following Requirements in this [repository](https://github.com/PonderLY/PC-GNN). 26 | ```sh 27 | argparse 1.1.0 28 | networkx 1.11 29 | numpy 1.16.4 30 | scikit_learn 0.21rc2 31 | scipy 1.2.1 32 | torch 1.4.0 33 | ``` 34 | 35 | # Reproduce 36 | ```sh 37 | python main.py --config ./config/gdn_yelpchi.yml 38 | ``` 39 | 40 | # Acknowledgement 41 | Our code references: 42 | - [CAREGNN](https://github.com/YingtongDou/CARE-GNN) 43 | 44 | - [PCGNN](https://github.com/PonderLY/PC-GNN) 45 | 46 | # Reference 47 | ``` 48 | @inproceedings{ 49 | gao2023gdn, 50 | title={Alleviating Structural Distribution Shift in Graph Anomaly Detection}, 51 | author={Yuan Gao and Xiang Wang and Xiangnan He and Zhenguang Liu and Huamin Feng and Yongdong Zhang}, 52 | booktitle={WSDM}, 53 | year={2023}, 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import torch 4 | import datetime 5 | import numpy as np 6 | from collections import defaultdict, OrderedDict 7 | import time 8 | 9 | from model_handler import ModelHandler 10 | import logging 11 | 12 | timestamp = time.time() 13 | timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S') 14 | logging.basicConfig(filename='result.log',level=logging.INFO) 15 | logging.info(timestamp) 16 | # timestamp = time.time() 17 | 18 | 19 | ################################################################################ 20 | # Main # 21 | ################################################################################ 22 | 23 | 24 | def set_random_seed(seed): 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | np.random.seed(seed) 28 | 29 | 30 | def main(config): 31 | print_config(config) 32 | set_random_seed(config['seed']) 33 | model = ModelHandler(config) 34 | f1_mac_test, auc_test, gmean_test = model.train() 35 | print("F1-Macro: {}".format(f1_mac_test)) 36 | print("AUC: {}".format(auc_test)) 37 | print("G-Mean: {}".format(gmean_test)) 38 | 39 | 40 | 41 | def multi_run_main(config): 42 | print_config(config) 43 | hyperparams = [] 44 | for k, v in config.items(): 45 | if isinstance(v, list): 46 | hyperparams.append(k) 47 | 48 | f1_list, f1_1_list, f1_0_list, auc_list, gmean_list = [], [], [], [], [] 49 | configs = grid(config) 50 | for i, cnf in enumerate(configs): 51 | print('Running {}:\n'.format(i)) 52 | for k in hyperparams: 53 | cnf['save_dir'] += '{}_{}_'.format(k, cnf[k]) 54 | print(cnf['save_dir']) 55 | set_random_seed(cnf['seed']) 56 | st = time.time() 57 | model = ModelHandler(cnf) 58 | f1_mac_test, f1_1_test, f1_0_test, auc_test, gmean_test = model.train() 59 | f1_list.append(f1_mac_test) 60 | f1_1_list.append(f1_1_test) 61 | f1_0_list.append(f1_0_test) 62 | auc_list.append(auc_test) 63 | gmean_list.append(gmean_test) 64 | print("Running {} done, 
elapsed time {}s".format(i, time.time()-st)) 65 | with open(cnf['result_save'], 'a+') as f: 66 | f.write("{}, F1-Macro: {}, AUC: {}, G-Mean: {}\n".format(cnf['save_dir'], f1_mac_test, auc_test, gmean_test)) 67 | f.close() 68 | 69 | print("F1-Macro: {}".format(f1_list)) 70 | print("AUC: {}".format(auc_list)) 71 | print("G-Mean: {}".format(gmean_list)) 72 | 73 | f1_mean, f1_std = np.mean(f1_list), np.std(f1_list, ddof=1) 74 | f1_1_mean, f1_1_std = np.mean(f1_1_list), np.std(f1_1_list, ddof=1) 75 | f1_0_mean, f1_0_std = np.mean(f1_0_list), np.std(f1_0_list, ddof=1) 76 | auc_mean, auc_std = np.mean(auc_list), np.std(auc_list, ddof=1) 77 | gmean_mean, gmean_std = np.mean(gmean_list), np.std(gmean_list, ddof=1) 78 | 79 | print("F1-Macro: {}+{}".format(f1_mean, f1_std)) 80 | print("F1-binary-1: {}+{}".format(f1_1_mean, f1_1_std)) 81 | print("F1-binary-0: {}+{}".format(f1_0_mean, f1_0_std)) 82 | print("AUC: {}+{}".format(auc_mean, auc_std)) 83 | print("G-Mean: {}+{}".format(gmean_mean, gmean_std)) 84 | 85 | 86 | 87 | ################################################################################ 88 | # ArgParse and Helper Functions # 89 | ################################################################################ 90 | def get_config(config_path="config.yml"): 91 | with open(config_path, "r") as setting: 92 | config = yaml.load(setting, Loader=yaml.FullLoader) 93 | return config 94 | 95 | def get_args(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file') 98 | parser.add_argument('--multi_run', action='store_true', help='flag: multi run') 99 | args = vars(parser.parse_args()) 100 | return args 101 | 102 | 103 | def print_config(config): 104 | logging.info("**************** MODEL CONFIGURATION ****************") 105 | for key in sorted(config.keys()): 106 | val = config[key] 107 | keystr = "{}".format(key) + (" " * (24 - len(key))) 108 | logging.info("{} --> {}".format(keystr, val)) 109 | logging.info("**************** MODEL CONFIGURATION ****************") 110 | 111 | 112 | def grid(kwargs): 113 | """Builds a mesh grid with given keyword arguments for this Config class. 114 | If the value is not a list, then it is considered fixed""" 115 | 116 | class MncDc: 117 | """This is because np.meshgrid does not always work properly...""" 118 | 119 | def __init__(self, a): 120 | self.a = a # tuple! 121 | 122 | def __call__(self): 123 | return self.a 124 | 125 | def merge_dicts(*dicts): 126 | """ 127 | Merges dictionaries recursively. 
Accepts also `None` and returns always a (possibly empty) dictionary 128 | """ 129 | from functools import reduce 130 | def merge_two_dicts(x, y): 131 | z = x.copy() # start with x's keys and values 132 | z.update(y) # modifies z with y's keys and values & returns None 133 | return z 134 | 135 | return reduce(lambda a, nd: merge_two_dicts(a, nd if nd else {}), dicts, {}) 136 | 137 | 138 | sin = OrderedDict({k: v for k, v in kwargs.items() if isinstance(v, list)}) 139 | for k, v in sin.items(): 140 | copy_v = [] 141 | for e in v: 142 | copy_v.append(MncDc(e) if isinstance(e, tuple) else e) 143 | sin[k] = copy_v 144 | 145 | grd = np.array(np.meshgrid(*sin.values()), dtype=object).T.reshape(-1, len(sin.values())) 146 | return [merge_dicts( 147 | {k: v for k, v in kwargs.items() if not isinstance(v, list)}, 148 | {k: vv[i]() if isinstance(vv[i], MncDc) else vv[i] for i, k in enumerate(sin)} 149 | ) for vv in grd] 150 | 151 | 152 | ################################################################################ 153 | # Module Command-line Behavior # 154 | ################################################################################ 155 | if __name__ == '__main__': 156 | cfg = get_args() 157 | config = get_config(cfg['config']) 158 | if cfg['multi_run']: 159 | multi_run_main(config) 160 | else: 161 | main(config) 162 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import numpy as np 4 | import scipy.sparse as sp 5 | from scipy.io import loadmat 6 | import copy as cp 7 | from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix 8 | from collections import defaultdict 9 | import logging 10 | 11 | 12 | logging.basicConfig(filename='result.log',level=logging.INFO) 13 | """ 14 | Utility functions to handle data and evaluate model. 
15 | """ 16 | 17 | def biased_split(dataset): 18 | prefix = '/data' 19 | if dataset == 'yelp': 20 | data_name = 'YelpChi.mat' # 'Amazon.mat' or 'YelpChi.mat' 21 | elif dataset == 'amazon': 22 | data_name = 'Amazon.mat' 23 | data = loadmat(prefix + data_name) 24 | 25 | if data_name == 'YelpChi.mat': 26 | net_list = [data['net_rur'].nonzero(), data['net_rtr'].nonzero(), 27 | data['net_rsr'].nonzero(), data['homo'].nonzero()] 28 | else: # amazon dataset 29 | net_list = [data['net_upu'].nonzero(), data['net_usu'].nonzero(), 30 | data['net_uvu'].nonzero(), data['homo'].nonzero()] 31 | 32 | label = data['label'][0] 33 | pos_nodes = set(label.nonzero()[0].tolist()) 34 | 35 | pos_node_dict, neg_node_dict = defaultdict(lambda: [0, 0]), defaultdict(lambda: [0, 0]) 36 | # extract the edges of positive nodes in each relation graph 37 | net = net_list[-1] 38 | 39 | # calculate (homophily) probability for nodes 40 | for u, v in zip(net[0].tolist(), net[1].tolist()): 41 | if dataset == 'amazon' and min(u, v) < 3305: # 0~3304 are unlabelled nodes for amazon 42 | continue 43 | if u in pos_nodes: 44 | pos_node_dict[u][0] += 1 45 | if label[u] == label[v]: 46 | pos_node_dict[u][1] += 1 47 | else: 48 | neg_node_dict[u][0] += 1 49 | if label[u] == label[v]: 50 | neg_node_dict[u][1] += 1 51 | 52 | p1 = np.zeros(len(label)) 53 | for k in pos_node_dict: 54 | p1[k] = pos_node_dict[k][1] / pos_node_dict[k][0] 55 | p1 = p1 / p1.sum() 56 | pos_index = np.random.choice(range(len(p1)), size=round(0.6 * len(pos_node_dict)), replace=False, p = p1.ravel()) 57 | p2 = np.zeros(len(label)) 58 | for k in neg_node_dict: 59 | p2[k] = neg_node_dict[k][1] / neg_node_dict[k][0] 60 | p2 = p2 / p2.sum() 61 | neg_index = np.random.choice(range(len(p2)), size=round(0.6 * len(neg_node_dict)), replace=False, p = p2.ravel()) 62 | 63 | idx_train = np.concatenate((pos_index, neg_index)) 64 | np.random.shuffle(idx_train) 65 | idx_train = list(idx_train) 66 | y_train = np.array(label[idx_train]) 67 | 68 | # find test label 69 | idx_test = list(set(range(len(label))).difference(set(idx_train)).difference(set(range(3305)))) 70 | random.shuffle(idx_test) 71 | y_test = np.array(label[idx_test]) 72 | return idx_train, idx_test, y_train, y_test 73 | 74 | def load_data(data, prefix='/data'): 75 | """ 76 | Load graph, feature, and label given dataset name 77 | :returns: home and single-relation graphs, feature, label 78 | """ 79 | 80 | if data == 'yelp': 81 | data_file = loadmat(prefix + 'YelpChi.mat') 82 | labels = data_file['label'].flatten() 83 | feat_data = data_file['features'].todense().A 84 | # load the preprocessed adj_lists 85 | with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file: 86 | homo = pickle.load(file) 87 | file.close() 88 | with open(prefix + 'yelp_rur_adjlists.pickle', 'rb') as file: 89 | relation1 = pickle.load(file) 90 | file.close() 91 | with open(prefix + 'yelp_rtr_adjlists.pickle', 'rb') as file: 92 | relation2 = pickle.load(file) 93 | file.close() 94 | with open(prefix + 'yelp_rsr_adjlists.pickle', 'rb') as file: 95 | relation3 = pickle.load(file) 96 | file.close() 97 | elif data == 'amazon': 98 | data_file = loadmat(prefix + 'Amazon.mat') 99 | labels = data_file['label'].flatten() 100 | feat_data = data_file['features'].todense().A 101 | # load the preprocessed adj_lists 102 | with open(prefix + 'amz_homo_adjlists.pickle', 'rb') as file: 103 | homo = pickle.load(file) 104 | file.close() 105 | with open(prefix + 'amz_upu_adjlists.pickle', 'rb') as file: 106 | relation1 = pickle.load(file) 107 | file.close() 108 | 
with open(prefix + 'amz_usu_adjlists.pickle', 'rb') as file: 109 | relation2 = pickle.load(file) 110 | file.close() 111 | with open(prefix + 'amz_uvu_adjlists.pickle', 'rb') as file: 112 | relation3 = pickle.load(file) 113 | 114 | return [homo, relation1, relation2, relation3], feat_data, labels 115 | 116 | 117 | def normalize(mx): 118 | """ 119 | Row-normalize sparse matrix 120 | Code from https://github.com/williamleif/graphsage-simple/ 121 | """ 122 | rowsum = np.array(mx.sum(1)) + 0.01 123 | r_inv = np.power(rowsum, -1).flatten() 124 | r_inv[np.isinf(r_inv)] = 0. 125 | r_mat_inv = sp.diags(r_inv) 126 | mx = r_mat_inv.dot(mx) 127 | return mx 128 | 129 | 130 | def sparse_to_adjlist(sp_matrix, filename): 131 | """ 132 | Transfer sparse matrix to adjacency list 133 | :param sp_matrix: the sparse matrix 134 | :param filename: the filename of adjlist 135 | """ 136 | # add self loop 137 | homo_adj = sp_matrix + sp.eye(sp_matrix.shape[0]) 138 | # create adj_list 139 | adj_lists = defaultdict(set) 140 | edges = homo_adj.nonzero() 141 | for index, node in enumerate(edges[0]): 142 | adj_lists[node].add(edges[1][index]) 143 | adj_lists[edges[1][index]].add(node) 144 | with open(filename, 'wb') as file: 145 | pickle.dump(adj_lists, file) 146 | file.close() 147 | 148 | 149 | def pos_neg_split(nodes, labels): 150 | """ 151 | Find positive and negative nodes given a list of nodes and their labels 152 | :param nodes: a list of nodes 153 | :param labels: a list of node labels 154 | :returns: the spited positive and negative nodes 155 | """ 156 | pos_nodes = [] 157 | neg_nodes = cp.deepcopy(nodes) 158 | aux_nodes = cp.deepcopy(nodes) 159 | for idx, label in enumerate(labels): 160 | if label == 1: 161 | pos_nodes.append(aux_nodes[idx]) 162 | neg_nodes.remove(aux_nodes[idx]) 163 | 164 | return pos_nodes, neg_nodes 165 | 166 | 167 | def test_sage(test_cases, labels, model, batch_size, thres=0.5, save=False): 168 | """ 169 | Test the performance of GraphSAGE 170 | :param test_cases: a list of testing node 171 | :param labels: a list of testing node labels 172 | :param model: the GNN model 173 | :param batch_size: number nodes in a batch 174 | """ 175 | 176 | test_batch_num = int(len(test_cases) / batch_size) + 1 177 | gnn_pred_list = [] 178 | gnn_prob_list = [] 179 | for iteration in range(test_batch_num): 180 | i_start = iteration * batch_size 181 | i_end = min((iteration + 1) * batch_size, len(test_cases)) 182 | batch_nodes = test_cases[i_start:i_end] 183 | gnn_prob = model.to_prob(batch_nodes, False) 184 | 185 | gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1] 186 | gnn_pred = prob2pred(gnn_prob_arr, thres) 187 | 188 | gnn_pred_list.extend(gnn_pred.tolist()) 189 | gnn_prob_list.extend(gnn_prob_arr.tolist()) 190 | 191 | auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list)) 192 | f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro') 193 | conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list)) 194 | tn, fp, fn, tp = conf_gnn.ravel() 195 | gmean_gnn = conf_gmean(conf_gnn) 196 | 197 | logging.info(f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}") 198 | logging.info(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}") 199 | return f1_macro_gnn, auc_gnn, gmean_gnn 200 | 201 | 202 | 203 | def prob2pred(y_prob, thres=0.5): 204 | """ 205 | Convert probability to predicted results according to given threshold 206 | :param y_prob: numpy array of probability in [0, 1] 207 | :param thres: binary classification threshold, default 0.5 208 | :returns: the predicted 
result with the same shape as y_prob 209 | """ 210 | y_pred = np.zeros_like(y_prob, dtype=np.int32) 211 | y_pred[y_prob >= thres] = 1 212 | y_pred[y_prob < thres] = 0 213 | return y_pred 214 | 215 | 216 | def test_GDN(test_cases, labels, model, batch_size, thres=0.5, save=False): 217 | """ 218 | Test the performance of GDN 219 | :param test_cases: a list of testing node 220 | :param labels: a list of testing node labels 221 | :param model: the GNN model 222 | :param batch_size: number nodes in a batch 223 | :returns: the AUC and Recall of GNN and Simi modules 224 | """ 225 | 226 | test_batch_num = int(len(test_cases) / batch_size) + 1 227 | gnn_pred_list = [] 228 | gnn_prob_list = [] 229 | 230 | for iteration in range(test_batch_num): 231 | i_start = iteration * batch_size 232 | i_end = min((iteration + 1) * batch_size, len(test_cases)) 233 | batch_nodes = test_cases[i_start:i_end] 234 | batch_label = labels[i_start:i_end] 235 | gnn_prob = model.to_prob(batch_nodes, batch_label) 236 | gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1] 237 | gnn_pred = prob2pred(gnn_prob_arr, thres) 238 | 239 | gnn_pred_list.extend(gnn_pred.tolist()) 240 | gnn_prob_list.extend(gnn_prob_arr.tolist()) 241 | 242 | auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list)) 243 | f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro') 244 | conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list)) 245 | tn, fp, fn, tp = conf_gnn.ravel() 246 | gmean_gnn = conf_gmean(conf_gnn) 247 | 248 | logging.info(f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}") 249 | logging.info(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}") 250 | return f1_macro_gnn, auc_gnn, gmean_gnn 251 | 252 | def conf_gmean(conf): 253 | tn, fp, fn, tp = conf.ravel() 254 | return (tp*tn/((tp+fn)*(tn+fp)))**0.5 -------------------------------------------------------------------------------- /model_handler.py: -------------------------------------------------------------------------------- 1 | import time, datetime 2 | import os 3 | import random 4 | import argparse 5 | import numpy as np 6 | from sklearn.model_selection import train_test_split 7 | 8 | from utils.utils import test_GDN, test_sage, load_data, pos_neg_split, normalize, biased_split 9 | from models.model import GDNLayer 10 | from models.layers import InterAgg, IntraAgg 11 | from models.graphsage import * 12 | import pickle as pkl 13 | import logging 14 | import torch 15 | import torch.nn as nn 16 | 17 | timestamp = time.time() 18 | timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S') 19 | logging.basicConfig(filename='result.log',level=logging.INFO) 20 | 21 | """ 22 | Training GDN 23 | """ 24 | 25 | 26 | class ModelHandler(object): 27 | 28 | def __init__(self, config): 29 | args = argparse.Namespace(**config) 30 | # load graph, feature, and label 31 | [homo, relation1, relation2, relation3], feat_data, labels = load_data(args.data_name, prefix=args.data_dir) 32 | 33 | # train_test split 34 | np.random.seed(args.seed) 35 | random.seed(args.seed) 36 | 37 | if not args.biased_split: 38 | if args.data_name == 'yelp': 39 | index = list(range(len(labels))) 40 | idx_rest, idx_test, y_rest, y_test = train_test_split(index, labels, stratify=labels, train_size=args.train_ratio, 41 | random_state=2, shuffle=True) 42 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio, 43 | random_state=2, shuffle=True) 44 | elif args.data_name == 'amazon': # amazon 45 
| # 0-3304 are unlabeled nodes 46 | index = list(range(3305, len(labels))) 47 | idx_rest, idx_test, y_rest, y_test = train_test_split(index, labels[3305:], stratify=labels[3305:], 48 | train_size=args.train_ratio, random_state=2, shuffle=True) 49 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio, 50 | random_state=2, shuffle=True) 51 | else: 52 | idx_rest, idx_test, y_rest, y_test = biased_split(args.data_name) 53 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio, 54 | random_state=2, shuffle=True) 55 | 56 | print(f'Run on {args.data_name}, postive/total num: {np.sum(labels)}/{len(labels)}, train num {len(y_train)},'+ 57 | f'valid num {len(y_valid)}, test num {len(y_test)}, test positive num {np.sum(y_test)}') 58 | print(f"Classification threshold: {args.thres}") 59 | print(f"Feature dimension: {feat_data.shape[1]}") 60 | 61 | 62 | # split pos neg sets for under-sampling 63 | train_pos, train_neg = pos_neg_split(idx_train, y_train) 64 | 65 | 66 | if args.data_name == 'amazon': 67 | feat_data = normalize(feat_data) 68 | 69 | args.cuda = not args.no_cuda and torch.cuda.is_available() 70 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_id 71 | 72 | # set input graph 73 | if args.model == 'SAGE' or args.model == 'GCN': 74 | adj_lists = homo 75 | else: 76 | adj_lists = [relation1, relation2, relation3] 77 | 78 | print(f'Model: {args.model}, multi-relation aggregator: {args.multi_relation}, emb_size: {args.emb_size}.') 79 | 80 | self.args = args 81 | self.dataset = {'feat_data': feat_data, 'labels': labels, 'adj_lists': adj_lists, 'homo': homo, 82 | 'idx_train': idx_train, 'idx_valid': idx_valid, 'idx_test': idx_test, 83 | 'y_train': y_train, 'y_valid': y_valid, 'y_test': y_test, 84 | 'train_pos': train_pos, 'train_neg': train_neg} 85 | 86 | 87 | def train(self): 88 | args = self.args 89 | feat_data, adj_lists = self.dataset['feat_data'], self.dataset['adj_lists'] 90 | idx_train, y_train = self.dataset['idx_train'], self.dataset['y_train'] 91 | idx_valid, y_valid, idx_test, y_test = self.dataset['idx_valid'], self.dataset['y_valid'], self.dataset['idx_test'], self.dataset['y_test'] 92 | # initialize model input 93 | features = nn.Embedding(feat_data.shape[0], feat_data.shape[1]) 94 | features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=True) 95 | if args.cuda: 96 | features.cuda() 97 | 98 | # build one-layer models 99 | if args.model == 'GDN': 100 | intra1 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda) 101 | intra2 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda) 102 | intra3 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda) 103 | inter1 = InterAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], self.dataset['train_neg'], 104 | adj_lists, [intra1, intra2, intra3], inter=args.multi_relation, cuda=args.cuda) 105 | elif args.model == 'SAGE': 106 | agg_sage = MeanAggregator(features, cuda=args.cuda) 107 | enc_sage = Encoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_sage, self.dataset['train_pos'], self.dataset['train_neg'], gcn=False, cuda=args.cuda) 108 | elif args.model == 'GCN': 109 | agg_gcn = GCNAggregator(features, cuda=args.cuda) 110 | enc_gcn = GCNEncoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_gcn, 
self.dataset['train_pos'], 111 | self.dataset['train_neg'], gcn=True, cuda=args.cuda) 112 | 113 | if args.model == 'GDN': 114 | gnn_model = GDNLayer(2, inter1) 115 | elif args.model == 'SAGE': 116 | # the vanilla GraphSAGE model as baseline 117 | enc_sage.num_samples = 5 118 | gnn_model = GraphSage(2, enc_sage) 119 | elif args.model == 'GCN': 120 | gnn_model = GCN(2, enc_gcn) 121 | 122 | if args.cuda: 123 | gnn_model.cuda() 124 | 125 | if args.model == 'GDN' or 'SAGE' or 'GCN': 126 | group_1 = [] 127 | group_2 = [] 128 | for name, param in gnn_model.named_parameters(): 129 | print(name) 130 | if name == 'inter1.features.weight': 131 | group_2 += [param] 132 | else: 133 | group_1 += [param] 134 | optimizer = torch.optim.Adam([ 135 | dict(params=group_1, weight_decay=args.weight_decay, lr=args.lr_1), 136 | dict(params=group_2, weight_decay=args.weight_decay_2, lr=args.lr_2) 137 | ], ) 138 | else: 139 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gnn_model.parameters()), lr=args.lr_1, weight_decay=args.weight_decay) 140 | 141 | dir_saver = args.save_dir+timestamp 142 | path_saver = os.path.join(dir_saver, '{}_{}.pkl'.format(args.data_name, args.model)) 143 | f1_mac_best, auc_best, ep_best = 0, 0, -1 144 | 145 | # train the model 146 | for epoch in range(args.num_epochs): 147 | num_batches = int(len(idx_train) / args.batch_size) + 1 148 | 149 | loss = 0.0 150 | epoch_time = 0 151 | 152 | # mini-batch training 153 | for batch in range(num_batches): 154 | start_time = time.time() 155 | i_start = batch * args.batch_size 156 | i_end = min((batch + 1) * args.batch_size, len(idx_train)) 157 | batch_nodes = idx_train[i_start:i_end] 158 | batch_label = self.dataset['labels'][np.array(batch_nodes)] 159 | optimizer.zero_grad() 160 | if args.cuda: 161 | loss = gnn_model.loss(batch_nodes, Variable(torch.cuda.LongTensor(batch_label))) 162 | else: 163 | loss = gnn_model.loss(batch_nodes, Variable(torch.LongTensor(batch_label))) 164 | loss.backward(retain_graph=True) 165 | 166 | # calculate grad and rank 167 | if args.add_constraint: 168 | # fl step 169 | if args.model == 'GDN': 170 | grad = torch.abs(torch.autograd.grad(outputs=loss, inputs=gnn_model.inter1.features.weight)[0]) 171 | elif args.model == 'GCN' or 'SAGE': 172 | grad = torch.abs(torch.autograd.grad(outputs=loss, inputs=gnn_model.enc.features.weight)[0]) 173 | grads_idx = grad.mean(dim=0).topk(k=args.topk).indices 174 | 175 | # find non grad idx 176 | mask_len = feat_data.shape[1] - args.topk 177 | non_grads_idx = torch.zeros(mask_len, dtype=torch.long) 178 | idx = 0 179 | for i in range(feat_data.shape[1]): 180 | if i not in grads_idx: 181 | non_grads_idx[idx] = i 182 | idx += 1 183 | 184 | if args.model == 'GDN': 185 | loss_pos, loss_neg = gnn_model.inter1.fl_loss(grads_idx) 186 | elif args.model == 'GCN' or 'SAGE': 187 | loss_pos, loss_neg = gnn_model.enc.constraint_loss(grads_idx) 188 | loss_stable = args.Beta * torch.exp((loss_pos - loss_neg)) 189 | loss_stable.backward(retain_graph=True) 190 | 191 | # fn step 192 | if args.model == 'GDN': 193 | fn_pos, fn_neg = gnn_model.inter1.fn_loss(batch_nodes, non_grads_idx) 194 | elif args.model == 'GCN' or 'SAGE': 195 | fn_pos, fn_neg = gnn_model.enc.fn_loss(batch_nodes, non_grads_idx) 196 | loss_fn = args.Beta * torch.exp((fn_pos - fn_neg)) 197 | loss_fn.backward() 198 | 199 | optimizer.step() 200 | end_time = time.time() 201 | epoch_time += end_time - start_time 202 | loss += loss.item() 203 | 204 | print(f'Epoch: {epoch}, loss: {loss.item() / num_batches}, time: {epoch_time}s') 
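# Recap of the mini-batch optimization above: each batch performs up to three backward passes
# before the single optimizer.step() call:
#   1) the cross-entropy loss on the GNN logits (always);
#   2) with add_constraint, a feature-level term Beta * exp(loss_pos - loss_neg) from
#      inter1.fl_loss (GDN), computed on the top-k feature dimensions with the largest loss
#      gradient, pulling labelled anomalies towards the positive prototype and away from the
#      negative one;
#   3) a neighborhood term of the same form from inter1.fn_loss on the remaining dimensions,
#      keeping positive nodes close to their relation neighbors and away from random non-neighbors.
# The two optimizer parameter groups update the node feature table (inter1.features.weight for GDN)
# with lr_2 / weight_decay_2 and every other parameter with lr_1 / weight_decay.
# The block below runs validation every valid_epochs epochs and checkpoints the best-AUC model.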
205 | 206 | # Valid the model for every $valid_epoch$ epoch 207 | if epoch % args.valid_epochs == 0: 208 | if args.model == 'SAGE' or args.model == 'GCN': 209 | print("Valid at epoch {}".format(epoch)) 210 | f1_mac_val, auc_val, gmean_val = test_sage(idx_valid, y_valid, gnn_model, args.test_batch_size, args.thres) 211 | if auc_val > auc_best: 212 | auc_best, ep_best = auc_val, epoch 213 | if not os.path.exists(dir_saver): 214 | os.makedirs(dir_saver) 215 | print(' Saving model ...') 216 | torch.save(gnn_model.state_dict(), path_saver) 217 | else: 218 | print("Valid at epoch {}".format(epoch)) 219 | f1_mac_val, auc_val, gmean_val = test_GDN(idx_valid, y_valid, gnn_model, args.batch_size, args.thres) 220 | if auc_val > auc_best: 221 | auc_best, ep_best = auc_val, epoch 222 | if not os.path.exists(dir_saver): 223 | os.makedirs(dir_saver) 224 | print(' Saving model ...') 225 | torch.save(gnn_model.state_dict(), path_saver) 226 | with open(args.data_name+'_features.pkl', 'wb+') as f: 227 | pkl.dump(gnn_model.inter1.features.weight, f) 228 | 229 | print("Restore model from epoch {}".format(ep_best)) 230 | print("Model path: {}".format(path_saver)) 231 | gnn_model.load_state_dict(torch.load(path_saver)) 232 | 233 | if args.model == 'SAGE' or args.model == 'GCN': 234 | f1_mac_test, auc_test, gmean_test = test_sage(idx_test, y_test, gnn_model, args.test_batch_size, args.thres) 235 | else: 236 | f1_mac_test, auc_test, gmean_test = test_GDN(idx_test, y_test, gnn_model, args.batch_size, args.thres, True) 237 | return f1_mac_test, auc_test, gmean_test 238 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import random 7 | 8 | """ 9 | GDN Layers 10 | """ 11 | 12 | 13 | class InterAgg(nn.Module): 14 | 15 | def __init__(self, features, feature_dim, embed_dim, 16 | train_pos, train_neg, adj_lists, intraggs, inter='GNN', cuda=True): 17 | """ 18 | Initialize the inter-relation aggregator 19 | :param features: the input node features or embeddings for all nodes 20 | :param feature_dim: the input dimension 21 | :param embed_dim: the embed dimension 22 | :param train_pos: positive samples in training set 23 | :param adj_lists: a list of adjacency lists for each single-relation graph 24 | :param intraggs: the intra-relation aggregators used by each single-relation graph 25 | :param inter: NOT used in this version, the aggregator type: 'Att', 'Weight', 'Mean', 'GNN' 26 | :param cuda: whether to use GPU 27 | """ 28 | super(InterAgg, self).__init__() 29 | 30 | # stored parameters 31 | self.features = features 32 | self.pos_vector = None 33 | self.neg_vector = None 34 | 35 | # Functions 36 | self.softmax = nn.Softmax(dim=-1) 37 | self.KLDiv = nn.KLDivLoss(reduction='batchmean') 38 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6) 39 | 40 | self.dropout = 0.6 41 | self.adj_lists = adj_lists 42 | self.intra_agg1 = intraggs[0] 43 | self.intra_agg2 = intraggs[1] 44 | self.intra_agg3 = intraggs[2] 45 | self.embed_dim = embed_dim 46 | self.feat_dim = feature_dim 47 | self.inter = inter 48 | self.cuda = cuda 49 | self.intra_agg1.cuda = cuda 50 | self.intra_agg2.cuda = cuda 51 | self.intra_agg3.cuda = cuda 52 | self.train_pos = train_pos 53 | self.train_neg = train_neg 54 | 55 | # initial filtering thresholds 56 | self.thresholds = [0.5, 0.5, 
0.5] 57 | 58 | # parameter used to transform node embeddings before inter-relation aggregation 59 | self.weight = nn.Parameter(torch.FloatTensor(self.embed_dim*len(intraggs)+self.feat_dim, self.embed_dim)) 60 | init.xavier_uniform_(self.weight) 61 | 62 | # label predictor for similarity measure 63 | self.label_clf = nn.Linear(self.feat_dim, 2) 64 | 65 | # initialize the parameter logs 66 | self.weights_log = [] 67 | self.thresholds_log = [self.thresholds] 68 | self.relation_score_log = [] 69 | 70 | 71 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list): 72 | self.pos_index = torch.LongTensor(self.train_pos).cuda() 73 | self.neg_index = torch.LongTensor(self.train_neg).cuda() 74 | else: 75 | self.pos_index = torch.LongTensor(self.train_pos) 76 | self.neg_index = torch.LongTensor(self.train_neg) 77 | 78 | def forward(self, nodes, labels): 79 | """ 80 | :param nodes: a list of batch node ids 81 | :param labels: a list of batch node labels 82 | :param train_flag: indicates whether in training or testing mode 83 | :return combined: the embeddings of a batch of input node features 84 | :return center_scores: the label-aware scores of batch nodes 85 | """ 86 | 87 | # extract 1-hop neighbor ids from adj lists of each single-relation graph 88 | to_neighs = [] 89 | for adj_list in self.adj_lists: 90 | to_neighs.append([set(adj_list[int(node)]) for node in nodes]) 91 | 92 | # get neighbor node id list for each batch node and relation 93 | r1_list = [list(to_neigh) for to_neigh in to_neighs[0]] 94 | r2_list = [list(to_neigh) for to_neigh in to_neighs[1]] 95 | r3_list = [list(to_neigh) for to_neigh in to_neighs[2]] 96 | 97 | # find unique nodes and their neighbors used in current batch 98 | unique_nodes = set.union(set.union(*to_neighs[0]), set.union(*to_neighs[1]), 99 | set.union(*to_neighs[2], set(nodes))) 100 | self.unique_nodes = unique_nodes 101 | 102 | # intra-aggregation steps for each relation 103 | r1_feats = self.intra_agg1.forward(nodes, r1_list) 104 | r2_feats = self.intra_agg2.forward(nodes, r2_list) 105 | r3_feats = self.intra_agg3.forward(nodes, r3_list) 106 | 107 | # get features or embeddings for batch nodes 108 | self_feats = self.fetch_feat(nodes) 109 | 110 | # Update label vector 111 | self.update_label_vector(self.features) 112 | 113 | # concat the intra-aggregated embeddings from each relation 114 | cat_feats = torch.cat((self_feats, r1_feats, r2_feats, r3_feats), dim=1) 115 | combined = F.relu(cat_feats.mm(self.weight).t()) 116 | return combined 117 | 118 | def fetch_feat(self, nodes): 119 | if self.cuda and isinstance(nodes, list): 120 | index = torch.LongTensor(nodes).cuda() 121 | else: 122 | index = torch.LongTensor(nodes) 123 | return self.features(index) 124 | 125 | def cal_simi_scores(self, nodes): 126 | self_feats = self.fetch_feat(nodes) 127 | cosine_pos = self.cos(self.pos_vector, self_feats).detach() 128 | cosine_neg = self.cos(self.neg_vector, self_feats).detach() 129 | simi_scores = torch.cat((cosine_neg.view(-1, 1), cosine_pos.view(-1, 1)), 1) 130 | return simi_scores 131 | 132 | def update_label_vector(self, x): 133 | # pdb.set_trace() 134 | if isinstance(x, torch.Tensor): 135 | x_pos = x[self.train_pos] 136 | x_neg = x[self.train_neg] 137 | elif isinstance(x, torch.nn.Embedding): 138 | x_pos = x(self.pos_index) 139 | x_neg = x(self.neg_index) 140 | if self.pos_vector is None: 141 | self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach() 142 | self.neg_vector = torch.mean(x_neg, dim=0, keepdim=True).detach() 143 | else: 
144 | cosine_pos = self.cos(self.pos_vector, x_pos) 145 | cosine_neg = self.cos(self.neg_vector, x_neg) 146 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1) 147 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1) 148 | self.pos_vector = torch.mm(weights_pos, x_pos).detach() 149 | self.neg_vector = torch.mm(weights_neg, x_neg).detach() 150 | 151 | def fl_loss(self, grads_idx): 152 | x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1) 153 | target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1) 154 | target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1) 155 | loss_pos = self.KLDiv(x, target_pos) 156 | loss_neg = self.KLDiv(x, target_neg) 157 | return loss_pos, loss_neg 158 | 159 | def fn_loss(self, nodes, non_grad_idx): 160 | pos_nodes = set(self.train_pos) 161 | to_neighs = [] 162 | target = [] 163 | for adj_list in self.adj_lists: 164 | target_r = [] 165 | to_neighs_r = [] 166 | for node in nodes: 167 | if int(node) in pos_nodes: 168 | target_r.append(int(node)) 169 | to_neighs_r.append(set(adj_list[int(node)])) 170 | to_neighs.append(to_neighs_r) 171 | target.append(target_r) 172 | 173 | to_neighs_all = [] 174 | for x, y, z in zip(to_neighs[0], to_neighs[1], to_neighs[2]): 175 | to_neighs_all.append(set.union(x, y, z)) 176 | 177 | r1_list = [list(to_neigh) for to_neigh in to_neighs[0]] 178 | r2_list = [list(to_neigh) for to_neigh in to_neighs[1]] 179 | r3_list = [list(to_neigh) for to_neigh in to_neighs[2]] 180 | # print(non_grad_idx) 181 | pos_1, neg_1 = self.intra_agg1.fn_loss(non_grad_idx, target[0], r1_list, self.unique_nodes, to_neighs_all) 182 | pos_2, neg_2 = self.intra_agg2.fn_loss(non_grad_idx, target[1], r2_list, self.unique_nodes, to_neighs_all) 183 | pos_3, neg_3 = self.intra_agg3.fn_loss(non_grad_idx, target[2], r3_list, self.unique_nodes, to_neighs_all) 184 | return pos_1 + pos_2 + pos_3, neg_1 + neg_2 + neg_3 185 | def softmax_with_temperature(self, input, t=1, axis=-1): 186 | ex = torch.exp(input/t) 187 | sum = torch.sum(ex, axis=axis) 188 | return ex/sum 189 | 190 | 191 | class IntraAgg(nn.Module): 192 | 193 | def __init__(self, features, feat_dim, embed_dim, train_pos, cuda=False): 194 | """ 195 | Initialize the intra-relation aggregator 196 | :param features: the input node features or embeddings for all nodes 197 | :param feat_dim: the input dimension 198 | :param embed_dim: the embed dimension 199 | :param train_pos: positive samples in training set 200 | :param cuda: whether to use GPU 201 | """ 202 | super(IntraAgg, self).__init__() 203 | 204 | self.features = features 205 | self.cuda = cuda 206 | self.feat_dim = feat_dim 207 | self.embed_dim = embed_dim 208 | self.train_pos = train_pos 209 | self.weight = nn.Parameter(torch.FloatTensor(2*self.feat_dim, self.embed_dim)) 210 | init.xavier_uniform_(self.weight) 211 | 212 | self.KLDiv = nn.KLDivLoss(reduction='batchmean') 213 | 214 | def forward(self, nodes, to_neighs_list): 215 | """ 216 | Code partially from https://github.com/williamleif/graphsage-simple/ 217 | :param nodes: list of nodes in a batch 218 | :param to_neighs_list: neighbor node id list for each batch node in one relation 219 | :param batch_scores: the label-aware scores of batch nodes 220 | :param neigh_scores: the label-aware scores 1-hop neighbors each batch node in one relation 221 | :param pos_scores: the label-aware scores 1-hop neighbors for the minority positive nodes 222 | :param train_flag: indicates whether in 
training or testing mode 223 | :param sample_list: the number of neighbors kept for each batch node in one relation 224 | :return to_feats: the aggregated embeddings of batch nodes neighbors in one relation 225 | :return samp_scores: the average neighbor distances for each relation after filtering 226 | """ 227 | 228 | samp_neighs = [set(x) for x in to_neighs_list] 229 | # find the unique nodes among batch nodes and the filtered neighbors 230 | unique_nodes_list = list(set.union(*samp_neighs)) 231 | unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)} 232 | 233 | 234 | # intra-relation aggregation only with sampled neighbors 235 | mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) 236 | column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] 237 | row_indices = [i for i in range(len(samp_neighs)) for _ in range(len(samp_neighs[i]))] 238 | mask[row_indices, column_indices] = 1 239 | if self.cuda: 240 | mask = mask.cuda() 241 | num_neigh = mask.sum(1, keepdim=True) 242 | mask = mask.div(num_neigh) # mean aggregator 243 | if self.cuda: 244 | self_feats = self.features(torch.LongTensor(nodes).cuda()) 245 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda()) 246 | else: 247 | self_feats = self.features(torch.LongTensor(nodes)) 248 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list)) 249 | agg_feats = mask.mm(embed_matrix) # single relation aggregator 250 | cat_feats = torch.cat((self_feats, agg_feats), dim=1) # concat with last layer 251 | to_feats = F.relu(cat_feats.mm(self.weight)) 252 | return to_feats 253 | 254 | def update_label_vector(self, x): 255 | # pdb.set_trace() 256 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list): 257 | pos_index = torch.LongTensor(self.train_pos).cuda() 258 | neg_index = torch.LongTensor(self.train_neg).cuda() 259 | if self.pos_vector is None: 260 | self.pos_vector = torch.mean(x(pos_index), dim=0, keepdim=True).detach() 261 | self.neg_vector = torch.mean(x(neg_index), dim=0, keepdim=True).detach() 262 | else: 263 | cosine_pos = self.cos(self.pos_vector, x(pos_index)) 264 | cosine_neg = self.cos(self.neg_vector, x(neg_index)) 265 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1) 266 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1) 267 | self.pos_vector = torch.mm(weights_pos, x(pos_index)).detach() 268 | self.neg_vector = torch.mm(weights_neg, x(neg_index)).detach() 269 | 270 | def fetch_feat(self, nodes): 271 | if self.cuda and isinstance(nodes, list): 272 | index = torch.LongTensor(nodes).cuda() 273 | else: 274 | index = torch.LongTensor(nodes) 275 | return self.features(index) 276 | 277 | def softmax_with_temperature(self, input, t=1, axis=-1): 278 | ex = torch.exp(input/t) 279 | sum = torch.sum(ex, axis=axis) 280 | return ex/sum 281 | 282 | def fn_loss(self, non_grad_idx, target, neighs, all_nodes, all_neighs): 283 | x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1) 284 | pos = torch.zeros_like(self.fetch_feat(target)) 285 | neg = torch.zeros_like(self.fetch_feat(target)) 286 | for i in range(len(target)): 287 | pos[i] = torch.mean(self.fetch_feat(neighs[i]), dim=0, keepdim=True) 288 | neg_idx = [random.choice(list(all_nodes.difference(all_neighs[i])))] 289 | neg[i] = self.fetch_feat(neg_idx) 290 | # pdb.set_trace() 291 | pos = pos[:, non_grad_idx].softmax(dim=-1) 292 | neg = neg[:, non_grad_idx].softmax(dim=-1) 293 | loss_pos = self.KLDiv(x, pos) 294 | 
loss_neg = self.KLDiv(x, neg) 295 | # pdb.set_trace() 296 | return loss_pos, loss_neg -------------------------------------------------------------------------------- /models/graphsage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import random 7 | import pdb 8 | 9 | 10 | """ 11 | GraphSAGE implementations 12 | Paper: Inductive Representation Learning on Large Graphs 13 | Source: https://github.com/williamleif/graphsage-simple/ 14 | """ 15 | 16 | 17 | class GraphSage(nn.Module): 18 | """ 19 | Vanilla GraphSAGE Model 20 | Code partially from https://github.com/williamleif/graphsage-simple/ 21 | """ 22 | def __init__(self, num_classes, enc): 23 | super(GraphSage, self).__init__() 24 | self.enc = enc 25 | self.xent = nn.CrossEntropyLoss() 26 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim)) 27 | init.xavier_uniform_(self.weight) 28 | 29 | def forward(self, nodes): 30 | embeds = self.enc(nodes) 31 | scores = self.weight.mm(embeds) 32 | return scores.t() 33 | 34 | def to_prob(self, nodes, train_flag=True): 35 | pos_scores = torch.sigmoid(self.forward(nodes)) 36 | return pos_scores 37 | 38 | def loss(self, nodes, labels): 39 | scores = self.forward(nodes) 40 | return self.xent(scores, labels.squeeze()) 41 | 42 | 43 | class MeanAggregator(nn.Module): 44 | """ 45 | Aggregates a node's embeddings using mean of neighbors' embeddings 46 | """ 47 | 48 | def __init__(self, features, cuda=False, gcn=False): 49 | """ 50 | Initializes the aggregator for a specific graph. 51 | 52 | features -- function mapping LongTensor of node ids to FloatTensor of feature values. 53 | cuda -- whether to use GPU 54 | gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style 55 | """ 56 | 57 | super(MeanAggregator, self).__init__() 58 | 59 | self.features = features 60 | self.cuda = cuda 61 | self.gcn = gcn 62 | 63 | def forward(self, nodes, to_neighs, num_sample=10): 64 | """ 65 | nodes --- list of nodes in a batch 66 | to_neighs --- list of sets, each set is the set of neighbors for node in batch 67 | num_sample --- number of neighbors to sample. No sampling if None. 
68 | """ 69 | # Local pointers to functions (speed hack) 70 | _set = set 71 | if not num_sample is None: 72 | _sample = random.sample 73 | samp_neighs = [_set(_sample(to_neigh, 74 | num_sample, 75 | )) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs] 76 | else: 77 | samp_neighs = to_neighs 78 | 79 | if self.gcn: 80 | samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in enumerate(samp_neighs)] 81 | unique_nodes_list = list(set.union(*samp_neighs)) 82 | unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)} 83 | mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) 84 | column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] 85 | row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))] 86 | mask[row_indices, column_indices] = 1 87 | if self.cuda: 88 | mask = mask.cuda() 89 | num_neigh = mask.sum(1, keepdim=True) 90 | mask = mask.div(num_neigh) 91 | if self.cuda: 92 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda()) 93 | else: 94 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list)) 95 | to_feats = mask.mm(embed_matrix) 96 | return to_feats 97 | 98 | 99 | class Encoder(nn.Module): 100 | """ 101 | Vanilla GraphSAGE Encoder Module 102 | Encodes a node's using 'convolutional' GraphSage approach 103 | """ 104 | def __init__(self, features, feature_dim, 105 | embed_dim, adj_lists, aggregator, 106 | train_pos, train_neg, num_sample=10, 107 | base_model=None, gcn=False, cuda=False, 108 | feature_transform=False): 109 | super(Encoder, self).__init__() 110 | 111 | self.features = features 112 | self.feat_dim = feature_dim 113 | self.adj_lists = adj_lists 114 | self.aggregator = aggregator 115 | self.num_sample = num_sample 116 | if base_model != None: 117 | self.base_model = base_model 118 | 119 | self.gcn = gcn 120 | self.embed_dim = embed_dim 121 | self.cuda = cuda 122 | self.aggregator.cuda = cuda 123 | self.weight = nn.Parameter( 124 | torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim)) 125 | init.xavier_uniform_(self.weight) 126 | 127 | self.pos_vector = None 128 | self.neg_vector = None 129 | self.train_pos = train_pos 130 | self.train_neg = train_neg 131 | self.softmax = nn.Softmax(dim=-1) 132 | self.KLDiv = nn.KLDivLoss(reduction='batchmean') 133 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6) 134 | 135 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list): 136 | self.pos_index = torch.LongTensor(self.train_pos).cuda() 137 | self.neg_index = torch.LongTensor(self.train_neg).cuda() 138 | else: 139 | self.pos_index = torch.LongTensor(self.train_pos) 140 | self.neg_index = torch.LongTensor(self.train_neg) 141 | 142 | self.unique_nodes = set() 143 | for _, adj_list in self.adj_lists.items(): 144 | self.unique_nodes = self.unique_nodes.union(adj_list) 145 | # def __init__(self, features, feature_dim, 146 | # embed_dim, adj_lists, aggregator, 147 | # num_sample=10, 148 | # base_model=None, gcn=False, cuda=False, 149 | # feature_transform=False): 150 | # super(Encoder, self).__init__() 151 | 152 | # self.features = features 153 | # self.feat_dim = feature_dim 154 | # self.adj_lists = adj_lists 155 | # self.aggregator = aggregator 156 | # self.num_sample = num_sample 157 | # if base_model != None: 158 | # self.base_model = base_model 159 | 160 | # self.gcn = gcn 161 | # self.embed_dim = embed_dim 162 | # self.cuda = cuda 163 | # self.aggregator.cuda = cuda 164 | # self.weight = 
nn.Parameter( 165 | # torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim)) 166 | # init.xavier_uniform_(self.weight) 167 | 168 | def forward(self, nodes): 169 | """ 170 | Generates embeddings for a batch of nodes. 171 | 172 | nodes -- list of nodes 173 | """ 174 | neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes], 175 | self.num_sample) 176 | self.update_label_vector(self.features) 177 | if isinstance(nodes, list): 178 | index = torch.LongTensor(nodes) 179 | else: 180 | index = nodes 181 | 182 | if not self.gcn: 183 | if self.cuda: 184 | self_feats = self.features(index).cuda() 185 | else: 186 | self_feats = self.features(index) 187 | combined = torch.cat((self_feats, neigh_feats), dim=1) 188 | else: 189 | combined = neigh_feats 190 | combined = F.relu(self.weight.mm(combined.t())) 191 | return combined 192 | 193 | def constraint_loss(self, grads_idx, with_grad): 194 | if with_grad: 195 | x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1) 196 | # x = F.log_softmax(self.stored_hidden[self.train_pos][:, grads_idx], dim=-1) 197 | target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1) 198 | target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1) 199 | loss_pos = self.KLDiv(x, target_pos) 200 | loss_neg = self.KLDiv(x, target_neg) 201 | # pdb.set_trace() 202 | else: 203 | x = F.log_softmax(self.features(self.pos_index), dim=-1) 204 | # x = F.log_softmax(self.stored_hidden[self.train_pos], dim=-1) 205 | target_pos = self.pos_vector.repeat(x.shape[0], 1).softmax(dim=-1) 206 | target_neg = self.neg_vector.repeat(x.shape[0], 1).softmax(dim=-1) 207 | loss_pos = self.KLDiv(x, target_pos) 208 | loss_neg = self.KLDiv(x, target_neg) 209 | # pdb.set_trace() 210 | return loss_pos, loss_neg 211 | 212 | def softmax_with_temperature(self, input, t=1, axis=-1): 213 | ex = torch.exp(input/t) 214 | sum = torch.sum(ex, axis=axis) 215 | return ex/sum 216 | 217 | def fn_loss(self, nodes, non_grad_idx): 218 | pos_nodes = set(self.train_pos) 219 | target = [] 220 | neighs = [] 221 | for node in nodes: 222 | if int(node) in pos_nodes: 223 | target.append(int(node)) 224 | neighs.append(self.adj_lists[int(node)]) 225 | x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1) 226 | pos = torch.zeros_like(self.fetch_feat(target)) 227 | neg = torch.zeros_like(self.fetch_feat(target)) 228 | for i in range(len(target)): 229 | pos[i] = torch.mean(self.fetch_feat(list(neighs[i])), dim=0, keepdim=True) 230 | neg_idx = [random.choice(list(self.unique_nodes.difference(neighs[i])))] 231 | neg[i] = self.fetch_feat(neg_idx) 232 | # pdb.set_trace() 233 | pos = pos[:, non_grad_idx].softmax(dim=-1) 234 | neg = neg[:, non_grad_idx].softmax(dim=-1) 235 | loss_pos = self.KLDiv(x, pos) 236 | loss_neg = self.KLDiv(x, neg) 237 | # pdb.set_trace() 238 | return loss_pos, loss_neg 239 | 240 | def fetch_feat(self, nodes): 241 | if self.cuda and isinstance(nodes, list): 242 | index = torch.LongTensor(nodes).cuda() 243 | else: 244 | index = torch.LongTensor(nodes) 245 | return self.features(index) 246 | 247 | def update_label_vector(self, x): 248 | # pdb.set_trace() 249 | if isinstance(x, torch.Tensor): 250 | x_pos = x[self.train_pos] 251 | x_neg = x[self.train_neg] 252 | elif isinstance(x, torch.nn.Embedding): 253 | x_pos = x(self.pos_index) 254 | x_neg = x(self.neg_index) 255 | if self.pos_vector is None: 256 | self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach() 257 | self.neg_vector 
= torch.mean(x_neg, dim=0, keepdim=True).detach() 258 | else: 259 | cosine_pos = self.cos(self.pos_vector, x_pos) 260 | cosine_neg = self.cos(self.neg_vector, x_neg) 261 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1) 262 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1) 263 | self.pos_vector = torch.mm(weights_pos, x_pos).detach() 264 | self.neg_vector = torch.mm(weights_neg, x_neg).detach() 265 | 266 | 267 | 268 | class GCN(nn.Module): 269 | """ 270 | Vanilla GCN Model 271 | Code partially from https://github.com/williamleif/graphsage-simple/ 272 | """ 273 | def __init__(self, num_classes, enc, add_constraint): 274 | super(GCN, self).__init__() 275 | self.enc = enc 276 | self.xent = nn.CrossEntropyLoss() 277 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim)) 278 | init.xavier_uniform_(self.weight) 279 | self.pos_vector = None 280 | self.neg_vector = None 281 | 282 | 283 | def forward(self, nodes, train_flag=True): 284 | if train_flag: 285 | embeds = self.enc(nodes, train_flag) 286 | scores = self.weight.mm(embeds) 287 | else: 288 | embeds = self.enc(nodes, train_flag) 289 | scores = self.weight.mm(embeds) 290 | return scores.t() 291 | 292 | def to_prob(self, nodes, train_flag=True): 293 | pos_scores = torch.sigmoid(self.forward(nodes, train_flag)) 294 | return pos_scores 295 | 296 | def loss(self, nodes, labels): 297 | scores = self.forward(nodes) 298 | return self.xent(scores, labels.squeeze()) 299 | 300 | 301 | class GCNAggregator(nn.Module): 302 | """ 303 | Aggregates a node's embeddings using normalized mean of neighbors' embeddings 304 | """ 305 | 306 | def __init__(self, features, cuda=False, gcn=False): 307 | """ 308 | Initializes the aggregator for a specific graph. 309 | 310 | features -- function mapping LongTensor of node ids to FloatTensor of feature values. 
311 |         cuda -- whether to use GPU
312 |         gcn -- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
313 |         """
314 | 
315 |         super(GCNAggregator, self).__init__()
316 | 
317 |         self.features = features
318 |         self.cuda = cuda
319 |         self.gcn = gcn
320 | 
321 |     def forward(self, nodes, to_neighs):
322 |         """
323 |         nodes -- list of nodes in a batch
324 |         to_neighs -- list of sets, each set is the set of neighbors for a node in the batch
325 |         """
326 |         # Local pointers to functions (speed hack)
327 | 
328 |         samp_neighs = to_neighs
329 | 
330 |         # Add self to neighs
331 |         samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in enumerate(samp_neighs)]
332 |         unique_nodes_list = list(set.union(*samp_neighs))
333 |         unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
334 |         mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
335 |         column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
336 |         row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
337 |         mask[row_indices, column_indices] = 1.0  # Adjacency matrix for the sub-graph
338 |         if self.cuda:
339 |             mask = mask.cuda()
340 |         row_normalized = mask.sum(1, keepdim=True).sqrt()
341 |         col_normalized = mask.sum(0, keepdim=True).sqrt()
342 |         mask = mask.div(row_normalized).div(col_normalized)
343 |         if self.cuda:
344 |             embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
345 |         else:
346 |             embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
347 |         to_feats = mask.mm(embed_matrix)
348 |         return to_feats
349 | 
350 | class GCNEncoder(nn.Module):
351 |     """
352 |     GCN Encoder Module
353 |     """
354 | 
355 |     def __init__(self, features, feature_dim,
356 |                  embed_dim, adj_lists, aggregator,
357 |                  train_pos, train_neg, num_sample=10,
358 |                  base_model=None, gcn=False, cuda=False,
359 |                  feature_transform=False):
360 |         super(GCNEncoder, self).__init__()
361 | 
362 |         self.features = features
363 |         self.feat_dim = feature_dim
364 |         self.adj_lists = adj_lists
365 |         self.aggregator = aggregator
366 |         self.num_sample = num_sample
367 |         if base_model != None:
368 |             self.base_model = base_model
369 | 
370 |         self.gcn = gcn
371 |         self.embed_dim = embed_dim
372 |         self.cuda = cuda
373 |         self.aggregator.cuda = cuda
374 |         self.weight = nn.Parameter(
375 |             torch.FloatTensor(embed_dim, self.feat_dim))
376 |         init.xavier_uniform_(self.weight)
377 | 
378 |         self.pos_vector = None
379 |         self.neg_vector = None
380 |         self.train_pos = train_pos
381 |         self.train_neg = train_neg
382 |         self.softmax = nn.Softmax(dim=-1)
383 |         self.KLDiv = nn.KLDivLoss(reduction='batchmean')
384 |         self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
385 | 
386 |         if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list):
387 |             self.pos_index = torch.LongTensor(self.train_pos).cuda()
388 |             self.neg_index = torch.LongTensor(self.train_neg).cuda()
389 |         else:
390 |             self.pos_index = torch.LongTensor(self.train_pos)
391 |             self.neg_index = torch.LongTensor(self.train_neg)
392 | 
393 |         self.unique_nodes = set()
394 |         for _, adj_list in self.adj_lists.items():
395 |             self.unique_nodes = self.unique_nodes.union(adj_list)
396 | 
397 | 
398 | 
399 |     def forward(self, nodes, train_flag=True):
400 |         """
401 |         Generates embeddings for a batch of nodes.
402 |         Input:
403 |             nodes -- list of nodes
404 |         Output:
405 |             embedding matrix of shape (embed_dim, len(nodes))
406 |         """
407 |         neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes])
408 |         self.update_label_vector(self.features)
409 |         if isinstance(nodes, list):
410 |             index = torch.LongTensor(nodes)
411 |         else:
412 |             index = nodes
413 |         self_feats = self.features(index)
414 | 
415 |         combined = F.relu(self.weight.mm(neigh_feats.t()))
416 |         if not train_flag:
417 |             cosine_pos = self.cos(self.pos_vector, self_feats).detach()
418 |             cosine_neg = self.cos(self.neg_vector, self_feats).detach()
419 |             simi_scores = torch.cat((cosine_neg.view(-1, 1), cosine_pos.view(-1, 1)), 1).t()
420 |             # simi_scores = torch.zeros(2, self_feats.shape[0])
421 |             return combined, simi_scores
422 |         return combined
423 | 
424 |     def update_label_vector(self, x):
425 |         # pdb.set_trace()
426 |         if isinstance(x, torch.Tensor):
427 |             x_pos = x[self.train_pos]
428 |             x_neg = x[self.train_neg]
429 |         elif isinstance(x, torch.nn.Embedding):
430 |             x_pos = x(self.pos_index)
431 |             x_neg = x(self.neg_index)
432 |         if self.pos_vector is None:
433 |             self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach()
434 |             self.neg_vector = torch.mean(x_neg, dim=0, keepdim=True).detach()
435 |         else:
436 |             cosine_pos = self.cos(self.pos_vector, x_pos)
437 |             cosine_neg = self.cos(self.neg_vector, x_neg)
438 |             weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1)
439 |             weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1)
440 |             self.pos_vector = torch.mm(weights_pos, x_pos).detach()
441 |             self.neg_vector = torch.mm(weights_neg, x_neg).detach()
442 |             # pdb.set_trace()
443 | 
444 |     def constraint_loss(self, grads_idx, with_grad):
445 |         if with_grad:
446 |             x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1)
447 |             # x = F.log_softmax(self.stored_hidden[self.train_pos][:, grads_idx], dim=-1)
448 |             target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
449 |             target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
450 |             loss_pos = self.KLDiv(x, target_pos)
451 |             loss_neg = self.KLDiv(x, target_neg)
452 |             # pdb.set_trace()
453 |         else:
454 |             x = F.log_softmax(self.features(self.pos_index), dim=-1)
455 |             # x = F.log_softmax(self.stored_hidden[self.train_pos], dim=-1)
456 |             target_pos = self.pos_vector.repeat(x.shape[0], 1).softmax(dim=-1)
457 |             target_neg = self.neg_vector.repeat(x.shape[0], 1).softmax(dim=-1)
458 |             loss_pos = self.KLDiv(x, target_pos)
459 |             loss_neg = self.KLDiv(x, target_neg)
460 |             # pdb.set_trace()
461 |         return loss_pos, loss_neg
462 | 
463 |     def softmax_with_temperature(self, input, t=1, axis=-1):
464 |         ex = torch.exp(input/t)
465 |         sum = torch.sum(ex, axis=axis)
466 |         return ex/sum
467 | 
468 |     def fn_loss(self, nodes, non_grad_idx):
469 |         pos_nodes = set(self.train_pos)
470 |         target = []
471 |         neighs = []
472 |         for node in nodes:
473 |             if int(node) in pos_nodes:
474 |                 target.append(int(node))
475 |                 neighs.append(self.adj_lists[int(node)])
476 |         x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1)
477 |         pos = torch.zeros_like(self.fetch_feat(target))
478 |         neg = torch.zeros_like(self.fetch_feat(target))
479 |         for i in range(len(target)):
480 |             pos[i] = torch.mean(self.fetch_feat(list(neighs[i])), dim=0, keepdim=True)
481 |             neg_idx = [random.choice(list(self.unique_nodes.difference(neighs[i])))]
482 |             neg[i] = self.fetch_feat(neg_idx)
483 |             # pdb.set_trace()
484 |         pos = pos[:, non_grad_idx].softmax(dim=-1)
485 |         neg = neg[:, non_grad_idx].softmax(dim=-1)
486 |         loss_pos = self.KLDiv(x, pos)
487 |         loss_neg = self.KLDiv(x, neg)
488 |         # pdb.set_trace()
489 |         return loss_pos, loss_neg
490 | 
491 |     def fetch_feat(self, nodes):
492 |         if self.cuda and isinstance(nodes, list):
493 |             index = torch.LongTensor(nodes).cuda()
494 |         else:
495 |             index = torch.LongTensor(nodes)
496 |         return self.features(index)
497 | 
--------------------------------------------------------------------------------
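
For orientation, here is a minimal, hypothetical sketch of how the classes above might be wired together. The toy graph, feature matrix, labels, and train_pos/train_neg lists below are placeholders invented for illustration, not the repository's data pipeline (model_handler.py builds the real inputs from the YelpChi adjacency lists), and the module-level imports of graphsage.py (torch, nn, init, F, random, Variable) are assumed to be in scope.

```python
# Hypothetical usage sketch, not part of the repository: wires GCNAggregator,
# GCNEncoder and GCN together on a fabricated ring graph for illustration only.
import torch
import torch.nn as nn

num_nodes, feat_dim, emb_size = 100, 32, 64
labels = torch.tensor([i % 2 for i in range(num_nodes)])            # toy labels
adj_lists = {i: {(i - 1) % num_nodes, (i + 1) % num_nodes} for i in range(num_nodes)}
train_nodes = list(range(60))
train_pos = [i for i in train_nodes if labels[i] == 1]              # anomalies
train_neg = [i for i in train_nodes if labels[i] == 0]              # normals

# Node features live in an nn.Embedding, which update_label_vector expects.
features = nn.Embedding(num_nodes, feat_dim)
features.weight = nn.Parameter(torch.randn(num_nodes, feat_dim), requires_grad=False)

agg = GCNAggregator(features, cuda=False, gcn=True)
enc = GCNEncoder(features, feat_dim, emb_size, adj_lists, agg,
                 train_pos, train_neg, gcn=True, cuda=False)
model = GCN(num_classes=2, enc=enc, add_constraint=False)

batch = train_nodes[:16]
loss = model.loss(batch, labels[batch])            # cross-entropy on the batch
probs = model.to_prob(batch, train_flag=True)      # (len(batch), 2) sigmoid scores
loss.backward()
```

Note that GCNEncoder.forward returns an extra simi_scores tensor when train_flag is False, so this sketch keeps train_flag=True throughout.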