├── README.md
├── graph
│   ├── dataset.py
│   └── train.py
├── img
│   └── arc.png
├── node
│   ├── dataset.py
│   └── train.py
├── requirements.txt
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Multi-View Representation Learning on Graphs
2 |
3 | This work introduces a self-supervised approach based on contrastive multi-view
4 | learning to learn node-level and graph-level representations.
5 |
6 | It has been accepted at ICML 2020:
7 |
8 | [https://arxiv.org/abs/2006.05582](https://arxiv.org/abs/2006.05582)
9 |
10 |
11 |
12 |
13 |
14 | ![MVGRL architecture](img/arc.png)
15 |
16 |
17 | ## Reference
18 |
19 | ```
20 | @incollection{icml2020_1971,
21 | author = {Hassani, Kaveh and Khasahmadi, Amir Hosein},
22 | booktitle = {Proceedings of International Conference on Machine Learning},
23 | pages = {3451--3461},
24 | title = {Contrastive Multi-View Representation Learning on Graphs},
25 | year = {2020}
26 | }
27 | ```
28 |
--------------------------------------------------------------------------------
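*Editor's note: a minimal quick-start sketch, not part of the original repository. It assumes the repository root is the working directory, the packages in requirements.txt are installed, and a CUDA device is available (both training scripts call `.cuda()`); the dataset names follow the `__main__` blocks of the two scripts.*

```python
# Hypothetical quick-start, run from the repository root.
from node.train import train as train_node_model
from graph.train import train as train_graph_model

# Node-level task on Cora (the dataset is downloaded via dgl on first use).
train_node_model('cora', verbose=True)

# Graph-level task on MUTAG (the TU benchmark zip is downloaded on first use).
train_graph_model('MUTAG', gpu=0, num_layer=4, epoch=40, batch=64)
```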
/graph/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | import networkx as nx
5 | from collections import Counter
6 | from utils import compute_ppr, normalize_adj
7 |
8 |
9 | def download(dataset):
10 | basedir = os.path.dirname(os.path.abspath(__file__))
11 | datadir = os.path.join(basedir, 'data', dataset)
12 | if not os.path.exists(datadir):
13 | os.makedirs(datadir)
14 | url = 'https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{0}.zip'.format(dataset)
15 | zipfile = os.path.basename(url)
16 | os.system('wget {0}; unzip {1}'.format(url, zipfile))
17 | os.system('mv {0}/* {1}'.format(dataset, datadir))
18 | os.system('rm -r {0}'.format(dataset))
19 | os.system('rm {0}'.format(zipfile))
20 |
21 |
22 | def process(dataset):
23 | src = os.path.join(os.path.dirname(__file__), 'data')
24 | prefix = os.path.join(src, dataset, dataset)
25 |
26 | graph_node_dict = {}
27 | with open('{0}_graph_indicator.txt'.format(prefix), 'r') as f:
28 | for idx, line in enumerate(f):
29 | graph_node_dict[idx + 1] = int(line.strip('\n'))
30 | max_nodes = Counter(graph_node_dict.values()).most_common(1)[0][1]
31 |
32 | node_labels = []
33 | if os.path.exists('{0}_node_labels.txt'.format(prefix)):
34 | with open('{0}_node_labels.txt'.format(prefix), 'r') as f:
35 | for line in f:
36 | node_labels += [int(line.strip('\n')) - 1]
37 | num_unique_node_labels = max(node_labels) + 1
38 | else:
39 | print('No node labels')
40 |
41 | node_attrs = []
42 | if os.path.exists('{0}_node_attributes.txt'.format(prefix)):
43 | with open('{0}_node_attributes.txt'.format(prefix), 'r') as f:
44 | for line in f:
45 | node_attrs.append(
46 | np.array([float(attr) for attr in re.split(r"[,\s]+", line.strip()) if attr], dtype=np.float)
47 | )
48 | else:
49 | print('No node attributes')
50 |
51 | graph_labels = []
52 | unique_labels = set()
53 | with open('{0}_graph_labels.txt'.format(prefix), 'r') as f:
54 | for line in f:
55 | val = int(line.strip('\n'))
56 | if val not in unique_labels:
57 | unique_labels.add(val)
58 | graph_labels.append(val)
59 | label_idx_dict = {val: idx for idx, val in enumerate(unique_labels)}
60 | graph_labels = np.array([label_idx_dict[l] for l in graph_labels])
61 |
62 | adj_list = {idx: [] for idx in range(1, len(graph_labels) + 1)}
63 | index_graph = {idx: [] for idx in range(1, len(graph_labels) + 1)}
64 | with open('{0}_A.txt'.format(prefix), 'r') as f:
65 | for line in f:
66 | u, v = tuple(map(int, line.strip('\n').split(',')))
67 | adj_list[graph_node_dict[u]].append((u, v))
68 | index_graph[graph_node_dict[u]] += [u, v]
69 |
70 | for k in index_graph.keys():
71 | index_graph[k] = [u - 1 for u in set(index_graph[k])]
72 |
73 | graphs, pprs = [], []
74 | for idx in range(1, 1 + len(adj_list)):
75 | graph = nx.from_edgelist(adj_list[idx])
76 | if max_nodes is not None and graph.number_of_nodes() > max_nodes:
77 | continue
78 |
79 | graph.graph['label'] = graph_labels[idx - 1]
80 | for u in graph.nodes():
81 | if len(node_labels) > 0:
82 | node_label_one_hot = [0] * num_unique_node_labels
83 | node_label = node_labels[u - 1]
84 | node_label_one_hot[node_label] = 1
85 | graph.nodes[u]['label'] = node_label_one_hot
86 | if len(node_attrs) > 0:
87 | graph.nodes[u]['feat'] = node_attrs[u - 1]
88 | if len(node_attrs) > 0:
89 | graph.graph['feat_dim'] = node_attrs[0].shape[0]
90 |
91 | # relabeling
92 | mapping = {}
93 | for node_idx, node in enumerate(graph.nodes()):
94 | mapping[node] = node_idx
95 |
96 | graphs.append(nx.relabel_nodes(graph, mapping))
97 | pprs.append(compute_ppr(graph, alpha=0.2))
98 |
99 | if 'feat_dim' in graphs[0].graph:
100 | pass
101 | else:
102 | max_deg = max([max(dict(graph.degree).values()) for graph in graphs])
103 | for graph in graphs:
104 | for u in graph.nodes(data=True):
105 | f = np.zeros(max_deg + 1)
106 | f[graph.degree[u[0]]] = 1.0
107 | if 'label' in u[1]:
108 | f = np.concatenate((np.array(u[1]['label'], dtype=np.float), f))
109 | graph.nodes[u[0]]['feat'] = f
110 | return graphs, pprs
111 |
112 |
113 | def load(dataset):
114 | basedir = os.path.dirname(os.path.abspath(__file__))
115 | datadir = os.path.join(basedir, 'data', dataset)
116 |
117 | if not os.path.exists(datadir):
118 | download(dataset)
119 | graphs, diff = process(dataset)
120 | feat, adj, labels = [], [], []
121 |
122 | for idx, graph in enumerate(graphs):
123 | adj.append(nx.to_numpy_array(graph))
124 | labels.append(graph.graph['label'])
125 | feat.append(np.array(list(nx.get_node_attributes(graph, 'feat').values())))
126 |
127 | adj, diff, feat, labels = np.array(adj), np.array(diff), np.array(feat), np.array(labels)
128 |
129 | np.save(f'{datadir}/adj.npy', adj)
130 | np.save(f'{datadir}/diff.npy', diff)
131 | np.save(f'{datadir}/feat.npy', feat)
132 | np.save(f'{datadir}/labels.npy', labels)
133 |
134 | else:
135 | adj = np.load(f'{datadir}/adj.npy', allow_pickle=True)
136 | diff = np.load(f'{datadir}/diff.npy', allow_pickle=True)
137 | feat = np.load(f'{datadir}/feat.npy', allow_pickle=True)
138 | labels = np.load(f'{datadir}/labels.npy', allow_pickle=True)
139 |
140 | max_nodes = max([a.shape[0] for a in adj])
141 | feat_dim = feat[0].shape[-1]
142 |
143 | num_nodes = []
144 |
145 | for idx in range(adj.shape[0]):
146 |
147 | num_nodes.append(adj[idx].shape[-1])
148 |
149 | adj[idx] = normalize_adj(adj[idx]).todense()
150 |
151 | diff[idx] = np.hstack(
152 | (np.vstack((diff[idx], np.zeros((max_nodes - diff[idx].shape[0], diff[idx].shape[0])))),
153 | np.zeros((max_nodes, max_nodes - diff[idx].shape[1]))))
154 |
155 | adj[idx] = np.hstack(
156 | (np.vstack((adj[idx], np.zeros((max_nodes - adj[idx].shape[0], adj[idx].shape[0])))),
157 | np.zeros((max_nodes, max_nodes - adj[idx].shape[1]))))
158 |
159 | feat[idx] = np.vstack((feat[idx], np.zeros((max_nodes - feat[idx].shape[0], feat_dim))))
160 |
161 | adj = np.array(adj.tolist()).reshape(-1, max_nodes, max_nodes)
162 | diff = np.array(diff.tolist()).reshape(-1, max_nodes, max_nodes)
163 | feat = np.array(feat.tolist()).reshape(-1, max_nodes, feat_dim)
164 |
165 | return adj, diff, feat, labels, num_nodes
166 |
167 |
168 | if __name__ == '__main__':
169 | # MUTAG, PTC_MR, IMDB-BINARY, IMDB-MULTI, REDDIT-BINARY, REDDIT-MULTI-5K,
170 | adj, diff, feat, labels, num_nodes = load('PTC_MR')
171 | print('done')
172 |
173 |
--------------------------------------------------------------------------------
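*Editor's note: a rough usage sketch of `load()` above, not from the repo. It assumes `wget` and `unzip` are available, since `download()` shells out to them, and that the script is run from the repository root.*

```python
# Sketch of the expected load() output for a TU dataset such as MUTAG;
# the first call downloads the zip and caches .npy files under graph/data/MUTAG.
from graph.dataset import load

adj, diff, feat, labels, num_nodes = load('MUTAG')
# Every graph is zero-padded to the size of the largest graph in the dataset:
#   adj  : (num_graphs, max_nodes, max_nodes) symmetrically normalized adjacency
#   diff : (num_graphs, max_nodes, max_nodes) personalized-PageRank diffusion
#   feat : (num_graphs, max_nodes, feat_dim) node features (one-hot degree/label
#          vectors when the dataset has no real node attributes)
#   labels, num_nodes : per-graph class label and original node count
print(adj.shape, diff.shape, feat.shape, labels.shape, len(num_nodes))
```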
/graph/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from sklearn.model_selection import GridSearchCV, StratifiedKFold
6 | from graph.dataset import load
7 |
8 |
9 | class GCNLayer(nn.Module):
10 | def __init__(self, in_ft, out_ft, bias=True):
11 | super(GCNLayer, self).__init__()
12 | self.fc = nn.Linear(in_ft, out_ft, bias=False)
13 | self.act = nn.PReLU()
14 |
15 | if bias:
16 | self.bias = nn.Parameter(torch.FloatTensor(out_ft))
17 | self.bias.data.fill_(0.0)
18 | else:
19 | self.register_parameter('bias', None)
20 |
21 | for m in self.modules():
22 | self.weights_init(m)
23 |
24 | def weights_init(self, m):
25 | if isinstance(m, nn.Linear):
26 | torch.nn.init.xavier_uniform_(m.weight.data)
27 | if m.bias is not None:
28 | m.bias.data.fill_(0.0)
29 |
30 | def forward(self, feat, adj):
31 | feat = self.fc(feat)
32 | out = torch.bmm(adj, feat)
33 | if self.bias is not None:
34 | out += self.bias
35 | return self.act(out)
36 |
37 |
38 | class GCN(nn.Module):
39 | def __init__(self, in_ft, out_ft, num_layers):
40 | super(GCN, self).__init__()
41 | n_h = out_ft
42 | self.layers = nn.ModuleList()  # register sub-layers so they are trained and moved with the model
43 | self.num_layers = num_layers
44 | self.layers.append(GCNLayer(in_ft, n_h).cuda())
45 | for __ in range(num_layers - 1):
46 | self.layers.append(GCNLayer(n_h, n_h).cuda())
47 |
48 | def forward(self, feat, adj, mask):
49 | h_1 = self.layers[0](feat, adj)
50 | h_1g = torch.sum(h_1, 1)
51 | for idx in range(self.num_layers - 1):
52 | h_1 = self.layers[idx + 1](h_1, adj)
53 | h_1g = torch.cat((h_1g, torch.sum(h_1, 1)), -1)
54 | return h_1, h_1g
55 |
56 |
57 | class MLP(nn.Module):
58 | def __init__(self, in_ft, out_ft):
59 | super(MLP, self).__init__()
60 | self.ffn = nn.Sequential(
61 | nn.Linear(in_ft, out_ft),
62 | nn.PReLU(),
63 | nn.Linear(out_ft, out_ft),
64 | nn.PReLU(),
65 | nn.Linear(out_ft, out_ft),
66 | nn.PReLU()
67 | )
68 | self.linear_shortcut = nn.Linear(in_ft, out_ft)
69 |
70 | def forward(self, x):
71 | return self.ffn(x) + self.linear_shortcut(x)
72 |
73 |
74 | class Model(nn.Module):
75 | def __init__(self, n_in, n_h, num_layers):
76 | super(Model, self).__init__()
77 | self.mlp1 = MLP(1 * n_h, n_h)
78 | self.mlp2 = MLP(num_layers * n_h, n_h)
79 | self.gnn1 = GCN(n_in, n_h, num_layers)
80 | self.gnn2 = GCN(n_in, n_h, num_layers)
81 |
82 | def forward(self, adj, diff, feat, mask):
83 | lv1, gv1 = self.gnn1(feat, adj, mask)
84 | lv2, gv2 = self.gnn2(feat, diff, mask)
85 |
86 | lv1 = self.mlp1(lv1)
87 | lv2 = self.mlp1(lv2)
88 |
89 | gv1 = self.mlp2(gv1)
90 | gv2 = self.mlp2(gv2)
91 |
92 | return lv1, gv1, lv2, gv2
93 |
94 | def embed(self, feat, adj, diff, mask):
95 | __, gv1, __, gv2 = self.forward(adj, diff, feat, mask)
96 | return (gv1 + gv2).detach()
97 |
98 |
99 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
100 | def get_positive_expectation(p_samples, measure, average=True):
101 | """Computes the positive part of a divergence / difference.
102 | Args:
103 | p_samples: Positive samples.
104 | measure: Measure to compute for.
105 | average: Average the result over samples.
106 | Returns:
107 | torch.Tensor
108 | """
109 | log_2 = np.log(2.)
110 |
111 | if measure == 'GAN':
112 | Ep = - F.softplus(-p_samples)
113 | elif measure == 'JSD':
114 | Ep = log_2 - F.softplus(- p_samples)
115 | elif measure == 'X2':
116 | Ep = p_samples ** 2
117 | elif measure == 'KL':
118 | Ep = p_samples + 1.
119 | elif measure == 'RKL':
120 | Ep = -torch.exp(-p_samples)
121 | elif measure == 'DV':
122 | Ep = p_samples
123 | elif measure == 'H2':
124 | Ep = 1. - torch.exp(-p_samples)
125 | elif measure == 'W1':
126 | Ep = p_samples
127 |
128 | if average:
129 | return Ep.mean()
130 | else:
131 | return Ep
132 |
133 |
134 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
135 | def get_negative_expectation(q_samples, measure, average=True):
136 | """Computes the negative part of a divergence / difference.
137 | Args:
138 | q_samples: Negative samples.
139 | measure: Measure to compute for.
140 | average: Average the result over samples.
141 | Returns:
142 | torch.Tensor
143 | """
144 | log_2 = np.log(2.)
145 |
146 | if measure == 'GAN':
147 | Eq = F.softplus(-q_samples) + q_samples
148 | elif measure == 'JSD':
149 | Eq = F.softplus(-q_samples) + q_samples - log_2
150 | elif measure == 'X2':
151 | Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
152 | elif measure == 'KL':
153 | Eq = torch.exp(q_samples)
154 | elif measure == 'RKL':
155 | Eq = q_samples - 1.
156 | elif measure == 'H2':
157 | Eq = torch.exp(q_samples) - 1.
158 | elif measure == 'W1':
159 | Eq = q_samples
160 |
161 | if average:
162 | return Eq.mean()
163 | else:
164 | return Eq
165 |
166 |
167 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
168 | def local_global_loss_(l_enc, g_enc, batch, measure, mask):
169 | '''
170 | Args:
171 | l_enc: Local (per-node) feature map.
172 | g_enc: Global (per-graph) features.
173 | batch: Graph index of each node row; mask: number of valid nodes per graph.
174 | measure: Divergence measure; 'JSD' (Jensen-Shannon) is used in this repo.
175 | Returns:
176 | torch.Tensor: Loss.
177 | '''
178 | num_graphs = g_enc.shape[0]
179 | num_nodes = l_enc.shape[0]
180 | max_nodes = num_nodes // num_graphs
181 |
182 | pos_mask = torch.zeros((num_nodes, num_graphs)).cuda()
183 | neg_mask = torch.ones((num_nodes, num_graphs)).cuda()
184 | msk = torch.ones((num_nodes, num_graphs)).cuda()
185 | for nodeidx, graphidx in enumerate(batch):
186 | pos_mask[nodeidx][graphidx] = 1.
187 | neg_mask[nodeidx][graphidx] = 0.
188 |
189 | for idx, m in enumerate(mask):
190 | msk[idx * max_nodes + m: idx * max_nodes + max_nodes, idx] = 0.
191 |
192 | res = torch.mm(l_enc, g_enc.t()) * msk
193 |
194 | E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
195 | E_pos = E_pos / num_nodes
196 | E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
197 | E_neg = E_neg / (num_nodes * (num_graphs - 1))
198 | return E_neg - E_pos
199 |
200 |
201 | def global_global_loss_(g1_enc, g2_enc, measure):
202 | '''
203 | Args:
204 | g1_enc: Global features from the first (adjacency) view.
205 | g2_enc: Global features from the second (diffusion) view.
206 | measure: Divergence measure; 'JSD' (Jensen-Shannon) is used in this repo.
207 | Positive pairs are the two views of the same graph; all other pairs are negatives.
208 | Returns:
209 | torch.Tensor: Loss.
210 | '''
211 | num_graphs = g1_enc.shape[0]
212 |
213 | pos_mask = torch.zeros((num_graphs, num_graphs)).cuda()
214 | neg_mask = torch.ones((num_graphs, num_graphs)).cuda()
215 | for graphidx in range(num_graphs):
216 | pos_mask[graphidx][graphidx] = 1.
217 | neg_mask[graphidx][graphidx] = 0.
218 |
219 | res = torch.mm(g1_enc, g2_enc.t())
220 |
221 | E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
222 | E_pos = E_pos / num_graphs
223 | E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
224 | E_neg = E_neg / (num_graphs * (num_graphs - 1))
225 | return E_neg - E_pos
226 |
227 |
228 | def train(dataset, gpu, num_layer=4, epoch=40, batch=64):
229 | nb_epochs = epoch
230 | batch_size = batch
231 | patience = 20
232 | lr = 0.001
233 | l2_coef = 0.0
234 | hid_units = 512
235 |
236 | adj, diff, feat, labels, num_nodes = load(dataset)
237 |
238 | feat = torch.FloatTensor(feat).cuda()
239 | diff = torch.FloatTensor(diff).cuda()
240 | adj = torch.FloatTensor(adj).cuda()
241 | labels = torch.LongTensor(labels).cuda()
242 |
243 | ft_size = feat[0].shape[1]
244 | max_nodes = feat[0].shape[0]
245 |
246 | model = Model(ft_size, hid_units, num_layer)
247 | optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
248 |
249 | model.cuda()
250 |
251 | cnt_wait = 0
252 | best = 1e9
253 |
254 | itr = (adj.shape[0] // batch_size) + 1
255 | for epoch in range(nb_epochs):
256 | epoch_loss = 0.0
257 | train_idx = np.arange(adj.shape[0])
258 | np.random.shuffle(train_idx)
259 |
260 | for idx in range(0, len(train_idx), batch_size):
261 | model.train()
262 | optimiser.zero_grad()
263 |
264 | batch = train_idx[idx: idx + batch_size]
265 | mask = np.array(num_nodes)[batch]  # node counts for the shuffled batch indices
266 |
267 | lv1, gv1, lv2, gv2 = model(adj[batch], diff[batch], feat[batch], mask)
268 |
269 | lv1 = lv1.view(batch.shape[0] * max_nodes, -1)
270 | lv2 = lv2.view(batch.shape[0] * max_nodes, -1)
271 |
272 | batch = torch.LongTensor(np.repeat(np.arange(batch.shape[0]), max_nodes)).cuda()
273 |
274 | loss1 = local_global_loss_(lv1, gv2, batch, 'JSD', mask)
275 | loss2 = local_global_loss_(lv2, gv1, batch, 'JSD', mask)
276 | # loss3 = global_global_loss_(gv1, gv2, 'JSD')
277 | loss = loss1 + loss2 #+ loss3
278 | epoch_loss += loss.item()
279 | loss.backward()
280 | optimiser.step()
281 |
282 | epoch_loss /= itr
283 |
284 | # print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, epoch_loss))
285 |
286 | if epoch_loss < best:
287 | best = epoch_loss
288 | best_t = epoch
289 | cnt_wait = 0
290 | torch.save(model.state_dict(), f'{dataset}-{gpu}.pkl')
291 | else:
292 | cnt_wait += 1
293 |
294 | if cnt_wait == patience:
295 | break
296 |
297 | model.load_state_dict(torch.load(f'{dataset}-{gpu}.pkl'))
298 |
299 | features = feat.cuda()
300 | adj = adj.cuda()
301 | diff = diff.cuda()
302 | labels = labels.cuda()
303 |
304 | embeds = model.embed(features, adj, diff, num_nodes)
305 |
306 | x = embeds.cpu().numpy()
307 | y = labels.cpu().numpy()
308 |
309 | from sklearn.svm import LinearSVC
310 | from sklearn.metrics import accuracy_score
311 | params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
312 | kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
313 | accuracies = []
314 | for train_index, test_index in kf.split(x, y):
315 |
316 | x_train, x_test = x[train_index], x[test_index]
317 | y_train, y_test = y[train_index], y[test_index]
318 | classifier = GridSearchCV(LinearSVC(), params, cv=5, scoring='accuracy', verbose=0)
319 | classifier.fit(x_train, y_train)
320 | accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
321 | print(np.mean(accuracies), np.std(accuracies))
322 |
323 |
324 | if __name__ == '__main__':
325 | import warnings
326 | warnings.filterwarnings("ignore")
327 | gpu = 1
328 | torch.cuda.set_device(gpu)
329 | layers = [2, 8, 12]
330 | batch = [32, 64, 128, 256]
331 | epoch = [20, 40, 100]
332 | ds = ['MUTAG', 'PTC_MR', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'REDDIT-MULTI-5K']
333 | seeds = [123, 132, 321, 312, 231]
334 | for d in ds:
335 | print(f'####################{d}####################')
336 | for l in layers:
337 | for b in batch:
338 | for e in epoch:
339 | for i in range(5):
340 | seed = seeds[i]
341 | torch.manual_seed(seed)
342 | torch.backends.cudnn.deterministic = True
343 | torch.backends.cudnn.benchmark = False
344 | np.random.seed(seed)
345 | print(f'Dataset: {d}, Layer:{l}, Batch: {b}, Epoch: {e}, Seed: {seed}')
346 | train(d, gpu, l, e, b)
347 | print('################################################')
--------------------------------------------------------------------------------
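*Editor's note: a toy, CPU-only check of the JSD estimator that `local_global_loss_` above is built on, added for illustration (the example tensors are mine). Each node acts as a positive for its own graph summary and as a negative for every other graph's summary; it assumes the repository root is on `sys.path`.*

```python
# Toy JSD local-global loss for 2 graphs with 3 nodes each (6 node rows).
import torch
from graph.train import get_positive_expectation, get_negative_expectation

scores = torch.randn(6, 2)                 # node-vs-graph similarity scores
pos_mask = torch.zeros(6, 2)
pos_mask[:3, 0] = 1.                       # nodes 0-2 belong to graph 0
pos_mask[3:, 1] = 1.                       # nodes 3-5 belong to graph 1
neg_mask = 1. - pos_mask

E_pos = get_positive_expectation(scores * pos_mask, 'JSD', average=False).sum() / 6
E_neg = get_negative_expectation(scores * neg_mask, 'JSD', average=False).sum() / (6 * 1)
print((E_neg - E_pos).item())              # same form as local_global_loss_
```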
/img/arc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kavehhassani/mvgrl/628ed2bdb4496f8519556f7b9e25f93b57cf6902/img/arc.png
--------------------------------------------------------------------------------
/node/dataset.py:
--------------------------------------------------------------------------------
1 | from dgl.data import CoraDataset, CitationGraphDataset
2 | from utils import preprocess_features, normalize_adj
3 | from sklearn.preprocessing import MinMaxScaler
4 | from utils import compute_ppr
5 | import scipy.sparse as sp
6 | import networkx as nx
7 | import numpy as np
8 | import os
9 |
10 |
11 | def download(dataset):
12 | if dataset == 'cora':
13 | return CoraDataset()
14 | elif dataset in ('citeseer', 'pubmed'):
15 | return CitationGraphDataset(name=dataset)
16 | else:
17 | return None
18 |
19 |
20 | def load(dataset):
21 | datadir = os.path.join('data', dataset)
22 |
23 | if not os.path.exists(datadir):
24 | os.makedirs(datadir)
25 | ds = download(dataset)
26 | adj = nx.to_numpy_array(ds.graph)
27 | diff = compute_ppr(ds.graph, 0.2)
28 | feat = ds.features[:]
29 | labels = ds.labels[:]
30 |
31 | idx_train = np.argwhere(ds.train_mask == 1).reshape(-1)
32 | idx_val = np.argwhere(ds.val_mask == 1).reshape(-1)
33 | idx_test = np.argwhere(ds.test_mask == 1).reshape(-1)
34 |
35 | np.save(f'{datadir}/adj.npy', adj)
36 | np.save(f'{datadir}/diff.npy', diff)
37 | np.save(f'{datadir}/feat.npy', feat)
38 | np.save(f'{datadir}/labels.npy', labels)
39 | np.save(f'{datadir}/idx_train.npy', idx_train)
40 | np.save(f'{datadir}/idx_val.npy', idx_val)
41 | np.save(f'{datadir}/idx_test.npy', idx_test)
42 | else:
43 | adj = np.load(f'{datadir}/adj.npy')
44 | diff = np.load(f'{datadir}/diff.npy')
45 | feat = np.load(f'{datadir}/feat.npy')
46 | labels = np.load(f'{datadir}/labels.npy')
47 | idx_train = np.load(f'{datadir}/idx_train.npy')
48 | idx_val = np.load(f'{datadir}/idx_val.npy')
49 | idx_test = np.load(f'{datadir}/idx_test.npy')
50 |
51 | if dataset == 'citeseer':
52 | feat = preprocess_features(feat)
53 |
54 | epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
55 | avg_degree = np.sum(adj) / adj.shape[0]
56 | epsilon = epsilons[np.argmin([abs(avg_degree - np.argwhere(diff >= e).shape[0] / diff.shape[0])
57 | for e in epsilons])]
58 |
59 | diff[diff < epsilon] = 0.0
60 | scaler = MinMaxScaler()
61 | scaler.fit(diff)
62 | diff = scaler.transform(diff)
63 |
64 | adj = normalize_adj(adj + sp.eye(adj.shape[0])).todense()
65 |
66 | return adj, diff, feat, labels, idx_train, idx_val, idx_test
67 |
68 |
69 | if __name__ == '__main__':
70 | load('cora')
71 |
--------------------------------------------------------------------------------
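*Editor's note: the epsilon selection inside `load()` above keeps the sparsified diffusion matrix about as dense as the original graph by picking the threshold whose surviving entries per node best match the average degree. Below is a standalone sketch of that rule; the function name `pick_epsilon` and the toy matrices are my own illustration.*

```python
import numpy as np

def pick_epsilon(adj, diff, candidates=(1e-5, 1e-4, 1e-3, 1e-2)):
    # Choose the threshold whose surviving diffusion entries per node
    # best match the graph's average degree, as in node/dataset.py.
    avg_degree = adj.sum() / adj.shape[0]
    kept_per_node = [np.argwhere(diff >= e).shape[0] / diff.shape[0] for e in candidates]
    return candidates[int(np.argmin([abs(avg_degree - k) for k in kept_per_node]))]

adj = np.array([[0., 1., 1.],
                [1., 0., 0.],
                [1., 0., 0.]])
diff = np.array([[0.60, 0.30, 0.005],
                 [0.30, 0.60, 0.0005],
                 [0.005, 0.0005, 0.99]])
print(pick_epsilon(adj, diff))  # 0.01: keeps ~5/3 entries per node vs. average degree 4/3
```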
/node/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 | import torch.nn as nn
5 | from utils import sparse_mx_to_torch_sparse_tensor
6 | from node.dataset import load
7 |
8 |
9 | # Borrowed from https://github.com/PetarV-/DGI
10 | class GCN(nn.Module):
11 | def __init__(self, in_ft, out_ft, bias=True):
12 | super(GCN, self).__init__()
13 | self.fc = nn.Linear(in_ft, out_ft, bias=False)
14 | self.act = nn.PReLU()
15 |
16 | if bias:
17 | self.bias = nn.Parameter(torch.FloatTensor(out_ft))
18 | self.bias.data.fill_(0.0)
19 | else:
20 | self.register_parameter('bias', None)
21 |
22 | for m in self.modules():
23 | self.weights_init(m)
24 |
25 | def weights_init(self, m):
26 | if isinstance(m, nn.Linear):
27 | torch.nn.init.xavier_uniform_(m.weight.data)
28 | if m.bias is not None:
29 | m.bias.data.fill_(0.0)
30 |
31 | # Shape of seq: (batch, nodes, features)
32 | def forward(self, seq, adj, sparse=False):
33 | seq_fts = self.fc(seq)
34 | if sparse:
35 | out = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq_fts, 0)), 0)
36 | else:
37 | out = torch.bmm(adj, seq_fts)
38 | if self.bias is not None:
39 | out += self.bias
40 | return self.act(out)
41 |
42 |
43 | # Borrowed from https://github.com/PetarV-/DGI
44 | class Readout(nn.Module):
45 | def __init__(self):
46 | super(Readout, self).__init__()
47 |
48 | def forward(self, seq, msk):
49 | if msk is None:
50 | return torch.mean(seq, 1)
51 | else:
52 | msk = torch.unsqueeze(msk, -1)
53 | return torch.mean(seq * msk, 1) / torch.sum(msk)
54 |
55 |
56 | # Borrowed from https://github.com/PetarV-/DGI
57 | class Discriminator(nn.Module):
58 | def __init__(self, n_h):
59 | super(Discriminator, self).__init__()
60 | self.f_k = nn.Bilinear(n_h, n_h, 1)
61 |
62 | for m in self.modules():
63 | self.weights_init(m)
64 |
65 | def weights_init(self, m):
66 | if isinstance(m, nn.Bilinear):
67 | torch.nn.init.xavier_uniform_(m.weight.data)
68 | if m.bias is not None:
69 | m.bias.data.fill_(0.0)
70 |
71 | def forward(self, c1, c2, h1, h2, h3, h4, s_bias1=None, s_bias2=None):
72 | c_x1 = torch.unsqueeze(c1, 1)
73 | c_x1 = c_x1.expand_as(h1).contiguous()
74 | c_x2 = torch.unsqueeze(c2, 1)
75 | c_x2 = c_x2.expand_as(h2).contiguous()
76 |
77 | # positive
78 | sc_1 = torch.squeeze(self.f_k(h2, c_x1), 2)
79 | sc_2 = torch.squeeze(self.f_k(h1, c_x2), 2)
80 |
81 | # negative
82 | sc_3 = torch.squeeze(self.f_k(h4, c_x1), 2)
83 | sc_4 = torch.squeeze(self.f_k(h3, c_x2), 2)
84 |
85 | logits = torch.cat((sc_1, sc_2, sc_3, sc_4), 1)
86 | return logits
87 |
88 |
89 | class Model(nn.Module):
90 | def __init__(self, n_in, n_h):
91 | super(Model, self).__init__()
92 | self.gcn1 = GCN(n_in, n_h)
93 | self.gcn2 = GCN(n_in, n_h)
94 | self.read = Readout()
95 |
96 | self.sigm = nn.Sigmoid()
97 |
98 | self.disc = Discriminator(n_h)
99 |
100 | def forward(self, seq1, seq2, adj, diff, sparse, msk, samp_bias1, samp_bias2):
101 | h_1 = self.gcn1(seq1, adj, sparse)
102 | c_1 = self.read(h_1, msk)
103 | c_1 = self.sigm(c_1)
104 |
105 | h_2 = self.gcn2(seq1, diff, sparse)
106 | c_2 = self.read(h_2, msk)
107 | c_2 = self.sigm(c_2)
108 |
109 | h_3 = self.gcn1(seq2, adj, sparse)
110 | h_4 = self.gcn2(seq2, diff, sparse)
111 |
112 | ret = self.disc(c_1, c_2, h_1, h_2, h_3, h_4, samp_bias1, samp_bias2)
113 |
114 | return ret, h_1, h_2
115 |
116 | def embed(self, seq, adj, diff, sparse, msk):
117 | h_1 = self.gcn1(seq, adj, sparse)
118 | c = self.read(h_1, msk)
119 |
120 | h_2 = self.gcn2(seq, diff, sparse)
121 | return (h_1 + h_2).detach(), c.detach()
122 |
123 |
124 | class LogReg(nn.Module):
125 | def __init__(self, ft_in, nb_classes):
126 | super(LogReg, self).__init__()
127 | self.fc = nn.Linear(ft_in, nb_classes)
128 | self.sigm = nn.Sigmoid()
129 |
130 | for m in self.modules():
131 | self.weights_init(m)
132 |
133 | def weights_init(self, m):
134 | if isinstance(m, nn.Linear):
135 | torch.nn.init.xavier_uniform_(m.weight.data)
136 | if m.bias is not None:
137 | m.bias.data.fill_(0.0)
138 |
139 | def forward(self, seq):
140 | ret = torch.log_softmax(self.fc(seq), dim=-1)
141 | return ret
142 |
143 |
144 | def train(dataset, verbose=False):
145 |
146 | nb_epochs = 3000
147 | patience = 20
148 | lr = 0.001
149 | l2_coef = 0.0
150 | hid_units = 512
151 | sparse = False
152 |
153 | adj, diff, features, labels, idx_train, idx_val, idx_test = load(dataset)
154 |
155 | ft_size = features.shape[1]
156 | nb_classes = np.unique(labels).shape[0]
157 |
158 | sample_size = 2000
159 | batch_size = 4
160 |
161 | labels = torch.LongTensor(labels)
162 | idx_train = torch.LongTensor(idx_train)
163 | idx_test = torch.LongTensor(idx_test)
164 |
165 | lbl_1 = torch.ones(batch_size, sample_size * 2)
166 | lbl_2 = torch.zeros(batch_size, sample_size * 2)
167 | lbl = torch.cat((lbl_1, lbl_2), 1)
168 |
169 | model = Model(ft_size, hid_units)
170 | optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
171 |
172 | if torch.cuda.is_available():
173 | model.cuda()
174 | labels = labels.cuda()
175 | lbl = lbl.cuda()
176 | idx_train = idx_train.cuda()
177 | idx_test = idx_test.cuda()
178 |
179 | b_xent = nn.BCEWithLogitsLoss()
180 | xent = nn.CrossEntropyLoss()
181 | cnt_wait = 0
182 | best = 1e9
183 | best_t = 0
184 |
185 | for epoch in range(nb_epochs):
186 |
187 | idx = np.random.randint(0, adj.shape[-1] - sample_size + 1, batch_size)
188 | ba, bd, bf = [], [], []
189 | for i in idx:
190 | ba.append(adj[i: i + sample_size, i: i + sample_size])
191 | bd.append(diff[i: i + sample_size, i: i + sample_size])
192 | bf.append(features[i: i + sample_size])
193 |
194 | ba = np.array(ba).reshape(batch_size, sample_size, sample_size)
195 | bd = np.array(bd).reshape(batch_size, sample_size, sample_size)
196 | bf = np.array(bf).reshape(batch_size, sample_size, ft_size)
197 |
198 | if sparse:
199 | ba = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(ba))
200 | bd = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(bd))
201 | else:
202 | ba = torch.FloatTensor(ba)
203 | bd = torch.FloatTensor(bd)
204 |
205 | bf = torch.FloatTensor(bf)
206 | idx = np.random.permutation(sample_size)
207 | shuf_fts = bf[:, idx, :]
208 |
209 | if torch.cuda.is_available():
210 | bf = bf.cuda()
211 | ba = ba.cuda()
212 | bd = bd.cuda()
213 | shuf_fts = shuf_fts.cuda()
214 |
215 | model.train()
216 | optimiser.zero_grad()
217 |
218 | logits, __, __ = model(bf, shuf_fts, ba, bd, sparse, None, None, None)
219 |
220 | loss = b_xent(logits, lbl)
221 |
222 | loss.backward()
223 | optimiser.step()
224 |
225 | if verbose:
226 | print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, loss.item()))
227 |
228 | if loss < best:
229 | best = loss
230 | best_t = epoch
231 | cnt_wait = 0
232 | torch.save(model.state_dict(), 'model.pkl')
233 | else:
234 | cnt_wait += 1
235 |
236 | if cnt_wait == patience:
237 | if verbose:
238 | print('Early stopping!')
239 | break
240 |
241 | if verbose:
242 | print('Loading {}th epoch'.format(best_t))
243 | model.load_state_dict(torch.load('model.pkl'))
244 |
245 | if sparse:
246 | adj = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(adj))
247 | diff = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(diff))
248 |
249 | features = torch.FloatTensor(features[np.newaxis])
250 | adj = torch.FloatTensor(adj[np.newaxis])
251 | diff = torch.FloatTensor(diff[np.newaxis])
252 | features = features.cuda()
253 | adj = adj.cuda()
254 | diff = diff.cuda()
255 |
256 | embeds, _ = model.embed(features, adj, diff, sparse, None)
257 | train_embs = embeds[0, idx_train]
258 | test_embs = embeds[0, idx_test]
259 |
260 | train_lbls = labels[idx_train]
261 | test_lbls = labels[idx_test]
262 |
263 | accs = []
264 | wd = 0.01 if dataset == 'citeseer' else 0.0
265 |
266 | for _ in range(50):
267 | log = LogReg(hid_units, nb_classes)
268 | opt = torch.optim.Adam(log.parameters(), lr=1e-2, weight_decay=wd)
269 | log.cuda()
270 | for _ in range(300):
271 | log.train()
272 | opt.zero_grad()
273 |
274 | logits = log(train_embs)
275 | loss = xent(logits, train_lbls)
276 |
277 | loss.backward()
278 | opt.step()
279 |
280 | logits = log(test_embs)
281 | preds = torch.argmax(logits, dim=1)
282 | acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
283 | accs.append(acc * 100)
284 |
285 | accs = torch.stack(accs)
286 | print(accs.mean().item(), accs.std().item())
287 |
288 |
289 | if __name__ == '__main__':
290 | import warnings
291 | warnings.filterwarnings("ignore")
292 | torch.cuda.set_device(3)
293 |
294 | # 'cora', 'citeseer', 'pubmed'
295 | dataset = 'cora'
296 | for __ in range(50):
297 | train(dataset)
298 |
--------------------------------------------------------------------------------
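*Editor's note: a quick CPU-only shape check, added for illustration, of how the discriminator logits above line up with the BCE targets built in `train()`: each view contributes one positive and one negative score block per node, so the logits have 4 × sample_size columns and the label row is ones for the first half and zeros for the second. It assumes the packages in requirements.txt are installed and the repository root is on `sys.path`.*

```python
import torch
from node.train import Discriminator

disc = Discriminator(n_h=8)
c1, c2 = torch.randn(2, 8), torch.randn(2, 8)            # summaries of the two views
h1, h2 = torch.randn(2, 5, 8), torch.randn(2, 5, 8)      # node features (true order)
h3, h4 = torch.randn(2, 5, 8), torch.randn(2, 5, 8)      # node features (shuffled)

logits = disc(c1, c2, h1, h2, h3, h4)
lbl = torch.cat((torch.ones(2, 2 * 5), torch.zeros(2, 2 * 5)), 1)
print(logits.shape, lbl.shape)   # both torch.Size([2, 20])
```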
/requirements.txt:
--------------------------------------------------------------------------------
1 | dgl==0.4.1
2 | networkx==2.4
3 | numpy==1.17.4
4 | opt-einsum==3.1.0
5 | pickleshare==0.7.5
6 | pytz==2019.2
7 | pyzmq==18.1.0
8 | requests==2.22.0
9 | scikit-learn==0.21.3
10 | scipy==1.3.2
11 | sklearn==0.0
12 | torch==1.3.1
13 | torch-cluster==1.4.5
14 | torch-geometric==1.3.2
15 | torch-scatter==1.4.0
16 | torch-sparse==0.4.3
17 | torchtext==0.4.0
18 | torchvision==0.4.2
19 | urllib3==1.25.6
20 | zipp==0.6.0
21 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | import torch
4 | from scipy.linalg import fractional_matrix_power, inv
5 | import scipy.sparse as sp
6 |
7 |
8 | def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
9 | a = nx.convert_matrix.to_numpy_array(graph)
10 | if self_loop:
11 | a = a + np.eye(a.shape[0]) # A^ = A + I_n
12 | d = np.diag(np.sum(a, 1)) # D^_ii = Sigma_j A^_ij
13 | dinv = fractional_matrix_power(d, -0.5) # D^(-1/2)
14 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2)
15 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1
16 |
17 |
18 | def compute_heat(graph: nx.Graph, t=5, self_loop=True):
19 | a = nx.convert_matrix.to_numpy_array(graph)
20 | if self_loop:
21 | a = a + np.eye(a.shape[0])
22 | d = np.diag(np.sum(a, 1))
23 | return np.exp(t * (np.matmul(a, inv(d)) - 1))
24 |
25 |
26 | def sparse_to_tuple(sparse_mx):
27 | """Convert sparse matrix to tuple representation."""
28 |
29 | def to_tuple(mx):
30 | if not sp.isspmatrix_coo(mx):
31 | mx = mx.tocoo()
32 | coords = np.vstack((mx.row, mx.col)).transpose()
33 | values = mx.data
34 | shape = mx.shape
35 | return coords, values, shape
36 |
37 | if isinstance(sparse_mx, list):
38 | for i in range(len(sparse_mx)):
39 | sparse_mx[i] = to_tuple(sparse_mx[i])
40 | else:
41 | sparse_mx = to_tuple(sparse_mx)
42 |
43 | return sparse_mx
44 |
45 |
46 | def preprocess_features(features):
47 | """Row-normalize feature matrix and convert to tuple representation"""
48 | rowsum = np.array(features.sum(1))
49 | r_inv = np.power(rowsum, -1).flatten()
50 | r_inv[np.isinf(r_inv)] = 0.
51 | r_mat_inv = sp.diags(r_inv)
52 | features = r_mat_inv.dot(features)
53 | if isinstance(features, np.ndarray):
54 | return features
55 | else:
56 | return features.todense(), sparse_to_tuple(features)
57 |
58 |
59 | def normalize_adj(adj, self_loop=True):
60 | """Symmetrically normalize adjacency matrix."""
61 | if self_loop:
62 | adj = adj + sp.eye(adj.shape[0])
63 | adj = sp.coo_matrix(adj)
64 | rowsum = np.array(adj.sum(1))
65 | d_inv_sqrt = np.power(rowsum, -0.5).flatten()
66 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
67 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
68 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
69 |
70 |
71 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
72 | """Convert a scipy sparse matrix to a torch sparse tensor."""
73 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
74 | indices = torch.from_numpy(
75 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
76 | values = torch.from_numpy(sparse_mx.data)
77 | shape = torch.Size(sparse_mx.shape)
78 | return torch.sparse.FloatTensor(indices, values, shape)
--------------------------------------------------------------------------------
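*Editor's note: a small sanity check of `compute_ppr`, added for illustration (the path graph is my own toy input). Because the function uses the symmetrically normalized adjacency, the resulting diffusion matrix is dense and symmetric.*

```python
import networkx as nx
import numpy as np
from utils import compute_ppr

g = nx.path_graph(4)                  # 0 - 1 - 2 - 3
ppr = compute_ppr(g, alpha=0.2)
print(ppr.shape)                      # (4, 4): dense diffusion matrix
print(np.allclose(ppr, ppr.T))        # True: symmetric normalization is used
```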