├── README.md
├── graph
│   ├── dataset.py
│   └── train.py
├── img
│   └── arc.png
├── node
│   ├── dataset.py
│   └── train.py
├── requirements.txt
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Contrastive Multi-View Representation Learning on Graphs
2 |
3 | This work introduces a self-supervised approach based on contrastive multi-view
4 | learning to learn node-level and graph-level representations.
5 |
6 | It has been accepted at ICML 2020:
7 |
8 | [https://arxiv.org/abs/2006.05582](https://arxiv.org/abs/2006.05582)
9 |
10 |
11 |
12 |
13 |
14 | ![MVGRL architecture](img/arc.png)
15 |
16 |
17 | ## Reference
18 |
19 | ```
20 | @incollection{icml2020_1971,
21 | author = {Hassani, Kaveh and Khasahmadi, Amir Hosein},
22 | booktitle = {Proceedings of International Conference on Machine Learning},
23 | pages = {3451--3461},
24 | title = {Contrastive Multi-View Representation Learning on Graphs},
25 | year = {2020}
26 | }
27 | ```
28 |
--------------------------------------------------------------------------------
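*Editor's note: a minimal quick-start sketch, not part of the original repository. It assumes the repository root is the working directory, the packages in requirements.txt are installed, and a CUDA device is available (both training scripts call `.cuda()`); the dataset names follow the `__main__` blocks of the two scripts.*

```python
# Hypothetical quick-start, run from the repository root.
from node.train import train as train_node_model
from graph.train import train as train_graph_model

# Node-level task on Cora (the dataset is downloaded via dgl on first use).
train_node_model('cora', verbose=True)

# Graph-level task on MUTAG (the TU benchmark zip is downloaded on first use).
train_graph_model('MUTAG', gpu=0, num_layer=4, epoch=40, batch=64)
```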
/graph/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import numpy as np
4 | import networkx as nx
5 | from collections import Counter
6 | from utils import compute_ppr, normalize_adj
7 |
8 |
9 | def download(dataset):
10 | basedir = os.path.dirname(os.path.abspath(__file__))
11 | datadir = os.path.join(basedir, 'data', dataset)
12 | if not os.path.exists(datadir):
13 | os.makedirs(datadir)
14 | url = 'https://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/{0}.zip'.format(dataset)
15 | zipfile = os.path.basename(url)
16 | os.system('wget {0}; unzip {1}'.format(url, zipfile))
17 | os.system('mv {0}/* {1}'.format(dataset, datadir))
18 | os.system('rm -r {0}'.format(dataset))
19 | os.system('rm {0}'.format(zipfile))
20 |
21 |
22 | def process(dataset):
23 | src = os.path.join(os.path.dirname(__file__), 'data')
24 | prefix = os.path.join(src, dataset, dataset)
25 |
26 | graph_node_dict = {}
27 | with open('{0}_graph_indicator.txt'.format(prefix), 'r') as f:
28 | for idx, line in enumerate(f):
29 | graph_node_dict[idx + 1] = int(line.strip('\n'))
30 | max_nodes = Counter(graph_node_dict.values()).most_common(1)[0][1]
31 |
32 | node_labels = []
33 | if os.path.exists('{0}_node_labels.txt'.format(prefix)):
34 | with open('{0}_node_labels.txt'.format(prefix), 'r') as f:
35 | for line in f:
36 | node_labels += [int(line.strip('\n')) - 1]
37 | num_unique_node_labels = max(node_labels) + 1
38 | else:
39 | print('No node labels')
40 |
41 | node_attrs = []
42 | if os.path.exists('{0}_node_attributes.txt'.format(prefix)):
43 | with open('{0}_node_attributes.txt'.format(prefix), 'r') as f:
44 | for line in f:
45 | node_attrs.append(
46 | np.array([float(attr) for attr in re.split(r"[,\s]+", line.strip()) if attr], dtype=np.float)
47 | )
48 | else:
49 | print('No node attributes')
50 |
51 | graph_labels = []
52 | unique_labels = set()
53 | with open('{0}_graph_labels.txt'.format(prefix), 'r') as f:
54 | for line in f:
55 | val = int(line.strip('\n'))
56 | if val not in unique_labels:
57 | unique_labels.add(val)
58 | graph_labels.append(val)
59 | label_idx_dict = {val: idx for idx, val in enumerate(unique_labels)}
60 | graph_labels = np.array([label_idx_dict[l] for l in graph_labels])
61 |
62 | adj_list = {idx: [] for idx in range(1, len(graph_labels) + 1)}
63 | index_graph = {idx: [] for idx in range(1, len(graph_labels) + 1)}
64 | with open('{0}_A.txt'.format(prefix), 'r') as f:
65 | for line in f:
66 | u, v = tuple(map(int, line.strip('\n').split(',')))
67 | adj_list[graph_node_dict[u]].append((u, v))
68 | index_graph[graph_node_dict[u]] += [u, v]
69 |
70 | for k in index_graph.keys():
71 | index_graph[k] = [u - 1 for u in set(index_graph[k])]
72 |
73 | graphs, pprs = [], []
74 | for idx in range(1, 1 + len(adj_list)):
75 | graph = nx.from_edgelist(adj_list[idx])
76 | if max_nodes is not None and graph.number_of_nodes() > max_nodes:
77 | continue
78 |
79 | graph.graph['label'] = graph_labels[idx - 1]
80 | for u in graph.nodes():
81 | if len(node_labels) > 0:
82 | node_label_one_hot = [0] * num_unique_node_labels
83 | node_label = node_labels[u - 1]
84 | node_label_one_hot[node_label] = 1
85 | graph.nodes[u]['label'] = node_label_one_hot
86 | if len(node_attrs) > 0:
87 | graph.nodes[u]['feat'] = node_attrs[u - 1]
88 | if len(node_attrs) > 0:
89 | graph.graph['feat_dim'] = node_attrs[0].shape[0]
90 |
91 | # relabeling
92 | mapping = {}
93 | for node_idx, node in enumerate(graph.nodes()):
94 | mapping[node] = node_idx
95 |
96 | graphs.append(nx.relabel_nodes(graph, mapping))
97 | pprs.append(compute_ppr(graph, alpha=0.2))
98 |
99 | if 'feat_dim' in graphs[0].graph:
100 | pass
101 | else:
102 | max_deg = max([max(dict(graph.degree).values()) for graph in graphs])
103 | for graph in graphs:
104 | for u in graph.nodes(data=True):
105 | f = np.zeros(max_deg + 1)
106 | f[graph.degree[u[0]]] = 1.0
107 | if 'label' in u[1]:
108 | f = np.concatenate((np.array(u[1]['label'], dtype=np.float), f))
109 | graph.nodes[u[0]]['feat'] = f
110 | return graphs, pprs
111 |
112 |
113 | def load(dataset):
114 | basedir = os.path.dirname(os.path.abspath(__file__))
115 | datadir = os.path.join(basedir, 'data', dataset)
116 |
117 | if not os.path.exists(datadir):
118 | download(dataset)
119 | graphs, diff = process(dataset)
120 | feat, adj, labels = [], [], []
121 |
122 | for idx, graph in enumerate(graphs):
123 | adj.append(nx.to_numpy_array(graph))
124 | labels.append(graph.graph['label'])
125 | feat.append(np.array(list(nx.get_node_attributes(graph, 'feat').values())))
126 |
127 | adj, diff, feat, labels = np.array(adj), np.array(diff), np.array(feat), np.array(labels)
128 |
129 | np.save(f'{datadir}/adj.npy', adj)
130 | np.save(f'{datadir}/diff.npy', diff)
131 | np.save(f'{datadir}/feat.npy', feat)
132 | np.save(f'{datadir}/labels.npy', labels)
133 |
134 | else:
135 | adj = np.load(f'{datadir}/adj.npy', allow_pickle=True)
136 | diff = np.load(f'{datadir}/diff.npy', allow_pickle=True)
137 | feat = np.load(f'{datadir}/feat.npy', allow_pickle=True)
138 | labels = np.load(f'{datadir}/labels.npy', allow_pickle=True)
139 |
140 | max_nodes = max([a.shape[0] for a in adj])
141 | feat_dim = feat[0].shape[-1]
142 |
143 | num_nodes = []
144 |
145 | for idx in range(adj.shape[0]):
146 |
147 | num_nodes.append(adj[idx].shape[-1])
148 |
149 | adj[idx] = normalize_adj(adj[idx]).todense()
150 |
151 | diff[idx] = np.hstack(
152 | (np.vstack((diff[idx], np.zeros((max_nodes - diff[idx].shape[0], diff[idx].shape[0])))),
153 | np.zeros((max_nodes, max_nodes - diff[idx].shape[1]))))
154 |
155 | adj[idx] = np.hstack(
156 | (np.vstack((adj[idx], np.zeros((max_nodes - adj[idx].shape[0], adj[idx].shape[0])))),
157 | np.zeros((max_nodes, max_nodes - adj[idx].shape[1]))))
158 |
159 | feat[idx] = np.vstack((feat[idx], np.zeros((max_nodes - feat[idx].shape[0], feat_dim))))
160 |
161 | adj = np.array(adj.tolist()).reshape(-1, max_nodes, max_nodes)
162 | diff = np.array(diff.tolist()).reshape(-1, max_nodes, max_nodes)
163 | feat = np.array(feat.tolist()).reshape(-1, max_nodes, feat_dim)
164 |
165 | return adj, diff, feat, labels, num_nodes
166 |
167 |
168 | if __name__ == '__main__':
169 | # MUTAG, PTC_MR, IMDB-BINARY, IMDB-MULTI, REDDIT-BINARY, REDDIT-MULTI-5K,
170 | adj, diff, feat, labels, num_nodes = load('PTC_MR')
171 | print('done')
172 |
173 |
--------------------------------------------------------------------------------
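*Editor's note: a rough usage sketch of `load()` above, not from the repo. It assumes `wget` and `unzip` are available, since `download()` shells out to them, and that the script is run from the repository root.*

```python
# Sketch of the expected load() output for a TU dataset such as MUTAG;
# the first call downloads the zip and caches .npy files under graph/data/MUTAG.
from graph.dataset import load

adj, diff, feat, labels, num_nodes = load('MUTAG')
# Every graph is zero-padded to the size of the largest graph in the dataset:
#   adj  : (num_graphs, max_nodes, max_nodes) symmetrically normalized adjacency
#   diff : (num_graphs, max_nodes, max_nodes) personalized-PageRank diffusion
#   feat : (num_graphs, max_nodes, feat_dim) node features (one-hot degree/label
#          vectors when the dataset has no real node attributes)
#   labels, num_nodes : per-graph class label and original node count
print(adj.shape, diff.shape, feat.shape, labels.shape, len(num_nodes))
```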
/graph/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from sklearn.model_selection import GridSearchCV, StratifiedKFold
6 | from graph.dataset import load
7 |
8 |
9 | class GCNLayer(nn.Module):
10 | def __init__(self, in_ft, out_ft, bias=True):
11 | super(GCNLayer, self).__init__()
12 | self.fc = nn.Linear(in_ft, out_ft, bias=False)
13 | self.act = nn.PReLU()
14 |
15 | if bias:
16 | self.bias = nn.Parameter(torch.FloatTensor(out_ft))
17 | self.bias.data.fill_(0.0)
18 | else:
19 | self.register_parameter('bias', None)
20 |
21 | for m in self.modules():
22 | self.weights_init(m)
23 |
24 | def weights_init(self, m):
25 | if isinstance(m, nn.Linear):
26 | torch.nn.init.xavier_uniform_(m.weight.data)
27 | if m.bias is not None:
28 | m.bias.data.fill_(0.0)
29 |
30 | def forward(self, feat, adj):
31 | feat = self.fc(feat)
32 | out = torch.bmm(adj, feat)
33 | if self.bias is not None:
34 | out += self.bias
35 | return self.act(out)
36 |
37 |
38 | class GCN(nn.Module):
39 | def __init__(self, in_ft, out_ft, num_layers):
40 | super(GCN, self).__init__()
41 | n_h = out_ft
42 | self.layers = nn.ModuleList()  # register sub-layers so they are trained and moved with the model
43 | self.num_layers = num_layers
44 | self.layers.append(GCNLayer(in_ft, n_h).cuda())
45 | for __ in range(num_layers - 1):
46 | self.layers.append(GCNLayer(n_h, n_h).cuda())
47 |
48 | def forward(self, feat, adj, mask):
49 | h_1 = self.layers[0](feat, adj)
50 | h_1g = torch.sum(h_1, 1)
51 | for idx in range(self.num_layers - 1):
52 | h_1 = self.layers[idx + 1](h_1, adj)
53 | h_1g = torch.cat((h_1g, torch.sum(h_1, 1)), -1)
54 | return h_1, h_1g
55 |
56 |
57 | class MLP(nn.Module):
58 | def __init__(self, in_ft, out_ft):
59 | super(MLP, self).__init__()
60 | self.ffn = nn.Sequential(
61 | nn.Linear(in_ft, out_ft),
62 | nn.PReLU(),
63 | nn.Linear(out_ft, out_ft),
64 | nn.PReLU(),
65 | nn.Linear(out_ft, out_ft),
66 | nn.PReLU()
67 | )
68 | self.linear_shortcut = nn.Linear(in_ft, out_ft)
69 |
70 | def forward(self, x):
71 | return self.ffn(x) + self.linear_shortcut(x)
72 |
73 |
74 | class Model(nn.Module):
75 | def __init__(self, n_in, n_h, num_layers):
76 | super(Model, self).__init__()
77 | self.mlp1 = MLP(1 * n_h, n_h)
78 | self.mlp2 = MLP(num_layers * n_h, n_h)
79 | self.gnn1 = GCN(n_in, n_h, num_layers)
80 | self.gnn2 = GCN(n_in, n_h, num_layers)
81 |
82 | def forward(self, adj, diff, feat, mask):
83 | lv1, gv1 = self.gnn1(feat, adj, mask)
84 | lv2, gv2 = self.gnn2(feat, diff, mask)
85 |
86 | lv1 = self.mlp1(lv1)
87 | lv2 = self.mlp1(lv2)
88 |
89 | gv1 = self.mlp2(gv1)
90 | gv2 = self.mlp2(gv2)
91 |
92 | return lv1, gv1, lv2, gv2
93 |
94 | def embed(self, feat, adj, diff, mask):
95 | __, gv1, __, gv2 = self.forward(adj, diff, feat, mask)
96 | return (gv1 + gv2).detach()
97 |
98 |
99 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
100 | def get_positive_expectation(p_samples, measure, average=True):
101 | """Computes the positive part of a divergence / difference.
102 | Args:
103 | p_samples: Positive samples.
104 | measure: Measure to compute for.
105 | average: Average the result over samples.
106 | Returns:
107 | torch.Tensor
108 | """
109 | log_2 = np.log(2.)
110 |
111 | if measure == 'GAN':
112 | Ep = - F.softplus(-p_samples)
113 | elif measure == 'JSD':
114 | Ep = log_2 - F.softplus(- p_samples)
115 | elif measure == 'X2':
116 | Ep = p_samples ** 2
117 | elif measure == 'KL':
118 | Ep = p_samples + 1.
119 | elif measure == 'RKL':
120 | Ep = -torch.exp(-p_samples)
121 | elif measure == 'DV':
122 | Ep = p_samples
123 | elif measure == 'H2':
124 | Ep = 1. - torch.exp(-p_samples)
125 | elif measure == 'W1':
126 | Ep = p_samples
127 |
128 | if average:
129 | return Ep.mean()
130 | else:
131 | return Ep
132 |
133 |
134 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
135 | def get_negative_expectation(q_samples, measure, average=True):
136 | """Computes the negative part of a divergence / difference.
137 | Args:
138 | q_samples: Negative samples.
139 | measure: Measure to compute for.
140 | average: Average the result over samples.
141 | Returns:
142 | torch.Tensor
143 | """
144 | log_2 = np.log(2.)
145 |
146 | if measure == 'GAN':
147 | Eq = F.softplus(-q_samples) + q_samples
148 | elif measure == 'JSD':
149 | Eq = F.softplus(-q_samples) + q_samples - log_2
150 | elif measure == 'X2':
151 | Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
152 | elif measure == 'KL':
153 | Eq = torch.exp(q_samples)
154 | elif measure == 'RKL':
155 | Eq = q_samples - 1.
156 | elif measure == 'H2':
157 | Eq = torch.exp(q_samples) - 1.
158 | elif measure == 'W1':
159 | Eq = q_samples
160 |
161 | if average:
162 | return Eq.mean()
163 | else:
164 | return Eq
165 |
166 |
167 | # Borrowed from https://github.com/fanyun-sun/InfoGraph
168 | def local_global_loss_(l_enc, g_enc, batch, measure, mask):
169 | '''
170 | Args:
171 | l_enc: Local (per-node) feature map.
172 | g_enc: Global (per-graph) features.
173 | batch: Graph index of each node row; mask: number of valid nodes per graph.
174 | measure: Divergence measure; 'JSD' (Jensen-Shannon) is used in this repo.
175 | Returns:
176 | torch.Tensor: Loss.
177 | '''
178 | num_graphs = g_enc.shape[0]
179 | num_nodes = l_enc.shape[0]
180 | max_nodes = num_nodes // num_graphs
181 |
182 | pos_mask = torch.zeros((num_nodes, num_graphs)).cuda()
183 | neg_mask = torch.ones((num_nodes, num_graphs)).cuda()
184 | msk = torch.ones((num_nodes, num_graphs)).cuda()
185 | for nodeidx, graphidx in enumerate(batch):
186 | pos_mask[nodeidx][graphidx] = 1.
187 | neg_mask[nodeidx][graphidx] = 0.
188 |
189 | for idx, m in enumerate(mask):
190 | msk[idx * max_nodes + m: idx * max_nodes + max_nodes, idx] = 0.
191 |
192 | res = torch.mm(l_enc, g_enc.t()) * msk
193 |
194 | E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
195 | E_pos = E_pos / num_nodes
196 | E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
197 | E_neg = E_neg / (num_nodes * (num_graphs - 1))
198 | return E_neg - E_pos
199 |
200 |
201 | def global_global_loss_(g1_enc, g2_enc, measure):
202 | '''
203 | Args:
204 | g1_enc: Global features from the first (adjacency) view.
205 | g2_enc: Global features from the second (diffusion) view.
206 | measure: Divergence measure; 'JSD' (Jensen-Shannon) is used in this repo.
207 | Positive pairs are the two views of the same graph; all other pairs are negatives.
208 | Returns:
209 | torch.Tensor: Loss.
210 | '''
211 | num_graphs = g1_enc.shape[0]
212 |
213 | pos_mask = torch.zeros((num_graphs, num_graphs)).cuda()
214 | neg_mask = torch.ones((num_graphs, num_graphs)).cuda()
215 | for graphidx in range(num_graphs):
216 | pos_mask[graphidx][graphidx] = 1.
217 | neg_mask[graphidx][graphidx] = 0.
218 |
219 | res = torch.mm(g1_enc, g2_enc.t())
220 |
221 | E_pos = get_positive_expectation(res * pos_mask, measure, average=False).sum()
222 | E_pos = E_pos / num_graphs
223 | E_neg = get_negative_expectation(res * neg_mask, measure, average=False).sum()
224 | E_neg = E_neg / (num_graphs * (num_graphs - 1))
225 | return E_neg - E_pos
226 |
227 |
228 | def train(dataset, gpu, num_layer=4, epoch=40, batch=64):
229 | nb_epochs = epoch
230 | batch_size = batch
231 | patience = 20
232 | lr = 0.001
233 | l2_coef = 0.0
234 | hid_units = 512
235 |
236 | adj, diff, feat, labels, num_nodes = load(dataset)
237 |
238 | feat = torch.FloatTensor(feat).cuda()
239 | diff = torch.FloatTensor(diff).cuda()
240 | adj = torch.FloatTensor(adj).cuda()
241 | labels = torch.LongTensor(labels).cuda()
242 |
243 | ft_size = feat[0].shape[1]
244 | max_nodes = feat[0].shape[0]
245 |
246 | model = Model(ft_size, hid_units, num_layer)
247 | optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
248 |
249 | model.cuda()
250 |
251 | cnt_wait = 0
252 | best = 1e9
253 |
254 | itr = (adj.shape[0] // batch_size) + 1
255 | for epoch in range(nb_epochs):
256 | epoch_loss = 0.0
257 | train_idx = np.arange(adj.shape[0])
258 | np.random.shuffle(train_idx)
259 |
260 | for idx in range(0, len(train_idx), batch_size):
261 | model.train()
262 | optimiser.zero_grad()
263 |
264 | batch = train_idx[idx: idx + batch_size]
265 | mask = np.array(num_nodes)[batch]  # node counts for the shuffled batch indices
266 |
267 | lv1, gv1, lv2, gv2 = model(adj[batch], diff[batch], feat[batch], mask)
268 |
269 | lv1 = lv1.view(batch.shape[0] * max_nodes, -1)
270 | lv2 = lv2.view(batch.shape[0] * max_nodes, -1)
271 |
272 | batch = torch.LongTensor(np.repeat(np.arange(batch.shape[0]), max_nodes)).cuda()
273 |
274 | loss1 = local_global_loss_(lv1, gv2, batch, 'JSD', mask)
275 | loss2 = local_global_loss_(lv2, gv1, batch, 'JSD', mask)
276 | # loss3 = global_global_loss_(gv1, gv2, 'JSD')
277 | loss = loss1 + loss2 #+ loss3
278 | epoch_loss += loss.item()
279 | loss.backward()
280 | optimiser.step()
281 |
282 | epoch_loss /= itr
283 |
284 | # print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, epoch_loss))
285 |
286 | if epoch_loss < best:
287 | best = epoch_loss
288 | best_t = epoch
289 | cnt_wait = 0
290 | torch.save(model.state_dict(), f'{dataset}-{gpu}.pkl')
291 | else:
292 | cnt_wait += 1
293 |
294 | if cnt_wait == patience:
295 | break
296 |
297 | model.load_state_dict(torch.load(f'{dataset}-{gpu}.pkl'))
298 |
299 | features = feat.cuda()
300 | adj = adj.cuda()
301 | diff = diff.cuda()
302 | labels = labels.cuda()
303 |
304 | embeds = model.embed(features, adj, diff, num_nodes)
305 |
306 | x = embeds.cpu().numpy()
307 | y = labels.cpu().numpy()
308 |
309 | from sklearn.svm import LinearSVC
310 | from sklearn.metrics import accuracy_score
311 | params = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
312 | kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
313 | accuracies = []
314 | for train_index, test_index in kf.split(x, y):
315 |
316 | x_train, x_test = x[train_index], x[test_index]
317 | y_train, y_test = y[train_index], y[test_index]
318 | classifier = GridSearchCV(LinearSVC(), params, cv=5, scoring='accuracy', verbose=0)
319 | classifier.fit(x_train, y_train)
320 | accuracies.append(accuracy_score(y_test, classifier.predict(x_test)))
321 | print(np.mean(accuracies), np.std(accuracies))
322 |
323 |
324 | if __name__ == '__main__':
325 | import warnings
326 | warnings.filterwarnings("ignore")
327 | gpu = 1
328 | torch.cuda.set_device(gpu)
329 | layers = [2, 8, 12]
330 | batch = [32, 64, 128, 256]
331 | epoch = [20, 40, 100]
332 | ds = ['MUTAG', 'PTC_MR', 'IMDB-BINARY', 'IMDB-MULTI', 'REDDIT-BINARY', 'REDDIT-MULTI-5K']
333 | seeds = [123, 132, 321, 312, 231]
334 | for d in ds:
335 | print(f'####################{d}####################')
336 | for l in layers:
337 | for b in batch:
338 | for e in epoch:
339 | for i in range(5):
340 | seed = seeds[i]
341 | torch.manual_seed(seed)
342 | torch.backends.cudnn.deterministic = True
343 | torch.backends.cudnn.benchmark = False
344 | np.random.seed(seed)
345 | print(f'Dataset: {d}, Layer:{l}, Batch: {b}, Epoch: {e}, Seed: {seed}')
346 | train(d, gpu, l, e, b)
347 | print('################################################')
--------------------------------------------------------------------------------
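*Editor's note: a toy, CPU-only check of the JSD estimator that `local_global_loss_` above is built on, added for illustration (the example tensors are mine). Each node acts as a positive for its own graph summary and as a negative for every other graph's summary; it assumes the repository root is on `sys.path`.*

```python
# Toy JSD local-global loss for 2 graphs with 3 nodes each (6 node rows).
import torch
from graph.train import get_positive_expectation, get_negative_expectation

scores = torch.randn(6, 2)                 # node-vs-graph similarity scores
pos_mask = torch.zeros(6, 2)
pos_mask[:3, 0] = 1.                       # nodes 0-2 belong to graph 0
pos_mask[3:, 1] = 1.                       # nodes 3-5 belong to graph 1
neg_mask = 1. - pos_mask

E_pos = get_positive_expectation(scores * pos_mask, 'JSD', average=False).sum() / 6
E_neg = get_negative_expectation(scores * neg_mask, 'JSD', average=False).sum() / (6 * 1)
print((E_neg - E_pos).item())              # same form as local_global_loss_
```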
/img/arc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kavehhassani/mvgrl/628ed2bdb4496f8519556f7b9e25f93b57cf6902/img/arc.png
--------------------------------------------------------------------------------
/node/dataset.py:
--------------------------------------------------------------------------------
1 | from dgl.data import CoraDataset, CitationGraphDataset
2 | from utils import preprocess_features, normalize_adj
3 | from sklearn.preprocessing import MinMaxScaler
4 | from utils import compute_ppr
5 | import scipy.sparse as sp
6 | import networkx as nx
7 | import numpy as np
8 | import os
9 |
10 |
11 | def download(dataset):
12 | if dataset == 'cora':
13 | return CoraDataset()
14 | elif dataset in ('citeseer', 'pubmed'):
15 | return CitationGraphDataset(name=dataset)
16 | else:
17 | return None
18 |
19 |
20 | def load(dataset):
21 | datadir = os.path.join('data', dataset)
22 |
23 | if not os.path.exists(datadir):
24 | os.makedirs(datadir)
25 | ds = download(dataset)
26 | adj = nx.to_numpy_array(ds.graph)
27 | diff = compute_ppr(ds.graph, 0.2)
28 | feat = ds.features[:]
29 | labels = ds.labels[:]
30 |
31 | idx_train = np.argwhere(ds.train_mask == 1).reshape(-1)
32 | idx_val = np.argwhere(ds.val_mask == 1).reshape(-1)
33 | idx_test = np.argwhere(ds.test_mask == 1).reshape(-1)
34 |
35 | np.save(f'{datadir}/adj.npy', adj)
36 | np.save(f'{datadir}/diff.npy', diff)
37 | np.save(f'{datadir}/feat.npy', feat)
38 | np.save(f'{datadir}/labels.npy', labels)
39 | np.save(f'{datadir}/idx_train.npy', idx_train)
40 | np.save(f'{datadir}/idx_val.npy', idx_val)
41 | np.save(f'{datadir}/idx_test.npy', idx_test)
42 | else:
43 | adj = np.load(f'{datadir}/adj.npy')
44 | diff = np.load(f'{datadir}/diff.npy')
45 | feat = np.load(f'{datadir}/feat.npy')
46 | labels = np.load(f'{datadir}/labels.npy')
47 | idx_train = np.load(f'{datadir}/idx_train.npy')
48 | idx_val = np.load(f'{datadir}/idx_val.npy')
49 | idx_test = np.load(f'{datadir}/idx_test.npy')
50 |
51 | if dataset == 'citeseer':
52 | feat = preprocess_features(feat)
53 |
54 | epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
55 | avg_degree = np.sum(adj) / adj.shape[0]
56 | epsilon = epsilons[np.argmin([abs(avg_degree - np.argwhere(diff >= e).shape[0] / diff.shape[0])
57 | for e in epsilons])]
58 |
59 | diff[diff < epsilon] = 0.0
60 | scaler = MinMaxScaler()
61 | scaler.fit(diff)
62 | diff = scaler.transform(diff)
63 |
64 | adj = normalize_adj(adj + sp.eye(adj.shape[0])).todense()
65 |
66 | return adj, diff, feat, labels, idx_train, idx_val, idx_test
67 |
68 |
69 | if __name__ == '__main__':
70 | load('cora')
71 |
--------------------------------------------------------------------------------
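*Editor's note: the epsilon selection inside `load()` above keeps the sparsified diffusion matrix about as dense as the original graph by picking the threshold whose surviving entries per node best match the average degree. Below is a standalone sketch of that rule; the function name `pick_epsilon` and the toy matrices are my own illustration.*

```python
import numpy as np

def pick_epsilon(adj, diff, candidates=(1e-5, 1e-4, 1e-3, 1e-2)):
    # Choose the threshold whose surviving diffusion entries per node
    # best match the graph's average degree, as in node/dataset.py.
    avg_degree = adj.sum() / adj.shape[0]
    kept_per_node = [np.argwhere(diff >= e).shape[0] / diff.shape[0] for e in candidates]
    return candidates[int(np.argmin([abs(avg_degree - k) for k in kept_per_node]))]

adj = np.array([[0., 1., 1.],
                [1., 0., 0.],
                [1., 0., 0.]])
diff = np.array([[0.60, 0.30, 0.005],
                 [0.30, 0.60, 0.0005],
                 [0.005, 0.0005, 0.99]])
print(pick_epsilon(adj, diff))  # 0.01: keeps ~5/3 entries per node vs. average degree 4/3
```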
/node/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import torch
4 | import torch.nn as nn
5 | from utils import sparse_mx_to_torch_sparse_tensor
6 | from node.dataset import load
7 |
8 |
9 | # Borrowed from https://github.com/PetarV-/DGI
10 | class GCN(nn.Module):
11 | def __init__(self, in_ft, out_ft, bias=True):
12 | super(GCN, self).__init__()
13 | self.fc = nn.Linear(in_ft, out_ft, bias=False)
14 | self.act = nn.PReLU()
15 |
16 | if bias:
17 | self.bias = nn.Parameter(torch.FloatTensor(out_ft))
18 | self.bias.data.fill_(0.0)
19 | else:
20 | self.register_parameter('bias', None)
21 |
22 | for m in self.modules():
23 | self.weights_init(m)
24 |
25 | def weights_init(self, m):
26 | if isinstance(m, nn.Linear):
27 | torch.nn.init.xavier_uniform_(m.weight.data)
28 | if m.bias is not None:
29 | m.bias.data.fill_(0.0)
30 |
31 | # Shape of seq: (batch, nodes, features)
32 | def forward(self, seq, adj, sparse=False):
33 | seq_fts = self.fc(seq)
34 | if sparse:
35 | out = torch.unsqueeze(torch.spmm(adj, torch.squeeze(seq_fts, 0)), 0)
36 | else:
37 | out = torch.bmm(adj, seq_fts)
38 | if self.bias is not None:
39 | out += self.bias
40 | return self.act(out)
41 |
42 |
43 | # Borrowed from https://github.com/PetarV-/DGI
44 | class Readout(nn.Module):
45 | def __init__(self):
46 | super(Readout, self).__init__()
47 |
48 | def forward(self, seq, msk):
49 | if msk is None:
50 | return torch.mean(seq, 1)
51 | else:
52 | msk = torch.unsqueeze(msk, -1)
53 | return torch.mean(seq * msk, 1) / torch.sum(msk)
54 |
55 |
56 | # Borrowed from https://github.com/PetarV-/DGI
57 | class Discriminator(nn.Module):
58 | def __init__(self, n_h):
59 | super(Discriminator, self).__init__()
60 | self.f_k = nn.Bilinear(n_h, n_h, 1)
61 |
62 | for m in self.modules():
63 | self.weights_init(m)
64 |
65 | def weights_init(self, m):
66 | if isinstance(m, nn.Bilinear):
67 | torch.nn.init.xavier_uniform_(m.weight.data)
68 | if m.bias is not None:
69 | m.bias.data.fill_(0.0)
70 |
71 | def forward(self, c1, c2, h1, h2, h3, h4, s_bias1=None, s_bias2=None):
72 | c_x1 = torch.unsqueeze(c1, 1)
73 | c_x1 = c_x1.expand_as(h1).contiguous()
74 | c_x2 = torch.unsqueeze(c2, 1)
75 | c_x2 = c_x2.expand_as(h2).contiguous()
76 |
77 | # positive
78 | sc_1 = torch.squeeze(self.f_k(h2, c_x1), 2)
79 | sc_2 = torch.squeeze(self.f_k(h1, c_x2), 2)
80 |
81 | # negative
82 | sc_3 = torch.squeeze(self.f_k(h4, c_x1), 2)
83 | sc_4 = torch.squeeze(self.f_k(h3, c_x2), 2)
84 |
85 | logits = torch.cat((sc_1, sc_2, sc_3, sc_4), 1)
86 | return logits
87 |
88 |
89 | class Model(nn.Module):
90 | def __init__(self, n_in, n_h):
91 | super(Model, self).__init__()
92 | self.gcn1 = GCN(n_in, n_h)
93 | self.gcn2 = GCN(n_in, n_h)
94 | self.read = Readout()
95 |
96 | self.sigm = nn.Sigmoid()
97 |
98 | self.disc = Discriminator(n_h)
99 |
100 | def forward(self, seq1, seq2, adj, diff, sparse, msk, samp_bias1, samp_bias2):
101 | h_1 = self.gcn1(seq1, adj, sparse)
102 | c_1 = self.read(h_1, msk)
103 | c_1 = self.sigm(c_1)
104 |
105 | h_2 = self.gcn2(seq1, diff, sparse)
106 | c_2 = self.read(h_2, msk)
107 | c_2 = self.sigm(c_2)
108 |
109 | h_3 = self.gcn1(seq2, adj, sparse)
110 | h_4 = self.gcn2(seq2, diff, sparse)
111 |
112 | ret = self.disc(c_1, c_2, h_1, h_2, h_3, h_4, samp_bias1, samp_bias2)
113 |
114 | return ret, h_1, h_2
115 |
116 | def embed(self, seq, adj, diff, sparse, msk):
117 | h_1 = self.gcn1(seq, adj, sparse)
118 | c = self.read(h_1, msk)
119 |
120 | h_2 = self.gcn2(seq, diff, sparse)
121 | return (h_1 + h_2).detach(), c.detach()
122 |
123 |
124 | class LogReg(nn.Module):
125 | def __init__(self, ft_in, nb_classes):
126 | super(LogReg, self).__init__()
127 | self.fc = nn.Linear(ft_in, nb_classes)
128 | self.sigm = nn.Sigmoid()
129 |
130 | for m in self.modules():
131 | self.weights_init(m)
132 |
133 | def weights_init(self, m):
134 | if isinstance(m, nn.Linear):
135 | torch.nn.init.xavier_uniform_(m.weight.data)
136 | if m.bias is not None:
137 | m.bias.data.fill_(0.0)
138 |
139 | def forward(self, seq):
140 | ret = torch.log_softmax(self.fc(seq), dim=-1)
141 | return ret
142 |
143 |
144 | def train(dataset, verbose=False):
145 |
146 | nb_epochs = 3000
147 | patience = 20
148 | lr = 0.001
149 | l2_coef = 0.0
150 | hid_units = 512
151 | sparse = False
152 |
153 | adj, diff, features, labels, idx_train, idx_val, idx_test = load(dataset)
154 |
155 | ft_size = features.shape[1]
156 | nb_classes = np.unique(labels).shape[0]
157 |
158 | sample_size = 2000
159 | batch_size = 4
160 |
161 | labels = torch.LongTensor(labels)
162 | idx_train = torch.LongTensor(idx_train)
163 | idx_test = torch.LongTensor(idx_test)
164 |
165 | lbl_1 = torch.ones(batch_size, sample_size * 2)
166 | lbl_2 = torch.zeros(batch_size, sample_size * 2)
167 | lbl = torch.cat((lbl_1, lbl_2), 1)
168 |
169 | model = Model(ft_size, hid_units)
170 | optimiser = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=l2_coef)
171 |
172 | if torch.cuda.is_available():
173 | model.cuda()
174 | labels = labels.cuda()
175 | lbl = lbl.cuda()
176 | idx_train = idx_train.cuda()
177 | idx_test = idx_test.cuda()
178 |
179 | b_xent = nn.BCEWithLogitsLoss()
180 | xent = nn.CrossEntropyLoss()
181 | cnt_wait = 0
182 | best = 1e9
183 | best_t = 0
184 |
185 | for epoch in range(nb_epochs):
186 |
187 | idx = np.random.randint(0, adj.shape[-1] - sample_size + 1, batch_size)
188 | ba, bd, bf = [], [], []
189 | for i in idx:
190 | ba.append(adj[i: i + sample_size, i: i + sample_size])
191 | bd.append(diff[i: i + sample_size, i: i + sample_size])
192 | bf.append(features[i: i + sample_size])
193 |
194 | ba = np.array(ba).reshape(batch_size, sample_size, sample_size)
195 | bd = np.array(bd).reshape(batch_size, sample_size, sample_size)
196 | bf = np.array(bf).reshape(batch_size, sample_size, ft_size)
197 |
198 | if sparse:
199 | ba = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(ba))
200 | bd = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(bd))
201 | else:
202 | ba = torch.FloatTensor(ba)
203 | bd = torch.FloatTensor(bd)
204 |
205 | bf = torch.FloatTensor(bf)
206 | idx = np.random.permutation(sample_size)
207 | shuf_fts = bf[:, idx, :]
208 |
209 | if torch.cuda.is_available():
210 | bf = bf.cuda()
211 | ba = ba.cuda()
212 | bd = bd.cuda()
213 | shuf_fts = shuf_fts.cuda()
214 |
215 | model.train()
216 | optimiser.zero_grad()
217 |
218 | logits, __, __ = model(bf, shuf_fts, ba, bd, sparse, None, None, None)
219 |
220 | loss = b_xent(logits, lbl)
221 |
222 | loss.backward()
223 | optimiser.step()
224 |
225 | if verbose:
226 | print('Epoch: {0}, Loss: {1:0.4f}'.format(epoch, loss.item()))
227 |
228 | if loss < best:
229 | best = loss
230 | best_t = epoch
231 | cnt_wait = 0
232 | torch.save(model.state_dict(), 'model.pkl')
233 | else:
234 | cnt_wait += 1
235 |
236 | if cnt_wait == patience:
237 | if verbose:
238 | print('Early stopping!')
239 | break
240 |
241 | if verbose:
242 | print('Loading {}th epoch'.format(best_t))
243 | model.load_state_dict(torch.load('model.pkl'))
244 |
245 | if sparse:
246 | adj = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(adj))
247 | diff = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(diff))
248 |
249 | features = torch.FloatTensor(features[np.newaxis])
250 | adj = torch.FloatTensor(adj[np.newaxis])
251 | diff = torch.FloatTensor(diff[np.newaxis])
252 | features = features.cuda()
253 | adj = adj.cuda()
254 | diff = diff.cuda()
255 |
256 | embeds, _ = model.embed(features, adj, diff, sparse, None)
257 | train_embs = embeds[0, idx_train]
258 | test_embs = embeds[0, idx_test]
259 |
260 | train_lbls = labels[idx_train]
261 | test_lbls = labels[idx_test]
262 |
263 | accs = []
264 | wd = 0.01 if dataset == 'citeseer' else 0.0
265 |
266 | for _ in range(50):
267 | log = LogReg(hid_units, nb_classes)
268 | opt = torch.optim.Adam(log.parameters(), lr=1e-2, weight_decay=wd)
269 | log.cuda()
270 | for _ in range(300):
271 | log.train()
272 | opt.zero_grad()
273 |
274 | logits = log(train_embs)
275 | loss = xent(logits, train_lbls)
276 |
277 | loss.backward()
278 | opt.step()
279 |
280 | logits = log(test_embs)
281 | preds = torch.argmax(logits, dim=1)
282 | acc = torch.sum(preds == test_lbls).float() / test_lbls.shape[0]
283 | accs.append(acc * 100)
284 |
285 | accs = torch.stack(accs)
286 | print(accs.mean().item(), accs.std().item())
287 |
288 |
289 | if __name__ == '__main__':
290 | import warnings
291 | warnings.filterwarnings("ignore")
292 | torch.cuda.set_device(3)
293 |
294 | # 'cora', 'citeseer', 'pubmed'
295 | dataset = 'cora'
296 | for __ in range(50):
297 | train(dataset)
298 |
--------------------------------------------------------------------------------
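*Editor's note: a quick CPU-only shape check, added for illustration, of how the discriminator logits above line up with the BCE targets built in `train()`: each view contributes one positive and one negative score block per node, so the logits have 4 × sample_size columns and the label row is ones for the first half and zeros for the second. It assumes the packages in requirements.txt are installed and the repository root is on `sys.path`.*

```python
import torch
from node.train import Discriminator

disc = Discriminator(n_h=8)
c1, c2 = torch.randn(2, 8), torch.randn(2, 8)            # summaries of the two views
h1, h2 = torch.randn(2, 5, 8), torch.randn(2, 5, 8)      # node features (true order)
h3, h4 = torch.randn(2, 5, 8), torch.randn(2, 5, 8)      # node features (shuffled)

logits = disc(c1, c2, h1, h2, h3, h4)
lbl = torch.cat((torch.ones(2, 2 * 5), torch.zeros(2, 2 * 5)), 1)
print(logits.shape, lbl.shape)   # both torch.Size([2, 20])
```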
/requirements.txt:
--------------------------------------------------------------------------------
1 | dgl==0.4.1
2 | networkx==2.4
3 | numpy==1.17.4
4 | opt-einsum==3.1.0
5 | pickleshare==0.7.5
6 | pytz==2019.2
7 | pyzmq==18.1.0
8 | requests==2.22.0
9 | scikit-learn==0.21.3
10 | scipy==1.3.2
11 | sklearn==0.0
12 | torch==1.3.1
13 | torch-cluster==1.4.5
14 | torch-geometric==1.3.2
15 | torch-scatter==1.4.0
16 | torch-sparse==0.4.3
17 | torchtext==0.4.0
18 | torchvision==0.4.2
19 | urllib3==1.25.6
20 | zipp==0.6.0
21 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | import torch
4 | from scipy.linalg import fractional_matrix_power, inv
5 | import scipy.sparse as sp
6 |
7 |
8 | def compute_ppr(graph: nx.Graph, alpha=0.2, self_loop=True):
9 | a = nx.convert_matrix.to_numpy_array(graph)
10 | if self_loop:
11 | a = a + np.eye(a.shape[0]) # A^ = A + I_n
12 | d = np.diag(np.sum(a, 1)) # D^_ii = Sigma_j A^_ij
13 | dinv = fractional_matrix_power(d, -0.5) # D^(-1/2)
14 | at = np.matmul(np.matmul(dinv, a), dinv) # A~ = D^(-1/2) x A^ x D^(-1/2)
15 | return alpha * inv((np.eye(a.shape[0]) - (1 - alpha) * at)) # a(I_n-(1-a)A~)^-1
16 |
17 |
18 | def compute_heat(graph: nx.Graph, t=5, self_loop=True):
19 | a = nx.convert_matrix.to_numpy_array(graph)
20 | if self_loop:
21 | a = a + np.eye(a.shape[0])
22 | d = np.diag(np.sum(a, 1))
23 | return np.exp(t * (np.matmul(a, inv(d)) - 1))
24 |
25 |
26 | def sparse_to_tuple(sparse_mx):
27 | """Convert sparse matrix to tuple representation."""
28 |
29 | def to_tuple(mx):
30 | if not sp.isspmatrix_coo(mx):
31 | mx = mx.tocoo()
32 | coords = np.vstack((mx.row, mx.col)).transpose()
33 | values = mx.data
34 | shape = mx.shape
35 | return coords, values, shape
36 |
37 | if isinstance(sparse_mx, list):
38 | for i in range(len(sparse_mx)):
39 | sparse_mx[i] = to_tuple(sparse_mx[i])
40 | else:
41 | sparse_mx = to_tuple(sparse_mx)
42 |
43 | return sparse_mx
44 |
45 |
46 | def preprocess_features(features):
47 | """Row-normalize feature matrix and convert to tuple representation"""
48 | rowsum = np.array(features.sum(1))
49 | r_inv = np.power(rowsum, -1).flatten()
50 | r_inv[np.isinf(r_inv)] = 0.
51 | r_mat_inv = sp.diags(r_inv)
52 | features = r_mat_inv.dot(features)
53 | if isinstance(features, np.ndarray):
54 | return features
55 | else:
56 | return features.todense(), sparse_to_tuple(features)
57 |
58 |
59 | def normalize_adj(adj, self_loop=True):
60 | """Symmetrically normalize adjacency matrix."""
61 | if self_loop:
62 | adj = adj + sp.eye(adj.shape[0])
63 | adj = sp.coo_matrix(adj)
64 | rowsum = np.array(adj.sum(1))
65 | d_inv_sqrt = np.power(rowsum, -0.5).flatten()
66 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
67 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
68 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
69 |
70 |
71 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
72 | """Convert a scipy sparse matrix to a torch sparse tensor."""
73 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
74 | indices = torch.from_numpy(
75 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
76 | values = torch.from_numpy(sparse_mx.data)
77 | shape = torch.Size(sparse_mx.shape)
78 | return torch.sparse.FloatTensor(indices, values, shape)
--------------------------------------------------------------------------------
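*Editor's note: a small sanity check of `compute_ppr`, added for illustration (the path graph is my own toy input). Because the function uses the symmetrically normalized adjacency, the resulting diffusion matrix is dense and symmetric.*

```python
import networkx as nx
import numpy as np
from utils import compute_ppr

g = nx.path_graph(4)                  # 0 - 1 - 2 - 3
ppr = compute_ppr(g, alpha=0.2)
print(ppr.shape)                      # (4, 4): dense diffusion matrix
print(np.allclose(ppr, ppr.T))        # True: symmetric normalization is used
```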