├── .gitignore
├── figures
│   └── topology.png
├── pytorch_models
│   └── 2022-11-04 16-47-13
│       └── yelp_GDN.pkl
├── config
│   └── gdn_yelpchi.yml
├── utils
│   ├── data_process.py
│   └── utils.py
├── models
│   ├── model.py
│   ├── layers.py
│   └── graphsage.py
├── README.md
├── main.py
└── model_handler.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/figures/topology.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blacksingular/wsdm_GDN/HEAD/figures/topology.png
--------------------------------------------------------------------------------
/pytorch_models/2022-11-04 16-47-13/yelp_GDN.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/blacksingular/wsdm_GDN/HEAD/pytorch_models/2022-11-04 16-47-13/yelp_GDN.pkl
--------------------------------------------------------------------------------
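
The released checkpoint stores a `state_dict` written by `torch.save(gnn_model.state_dict(), path_saver)` in `model_handler.py`, so the full model has to be rebuilt exactly as `ModelHandler` does before the weights can be restored. A minimal sketch of the I/O step only:

```python
import torch

# Inspect the released checkpoint; rebuild the GDN model (see model_handler.py)
# before calling gnn_model.load_state_dict(state_dict).
state_dict = torch.load('pytorch_models/2022-11-04 16-47-13/yelp_GDN.pkl', map_location='cpu')
print(list(state_dict.keys())[:5])
```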
/config/gdn_yelpchi.yml:
--------------------------------------------------------------------------------
1 | # Data
2 | data_name: 'yelp'
3 | data_dir: 'data/'
4 | train_ratio: 0.6   # fraction used for train+valid; the remainder is the test split
5 | test_ratio: 0.33   # fraction of the train+valid part held out for validation (see model_handler.py)
6 | save_dir: './pytorch_models/'
7 |
8 | # Model
9 | model: 'GDN'
10 | multi_relation: 'GNN'
11 |
12 |
13 | # Model architecture
14 | emb_size: 64
15 |
16 | thres: 0.2
17 |
18 | seed: 42
19 |
20 | # Run multiple times with different random seeds
21 | # seed:
22 | # - 42
23 | # - 448
24 | # - 854
25 | # - 29493
26 | # - 88867
27 |
28 |
29 | # hyper-parameters
30 | optimizer: 'adam'
31 | lr_1: 0.001
32 | lr_2: 0.00001
33 | weight_decay: 0.001
34 | weight_decay_2: 0.1
35 | batch_size: 1024
36 | test_batch_size: 1024
37 | num_epochs: 101
38 | valid_epochs: 5
39 |
40 | add_constraint: True
41 | Beta: 0.1
42 | # - 0.1
43 | # - 0.5
44 | # - 1
45 | # - 2
46 |
47 | topk: 10
48 | # - 6
49 | # - 12
50 | # - 18
51 | # - 24
52 | biased_split: False
53 | # Device
54 | no_cuda: True
55 | cuda_id: '2'
56 |
--------------------------------------------------------------------------------
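
A minimal sketch of how this config is consumed, mirroring `get_config` in `main.py` and the `argparse.Namespace` conversion in `model_handler.py` (the path is the repository default):

```python
import argparse

import yaml

# Load the YAML file into a plain dict, as main.get_config does.
with open('config/gdn_yelpchi.yml', 'r') as setting:
    config = yaml.load(setting, Loader=yaml.FullLoader)

# ModelHandler exposes the same values through attribute access.
args = argparse.Namespace(**config)
print(args.data_name, args.emb_size, args.lr_1)  # yelp 64 0.001
```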
/utils/data_process.py:
--------------------------------------------------------------------------------
1 | from utils.utils import sparse_to_adjlist
2 | from scipy.io import loadmat
3 |
4 | """
5 | Read data and save the adjacency matrices to adjacency lists
6 | """
7 |
8 |
9 | if __name__ == "__main__":
10 |
11 | prefix = 'data/'
12 |
13 | yelp = loadmat('data/YelpChi.mat')
14 | net_rur = yelp['net_rur']
15 | net_rtr = yelp['net_rtr']
16 | net_rsr = yelp['net_rsr']
17 | yelp_homo = yelp['homo']
18 |
19 | sparse_to_adjlist(net_rur, prefix + 'yelp_rur_adjlists.pickle')
20 | sparse_to_adjlist(net_rtr, prefix + 'yelp_rtr_adjlists.pickle')
21 | sparse_to_adjlist(net_rsr, prefix + 'yelp_rsr_adjlists.pickle')
22 | sparse_to_adjlist(yelp_homo, prefix + 'yelp_homo_adjlists.pickle')
23 |
24 | # amz = loadmat('data/Amazon.mat')
25 | # net_upu = amz['net_upu']
26 | # net_usu = amz['net_usu']
27 | # net_uvu = amz['net_uvu']
28 | # amz_homo = amz['homo']
29 |
30 | # sparse_to_adjlist(net_upu, prefix + 'amz_upu_adjlists.pickle')
31 | # sparse_to_adjlist(net_usu, prefix + 'amz_usu_adjlists.pickle')
32 | # sparse_to_adjlist(net_uvu, prefix + 'amz_uvu_adjlists.pickle')
33 | # sparse_to_adjlist(amz_homo, prefix + 'amz_homo_adjlists.pickle')
34 |
--------------------------------------------------------------------------------
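
The script above writes each relation graph as a pickled `defaultdict(set)` keyed by node id (self-loops are added before conversion). A small sanity-check sketch, assuming the `data/` prefix used above:

```python
import pickle

# Inspect one of the adjacency lists written by data_process.py.
with open('data/yelp_rur_adjlists.pickle', 'rb') as f:
    adj = pickle.load(f)

print(len(adj))             # number of nodes; the self-loops ensure every node appears
print(sorted(adj[0])[:10])  # neighbours of node 0 under the R-U-R relation
```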
/models/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | """
6 | GDN Model
7 | """
8 |
9 |
10 | class GDNLayer(nn.Module):
11 | """
12 | One GDN layer
13 | """
14 |
15 | def __init__(self, num_classes, inter1):
16 | """
17 | Initialize GDN model
18 | :param num_classes: number of classes (2 in our paper)
19 |         :param inter1: the inter-relation aggregator that outputs the final embedding
20 | """
21 | super(GDNLayer, self).__init__()
22 | self.inter1 = inter1
23 | self.xent = nn.CrossEntropyLoss()
24 | self.softmax = nn.Softmax(dim=-1)
25 | self.KLDiv = nn.KLDivLoss(reduction='batchmean')
26 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
27 | # the parameter to transform the final embedding
28 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, inter1.embed_dim))
29 | init.xavier_uniform_(self.weight)
30 |
31 |
32 | def forward(self, nodes, labels):
33 | embeds1 = self.inter1(nodes, labels)
34 | scores = self.weight.mm(embeds1)
35 | return scores.t()
36 |
37 | def to_prob(self, nodes, labels):
38 | gnn_logits = self.forward(nodes, labels)
39 | gnn_scores = self.softmax(gnn_logits)
40 | return gnn_scores
41 |
42 | def loss(self, nodes, labels):
43 | gnn_scores = self.forward(nodes, labels)
44 | # GNN loss
45 | gnn_loss = self.xent(gnn_scores, labels.squeeze())
46 | return gnn_loss
47 |
--------------------------------------------------------------------------------
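
A shape-level usage sketch for `GDNLayer`. The stub aggregator below is illustrative only; in the real pipeline `inter1` is the `InterAgg` built in `model_handler.py`, which returns embeddings shaped `[embed_dim, batch_size]`:

```python
import torch
import torch.nn as nn

from models.model import GDNLayer  # run from the repository root


class StubInterAgg(nn.Module):
    """Illustrative stand-in for models.layers.InterAgg."""
    def __init__(self, embed_dim=64):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, nodes, labels):
        # InterAgg returns embeddings shaped [embed_dim, batch_size]
        return torch.randn(self.embed_dim, len(nodes))


model = GDNLayer(num_classes=2, inter1=StubInterAgg())
nodes = [0, 1, 2, 3]
labels = torch.tensor([0, 1, 0, 0])

logits = model(nodes, labels)         # [batch, num_classes]
probs = model.to_prob(nodes, labels)  # softmax over the two classes
loss = model.loss(nodes, labels)      # cross-entropy on the logits
print(logits.shape, probs.shape, loss.item())
```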
/README.md:
--------------------------------------------------------------------------------
1 | # GDN: Alleviating Structural Distribution Shift in Graph Anomaly Detection
2 | PyTorch implementation of
3 | 
4 | Alleviating Structural Distribution Shift in Graph Anomaly Detection (WSDM 2023)
5 |
6 | # Overview
7 | This work addresses the structural distribution shift (SDS) problem from a feature view. We observe that the degree of heterophily differs between the training and test environments in GAD, which leads to poor generalization of the classifier. To address this issue, we propose GDN, which resists high heterophily for anomalies while letting normal nodes benefit from
8 | homophily. Since label differences are reflected in the critical anomaly features that contribute most to GAD, we tease out these anomaly features and constrain them to be invariant, mitigating the effect of heterophilous neighbors. To better estimate the prior distribution of the anomaly features, we devise a prototype vector that infers and updates this distribution during training. For normal nodes, we constrain the remaining features to preserve node connectivity and reinforce the influence of the homophilous neighborhood.
9 |
10 | 
11 | ![Illustration of GDN](figures/topology.png)
12 | 
13 | 
14 | Illustration of GDN. The feature separation module
15 | separates the node feature into two sets. Two constraints
16 | are leveraged to assist separation. Blank positions in node
17 | representations mean they are zero when calculating losses.
18 |
19 | # Dataset
20 | YelpChi and Amazon can be downloaded from [here](https://github.com/YingtongDou/CARE-GNN/tree/master/data) or [dgl.data.FraudDataset](https://docs.dgl.ai/api/python/dgl.data.html#fraud-dataset).
21 |
22 | Run `python -m utils.data_process` from the repository root to pre-process the data.
23 |
24 | # Dependencies
25 | Please set up the environment following the Requirements section of this [repository](https://github.com/PonderLY/PC-GNN).
26 | ```sh
27 | argparse 1.1.0
28 | networkx 1.11
29 | numpy 1.16.4
30 | scikit_learn 0.21rc2
31 | scipy 1.2.1
32 | torch 1.4.0
33 | ```
34 |
35 | # Reproduce
36 | ```sh
37 | python main.py --config ./config/gdn_yelpchi.yml
38 | ```
39 |
40 | # Acknowledgement
41 | Our code references:
42 | - [CAREGNN](https://github.com/YingtongDou/CARE-GNN)
43 |
44 | - [PCGNN](https://github.com/PonderLY/PC-GNN)
45 |
46 | # Reference
47 | ```
48 | @inproceedings{
49 | gao2023gdn,
50 | title={Alleviating Structural Distribution Shift in Graph Anomaly Detection},
51 | author={Yuan Gao and Xiang Wang and Xiangnan He and Zhenguang Liu and Huamin Feng and Yongdong Zhang},
52 | booktitle={WSDM},
53 | year={2023},
54 | }
55 | ```
56 |
--------------------------------------------------------------------------------
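
The Reproduce command can also be driven from Python, mirroring what `main.py` does (assumes the repository root as working directory and a pre-processed `data/` directory):

```python
from main import get_config, set_random_seed
from model_handler import ModelHandler

config = get_config('./config/gdn_yelpchi.yml')
set_random_seed(config['seed'])

handler = ModelHandler(config)
f1_mac, auc, gmean = handler.train()
print(f'F1-Macro {f1_mac:.4f}  AUC {auc:.4f}  G-Mean {gmean:.4f}')
```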
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import yaml
3 | import torch
4 | import datetime
5 | import numpy as np
6 | from collections import defaultdict, OrderedDict
7 | import time
8 |
9 | from model_handler import ModelHandler
10 | import logging
11 |
12 | timestamp = time.time()
13 | timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S')
14 | logging.basicConfig(filename='result.log',level=logging.INFO)
15 | logging.info(timestamp)
16 | # timestamp = time.time()
17 |
18 |
19 | ################################################################################
20 | # Main #
21 | ################################################################################
22 |
23 |
24 | def set_random_seed(seed):
25 | torch.manual_seed(seed)
26 | torch.cuda.manual_seed_all(seed)
27 | np.random.seed(seed)
28 |
29 |
30 | def main(config):
31 | print_config(config)
32 | set_random_seed(config['seed'])
33 | model = ModelHandler(config)
34 | f1_mac_test, auc_test, gmean_test = model.train()
35 | print("F1-Macro: {}".format(f1_mac_test))
36 | print("AUC: {}".format(auc_test))
37 | print("G-Mean: {}".format(gmean_test))
38 |
39 |
40 |
41 | def multi_run_main(config):
42 | print_config(config)
43 | hyperparams = []
44 | for k, v in config.items():
45 | if isinstance(v, list):
46 | hyperparams.append(k)
47 |
48 |     f1_list, auc_list, gmean_list = [], [], []
49 | configs = grid(config)
50 | for i, cnf in enumerate(configs):
51 | print('Running {}:\n'.format(i))
52 | for k in hyperparams:
53 | cnf['save_dir'] += '{}_{}_'.format(k, cnf[k])
54 | print(cnf['save_dir'])
55 | set_random_seed(cnf['seed'])
56 | st = time.time()
57 | model = ModelHandler(cnf)
58 |         # ModelHandler.train() returns (f1_mac_test, auc_test, gmean_test)
59 |         f1_mac_test, auc_test, gmean_test = model.train()
60 |         f1_list.append(f1_mac_test)
61 |         auc_list.append(auc_test)
62 |         gmean_list.append(gmean_test)
63 | 
64 | print("Running {} done, elapsed time {}s".format(i, time.time()-st))
65 | with open(cnf['result_save'], 'a+') as f:
66 | f.write("{}, F1-Macro: {}, AUC: {}, G-Mean: {}\n".format(cnf['save_dir'], f1_mac_test, auc_test, gmean_test))
67 | f.close()
68 |
69 | print("F1-Macro: {}".format(f1_list))
70 | print("AUC: {}".format(auc_list))
71 | print("G-Mean: {}".format(gmean_list))
72 |
73 |     f1_mean, f1_std = np.mean(f1_list), np.std(f1_list, ddof=1)
74 |     auc_mean, auc_std = np.mean(auc_list), np.std(auc_list, ddof=1)
75 |     gmean_mean, gmean_std = np.mean(gmean_list), np.std(gmean_list, ddof=1)
78 |
79 | print("F1-Macro: {}+{}".format(f1_mean, f1_std))
82 | print("AUC: {}+{}".format(auc_mean, auc_std))
83 | print("G-Mean: {}+{}".format(gmean_mean, gmean_std))
84 |
85 |
86 |
87 | ################################################################################
88 | # ArgParse and Helper Functions #
89 | ################################################################################
90 | def get_config(config_path="config.yml"):
91 | with open(config_path, "r") as setting:
92 | config = yaml.load(setting, Loader=yaml.FullLoader)
93 | return config
94 |
95 | def get_args():
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument('-config', '--config', required=True, type=str, help='path to the config file')
98 | parser.add_argument('--multi_run', action='store_true', help='flag: multi run')
99 | args = vars(parser.parse_args())
100 | return args
101 |
102 |
103 | def print_config(config):
104 | logging.info("**************** MODEL CONFIGURATION ****************")
105 | for key in sorted(config.keys()):
106 | val = config[key]
107 | keystr = "{}".format(key) + (" " * (24 - len(key)))
108 | logging.info("{} --> {}".format(keystr, val))
109 | logging.info("**************** MODEL CONFIGURATION ****************")
110 |
111 |
112 | def grid(kwargs):
113 | """Builds a mesh grid with given keyword arguments for this Config class.
114 | If the value is not a list, then it is considered fixed"""
115 |
116 | class MncDc:
117 | """This is because np.meshgrid does not always work properly..."""
118 |
119 | def __init__(self, a):
120 | self.a = a # tuple!
121 |
122 | def __call__(self):
123 | return self.a
124 |
125 | def merge_dicts(*dicts):
126 | """
127 | Merges dictionaries recursively. Accepts also `None` and returns always a (possibly empty) dictionary
128 | """
129 | from functools import reduce
130 | def merge_two_dicts(x, y):
131 | z = x.copy() # start with x's keys and values
132 | z.update(y) # modifies z with y's keys and values & returns None
133 | return z
134 |
135 | return reduce(lambda a, nd: merge_two_dicts(a, nd if nd else {}), dicts, {})
136 |
137 |
138 | sin = OrderedDict({k: v for k, v in kwargs.items() if isinstance(v, list)})
139 | for k, v in sin.items():
140 | copy_v = []
141 | for e in v:
142 | copy_v.append(MncDc(e) if isinstance(e, tuple) else e)
143 | sin[k] = copy_v
144 |
145 | grd = np.array(np.meshgrid(*sin.values()), dtype=object).T.reshape(-1, len(sin.values()))
146 | return [merge_dicts(
147 | {k: v for k, v in kwargs.items() if not isinstance(v, list)},
148 | {k: vv[i]() if isinstance(vv[i], MncDc) else vv[i] for i, k in enumerate(sin)}
149 | ) for vv in grd]
150 |
151 |
152 | ################################################################################
153 | # Module Command-line Behavior #
154 | ################################################################################
155 | if __name__ == '__main__':
156 | cfg = get_args()
157 | config = get_config(cfg['config'])
158 | if cfg['multi_run']:
159 | multi_run_main(config)
160 | else:
161 | main(config)
162 |
--------------------------------------------------------------------------------
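
The `grid` helper above expands every list-valued config entry into the cartesian product of concrete configs for `--multi_run`. A small illustration with made-up values (run from the repository root):

```python
from main import grid

cfg = {'lr_1': 0.001, 'seed': [42, 448], 'Beta': [0.1, 0.5]}
runs = grid(cfg)

print(len(runs))  # 4: every combination of seed x Beta
print(runs[0])    # e.g. {'lr_1': 0.001, 'seed': 42, 'Beta': 0.1}
```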
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import random
3 | import numpy as np
4 | import scipy.sparse as sp
5 | from scipy.io import loadmat
6 | import copy as cp
7 | from sklearn.metrics import f1_score, roc_auc_score, confusion_matrix
8 | from collections import defaultdict
9 | import logging
10 |
11 |
12 | logging.basicConfig(filename='result.log',level=logging.INFO)
13 | """
14 | Utility functions to handle data and evaluate model.
15 | """
16 |
17 | def biased_split(dataset):
18 |     prefix = 'data/'
19 | if dataset == 'yelp':
20 | data_name = 'YelpChi.mat' # 'Amazon.mat' or 'YelpChi.mat'
21 | elif dataset == 'amazon':
22 | data_name = 'Amazon.mat'
23 | data = loadmat(prefix + data_name)
24 |
25 | if data_name == 'YelpChi.mat':
26 | net_list = [data['net_rur'].nonzero(), data['net_rtr'].nonzero(),
27 | data['net_rsr'].nonzero(), data['homo'].nonzero()]
28 | else: # amazon dataset
29 | net_list = [data['net_upu'].nonzero(), data['net_usu'].nonzero(),
30 | data['net_uvu'].nonzero(), data['homo'].nonzero()]
31 |
32 | label = data['label'][0]
33 | pos_nodes = set(label.nonzero()[0].tolist())
34 |
35 | pos_node_dict, neg_node_dict = defaultdict(lambda: [0, 0]), defaultdict(lambda: [0, 0])
36 | # extract the edges of positive nodes in each relation graph
37 | net = net_list[-1]
38 |
39 | # calculate (homophily) probability for nodes
40 | for u, v in zip(net[0].tolist(), net[1].tolist()):
41 | if dataset == 'amazon' and min(u, v) < 3305: # 0~3304 are unlabelled nodes for amazon
42 | continue
43 | if u in pos_nodes:
44 | pos_node_dict[u][0] += 1
45 | if label[u] == label[v]:
46 | pos_node_dict[u][1] += 1
47 | else:
48 | neg_node_dict[u][0] += 1
49 | if label[u] == label[v]:
50 | neg_node_dict[u][1] += 1
51 |
52 | p1 = np.zeros(len(label))
53 | for k in pos_node_dict:
54 | p1[k] = pos_node_dict[k][1] / pos_node_dict[k][0]
55 | p1 = p1 / p1.sum()
56 | pos_index = np.random.choice(range(len(p1)), size=round(0.6 * len(pos_node_dict)), replace=False, p = p1.ravel())
57 | p2 = np.zeros(len(label))
58 | for k in neg_node_dict:
59 | p2[k] = neg_node_dict[k][1] / neg_node_dict[k][0]
60 | p2 = p2 / p2.sum()
61 | neg_index = np.random.choice(range(len(p2)), size=round(0.6 * len(neg_node_dict)), replace=False, p = p2.ravel())
62 |
63 | idx_train = np.concatenate((pos_index, neg_index))
64 | np.random.shuffle(idx_train)
65 | idx_train = list(idx_train)
66 | y_train = np.array(label[idx_train])
67 |
68 | # find test label
69 | idx_test = list(set(range(len(label))).difference(set(idx_train)).difference(set(range(3305))))
70 | random.shuffle(idx_test)
71 | y_test = np.array(label[idx_test])
72 | return idx_train, idx_test, y_train, y_test
73 |
74 | def load_data(data, prefix='data/'):
75 | """
76 | Load graph, feature, and label given dataset name
77 | :returns: home and single-relation graphs, feature, label
78 | """
79 |
80 | if data == 'yelp':
81 | data_file = loadmat(prefix + 'YelpChi.mat')
82 | labels = data_file['label'].flatten()
83 | feat_data = data_file['features'].todense().A
84 | # load the preprocessed adj_lists
85 | with open(prefix + 'yelp_homo_adjlists.pickle', 'rb') as file:
86 | homo = pickle.load(file)
87 | file.close()
88 | with open(prefix + 'yelp_rur_adjlists.pickle', 'rb') as file:
89 | relation1 = pickle.load(file)
90 | file.close()
91 | with open(prefix + 'yelp_rtr_adjlists.pickle', 'rb') as file:
92 | relation2 = pickle.load(file)
93 | file.close()
94 | with open(prefix + 'yelp_rsr_adjlists.pickle', 'rb') as file:
95 | relation3 = pickle.load(file)
96 | file.close()
97 | elif data == 'amazon':
98 | data_file = loadmat(prefix + 'Amazon.mat')
99 | labels = data_file['label'].flatten()
100 | feat_data = data_file['features'].todense().A
101 | # load the preprocessed adj_lists
102 | with open(prefix + 'amz_homo_adjlists.pickle', 'rb') as file:
103 | homo = pickle.load(file)
104 | file.close()
105 | with open(prefix + 'amz_upu_adjlists.pickle', 'rb') as file:
106 | relation1 = pickle.load(file)
107 | file.close()
108 | with open(prefix + 'amz_usu_adjlists.pickle', 'rb') as file:
109 | relation2 = pickle.load(file)
110 | file.close()
111 | with open(prefix + 'amz_uvu_adjlists.pickle', 'rb') as file:
112 | relation3 = pickle.load(file)
113 |
114 | return [homo, relation1, relation2, relation3], feat_data, labels
115 |
116 |
117 | def normalize(mx):
118 | """
119 | Row-normalize sparse matrix
120 | Code from https://github.com/williamleif/graphsage-simple/
121 | """
122 | rowsum = np.array(mx.sum(1)) + 0.01
123 | r_inv = np.power(rowsum, -1).flatten()
124 | r_inv[np.isinf(r_inv)] = 0.
125 | r_mat_inv = sp.diags(r_inv)
126 | mx = r_mat_inv.dot(mx)
127 | return mx
128 |
129 |
130 | def sparse_to_adjlist(sp_matrix, filename):
131 | """
132 | Transfer sparse matrix to adjacency list
133 | :param sp_matrix: the sparse matrix
134 | :param filename: the filename of adjlist
135 | """
136 | # add self loop
137 | homo_adj = sp_matrix + sp.eye(sp_matrix.shape[0])
138 | # create adj_list
139 | adj_lists = defaultdict(set)
140 | edges = homo_adj.nonzero()
141 | for index, node in enumerate(edges[0]):
142 | adj_lists[node].add(edges[1][index])
143 | adj_lists[edges[1][index]].add(node)
144 | with open(filename, 'wb') as file:
145 | pickle.dump(adj_lists, file)
146 | file.close()
147 |
148 |
149 | def pos_neg_split(nodes, labels):
150 | """
151 | Find positive and negative nodes given a list of nodes and their labels
152 | :param nodes: a list of nodes
153 | :param labels: a list of node labels
154 |     :returns: the split positive and negative nodes
155 | """
156 | pos_nodes = []
157 | neg_nodes = cp.deepcopy(nodes)
158 | aux_nodes = cp.deepcopy(nodes)
159 | for idx, label in enumerate(labels):
160 | if label == 1:
161 | pos_nodes.append(aux_nodes[idx])
162 | neg_nodes.remove(aux_nodes[idx])
163 |
164 | return pos_nodes, neg_nodes
165 |
166 |
167 | def test_sage(test_cases, labels, model, batch_size, thres=0.5, save=False):
168 | """
169 | Test the performance of GraphSAGE
170 |     :param test_cases: a list of testing nodes
171 |     :param labels: a list of testing node labels
172 |     :param model: the GNN model
173 |     :param batch_size: number of nodes in a batch
174 | """
175 |
176 | test_batch_num = int(len(test_cases) / batch_size) + 1
177 | gnn_pred_list = []
178 | gnn_prob_list = []
179 | for iteration in range(test_batch_num):
180 | i_start = iteration * batch_size
181 | i_end = min((iteration + 1) * batch_size, len(test_cases))
182 | batch_nodes = test_cases[i_start:i_end]
183 | gnn_prob = model.to_prob(batch_nodes, False)
184 |
185 | gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1]
186 | gnn_pred = prob2pred(gnn_prob_arr, thres)
187 |
188 | gnn_pred_list.extend(gnn_pred.tolist())
189 | gnn_prob_list.extend(gnn_prob_arr.tolist())
190 |
191 | auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list))
192 | f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro')
193 | conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list))
194 | tn, fp, fn, tp = conf_gnn.ravel()
195 | gmean_gnn = conf_gmean(conf_gnn)
196 |
197 | logging.info(f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}")
198 | logging.info(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}")
199 | return f1_macro_gnn, auc_gnn, gmean_gnn
200 |
201 |
202 |
203 | def prob2pred(y_prob, thres=0.5):
204 | """
205 | Convert probability to predicted results according to given threshold
206 | :param y_prob: numpy array of probability in [0, 1]
207 | :param thres: binary classification threshold, default 0.5
208 | :returns: the predicted result with the same shape as y_prob
209 | """
210 | y_pred = np.zeros_like(y_prob, dtype=np.int32)
211 | y_pred[y_prob >= thres] = 1
212 | y_pred[y_prob < thres] = 0
213 | return y_pred
214 |
215 |
216 | def test_GDN(test_cases, labels, model, batch_size, thres=0.5, save=False):
217 | """
218 | Test the performance of GDN
219 |     :param test_cases: a list of testing nodes
220 |     :param labels: a list of testing node labels
221 |     :param model: the GNN model
222 |     :param batch_size: number of nodes in a batch
223 |     :returns: the F1-macro, AUC, and G-Mean of the GNN module
224 | """
225 |
226 | test_batch_num = int(len(test_cases) / batch_size) + 1
227 | gnn_pred_list = []
228 | gnn_prob_list = []
229 |
230 | for iteration in range(test_batch_num):
231 | i_start = iteration * batch_size
232 | i_end = min((iteration + 1) * batch_size, len(test_cases))
233 | batch_nodes = test_cases[i_start:i_end]
234 | batch_label = labels[i_start:i_end]
235 | gnn_prob = model.to_prob(batch_nodes, batch_label)
236 | gnn_prob_arr = gnn_prob.data.cpu().numpy()[:, 1]
237 | gnn_pred = prob2pred(gnn_prob_arr, thres)
238 |
239 | gnn_pred_list.extend(gnn_pred.tolist())
240 | gnn_prob_list.extend(gnn_prob_arr.tolist())
241 |
242 | auc_gnn = roc_auc_score(labels, np.array(gnn_prob_list))
243 | f1_macro_gnn = f1_score(labels, np.array(gnn_pred_list), average='macro')
244 | conf_gnn = confusion_matrix(labels, np.array(gnn_pred_list))
245 | tn, fp, fn, tp = conf_gnn.ravel()
246 | gmean_gnn = conf_gmean(conf_gnn)
247 |
248 | logging.info(f"\tF1-macro: {f1_macro_gnn:.4f}\tG-Mean: {gmean_gnn:.4f}\tAUC: {auc_gnn:.4f}")
249 | logging.info(f" GNN TP: {tp}\tTN: {tn}\tFN: {fn}\tFP: {fp}")
250 | return f1_macro_gnn, auc_gnn, gmean_gnn
251 |
252 | def conf_gmean(conf):
253 | tn, fp, fn, tp = conf.ravel()
254 | return (tp*tn/((tp+fn)*(tn+fp)))**0.5
--------------------------------------------------------------------------------
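
A self-contained example of the metric helpers above, using synthetic labels and scores (not results from the paper; run from the repository root):

```python
import numpy as np
from sklearn.metrics import confusion_matrix, f1_score, roc_auc_score

from utils.utils import conf_gmean, prob2pred

labels = np.array([0, 0, 1, 1, 0, 1])
probs = np.array([0.1, 0.4, 0.7, 0.9, 0.3, 0.2])

preds = prob2pred(probs, thres=0.2)   # same thresholding test_GDN applies
conf = confusion_matrix(labels, preds)

print('AUC     ', roc_auc_score(labels, probs))
print('F1-macro', f1_score(labels, preds, average='macro'))
print('G-Mean  ', conf_gmean(conf))
```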
/model_handler.py:
--------------------------------------------------------------------------------
1 | import time, datetime
2 | import os
3 | import random
4 | import argparse
5 | import numpy as np
6 | from sklearn.model_selection import train_test_split
7 |
8 | from utils.utils import test_GDN, test_sage, load_data, pos_neg_split, normalize, biased_split
9 | from models.model import GDNLayer
10 | from models.layers import InterAgg, IntraAgg
11 | from models.graphsage import *
12 | import pickle as pkl
13 | import logging
14 | import torch
15 | import torch.nn as nn
16 |
17 | timestamp = time.time()
18 | timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S')
19 | logging.basicConfig(filename='result.log',level=logging.INFO)
20 |
21 | """
22 | Training GDN
23 | """
24 |
25 |
26 | class ModelHandler(object):
27 |
28 | def __init__(self, config):
29 | args = argparse.Namespace(**config)
30 | # load graph, feature, and label
31 | [homo, relation1, relation2, relation3], feat_data, labels = load_data(args.data_name, prefix=args.data_dir)
32 |
33 | # train_test split
34 | np.random.seed(args.seed)
35 | random.seed(args.seed)
36 |
37 | if not args.biased_split:
38 | if args.data_name == 'yelp':
39 | index = list(range(len(labels)))
40 | idx_rest, idx_test, y_rest, y_test = train_test_split(index, labels, stratify=labels, train_size=args.train_ratio,
41 | random_state=2, shuffle=True)
42 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio,
43 | random_state=2, shuffle=True)
44 | elif args.data_name == 'amazon': # amazon
45 | # 0-3304 are unlabeled nodes
46 | index = list(range(3305, len(labels)))
47 | idx_rest, idx_test, y_rest, y_test = train_test_split(index, labels[3305:], stratify=labels[3305:],
48 | train_size=args.train_ratio, random_state=2, shuffle=True)
49 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio,
50 | random_state=2, shuffle=True)
51 | else:
52 | idx_rest, idx_test, y_rest, y_test = biased_split(args.data_name)
53 | idx_train, idx_valid, y_train, y_valid = train_test_split(idx_rest, y_rest, stratify=y_rest, test_size=args.test_ratio,
54 | random_state=2, shuffle=True)
55 |
56 |         print(f'Run on {args.data_name}, positive/total num: {np.sum(labels)}/{len(labels)}, train num {len(y_train)}, ' +
57 | f'valid num {len(y_valid)}, test num {len(y_test)}, test positive num {np.sum(y_test)}')
58 | print(f"Classification threshold: {args.thres}")
59 | print(f"Feature dimension: {feat_data.shape[1]}")
60 |
61 |
62 | # split pos neg sets for under-sampling
63 | train_pos, train_neg = pos_neg_split(idx_train, y_train)
64 |
65 |
66 | if args.data_name == 'amazon':
67 | feat_data = normalize(feat_data)
68 |
69 | args.cuda = not args.no_cuda and torch.cuda.is_available()
70 | os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_id
71 |
72 | # set input graph
73 | if args.model == 'SAGE' or args.model == 'GCN':
74 | adj_lists = homo
75 | else:
76 | adj_lists = [relation1, relation2, relation3]
77 |
78 | print(f'Model: {args.model}, multi-relation aggregator: {args.multi_relation}, emb_size: {args.emb_size}.')
79 |
80 | self.args = args
81 | self.dataset = {'feat_data': feat_data, 'labels': labels, 'adj_lists': adj_lists, 'homo': homo,
82 | 'idx_train': idx_train, 'idx_valid': idx_valid, 'idx_test': idx_test,
83 | 'y_train': y_train, 'y_valid': y_valid, 'y_test': y_test,
84 | 'train_pos': train_pos, 'train_neg': train_neg}
85 |
86 |
87 | def train(self):
88 | args = self.args
89 | feat_data, adj_lists = self.dataset['feat_data'], self.dataset['adj_lists']
90 | idx_train, y_train = self.dataset['idx_train'], self.dataset['y_train']
91 | idx_valid, y_valid, idx_test, y_test = self.dataset['idx_valid'], self.dataset['y_valid'], self.dataset['idx_test'], self.dataset['y_test']
92 | # initialize model input
93 | features = nn.Embedding(feat_data.shape[0], feat_data.shape[1])
94 | features.weight = nn.Parameter(torch.FloatTensor(feat_data), requires_grad=True)
95 | if args.cuda:
96 | features.cuda()
97 |
98 | # build one-layer models
99 | if args.model == 'GDN':
100 | intra1 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda)
101 | intra2 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda)
102 | intra3 = IntraAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], cuda=args.cuda)
103 | inter1 = InterAgg(features, feat_data.shape[1], args.emb_size, self.dataset['train_pos'], self.dataset['train_neg'],
104 | adj_lists, [intra1, intra2, intra3], inter=args.multi_relation, cuda=args.cuda)
105 | elif args.model == 'SAGE':
106 | agg_sage = MeanAggregator(features, cuda=args.cuda)
107 | enc_sage = Encoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_sage, self.dataset['train_pos'], self.dataset['train_neg'], gcn=False, cuda=args.cuda)
108 | elif args.model == 'GCN':
109 | agg_gcn = GCNAggregator(features, cuda=args.cuda)
110 | enc_gcn = GCNEncoder(features, feat_data.shape[1], args.emb_size, adj_lists, agg_gcn, self.dataset['train_pos'],
111 | self.dataset['train_neg'], gcn=True, cuda=args.cuda)
112 |
113 | if args.model == 'GDN':
114 | gnn_model = GDNLayer(2, inter1)
115 | elif args.model == 'SAGE':
116 | # the vanilla GraphSAGE model as baseline
117 | enc_sage.num_samples = 5
118 | gnn_model = GraphSage(2, enc_sage)
119 | elif args.model == 'GCN':
120 | gnn_model = GCN(2, enc_gcn)
121 |
122 | if args.cuda:
123 | gnn_model.cuda()
124 |
125 |         if args.model in ('GDN', 'SAGE', 'GCN'):
126 | group_1 = []
127 | group_2 = []
128 | for name, param in gnn_model.named_parameters():
129 | print(name)
130 | if name == 'inter1.features.weight':
131 | group_2 += [param]
132 | else:
133 | group_1 += [param]
134 | optimizer = torch.optim.Adam([
135 | dict(params=group_1, weight_decay=args.weight_decay, lr=args.lr_1),
136 | dict(params=group_2, weight_decay=args.weight_decay_2, lr=args.lr_2)
137 | ], )
138 | else:
139 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, gnn_model.parameters()), lr=args.lr_1, weight_decay=args.weight_decay)
140 |
141 | dir_saver = args.save_dir+timestamp
142 | path_saver = os.path.join(dir_saver, '{}_{}.pkl'.format(args.data_name, args.model))
143 | f1_mac_best, auc_best, ep_best = 0, 0, -1
144 |
145 | # train the model
146 | for epoch in range(args.num_epochs):
147 | num_batches = int(len(idx_train) / args.batch_size) + 1
148 |
149 |             epoch_loss = 0.0
150 | epoch_time = 0
151 |
152 | # mini-batch training
153 | for batch in range(num_batches):
154 | start_time = time.time()
155 | i_start = batch * args.batch_size
156 | i_end = min((batch + 1) * args.batch_size, len(idx_train))
157 | batch_nodes = idx_train[i_start:i_end]
158 | batch_label = self.dataset['labels'][np.array(batch_nodes)]
159 | optimizer.zero_grad()
160 | if args.cuda:
161 | loss = gnn_model.loss(batch_nodes, Variable(torch.cuda.LongTensor(batch_label)))
162 | else:
163 | loss = gnn_model.loss(batch_nodes, Variable(torch.LongTensor(batch_label)))
164 | loss.backward(retain_graph=True)
165 |
166 | # calculate grad and rank
167 | if args.add_constraint:
168 | # fl step
169 | if args.model == 'GDN':
170 | grad = torch.abs(torch.autograd.grad(outputs=loss, inputs=gnn_model.inter1.features.weight)[0])
171 |                     elif args.model in ('GCN', 'SAGE'):
172 | grad = torch.abs(torch.autograd.grad(outputs=loss, inputs=gnn_model.enc.features.weight)[0])
173 | grads_idx = grad.mean(dim=0).topk(k=args.topk).indices
174 |
175 | # find non grad idx
176 | mask_len = feat_data.shape[1] - args.topk
177 | non_grads_idx = torch.zeros(mask_len, dtype=torch.long)
178 | idx = 0
179 | for i in range(feat_data.shape[1]):
180 | if i not in grads_idx:
181 | non_grads_idx[idx] = i
182 | idx += 1
183 |
184 | if args.model == 'GDN':
185 | loss_pos, loss_neg = gnn_model.inter1.fl_loss(grads_idx)
186 |                     elif args.model in ('GCN', 'SAGE'):
187 | loss_pos, loss_neg = gnn_model.enc.constraint_loss(grads_idx)
188 | loss_stable = args.Beta * torch.exp((loss_pos - loss_neg))
189 | loss_stable.backward(retain_graph=True)
190 |
191 | # fn step
192 | if args.model == 'GDN':
193 | fn_pos, fn_neg = gnn_model.inter1.fn_loss(batch_nodes, non_grads_idx)
194 |                     elif args.model in ('GCN', 'SAGE'):
195 | fn_pos, fn_neg = gnn_model.enc.fn_loss(batch_nodes, non_grads_idx)
196 | loss_fn = args.Beta * torch.exp((fn_pos - fn_neg))
197 | loss_fn.backward()
198 |
199 | optimizer.step()
200 | end_time = time.time()
201 | epoch_time += end_time - start_time
202 |                 epoch_loss += loss.item()
203 |
204 |             print(f'Epoch: {epoch}, loss: {epoch_loss / num_batches}, time: {epoch_time}s')
205 |
206 | # Valid the model for every $valid_epoch$ epoch
207 | if epoch % args.valid_epochs == 0:
208 | if args.model == 'SAGE' or args.model == 'GCN':
209 | print("Valid at epoch {}".format(epoch))
210 | f1_mac_val, auc_val, gmean_val = test_sage(idx_valid, y_valid, gnn_model, args.test_batch_size, args.thres)
211 | if auc_val > auc_best:
212 | auc_best, ep_best = auc_val, epoch
213 | if not os.path.exists(dir_saver):
214 | os.makedirs(dir_saver)
215 | print(' Saving model ...')
216 | torch.save(gnn_model.state_dict(), path_saver)
217 | else:
218 | print("Valid at epoch {}".format(epoch))
219 | f1_mac_val, auc_val, gmean_val = test_GDN(idx_valid, y_valid, gnn_model, args.batch_size, args.thres)
220 | if auc_val > auc_best:
221 | auc_best, ep_best = auc_val, epoch
222 | if not os.path.exists(dir_saver):
223 | os.makedirs(dir_saver)
224 | print(' Saving model ...')
225 | torch.save(gnn_model.state_dict(), path_saver)
226 | with open(args.data_name+'_features.pkl', 'wb+') as f:
227 | pkl.dump(gnn_model.inter1.features.weight, f)
228 |
229 | print("Restore model from epoch {}".format(ep_best))
230 | print("Model path: {}".format(path_saver))
231 | gnn_model.load_state_dict(torch.load(path_saver))
232 |
233 | if args.model == 'SAGE' or args.model == 'GCN':
234 | f1_mac_test, auc_test, gmean_test = test_sage(idx_test, y_test, gnn_model, args.test_batch_size, args.thres)
235 | else:
236 | f1_mac_test, auc_test, gmean_test = test_GDN(idx_test, y_test, gnn_model, args.batch_size, args.thres, True)
237 | return f1_mac_test, auc_test, gmean_test
238 |
--------------------------------------------------------------------------------
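
The feature-separation step in `train()` above (gradient of the loss w.r.t. the feature embedding, top-k selection, complement set) in a toy, self-contained form; the tiny classifier and sizes are made up for illustration, and in the real loop the two index sets feed `inter1.fl_loss` and `fn_loss` respectively:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy version of the split in ModelHandler.train(): rank feature columns by the
# gradient magnitude of the loss w.r.t. the feature embedding, keep the top-k
# indices (the constrained "anomaly" features) and the complement (the rest).
torch.manual_seed(0)
num_nodes, feat_dim, topk = 20, 8, 3

features = nn.Embedding(num_nodes, feat_dim)   # stands in for gnn_model.inter1.features
clf = nn.Linear(feat_dim, 2)                   # stand-in classifier, not the real GDN
labels = torch.randint(0, 2, (num_nodes,))

loss = F.cross_entropy(clf(features.weight), labels)
grad = torch.abs(torch.autograd.grad(loss, features.weight)[0])

grads_idx = grad.mean(dim=0).topk(k=topk).indices
non_grads_idx = torch.tensor([i for i in range(feat_dim) if i not in grads_idx])
print(grads_idx.tolist(), non_grads_idx.tolist())
```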
/models/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import random
7 |
8 | """
9 | GDN Layers
10 | """
11 |
12 |
13 | class InterAgg(nn.Module):
14 |
15 | def __init__(self, features, feature_dim, embed_dim,
16 | train_pos, train_neg, adj_lists, intraggs, inter='GNN', cuda=True):
17 | """
18 | Initialize the inter-relation aggregator
19 | :param features: the input node features or embeddings for all nodes
20 | :param feature_dim: the input dimension
21 | :param embed_dim: the embed dimension
22 | :param train_pos: positive samples in training set
23 | :param adj_lists: a list of adjacency lists for each single-relation graph
24 | :param intraggs: the intra-relation aggregators used by each single-relation graph
25 | :param inter: NOT used in this version, the aggregator type: 'Att', 'Weight', 'Mean', 'GNN'
26 | :param cuda: whether to use GPU
27 | """
28 | super(InterAgg, self).__init__()
29 |
30 | # stored parameters
31 | self.features = features
32 | self.pos_vector = None
33 | self.neg_vector = None
34 |
35 | # Functions
36 | self.softmax = nn.Softmax(dim=-1)
37 | self.KLDiv = nn.KLDivLoss(reduction='batchmean')
38 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
39 |
40 | self.dropout = 0.6
41 | self.adj_lists = adj_lists
42 | self.intra_agg1 = intraggs[0]
43 | self.intra_agg2 = intraggs[1]
44 | self.intra_agg3 = intraggs[2]
45 | self.embed_dim = embed_dim
46 | self.feat_dim = feature_dim
47 | self.inter = inter
48 | self.cuda = cuda
49 | self.intra_agg1.cuda = cuda
50 | self.intra_agg2.cuda = cuda
51 | self.intra_agg3.cuda = cuda
52 | self.train_pos = train_pos
53 | self.train_neg = train_neg
54 |
55 | # initial filtering thresholds
56 | self.thresholds = [0.5, 0.5, 0.5]
57 |
58 | # parameter used to transform node embeddings before inter-relation aggregation
59 | self.weight = nn.Parameter(torch.FloatTensor(self.embed_dim*len(intraggs)+self.feat_dim, self.embed_dim))
60 | init.xavier_uniform_(self.weight)
61 |
62 | # label predictor for similarity measure
63 | self.label_clf = nn.Linear(self.feat_dim, 2)
64 |
65 | # initialize the parameter logs
66 | self.weights_log = []
67 | self.thresholds_log = [self.thresholds]
68 | self.relation_score_log = []
69 |
70 |
71 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list):
72 | self.pos_index = torch.LongTensor(self.train_pos).cuda()
73 | self.neg_index = torch.LongTensor(self.train_neg).cuda()
74 | else:
75 | self.pos_index = torch.LongTensor(self.train_pos)
76 | self.neg_index = torch.LongTensor(self.train_neg)
77 |
78 | def forward(self, nodes, labels):
79 | """
80 | :param nodes: a list of batch node ids
81 | :param labels: a list of batch node labels
82 |         :return combined: the embeddings of the batch nodes, shaped [embed_dim, batch_size]
85 | """
86 |
87 | # extract 1-hop neighbor ids from adj lists of each single-relation graph
88 | to_neighs = []
89 | for adj_list in self.adj_lists:
90 | to_neighs.append([set(adj_list[int(node)]) for node in nodes])
91 |
92 | # get neighbor node id list for each batch node and relation
93 | r1_list = [list(to_neigh) for to_neigh in to_neighs[0]]
94 | r2_list = [list(to_neigh) for to_neigh in to_neighs[1]]
95 | r3_list = [list(to_neigh) for to_neigh in to_neighs[2]]
96 |
97 | # find unique nodes and their neighbors used in current batch
98 | unique_nodes = set.union(set.union(*to_neighs[0]), set.union(*to_neighs[1]),
99 | set.union(*to_neighs[2], set(nodes)))
100 | self.unique_nodes = unique_nodes
101 |
102 | # intra-aggregation steps for each relation
103 | r1_feats = self.intra_agg1.forward(nodes, r1_list)
104 | r2_feats = self.intra_agg2.forward(nodes, r2_list)
105 | r3_feats = self.intra_agg3.forward(nodes, r3_list)
106 |
107 | # get features or embeddings for batch nodes
108 | self_feats = self.fetch_feat(nodes)
109 |
110 | # Update label vector
111 | self.update_label_vector(self.features)
112 |
113 | # concat the intra-aggregated embeddings from each relation
114 | cat_feats = torch.cat((self_feats, r1_feats, r2_feats, r3_feats), dim=1)
115 | combined = F.relu(cat_feats.mm(self.weight).t())
116 | return combined
117 |
118 | def fetch_feat(self, nodes):
119 | if self.cuda and isinstance(nodes, list):
120 | index = torch.LongTensor(nodes).cuda()
121 | else:
122 | index = torch.LongTensor(nodes)
123 | return self.features(index)
124 |
125 | def cal_simi_scores(self, nodes):
126 | self_feats = self.fetch_feat(nodes)
127 | cosine_pos = self.cos(self.pos_vector, self_feats).detach()
128 | cosine_neg = self.cos(self.neg_vector, self_feats).detach()
129 | simi_scores = torch.cat((cosine_neg.view(-1, 1), cosine_pos.view(-1, 1)), 1)
130 | return simi_scores
131 |
132 | def update_label_vector(self, x):
133 | # pdb.set_trace()
134 | if isinstance(x, torch.Tensor):
135 | x_pos = x[self.train_pos]
136 | x_neg = x[self.train_neg]
137 | elif isinstance(x, torch.nn.Embedding):
138 | x_pos = x(self.pos_index)
139 | x_neg = x(self.neg_index)
140 | if self.pos_vector is None:
141 | self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach()
142 | self.neg_vector = torch.mean(x_neg, dim=0, keepdim=True).detach()
143 | else:
144 | cosine_pos = self.cos(self.pos_vector, x_pos)
145 | cosine_neg = self.cos(self.neg_vector, x_neg)
146 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1)
147 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1)
148 | self.pos_vector = torch.mm(weights_pos, x_pos).detach()
149 | self.neg_vector = torch.mm(weights_neg, x_neg).detach()
150 |
151 | def fl_loss(self, grads_idx):
152 | x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1)
153 | target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
154 | target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
155 | loss_pos = self.KLDiv(x, target_pos)
156 | loss_neg = self.KLDiv(x, target_neg)
157 | return loss_pos, loss_neg
158 |
159 | def fn_loss(self, nodes, non_grad_idx):
160 | pos_nodes = set(self.train_pos)
161 | to_neighs = []
162 | target = []
163 | for adj_list in self.adj_lists:
164 | target_r = []
165 | to_neighs_r = []
166 | for node in nodes:
167 | if int(node) in pos_nodes:
168 | target_r.append(int(node))
169 | to_neighs_r.append(set(adj_list[int(node)]))
170 | to_neighs.append(to_neighs_r)
171 | target.append(target_r)
172 |
173 | to_neighs_all = []
174 | for x, y, z in zip(to_neighs[0], to_neighs[1], to_neighs[2]):
175 | to_neighs_all.append(set.union(x, y, z))
176 |
177 | r1_list = [list(to_neigh) for to_neigh in to_neighs[0]]
178 | r2_list = [list(to_neigh) for to_neigh in to_neighs[1]]
179 | r3_list = [list(to_neigh) for to_neigh in to_neighs[2]]
180 | # print(non_grad_idx)
181 | pos_1, neg_1 = self.intra_agg1.fn_loss(non_grad_idx, target[0], r1_list, self.unique_nodes, to_neighs_all)
182 | pos_2, neg_2 = self.intra_agg2.fn_loss(non_grad_idx, target[1], r2_list, self.unique_nodes, to_neighs_all)
183 | pos_3, neg_3 = self.intra_agg3.fn_loss(non_grad_idx, target[2], r3_list, self.unique_nodes, to_neighs_all)
184 | return pos_1 + pos_2 + pos_3, neg_1 + neg_2 + neg_3
185 | def softmax_with_temperature(self, input, t=1, axis=-1):
186 | ex = torch.exp(input/t)
187 | sum = torch.sum(ex, axis=axis)
188 | return ex/sum
189 |
190 |
191 | class IntraAgg(nn.Module):
192 |
193 | def __init__(self, features, feat_dim, embed_dim, train_pos, cuda=False):
194 | """
195 | Initialize the intra-relation aggregator
196 | :param features: the input node features or embeddings for all nodes
197 | :param feat_dim: the input dimension
198 | :param embed_dim: the embed dimension
199 | :param train_pos: positive samples in training set
200 | :param cuda: whether to use GPU
201 | """
202 | super(IntraAgg, self).__init__()
203 |
204 | self.features = features
205 | self.cuda = cuda
206 | self.feat_dim = feat_dim
207 | self.embed_dim = embed_dim
208 | self.train_pos = train_pos
209 | self.weight = nn.Parameter(torch.FloatTensor(2*self.feat_dim, self.embed_dim))
210 | init.xavier_uniform_(self.weight)
211 |
212 | self.KLDiv = nn.KLDivLoss(reduction='batchmean')
213 |
214 | def forward(self, nodes, to_neighs_list):
215 | """
216 | Code partially from https://github.com/williamleif/graphsage-simple/
217 | :param nodes: list of nodes in a batch
218 |         :param to_neighs_list: neighbor node id list for each batch node in one relation
219 |         :return to_feats: the aggregated embeddings of the batch nodes' neighbors in one relation
226 | """
227 |
228 | samp_neighs = [set(x) for x in to_neighs_list]
229 | # find the unique nodes among batch nodes and the filtered neighbors
230 | unique_nodes_list = list(set.union(*samp_neighs))
231 | unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
232 |
233 |
234 | # intra-relation aggregation only with sampled neighbors
235 | mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
236 | column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
237 | row_indices = [i for i in range(len(samp_neighs)) for _ in range(len(samp_neighs[i]))]
238 | mask[row_indices, column_indices] = 1
239 | if self.cuda:
240 | mask = mask.cuda()
241 | num_neigh = mask.sum(1, keepdim=True)
242 | mask = mask.div(num_neigh) # mean aggregator
243 | if self.cuda:
244 | self_feats = self.features(torch.LongTensor(nodes).cuda())
245 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
246 | else:
247 | self_feats = self.features(torch.LongTensor(nodes))
248 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
249 | agg_feats = mask.mm(embed_matrix) # single relation aggregator
250 | cat_feats = torch.cat((self_feats, agg_feats), dim=1) # concat with last layer
251 | to_feats = F.relu(cat_feats.mm(self.weight))
252 | return to_feats
253 |
254 | def update_label_vector(self, x):
255 |         # NOTE: not called in this version; InterAgg.update_label_vector maintains the prototype vectors
256 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list):
257 | pos_index = torch.LongTensor(self.train_pos).cuda()
258 | neg_index = torch.LongTensor(self.train_neg).cuda()
259 | if self.pos_vector is None:
260 | self.pos_vector = torch.mean(x(pos_index), dim=0, keepdim=True).detach()
261 | self.neg_vector = torch.mean(x(neg_index), dim=0, keepdim=True).detach()
262 | else:
263 | cosine_pos = self.cos(self.pos_vector, x(pos_index))
264 | cosine_neg = self.cos(self.neg_vector, x(neg_index))
265 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1)
266 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1)
267 | self.pos_vector = torch.mm(weights_pos, x(pos_index)).detach()
268 | self.neg_vector = torch.mm(weights_neg, x(neg_index)).detach()
269 |
270 | def fetch_feat(self, nodes):
271 | if self.cuda and isinstance(nodes, list):
272 | index = torch.LongTensor(nodes).cuda()
273 | else:
274 | index = torch.LongTensor(nodes)
275 | return self.features(index)
276 |
277 | def softmax_with_temperature(self, input, t=1, axis=-1):
278 | ex = torch.exp(input/t)
279 | sum = torch.sum(ex, axis=axis)
280 | return ex/sum
281 |
282 | def fn_loss(self, non_grad_idx, target, neighs, all_nodes, all_neighs):
283 | x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1)
284 | pos = torch.zeros_like(self.fetch_feat(target))
285 | neg = torch.zeros_like(self.fetch_feat(target))
286 | for i in range(len(target)):
287 | pos[i] = torch.mean(self.fetch_feat(neighs[i]), dim=0, keepdim=True)
288 | neg_idx = [random.choice(list(all_nodes.difference(all_neighs[i])))]
289 | neg[i] = self.fetch_feat(neg_idx)
290 | # pdb.set_trace()
291 | pos = pos[:, non_grad_idx].softmax(dim=-1)
292 | neg = neg[:, non_grad_idx].softmax(dim=-1)
293 | loss_pos = self.KLDiv(x, pos)
294 | loss_neg = self.KLDiv(x, neg)
295 | # pdb.set_trace()
296 | return loss_pos, loss_neg
--------------------------------------------------------------------------------
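
A plain-tensor sketch of the prototype ("label vector") update performed by `InterAgg.update_label_vector`: the first call stores the mean of the positive training features, and later calls re-estimate it as a cosine-similarity-weighted average through `softmax_with_temperature(t=5)` (toy sizes, illustrative only):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

x_pos = torch.randn(5, 8)                     # features of the positive training nodes (toy)
pos_vector = x_pos.mean(dim=0, keepdim=True)  # first call: plain mean

# later calls: cosine-similarity weights via a temperature softmax (t=5)
weights = torch.softmax(cos(pos_vector, x_pos) / 5, dim=-1).reshape(1, -1)
pos_vector = weights.mm(x_pos).detach()
print(pos_vector.shape)  # torch.Size([1, 8])
```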
/models/graphsage.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import random
7 | import pdb
8 |
9 |
10 | """
11 | GraphSAGE implementations
12 | Paper: Inductive Representation Learning on Large Graphs
13 | Source: https://github.com/williamleif/graphsage-simple/
14 | """
15 |
16 |
17 | class GraphSage(nn.Module):
18 | """
19 | Vanilla GraphSAGE Model
20 | Code partially from https://github.com/williamleif/graphsage-simple/
21 | """
22 | def __init__(self, num_classes, enc):
23 | super(GraphSage, self).__init__()
24 | self.enc = enc
25 | self.xent = nn.CrossEntropyLoss()
26 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
27 | init.xavier_uniform_(self.weight)
28 |
29 | def forward(self, nodes):
30 | embeds = self.enc(nodes)
31 | scores = self.weight.mm(embeds)
32 | return scores.t()
33 |
34 | def to_prob(self, nodes, train_flag=True):
35 | pos_scores = torch.sigmoid(self.forward(nodes))
36 | return pos_scores
37 |
38 | def loss(self, nodes, labels):
39 | scores = self.forward(nodes)
40 | return self.xent(scores, labels.squeeze())
41 |
42 |
43 | class MeanAggregator(nn.Module):
44 | """
45 | Aggregates a node's embeddings using mean of neighbors' embeddings
46 | """
47 |
48 | def __init__(self, features, cuda=False, gcn=False):
49 | """
50 | Initializes the aggregator for a specific graph.
51 |
52 | features -- function mapping LongTensor of node ids to FloatTensor of feature values.
53 | cuda -- whether to use GPU
54 | gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
55 | """
56 |
57 | super(MeanAggregator, self).__init__()
58 |
59 | self.features = features
60 | self.cuda = cuda
61 | self.gcn = gcn
62 |
63 | def forward(self, nodes, to_neighs, num_sample=10):
64 | """
65 | nodes --- list of nodes in a batch
66 | to_neighs --- list of sets, each set is the set of neighbors for node in batch
67 | num_sample --- number of neighbors to sample. No sampling if None.
68 | """
69 | # Local pointers to functions (speed hack)
70 | _set = set
71 | if not num_sample is None:
72 | _sample = random.sample
73 | samp_neighs = [_set(_sample(to_neigh,
74 | num_sample,
75 | )) if len(to_neigh) >= num_sample else to_neigh for to_neigh in to_neighs]
76 | else:
77 | samp_neighs = to_neighs
78 |
79 | if self.gcn:
80 | samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in enumerate(samp_neighs)]
81 | unique_nodes_list = list(set.union(*samp_neighs))
82 | unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
83 | mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
84 | column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
85 | row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
86 | mask[row_indices, column_indices] = 1
87 | if self.cuda:
88 | mask = mask.cuda()
89 | num_neigh = mask.sum(1, keepdim=True)
90 | mask = mask.div(num_neigh)
91 | if self.cuda:
92 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
93 | else:
94 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
95 | to_feats = mask.mm(embed_matrix)
96 | return to_feats
97 |
98 |
99 | class Encoder(nn.Module):
100 | """
101 | Vanilla GraphSAGE Encoder Module
102 |     Encodes a node using the 'convolutional' GraphSAGE approach
103 | """
104 | def __init__(self, features, feature_dim,
105 | embed_dim, adj_lists, aggregator,
106 | train_pos, train_neg, num_sample=10,
107 | base_model=None, gcn=False, cuda=False,
108 | feature_transform=False):
109 | super(Encoder, self).__init__()
110 |
111 | self.features = features
112 | self.feat_dim = feature_dim
113 | self.adj_lists = adj_lists
114 | self.aggregator = aggregator
115 | self.num_sample = num_sample
116 | if base_model != None:
117 | self.base_model = base_model
118 |
119 | self.gcn = gcn
120 | self.embed_dim = embed_dim
121 | self.cuda = cuda
122 | self.aggregator.cuda = cuda
123 | self.weight = nn.Parameter(
124 | torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim))
125 | init.xavier_uniform_(self.weight)
126 |
127 | self.pos_vector = None
128 | self.neg_vector = None
129 | self.train_pos = train_pos
130 | self.train_neg = train_neg
131 | self.softmax = nn.Softmax(dim=-1)
132 | self.KLDiv = nn.KLDivLoss(reduction='batchmean')
133 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
134 |
135 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list):
136 | self.pos_index = torch.LongTensor(self.train_pos).cuda()
137 | self.neg_index = torch.LongTensor(self.train_neg).cuda()
138 | else:
139 | self.pos_index = torch.LongTensor(self.train_pos)
140 | self.neg_index = torch.LongTensor(self.train_neg)
141 |
142 | self.unique_nodes = set()
143 | for _, adj_list in self.adj_lists.items():
144 | self.unique_nodes = self.unique_nodes.union(adj_list)
145 | # def __init__(self, features, feature_dim,
146 | # embed_dim, adj_lists, aggregator,
147 | # num_sample=10,
148 | # base_model=None, gcn=False, cuda=False,
149 | # feature_transform=False):
150 | # super(Encoder, self).__init__()
151 |
152 | # self.features = features
153 | # self.feat_dim = feature_dim
154 | # self.adj_lists = adj_lists
155 | # self.aggregator = aggregator
156 | # self.num_sample = num_sample
157 | # if base_model != None:
158 | # self.base_model = base_model
159 |
160 | # self.gcn = gcn
161 | # self.embed_dim = embed_dim
162 | # self.cuda = cuda
163 | # self.aggregator.cuda = cuda
164 | # self.weight = nn.Parameter(
165 | # torch.FloatTensor(embed_dim, self.feat_dim if self.gcn else 2 * self.feat_dim))
166 | # init.xavier_uniform_(self.weight)
167 |
168 | def forward(self, nodes):
169 | """
170 | Generates embeddings for a batch of nodes.
171 |
172 | nodes -- list of nodes
173 | """
174 | neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes],
175 | self.num_sample)
176 | self.update_label_vector(self.features)
177 | if isinstance(nodes, list):
178 | index = torch.LongTensor(nodes)
179 | else:
180 | index = nodes
181 |
182 | if not self.gcn:
183 | if self.cuda:
184 | self_feats = self.features(index).cuda()
185 | else:
186 | self_feats = self.features(index)
187 | combined = torch.cat((self_feats, neigh_feats), dim=1)
188 | else:
189 | combined = neigh_feats
190 | combined = F.relu(self.weight.mm(combined.t()))
191 | return combined
192 |
193 | def constraint_loss(self, grads_idx, with_grad):
194 | if with_grad:
195 | x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1)
196 | # x = F.log_softmax(self.stored_hidden[self.train_pos][:, grads_idx], dim=-1)
197 | target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
198 | target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
199 | loss_pos = self.KLDiv(x, target_pos)
200 | loss_neg = self.KLDiv(x, target_neg)
201 | # pdb.set_trace()
202 | else:
203 | x = F.log_softmax(self.features(self.pos_index), dim=-1)
204 | # x = F.log_softmax(self.stored_hidden[self.train_pos], dim=-1)
205 | target_pos = self.pos_vector.repeat(x.shape[0], 1).softmax(dim=-1)
206 | target_neg = self.neg_vector.repeat(x.shape[0], 1).softmax(dim=-1)
207 | loss_pos = self.KLDiv(x, target_pos)
208 | loss_neg = self.KLDiv(x, target_neg)
209 | # pdb.set_trace()
210 | return loss_pos, loss_neg
211 |
212 | def softmax_with_temperature(self, input, t=1, axis=-1):
213 | ex = torch.exp(input/t)
214 | sum = torch.sum(ex, axis=axis)
215 | return ex/sum
216 |
217 | def fn_loss(self, nodes, non_grad_idx):
218 | pos_nodes = set(self.train_pos)
219 | target = []
220 | neighs = []
221 | for node in nodes:
222 | if int(node) in pos_nodes:
223 | target.append(int(node))
224 | neighs.append(self.adj_lists[int(node)])
225 | x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1)
226 | pos = torch.zeros_like(self.fetch_feat(target))
227 | neg = torch.zeros_like(self.fetch_feat(target))
228 | for i in range(len(target)):
229 | pos[i] = torch.mean(self.fetch_feat(list(neighs[i])), dim=0, keepdim=True)
230 | neg_idx = [random.choice(list(self.unique_nodes.difference(neighs[i])))]
231 | neg[i] = self.fetch_feat(neg_idx)
232 | # pdb.set_trace()
233 | pos = pos[:, non_grad_idx].softmax(dim=-1)
234 | neg = neg[:, non_grad_idx].softmax(dim=-1)
235 | loss_pos = self.KLDiv(x, pos)
236 | loss_neg = self.KLDiv(x, neg)
237 | # pdb.set_trace()
238 | return loss_pos, loss_neg
239 |
240 | def fetch_feat(self, nodes):
241 | if self.cuda and isinstance(nodes, list):
242 | index = torch.LongTensor(nodes).cuda()
243 | else:
244 | index = torch.LongTensor(nodes)
245 | return self.features(index)
246 |
247 | def update_label_vector(self, x):
248 | # pdb.set_trace()
249 | if isinstance(x, torch.Tensor):
250 | x_pos = x[self.train_pos]
251 | x_neg = x[self.train_neg]
252 | elif isinstance(x, torch.nn.Embedding):
253 | x_pos = x(self.pos_index)
254 | x_neg = x(self.neg_index)
255 | if self.pos_vector is None:
256 | self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach()
257 | self.neg_vector = torch.mean(x_neg, dim=0, keepdim=True).detach()
258 | else:
259 | cosine_pos = self.cos(self.pos_vector, x_pos)
260 | cosine_neg = self.cos(self.neg_vector, x_neg)
261 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1)
262 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1)
263 | self.pos_vector = torch.mm(weights_pos, x_pos).detach()
264 | self.neg_vector = torch.mm(weights_neg, x_neg).detach()
265 |
266 |
267 |
268 | class GCN(nn.Module):
269 | """
270 | Vanilla GCN Model
271 | Code partially from https://github.com/williamleif/graphsage-simple/
272 | """
273 |     def __init__(self, num_classes, enc, add_constraint=False):
274 | super(GCN, self).__init__()
275 | self.enc = enc
276 | self.xent = nn.CrossEntropyLoss()
277 | self.weight = nn.Parameter(torch.FloatTensor(num_classes, enc.embed_dim))
278 | init.xavier_uniform_(self.weight)
279 | self.pos_vector = None
280 | self.neg_vector = None
281 |
282 |
283 |     def forward(self, nodes, train_flag=True):
284 |         embeds = self.enc(nodes, train_flag)
285 |         scores = self.weight.mm(embeds)
286 |         return scores.t()
291 |
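    |     # to_prob applies an element-wise sigmoid to the per-class scores, so the two columns are
    |     # independent confidences rather than a softmax-normalized distribution.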
292 | def to_prob(self, nodes, train_flag=True):
293 | pos_scores = torch.sigmoid(self.forward(nodes, train_flag))
294 | return pos_scores
295 |
296 | def loss(self, nodes, labels):
297 | scores = self.forward(nodes)
298 | return self.xent(scores, labels.squeeze())
299 |
300 |
301 | class GCNAggregator(nn.Module):
302 | """
303 | Aggregates a node's embeddings using normalized mean of neighbors' embeddings
304 | """
305 |
306 | def __init__(self, features, cuda=False, gcn=False):
307 | """
308 | Initializes the aggregator for a specific graph.
309 |
310 | features -- function mapping LongTensor of node ids to FloatTensor of feature values.
311 | cuda -- whether to use GPU
312 |         gcn -- whether to perform concatenation GraphSAGE-style, or to add self-loops GCN-style
313 | """
314 |
315 | super(GCNAggregator, self).__init__()
316 |
317 | self.features = features
318 | self.cuda = cuda
319 | self.gcn = gcn
320 |
321 | def forward(self, nodes, to_neighs):
322 | """
323 |         nodes -- list of nodes in a batch
324 |         to_neighs -- list of sets, each set being the set of neighbors of the corresponding node in the batch
325 | """
326 |         # No neighbor sampling here: every node in the batch aggregates over its full neighbor set
327 | 
328 |         samp_neighs = to_neighs
329 |
330 | # Add self to neighs
331 | samp_neighs = [samp_neigh.union(set([int(nodes[i])])) for i, samp_neigh in enumerate(samp_neighs)]
332 | unique_nodes_list = list(set.union(*samp_neighs))
333 | unique_nodes = {n: i for i, n in enumerate(unique_nodes_list)}
334 | mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes)))
335 | column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh]
336 | row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))]
337 | mask[row_indices, column_indices] = 1.0 # Adjacency matrix for the sub-graph
338 | if self.cuda:
339 | mask = mask.cuda()
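    |         # Symmetric degree normalization on the batch sub-matrix: every entry of the
    |         # self-loop-augmented adjacency is divided by sqrt(row sum) * sqrt(column sum),
    |         # approximating the usual D^{-1/2} (A + I) D^{-1/2} GCN propagation.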
340 | row_normalized = mask.sum(1, keepdim=True).sqrt()
341 | col_normalized = mask.sum(0, keepdim=True).sqrt()
342 | mask = mask.div(row_normalized).div(col_normalized)
343 | if self.cuda:
344 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list).cuda())
345 | else:
346 | embed_matrix = self.features(torch.LongTensor(unique_nodes_list))
347 | to_feats = mask.mm(embed_matrix)
348 | return to_feats
349 |
350 | class GCNEncoder(nn.Module):
351 | """
352 | GCN Encoder Module
353 | """
354 |
355 | def __init__(self, features, feature_dim,
356 | embed_dim, adj_lists, aggregator,
357 | train_pos, train_neg, num_sample=10,
358 | base_model=None, gcn=False, cuda=False,
359 | feature_transform=False):
360 | super(GCNEncoder, self).__init__()
361 |
362 | self.features = features
363 | self.feat_dim = feature_dim
364 | self.adj_lists = adj_lists
365 | self.aggregator = aggregator
366 | self.num_sample = num_sample
367 |         if base_model is not None:
368 | self.base_model = base_model
369 |
370 | self.gcn = gcn
371 | self.embed_dim = embed_dim
372 | self.cuda = cuda
373 | self.aggregator.cuda = cuda
374 | self.weight = nn.Parameter(
375 |             torch.FloatTensor(embed_dim, self.feat_dim))
376 | init.xavier_uniform_(self.weight)
377 |
378 | self.pos_vector = None
379 | self.neg_vector = None
380 | self.train_pos = train_pos
381 | self.train_neg = train_neg
382 | self.softmax = nn.Softmax(dim=-1)
383 | self.KLDiv = nn.KLDivLoss(reduction='batchmean')
384 | self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)
385 |
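    |         # Cache the labeled training indices as LongTensors; they are moved to the GPU only
    |         # when CUDA is enabled and the indices are given as plain Python lists.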
386 | if self.cuda and isinstance(self.train_pos, list) and isinstance(self.train_neg, list):
387 | self.pos_index = torch.LongTensor(self.train_pos).cuda()
388 | self.neg_index = torch.LongTensor(self.train_neg).cuda()
389 | else:
390 | self.pos_index = torch.LongTensor(self.train_pos)
391 | self.neg_index = torch.LongTensor(self.train_neg)
392 |
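    |         # unique_nodes collects every node id that appears as a neighbor; fn_loss draws its
    |         # negative samples from this pool.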
393 | self.unique_nodes = set()
394 | for _, adj_list in self.adj_lists.items():
395 | self.unique_nodes = self.unique_nodes.union(adj_list)
396 |
397 |
398 |
399 | def forward(self, nodes, train_flag=True):
400 | """
401 | Generates embeddings for a batch of nodes.
402 | Input:
403 | nodes -- list of nodes
404 | Output:
405 |             embeddings of shape (embed_dim, len(nodes)); when train_flag is False, a (embeddings, similarity_scores) tuple is returned instead
406 | """
407 | neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes])
408 | self.update_label_vector(self.features)
409 | if isinstance(nodes, list):
410 | index = torch.LongTensor(nodes)
411 | else:
412 | index = nodes
413 | self_feats = self.features(index)
414 |
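    |         # `combined` is built from the aggregated neighborhood features alone (the self node is
    |         # already included through the aggregator's self-loop); `self_feats` is kept for the
    |         # prototype similarity scores returned alongside the embeddings at eval time.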
415 | combined = F.relu(self.weight.mm(neigh_feats.t()))
416 | if not train_flag:
417 | cosine_pos = self.cos(self.pos_vector, self_feats).detach()
418 | cosine_neg = self.cos(self.neg_vector, self_feats).detach()
419 | simi_scores = torch.cat((cosine_neg.view(-1, 1), cosine_pos.view(-1, 1)), 1).t()
420 | # simi_scores = torch.zeros(2, self_feats.shape[0])
421 | return combined, simi_scores
422 | return combined
423 |
424 | def update_label_vector(self, x):
425 | # pdb.set_trace()
426 | if isinstance(x, torch.Tensor):
427 | x_pos = x[self.train_pos]
428 | x_neg = x[self.train_neg]
429 | elif isinstance(x, torch.nn.Embedding):
430 | x_pos = x(self.pos_index)
431 | x_neg = x(self.neg_index)
432 | if self.pos_vector is None:
433 | self.pos_vector = torch.mean(x_pos, dim=0, keepdim=True).detach()
434 | self.neg_vector = torch.mean(x_neg, dim=0, keepdim=True).detach()
435 | else:
436 | cosine_pos = self.cos(self.pos_vector, x_pos)
437 | cosine_neg = self.cos(self.neg_vector, x_neg)
438 | weights_pos = self.softmax_with_temperature(cosine_pos, t=5).reshape(1, -1)
439 | weights_neg = self.softmax_with_temperature(cosine_neg, t=5).reshape(1, -1)
440 | self.pos_vector = torch.mm(weights_pos, x_pos).detach()
441 | self.neg_vector = torch.mm(weights_neg, x_neg).detach()
442 | # pdb.set_trace()
443 |
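    |     # constraint_loss measures, with KL divergence, how close the (log-softmaxed) features of
    |     # the positive training nodes are to the softmaxed positive prototype (loss_pos) and to the
    |     # softmaxed negative prototype (loss_neg). When with_grad is set, only the feature
    |     # dimensions selected by grads_idx are compared.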
444 | def constraint_loss(self, grads_idx, with_grad):
445 | if with_grad:
446 | x = F.log_softmax(self.features(self.pos_index)[:, grads_idx], dim=-1)
447 | # x = F.log_softmax(self.stored_hidden[self.train_pos][:, grads_idx], dim=-1)
448 | target_pos = self.pos_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
449 | target_neg = self.neg_vector[:, grads_idx].repeat(x.shape[0], 1).softmax(dim=-1)
450 | loss_pos = self.KLDiv(x, target_pos)
451 | loss_neg = self.KLDiv(x, target_neg)
452 | # pdb.set_trace()
453 | else:
454 | x = F.log_softmax(self.features(self.pos_index), dim=-1)
455 | # x = F.log_softmax(self.stored_hidden[self.train_pos], dim=-1)
456 | target_pos = self.pos_vector.repeat(x.shape[0], 1).softmax(dim=-1)
457 | target_neg = self.neg_vector.repeat(x.shape[0], 1).softmax(dim=-1)
458 | loss_pos = self.KLDiv(x, target_pos)
459 | loss_neg = self.KLDiv(x, target_neg)
460 | # pdb.set_trace()
461 | return loss_pos, loss_neg
462 |
463 |     def softmax_with_temperature(self, input, t=1, axis=-1):
464 |         # Temperature-scaled softmax; keepdim=True keeps the division broadcastable for 2-D inputs
465 |         ex = torch.exp(input / t)
466 |         return ex / torch.sum(ex, dim=axis, keepdim=True)
467 |
468 | def fn_loss(self, nodes, non_grad_idx):
469 | pos_nodes = set(self.train_pos)
470 | target = []
471 | neighs = []
472 | for node in nodes:
473 | if int(node) in pos_nodes:
474 | target.append(int(node))
475 | neighs.append(self.adj_lists[int(node)])
476 | x = F.log_softmax(self.fetch_feat(target)[:, non_grad_idx], dim=-1)
477 | pos = torch.zeros_like(self.fetch_feat(target))
478 | neg = torch.zeros_like(self.fetch_feat(target))
479 | for i in range(len(target)):
480 | pos[i] = torch.mean(self.fetch_feat(list(neighs[i])), dim=0, keepdim=True)
481 | neg_idx = [random.choice(list(self.unique_nodes.difference(neighs[i])))]
482 | neg[i] = self.fetch_feat(neg_idx)
483 | # pdb.set_trace()
484 | pos = pos[:, non_grad_idx].softmax(dim=-1)
485 | neg = neg[:, non_grad_idx].softmax(dim=-1)
486 | loss_pos = self.KLDiv(x, pos)
487 | loss_neg = self.KLDiv(x, neg)
488 | # pdb.set_trace()
489 | return loss_pos, loss_neg
490 |
491 | def fetch_feat(self, nodes):
492 | if self.cuda and isinstance(nodes, list):
493 | index = torch.LongTensor(nodes).cuda()
494 | else:
495 | index = torch.LongTensor(nodes)
496 | return self.features(index)
497 |
--------------------------------------------------------------------------------