├── NeighborOverlap.py ├── NeighborOverlapCitation2.py ├── README.md ├── env.yaml ├── model.py ├── ogbdataset.py └── utils.py /NeighborOverlap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import torch.nn as nn 6 | from torch_sparse import SparseTensor 7 | import torch_geometric.transforms as T 8 | from model import predictor_dict, convdict, GCN, DropEdge 9 | from functools import partial 10 | from sklearn.metrics import roc_auc_score, average_precision_score 11 | from ogb.linkproppred import PygLinkPropPredDataset, Evaluator 12 | from torch_geometric.utils import negative_sampling 13 | from torch.utils.tensorboard import SummaryWriter 14 | from utils import PermIterator 15 | import time 16 | from ogbdataset import loaddataset 17 | from typing import Iterable 18 | 19 | 20 | def set_seed(seed): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 23 | np.random.seed(seed) 24 | 25 | 26 | def train(model, 27 | predictor, 28 | data, 29 | split_edge, 30 | optimizer, 31 | batch_size, 32 | maskinput: bool = True, 33 | cnprobs: Iterable[float]=[], 34 | alpha: float=None): 35 | 36 | if alpha is not None: 37 | predictor.setalpha(alpha) 38 | 39 | model.train() 40 | predictor.train() 41 | 42 | pos_train_edge = split_edge['train']['edge'].to(data.x.device) 43 | pos_train_edge = pos_train_edge.t() 44 | 45 | total_loss = [] 46 | adjmask = torch.ones_like(pos_train_edge[0], dtype=torch.bool) 47 | 48 | negedge = negative_sampling(data.edge_index.to(pos_train_edge.device), data.adj_t.sizes()[0]) 49 | for perm in PermIterator( 50 | adjmask.device, adjmask.shape[0], batch_size 51 | ): 52 | optimizer.zero_grad() 53 | if maskinput: 54 | adjmask[perm] = 0 55 | tei = pos_train_edge[:, adjmask] 56 | adj = SparseTensor.from_edge_index(tei, 57 | sparse_sizes=(data.num_nodes, data.num_nodes)).to_device( 58 | pos_train_edge.device, non_blocking=True) 59 | adjmask[perm] = 1 60 | adj = adj.to_symmetric() 61 | else: 62 | adj = data.adj_t 63 | h = model(data.x, adj) 64 | edge = pos_train_edge[:, perm] 65 | pos_outs = predictor.multidomainforward(h, 66 | adj, 67 | edge, 68 | cndropprobs=cnprobs) 69 | 70 | pos_losss = -F.logsigmoid(pos_outs).mean() 71 | edge = negedge[:, perm] 72 | neg_outs = predictor.multidomainforward(h, adj, edge, cndropprobs=cnprobs) 73 | neg_losss = -F.logsigmoid(-neg_outs).mean() 74 | loss = neg_losss + pos_losss 75 | loss.backward() 76 | optimizer.step() 77 | 78 | total_loss.append(loss) 79 | total_loss = np.average([_.item() for _ in total_loss]) 80 | return total_loss 81 | 82 | 83 | @torch.no_grad() 84 | def test(model, predictor, data, split_edge, evaluator, batch_size, 85 | use_valedges_as_input): 86 | model.eval() 87 | predictor.eval() 88 | 89 | pos_train_edge = split_edge['train']['edge'].to(data.adj_t.device()) 90 | pos_valid_edge = split_edge['valid']['edge'].to(data.adj_t.device()) 91 | neg_valid_edge = split_edge['valid']['edge_neg'].to(data.adj_t.device()) 92 | pos_test_edge = split_edge['test']['edge'].to(data.adj_t.device()) 93 | neg_test_edge = split_edge['test']['edge_neg'].to(data.adj_t.device()) 94 | 95 | adj = data.adj_t 96 | h = model(data.x, adj) 97 | 98 | 99 | pos_train_pred = torch.cat([ 100 | predictor(h, adj, pos_train_edge[perm].t()).squeeze().cpu() 101 | for perm in PermIterator(pos_train_edge.device, 102 | pos_train_edge.shape[0], batch_size, False) 103 | ], 104 | dim=0) 105 | 106 | 107 | pos_valid_pred = torch.cat([ 108 | 
predictor(h, adj, pos_valid_edge[perm].t()).squeeze().cpu() 109 | for perm in PermIterator(pos_valid_edge.device, 110 | pos_valid_edge.shape[0], batch_size, False) 111 | ], 112 | dim=0) 113 | neg_valid_pred = torch.cat([ 114 | predictor(h, adj, neg_valid_edge[perm].t()).squeeze().cpu() 115 | for perm in PermIterator(neg_valid_edge.device, 116 | neg_valid_edge.shape[0], batch_size, False) 117 | ], 118 | dim=0) 119 | if use_valedges_as_input: 120 | adj = data.full_adj_t 121 | h = model(data.x, adj) 122 | 123 | pos_test_pred = torch.cat([ 124 | predictor(h, adj, pos_test_edge[perm].t()).squeeze().cpu() 125 | for perm in PermIterator(pos_test_edge.device, pos_test_edge.shape[0], 126 | batch_size, False) 127 | ], 128 | dim=0) 129 | 130 | neg_test_pred = torch.cat([ 131 | predictor(h, adj, neg_test_edge[perm].t()).squeeze().cpu() 132 | for perm in PermIterator(neg_test_edge.device, neg_test_edge.shape[0], 133 | batch_size, False) 134 | ], 135 | dim=0) 136 | 137 | results = {} 138 | for K in [20, 50, 100]: 139 | evaluator.K = K 140 | 141 | train_hits = evaluator.eval({ 142 | 'y_pred_pos': pos_train_pred, 143 | 'y_pred_neg': neg_valid_pred, 144 | })[f'hits@{K}'] 145 | 146 | valid_hits = evaluator.eval({ 147 | 'y_pred_pos': pos_valid_pred, 148 | 'y_pred_neg': neg_valid_pred, 149 | })[f'hits@{K}'] 150 | test_hits = evaluator.eval({ 151 | 'y_pred_pos': pos_test_pred, 152 | 'y_pred_neg': neg_test_pred, 153 | })[f'hits@{K}'] 154 | 155 | results[f'Hits@{K}'] = (train_hits, valid_hits, test_hits) 156 | return results, h.cpu() 157 | 158 | 159 | def parseargs(): 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument('--use_valedges_as_input', action='store_true', help="whether to add validation edges to the input adjacency matrix of gnn") 162 | parser.add_argument('--epochs', type=int, default=40, help="number of epochs") 163 | parser.add_argument('--runs', type=int, default=3, help="number of repeated runs") 164 | parser.add_argument('--dataset', type=str, default="collab") 165 | 166 | parser.add_argument('--batch_size', type=int, default=8192, help="batch size") 167 | parser.add_argument('--testbs', type=int, default=8192, help="batch size for test") 168 | parser.add_argument('--maskinput', action="store_true", help="whether to use target link removal") 169 | 170 | parser.add_argument('--mplayers', type=int, default=1, help="number of message passing layers") 171 | parser.add_argument('--nnlayers', type=int, default=3, help="number of mlp layers") 172 | parser.add_argument('--hiddim', type=int, default=32, help="hidden dimension") 173 | parser.add_argument('--ln', action="store_true", help="whether to use layernorm in MPNN") 174 | parser.add_argument('--lnnn', action="store_true", help="whether to use layernorm in mlp") 175 | parser.add_argument('--res', action="store_true", help="whether to use residual connection") 176 | parser.add_argument('--jk', action="store_true", help="whether to use JumpingKnowledge connection") 177 | parser.add_argument('--gnndp', type=float, default=0.3, help="dropout ratio of gnn") 178 | parser.add_argument('--xdp', type=float, default=0.3, help="dropout ratio of gnn") 179 | parser.add_argument('--tdp', type=float, default=0.3, help="dropout ratio of gnn") 180 | parser.add_argument('--gnnedp', type=float, default=0.3, help="edge dropout ratio of gnn") 181 | parser.add_argument('--predp', type=float, default=0.3, help="dropout ratio of predictor") 182 | parser.add_argument('--preedp', type=float, default=0.3, help="edge dropout ratio of predictor") 183 | 
parser.add_argument('--gnnlr', type=float, default=0.0003, help="learning rate of gnn") 184 | parser.add_argument('--prelr', type=float, default=0.0003, help="learning rate of predictor") 185 | # detailed hyperparameters 186 | parser.add_argument('--beta', type=float, default=1) 187 | parser.add_argument('--alpha', type=float, default=1) 188 | parser.add_argument("--use_xlin", action="store_true") 189 | parser.add_argument("--tailact", action="store_true") 190 | parser.add_argument("--twolayerlin", action="store_true") 191 | parser.add_argument("--increasealpha", action="store_true") 192 | 193 | parser.add_argument('--splitsize', type=int, default=-1, help="split some operations inner the model. Only speed and GPU memory consumption are affected.") 194 | 195 | # parameters used to calibrate the edge existence probability in NCNC 196 | parser.add_argument('--probscale', type=float, default=5) 197 | parser.add_argument('--proboffset', type=float, default=3) 198 | parser.add_argument('--pt', type=float, default=0.5) 199 | parser.add_argument("--learnpt", action="store_true") 200 | 201 | # For scalability, NCNC samples neighbors to complete common neighbor. 202 | parser.add_argument('--trndeg', type=int, default=-1, help="maximum number of sampled neighbors during the training process. -1 means no sample") 203 | parser.add_argument('--tstdeg', type=int, default=-1, help="maximum number of sampled neighbors during the test process") 204 | # NCN can sample common neighbors for scalability. Generally not used. 205 | parser.add_argument('--cndeg', type=int, default=-1) 206 | 207 | # predictor used, such as NCN, NCNC 208 | parser.add_argument('--predictor', choices=predictor_dict.keys()) 209 | parser.add_argument("--depth", type=int, default=1, help="number of completion steps in NCNC") 210 | # gnn used, such as gin, gcn. 
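# available choices are the keys of convdict in model.py: gcn, gcn_cached, sage, gin, max, puremax, puresum, puremean, puregcn, none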
211 | parser.add_argument('--model', choices=convdict.keys()) 212 | 213 | parser.add_argument('--save_gemb', action="store_true", help="whether to save node representations produced by GNN") 214 | parser.add_argument('--load', type=str, help="where to load node representations produced by GNN") 215 | parser.add_argument("--loadmod", action="store_true", help="whether to load trained models") 216 | parser.add_argument("--savemod", action="store_true", help="whether to save trained models") 217 | 218 | parser.add_argument("--savex", action="store_true", help="whether to save trained node embeddings") 219 | parser.add_argument("--loadx", action="store_true", help="whether to load trained node embeddings") 220 | 221 | 222 | # not used in experiments 223 | parser.add_argument('--cnprob', type=float, default=0) 224 | args = parser.parse_args() 225 | return args 226 | 227 | 228 | def main(): 229 | args = parseargs() 230 | print(args, flush=True) 231 | 232 | hpstr = str(args).replace(" ", "").replace("Namespace(", "").replace( 233 | ")", "").replace("True", "1").replace("False", "0").replace("=", "").replace("epochs", "").replace("runs", "").replace("save_gemb", "") 234 | writer = SummaryWriter(f"./rec/{args.model}_{args.predictor}") 235 | writer.add_text("hyperparams", hpstr) 236 | 237 | if args.dataset in ["Cora", "Citeseer", "Pubmed"]: 238 | evaluator = Evaluator(name=f'ogbl-ppa') 239 | else: 240 | evaluator = Evaluator(name=f'ogbl-{args.dataset}') 241 | 242 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 243 | data, split_edge = loaddataset(args.dataset, args.use_valedges_as_input, args.load) 244 | data = data.to(device) 245 | 246 | predfn = predictor_dict[args.predictor] 247 | if args.predictor != "cn0": 248 | predfn = partial(predfn, cndeg=args.cndeg) 249 | if args.predictor in ["cn1", "incn1cn1", "scn1", "catscn1", "sincn1cn1"]: 250 | predfn = partial(predfn, use_xlin=args.use_xlin, tailact=args.tailact, twolayerlin=args.twolayerlin, beta=args.beta) 251 | if args.predictor == "incn1cn1": 252 | predfn = partial(predfn, depth=args.depth, splitsize=args.splitsize, scale=args.probscale, offset=args.proboffset, trainresdeg=args.trndeg, testresdeg=args.tstdeg, pt=args.pt, learnablept=args.learnpt, alpha=args.alpha) 253 | 254 | ret = [] 255 | 256 | for run in range(0, args.runs): 257 | set_seed(run) 258 | if args.dataset in ["Cora", "Citeseer", "Pubmed"]: 259 | data, split_edge = loaddataset(args.dataset, args.use_valedges_as_input, args.load) # get a new split of dataset 260 | data = data.to(device) 261 | bestscore = None 262 | 263 | # build model 264 | model = GCN(data.num_features, args.hiddim, args.hiddim, args.mplayers, 265 | args.gnndp, args.ln, args.res, data.max_x, 266 | args.model, args.jk, args.gnnedp, xdropout=args.xdp, taildropout=args.tdp, noinputlin=args.loadx).to(device) 267 | if args.loadx: 268 | with torch.no_grad(): 269 | model.xemb[0].weight.copy_(torch.load(f"gemb/{args.dataset}_{args.model}_cn1_{args.hiddim}_{run}.pt", map_location="cpu")) 270 | model.xemb[0].weight.requires_grad_(False) 271 | predictor = predfn(args.hiddim, args.hiddim, 1, args.nnlayers, 272 | args.predp, args.preedp, args.lnnn).to(device) 273 | if args.loadmod: 274 | keys = model.load_state_dict(torch.load(f"gmodel/{args.dataset}_{args.model}_cn1_{args.hiddim}_{run}.pt", map_location="cpu"), strict=False) 275 | print("unmatched params", keys, flush=True) 276 | keys = predictor.load_state_dict(torch.load(f"gmodel/{args.dataset}_{args.model}_cn1_{args.hiddim}_{run}.pre.pt", 
map_location="cpu"), strict=False) 277 | print("unmatched params", keys, flush=True) 278 | 279 | 280 | optimizer = torch.optim.Adam([{'params': model.parameters(), "lr": args.gnnlr}, 281 | {'params': predictor.parameters(), 'lr': args.prelr}]) 282 | 283 | for epoch in range(1, 1 + args.epochs): 284 | alpha = max(0, min((epoch-5)*0.1, 1)) if args.increasealpha else None 285 | t1 = time.time() 286 | loss = train(model, predictor, data, split_edge, optimizer, 287 | args.batch_size, args.maskinput, [], alpha) 288 | print(f"trn time {time.time()-t1:.2f} s", flush=True) 289 | if True: 290 | t1 = time.time() 291 | results, h = test(model, predictor, data, split_edge, evaluator, 292 | args.testbs, args.use_valedges_as_input) 293 | print(f"test time {time.time()-t1:.2f} s") 294 | if bestscore is None: 295 | bestscore = {key: list(results[key]) for key in results} 296 | for key, result in results.items(): 297 | writer.add_scalars(f"{key}_{run}", { 298 | "trn": result[0], 299 | "val": result[1], 300 | "tst": result[2] 301 | }, epoch) 302 | 303 | if True: 304 | for key, result in results.items(): 305 | train_hits, valid_hits, test_hits = result 306 | if valid_hits > bestscore[key][1]: 307 | bestscore[key] = list(result) 308 | if args.save_gemb: 309 | torch.save(h, f"gemb/{args.dataset}_{args.model}_{args.predictor}_{args.hiddim}.pt") 310 | if args.savex: 311 | torch.save(model.xemb[0].weight.detach(), f"gemb/{args.dataset}_{args.model}_{args.predictor}_{args.hiddim}_{run}.pt") 312 | if args.savemod: 313 | torch.save(model.state_dict(), f"gmodel/{args.dataset}_{args.model}_{args.predictor}_{args.hiddim}_{run}.pt") 314 | torch.save(predictor.state_dict(), f"gmodel/{args.dataset}_{args.model}_{args.predictor}_{args.hiddim}_{run}.pre.pt") 315 | print(key) 316 | print(f'Run: {run + 1:02d}, ' 317 | f'Epoch: {epoch:02d}, ' 318 | f'Loss: {loss:.4f}, ' 319 | f'Train: {100 * train_hits:.2f}%, ' 320 | f'Valid: {100 * valid_hits:.2f}%, ' 321 | f'Test: {100 * test_hits:.2f}%') 322 | print('---', flush=True) 323 | print(f"best {bestscore}") 324 | if args.dataset == "collab": 325 | ret.append(bestscore["Hits@50"][-2:]) 326 | elif args.dataset == "ppa": 327 | ret.append(bestscore["Hits@100"][-2:]) 328 | elif args.dataset == "ddi": 329 | ret.append(bestscore["Hits@20"][-2:]) 330 | elif args.dataset == "citation2": 331 | ret.append(bestscore[-2:]) 332 | elif args.dataset in ["Pubmed", "Cora", "Citeseer"]: 333 | ret.append(bestscore["Hits@100"][-2:]) 334 | else: 335 | raise NotImplementedError 336 | ret = np.array(ret) 337 | print(ret) 338 | print(f"Final result: val {np.average(ret[:, 0]):.4f} {np.std(ret[:, 0]):.4f} tst {np.average(ret[:, 1]):.4f} {np.std(ret[:, 1]):.4f}") 339 | 340 | 341 | if __name__ == "__main__": 342 | main() -------------------------------------------------------------------------------- /NeighborOverlapCitation2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from sklearn.metrics import accuracy_score, roc_auc_score 4 | import torch 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | from torch_sparse import SparseTensor 8 | import torch_geometric.transforms as T 9 | from model import predictor_dict, convdict, GCN, DropEdge 10 | from functools import partial 11 | 12 | from ogb.linkproppred import Evaluator 13 | from ogbdataset import loaddataset 14 | from torch.utils.tensorboard import SummaryWriter 15 | from utils import PermIterator 16 | import time 17 | 18 | 19 | def set_seed(seed): 20 | 
torch.manual_seed(seed) 21 | torch.cuda.manual_seed_all(seed) 22 | np.random.seed(seed) 23 | 24 | 25 | def train(model, 26 | predictor, 27 | data, 28 | split_edge, 29 | optimizer, 30 | batch_size, 31 | maskinput: bool = True): 32 | model.train() 33 | predictor.train() 34 | 35 | source_edge = split_edge['train']['source_node'].to(data.x.device) 36 | target_edge = split_edge['train']['target_node'].to(data.x.device) 37 | 38 | total_loss = [] 39 | adjmask = torch.ones_like(source_edge, dtype=torch.bool) 40 | for perm in PermIterator( 41 | source_edge.device, source_edge.shape[0], batch_size 42 | ): 43 | optimizer.zero_grad() 44 | if maskinput: 45 | adjmask[perm] = 0 46 | tei = torch.stack((source_edge[adjmask], target_edge[adjmask]), dim=0) 47 | adj = SparseTensor.from_edge_index(tei, 48 | sparse_sizes=(data.num_nodes, data.num_nodes)).to_device( 49 | source_edge.device, non_blocking=True) 50 | adjmask[perm] = 1 51 | adj = adj.to_symmetric() 52 | else: 53 | adj = data.adj_t 54 | h = model(data.x, adj) 55 | 56 | src, dst = source_edge[perm], target_edge[perm] 57 | pos_out = predictor(h, adj, torch.stack((src, dst))) 58 | 59 | pos_loss = -F.logsigmoid(pos_out).mean() 60 | 61 | dst_neg = torch.randint(0, data.num_nodes, src.size(), 62 | dtype=torch.long, device=h.device) 63 | neg_out = predictor(h, adj, torch.stack((src, dst_neg))) 64 | neg_loss = -F.logsigmoid(-neg_out).mean() 65 | 66 | loss = pos_loss + neg_loss 67 | loss.backward() 68 | 69 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 70 | nn.utils.clip_grad_norm_(predictor.parameters(), 1.0) 71 | 72 | optimizer.step() 73 | 74 | total_loss.append(loss) 75 | total_loss = np.average([_.item() for _ in total_loss]) 76 | return total_loss 77 | 78 | 79 | @torch.no_grad() 80 | def test(model, predictor, data, split_edge, evaluator, batch_size): 81 | model.eval() 82 | predictor.eval() 83 | adj = data.full_adj_t 84 | h = model(data.x, adj) 85 | 86 | def test_split(split): 87 | source = split_edge[split]['source_node'].to(h.device) 88 | target = split_edge[split]['target_node'].to(h.device) 89 | target_neg = split_edge[split]['target_node_neg'].to(h.device) 90 | 91 | pos_preds = [] 92 | for perm in PermIterator(source.device, source.shape[0], batch_size, False): 93 | src, dst = source[perm], target[perm] 94 | pos_preds += [predictor(h, adj, torch.stack((src, dst))).squeeze().cpu()] 95 | pos_pred = torch.cat(pos_preds, dim=0) 96 | 97 | neg_preds = [] 98 | source = source.view(-1, 1).repeat(1, 1000).view(-1) 99 | target_neg = target_neg.view(-1) 100 | for perm in PermIterator(source.device, source.shape[0], batch_size, False): 101 | src, dst_neg = source[perm], target_neg[perm] 102 | neg_preds += [predictor(h, adj, torch.stack((src, dst_neg))).squeeze().cpu()] 103 | neg_pred = torch.cat(neg_preds, dim=0).view(-1, 1000) 104 | 105 | return evaluator.eval({ 106 | 'y_pred_pos': pos_pred, 107 | 'y_pred_neg': neg_pred, 108 | })['mrr_list'].mean().item() 109 | 110 | train_mrr = 0.0 #test_split('eval_train') 111 | valid_mrr = test_split('valid') 112 | test_mrr = test_split('test') 113 | 114 | return train_mrr, valid_mrr, test_mrr, h.cpu() 115 | 116 | 117 | def parseargs(): 118 | #please refer to NeighborOverlap.py/parseargs for the meanings of these options 119 | parser = argparse.ArgumentParser() 120 | parser.add_argument('--maskinput', action="store_true") 121 | 122 | parser.add_argument('--mplayers', type=int, default=1) 123 | parser.add_argument('--nnlayers', type=int, default=3) 124 | parser.add_argument('--hiddim', type=int, default=32) 125 | 
parser.add_argument('--ln', action="store_true") 126 | parser.add_argument('--lnnn', action="store_true") 127 | parser.add_argument('--res', action="store_true") 128 | parser.add_argument('--jk', action="store_true") 129 | parser.add_argument('--gnndp', type=float, default=0.3) 130 | parser.add_argument('--xdp', type=float, default=0.3) 131 | parser.add_argument('--tdp', type=float, default=0.3) 132 | parser.add_argument('--gnnedp', type=float, default=0.3) 133 | parser.add_argument('--predp', type=float, default=0.3) 134 | parser.add_argument('--preedp', type=float, default=0.3) 135 | parser.add_argument('--gnnlr', type=float, default=0.0003) 136 | parser.add_argument('--prelr', type=float, default=0.0003) 137 | parser.add_argument('--batch_size', type=int, default=8192) 138 | parser.add_argument('--testbs', type=int, default=8192) 139 | parser.add_argument('--epochs', type=int, default=40) 140 | parser.add_argument('--runs', type=int, default=3) 141 | parser.add_argument('--probscale', type=float, default=5) 142 | parser.add_argument('--proboffset', type=float, default=3) 143 | parser.add_argument('--beta', type=float, default=1) 144 | parser.add_argument('--alpha', type=float, default=1) 145 | parser.add_argument('--trndeg', type=int, default=-1) 146 | parser.add_argument('--tstdeg', type=int, default=-1) 147 | parser.add_argument('--dataset', type=str, default="collab") 148 | parser.add_argument('--predictor', choices=predictor_dict.keys()) 149 | parser.add_argument('--model', choices=convdict.keys()) 150 | parser.add_argument('--cndeg', type=int, default=-1) 151 | parser.add_argument('--save_gemb', action="store_true") 152 | parser.add_argument('--load', type=str) 153 | parser.add_argument('--cnprob', type=float, default=0) 154 | parser.add_argument('--pt', type=float, default=0.5) 155 | parser.add_argument("--learnpt", action="store_true") 156 | parser.add_argument("--use_xlin", action="store_true") 157 | parser.add_argument("--tailact", action="store_true") 158 | parser.add_argument("--twolayerlin", action="store_true") 159 | parser.add_argument("--use_valedges_as_input", action="store_true") 160 | parser.add_argument('--splitsize', type=int, default=-1) 161 | parser.add_argument('--depth', type=int, default=-1) 162 | args = parser.parse_args() 163 | return args 164 | 165 | 166 | def main(): 167 | args = parseargs() 168 | print(args, flush=True) 169 | hpstr = str(args).replace(" ", "").replace("Namespace(", "").replace( 170 | ")", "").replace("True", "1").replace("False", "0").replace("=", "").replace("epochs", "").replace("runs", "").replace("save_gemb", "") 171 | writer = SummaryWriter(f"./rec/{args.model}_{args.predictor}") 172 | writer.add_text("hyperparams", hpstr) 173 | 174 | device = torch.device(f'cuda' if torch.cuda.is_available() else 'cpu') 175 | evaluator = Evaluator(name=f'ogbl-{args.dataset}') 176 | 177 | data, split_edge = loaddataset(args.dataset, False, args.load) 178 | 179 | data = data.to(device) 180 | 181 | predfn = predictor_dict[args.predictor] 182 | 183 | if args.predictor != "cn0": 184 | predfn = partial(predfn, cndeg=args.cndeg) 185 | if args.predictor in ["cn1", "incn1cn1", "scn1", "catscn1", "sincn1cn1"]: 186 | predfn = partial(predfn, use_xlin=args.use_xlin, tailact=args.tailact, twolayerlin=args.twolayerlin, beta=args.beta) 187 | if args.predictor in ["incn1cn1", "sincn1cn1"]: 188 | predfn = partial(predfn, depth=args.depth, splitsize=args.splitsize, scale=args.probscale, offset=args.proboffset, trainresdeg=args.trndeg, testresdeg=args.tstdeg, 
pt=args.pt, learnablept=args.learnpt, alpha=args.alpha) 189 | ret = [] 190 | 191 | for run in range(args.runs): 192 | set_seed(run) 193 | bestscore = [0, 0, 0] 194 | model = GCN(data.num_features, args.hiddim, args.hiddim, args.mplayers, 195 | args.gnndp, args.ln, args.res, data.max_x, 196 | args.model, args.jk, args.gnnedp, xdropout=args.xdp, taildropout=args.tdp).to(device) 197 | 198 | predictor = predfn(args.hiddim, args.hiddim, 1, args.nnlayers, 199 | args.predp, args.preedp, args.lnnn).to(device) 200 | optimizer = torch.optim.Adam([{'params': model.parameters(), "lr": args.gnnlr}, 201 | {'params': predictor.parameters(), 'lr': args.prelr}]) 202 | 203 | for epoch in range(1, 1 + args.epochs): 204 | t1 = time.time() 205 | loss = train(model, predictor, data, split_edge, optimizer, 206 | args.batch_size, args.maskinput) 207 | print(f"trn time {time.time()-t1:.2f} s") 208 | if True: 209 | t1 = time.time() 210 | results = test(model, predictor, data, split_edge, evaluator, 211 | args.testbs) 212 | results, h = results[:-1], results[-1] 213 | print(f"test time {time.time()-t1:.2f} s") 214 | writer.add_scalars(f"mrr_{run}", { 215 | "trn": results[0], 216 | "val": results[1], 217 | "tst": results[2] 218 | }, epoch) 219 | 220 | if True: 221 | train_mrr, valid_mrr, test_mrr = results 222 | train_mrr, valid_mrr, test_mrr = results 223 | if valid_mrr > bestscore[1]: 224 | bestscore = list(results) 225 | bestscore = list(results) 226 | if args.save_gemb: 227 | torch.save(h, f"gemb/citation2_{args.model}_{args.predictor}.pt") 228 | 229 | print(f'Run: {run + 1:02d}, ' 230 | f'Epoch: {epoch:02d}, ' 231 | f'Loss: {loss:.4f}, ' 232 | f'Train: {100 * train_mrr:.2f}%, ' 233 | f'Valid: {100 * valid_mrr:.2f}%, ' 234 | f'Test: {100 * test_mrr:.2f}%') 235 | print('---', flush=True) 236 | print(f"best {bestscore}") 237 | if args.dataset == "citation2": 238 | ret.append(bestscore) 239 | else: 240 | raise NotImplementedError 241 | ret = np.array(ret) 242 | print(ret) 243 | print(f"Final result: {np.average(ret[:, 1])} {np.std(ret[:, 1])} {np.average(ret[:, 2])} {np.std(ret[:, 2])}") 244 | 245 | if __name__ == "__main__": 246 | main() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains the official code for the paper [Neural Common Neighbor with Completion for Link Prediction](https://arxiv.org/pdf/2302.00890.pdf). 2 | 3 | **Environment** 4 | 5 | Tested Combination: 6 | torch 1.13.0 + pyg 2.2.0 + ogb 1.3.5 7 | 8 | ``` 9 | conda env create -f env.yaml 10 | ``` 11 | 12 | **Prepare Datasets** 13 | 14 | ``` 15 | python ogbdataset.py 16 | ``` 17 | 18 | **Reproduce Results** 19 | 20 | We implement the following models. 21 | 22 | | name | $model | command change | 23 | |----------|-----------|--------------------| 24 | | GAE | cn0 | | 25 | | NCN | cn1 | | 26 | | NCNC | incn1cn1 | | 27 | | NCNC2 | incn1cn1 | add --depth 2 --splitsize 131072 | 28 | | GAE+CN | scn1 | | 29 | | NCN2 | cn1.5 | | 30 | | NCN-diff | cn1res | | 31 | | NoTLR | cn1 | delete --maskinput | 32 | 33 | To reproduce the results, please modify the following commands as shown in the table above. 
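For example, to reproduce NCNC on Cora, replace `$model` with `incn1cn1` in the Cora command below; for NCNC2, additionally append `--depth 2 --splitsize 131072`; for NoTLR, use `cn1` and delete `--maskinput`.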
34 | 35 | Cora 36 | ``` 37 | python NeighborOverlap.py --xdp 0.7 --tdp 0.3 --pt 0.75 --gnnedp 0.0 --preedp 0.4 --predp 0.05 --gnndp 0.05 --probscale 4.3 --proboffset 2.8 --alpha 1.0 --gnnlr 0.0043 --prelr 0.0024 --batch_size 1152 --ln --lnnn --predictor $model --dataset Cora --epochs 100 --runs 10 --model puregcn --hiddim 256 --mplayers 1 --testbs 8192 --maskinput --jk --use_xlin --tailact 38 | ``` 39 | 40 | Citeseer 41 | ``` 42 | python NeighborOverlap.py --xdp 0.4 --tdp 0.0 --pt 0.75 --gnnedp 0.0 --preedp 0.0 --predp 0.55 --gnndp 0.75 --probscale 6.5 --proboffset 4.4 --alpha 0.4 --gnnlr 0.0085 --prelr 0.0078 --batch_size 384 --ln --lnnn --predictor $model --dataset Citeseer --epochs 100 --runs 10 --model puregcn --hiddim 256 --mplayers 1 --testbs 4096 --maskinput --jk --use_xlin --tailact --twolayerlin 43 | ``` 44 | 45 | Pubmed 46 | ``` 47 | python NeighborOverlap.py --xdp 0.3 --tdp 0.0 --pt 0.5 --gnnedp 0.0 --preedp 0.0 --predp 0.05 --gnndp 0.1 --probscale 5.3 --proboffset 0.5 --alpha 0.3 --gnnlr 0.0097 --prelr 0.002 --batch_size 2048 --ln --lnnn --predictor $model --dataset Pubmed --epochs 100 --runs 10 --model puregcn --hiddim 256 --mplayers 1 --testbs 8192 --maskinput --jk --use_xlin --tailact 48 | ``` 49 | 50 | collab 51 | ``` 52 | python NeighborOverlap.py --xdp 0.25 --tdp 0.05 --pt 0.1 --gnnedp 0.25 --preedp 0.0 --predp 0.3 --gnndp 0.1 --probscale 2.5 --proboffset 6.0 --alpha 1.05 --gnnlr 0.0082 --prelr 0.0037 --batch_size 65536 --ln --lnnn --predictor $model --dataset collab --epochs 100 --runs 10 --model gcn --hiddim 64 --mplayers 1 --testbs 131072 --maskinput --use_valedges_as_input --res --use_xlin --tailact 53 | ``` 54 | 55 | ppa 56 | ``` 57 | python NeighborOverlap.py --xdp 0.0 --tdp 0.0 --gnnedp 0.1 --preedp 0.0 --predp 0.1 --gnndp 0.0 --gnnlr 0.0013 --prelr 0.0013 --batch_size 16384 --ln --lnnn --predictor $model --dataset ppa --epochs 25 --runs 10 --model gcn --hiddim 64 --mplayers 3 --maskinput --tailact --res --testbs 65536 --proboffset 8.5 --probscale 4.0 --pt 0.1 --alpha 0.9 --splitsize 131072 58 | ``` 59 | 60 | The following datasets use separate commands for NCN and NCNC. To use other models, please modify NCN's command. Note that NCNC models in these datasets initialize parameters with trained NCN models to accelerate training. Please use our pre-trained model or run NCN first. 
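For example, in the ddi commands below, the NCN run saves its trained model with `--savemod` and the NCNC run restores it with `--loadmod`; for citation2, the NCN run exports node representations with `--save_gemb`, which the NCNC run then loads via `--load gemb/citation2_puregcn_cn1.pt`.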
61 | 62 | citation2 63 | ``` 64 | python NeighborOverlapCitation2.py --xdp 0.0 --tdp 0.3 --gnnedp 0.0 --preedp 0.0 --predp 0.2 --gnndp 0.2 --gnnlr 0.0088 --prelr 0.0058 --batch_size 32768 --ln --lnnn --predictor cn1 --dataset citation2 --epochs 20 --runs 10 --model puregcn --hiddim 64 --mplayers 3 --res --testbs 65536 --use_xlin --tailact --proboffset 4.7 --probscale 7.0 --pt 0.3 --trndeg 128 --tstdeg 128 --save_gemb 65 | 66 | 67 | python NeighborOverlapCitation2.py --xdp 0.0 --tdp 0.3 --gnnedp 0.0 --preedp 0.0 --predp 0.2 --gnndp 0.2 --gnnlr 0.0088 --prelr 0.001 --batch_size 24576 --ln --lnnn --predictor incn1cn1 --dataset citation2 --epochs 20 --runs 10 --model none --hiddim 64 --mplayers 0 --res --testbs 65536 --use_xlin --tailact --load gemb/citation2_puregcn_cn1.pt --proboffset -0.3 --probscale 1.4 --pt 0.25 --trndeg 96 --tstdeg 96 --load gemb/citation2_puregcn_cn1.pt 68 | ``` 69 | 70 | 71 | ddi 72 | ``` 73 | python NeighborOverlap.py --xdp 0.05 --tdp 0.0 --gnnedp 0.0 --preedp 0.0 --predp 0.6 --gnndp 0.4 --gnnlr 0.0021 --prelr 0.0018 --batch_size 24576 --ln --lnnn --predictor cn1 --dataset ddi --epochs 100 --runs 10 --model puresum --hiddim 224 --mplayers 1 --testbs 131072 --use_xlin --twolayerlin --res --maskinput --savemod 74 | 75 | python NeighborOverlap.py --xdp 0.05 --tdp 0.0 --gnnedp 0.0 --preedp 0.0 --predp 0.6 --gnndp 0.4 --gnnlr 0.0000000 --prelr 0.0025 --batch_size 24576 --ln --lnnn --predictor incn1cn1 --dataset ddi --proboffset 3 --probscale 10 --pt 0.1 --alpha 0.5 --epochs 2 --runs 10 --model puresum --hiddim 224 --mplayers 1 --testbs 24576 --splitsize 262144 --use_xlin --twolayerlin --res --maskinput --loadmod 76 | ``` -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pyg 4 | - pytorch 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=5.1=1_gnu 11 | - absl-py=1.3.0=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h5764c6d_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - anyio=3.5.0=py310h06a4308_0 15 | - argon2-cffi=21.3.0=pyhd3eb1b0_0 16 | - argon2-cffi-bindings=21.2.0=py310h7f8727e_0 17 | - asttokens=2.0.5=pyhd3eb1b0_0 18 | - async-timeout=4.0.2=pyhd8ed1ab_0 19 | - babel=2.11.0=py310h06a4308_0 20 | - backcall=0.2.0=pyhd3eb1b0_0 21 | - beautifulsoup4=4.11.1=py310h06a4308_0 22 | - blas=1.0=mkl 23 | - bleach=4.1.0=pyhd3eb1b0_0 24 | - blinker=1.5=pyhd8ed1ab_0 25 | - brotli=1.0.9=h5eee18b_7 26 | - brotli-bin=1.0.9=h5eee18b_7 27 | - brotlipy=0.7.0=py310h7f8727e_1002 28 | - bzip2=1.0.8=h7b6447c_0 29 | - c-ares=1.18.1=h7f98852_0 30 | - ca-certificates=2022.12.7=ha878542_0 31 | - cachetools=5.2.0=pyhd8ed1ab_0 32 | - certifi=2022.12.7=pyhd8ed1ab_0 33 | - cffi=1.15.1=py310h5eee18b_3 34 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 35 | - click=8.1.3=unix_pyhd8ed1ab_2 36 | - conda=22.11.1=py310hff52083_1 37 | - conda-content-trust=0.1.3=py310h06a4308_0 38 | - conda-package-handling=1.9.0=py310h5eee18b_1 39 | - contourpy=1.0.5=py310hdb19cb5_0 40 | - cryptography=38.0.1=py310h9ce1e76_0 41 | - cuda=11.7.1=0 42 | - cuda-cccl=11.7.91=0 43 | - cuda-command-line-tools=11.7.1=0 44 | - cuda-compiler=11.7.1=0 45 | - cuda-cudart=11.7.99=0 46 | - cuda-cudart-dev=11.7.99=0 47 | - cuda-cuobjdump=11.7.91=0 48 | - cuda-cupti=11.7.101=0 49 | - cuda-cuxxfilt=11.7.91=0 50 | - cuda-demo-suite=12.0.76=0 51 | - cuda-documentation=12.0.76=0 52 | - cuda-driver-dev=11.7.99=0 53 | - cuda-gdb=12.0.90=0 54 | 
- cuda-libraries=11.7.1=0 55 | - cuda-libraries-dev=11.7.1=0 56 | - cuda-memcheck=11.8.86=0 57 | - cuda-nsight=12.0.78=0 58 | - cuda-nsight-compute=12.0.0=0 59 | - cuda-nvcc=11.7.99=0 60 | - cuda-nvdisasm=12.0.76=0 61 | - cuda-nvml-dev=11.7.91=0 62 | - cuda-nvprof=12.0.90=0 63 | - cuda-nvprune=11.7.91=0 64 | - cuda-nvrtc=11.7.99=0 65 | - cuda-nvrtc-dev=11.7.99=0 66 | - cuda-nvtx=11.7.91=0 67 | - cuda-nvvp=12.0.90=0 68 | - cuda-runtime=11.7.1=0 69 | - cuda-sanitizer-api=12.0.90=0 70 | - cuda-toolkit=11.7.1=0 71 | - cuda-tools=11.7.1=0 72 | - cuda-visual-tools=11.7.1=0 73 | - cycler=0.11.0=pyhd3eb1b0_0 74 | - dbus=1.13.18=hb2f20db_0 75 | - debugpy=1.5.1=py310h295c915_0 76 | - decorator=5.1.1=pyhd3eb1b0_0 77 | - defusedxml=0.7.1=pyhd3eb1b0_0 78 | - entrypoints=0.4=py310h06a4308_0 79 | - executing=0.8.3=pyhd3eb1b0_0 80 | - expat=2.4.9=h6a678d5_0 81 | - ffmpeg=4.3=hf484d3e_0 82 | - fftw=3.3.9=h27cfd23_1 83 | - flit-core=3.6.0=pyhd3eb1b0_0 84 | - fontconfig=2.14.1=h52c9d5c_1 85 | - fonttools=4.25.0=pyhd3eb1b0_0 86 | - freetype=2.12.1=h4a9f257_0 87 | - frozenlist=1.3.3=py310h5eee18b_0 88 | - gds-tools=1.5.0.59=0 89 | - giflib=5.2.1=h7b6447c_0 90 | - glib=2.69.1=he621ea3_2 91 | - gmp=6.2.1=h295c915_3 92 | - gnutls=3.6.15=he1e5248_0 93 | - google-auth=2.15.0=pyh1a96a4e_0 94 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 95 | - grpcio=1.42.0=py310hce63b2e_0 96 | - gst-plugins-base=1.14.0=h8213a91_2 97 | - gstreamer=1.14.0=h28cd5cc_2 98 | - icu=58.2=he6710b0_3 99 | - idna=3.4=py310h06a4308_0 100 | - intel-openmp=2021.4.0=h06a4308_3561 101 | - ipykernel=6.15.2=py310h06a4308_0 102 | - ipython=8.7.0=py310h06a4308_0 103 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 104 | - ipywidgets=7.6.5=pyhd3eb1b0_1 105 | - jedi=0.18.1=py310h06a4308_1 106 | - jinja2=3.1.2=py310h06a4308_0 107 | - joblib=1.1.1=py310h06a4308_0 108 | - jpeg=9e=h7f8727e_0 109 | - json5=0.9.6=pyhd3eb1b0_0 110 | - jsonschema=4.16.0=py310h06a4308_0 111 | - jupyter=1.0.0=py310h06a4308_8 112 | - jupyter_client=7.4.8=py310h06a4308_0 113 | - jupyter_console=6.4.4=py310h06a4308_0 114 | - jupyter_core=4.11.2=py310h06a4308_0 115 | - jupyter_server=1.23.4=py310h06a4308_0 116 | - jupyterlab=3.5.0=py310h06a4308_0 117 | - jupyterlab_pygments=0.1.2=py_0 118 | - jupyterlab_server=2.16.3=py310h06a4308_0 119 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 120 | - kiwisolver=1.4.2=py310h295c915_0 121 | - krb5=1.19.2=hac12032_0 122 | - lame=3.100=h7b6447c_0 123 | - lcms2=2.12=h3be6417_0 124 | - ld_impl_linux-64=2.38=h1181459_1 125 | - lerc=3.0=h295c915_0 126 | - libbrotlicommon=1.0.9=h5eee18b_7 127 | - libbrotlidec=1.0.9=h5eee18b_7 128 | - libbrotlienc=1.0.9=h5eee18b_7 129 | - libclang=10.0.1=default_hb85057a_2 130 | - libcublas=11.10.3.66=0 131 | - libcublas-dev=11.10.3.66=0 132 | - libcufft=10.7.2.124=h4fbf590_0 133 | - libcufft-dev=10.7.2.124=h98a8f43_0 134 | - libcufile=1.5.0.59=0 135 | - libcufile-dev=1.5.0.59=0 136 | - libcurand=10.3.1.50=0 137 | - libcurand-dev=10.3.1.50=0 138 | - libcusolver=11.4.0.1=0 139 | - libcusolver-dev=11.4.0.1=0 140 | - libcusparse=11.7.4.91=0 141 | - libcusparse-dev=11.7.4.91=0 142 | - libdeflate=1.8=h7f8727e_5 143 | - libedit=3.1.20221030=h5eee18b_0 144 | - libevent=2.1.12=h8f2d780_0 145 | - libffi=3.4.2=h6a678d5_6 146 | - libgcc-ng=11.2.0=h1234567_1 147 | - libgfortran-ng=11.2.0=h00389a5_1 148 | - libgfortran5=11.2.0=h1234567_1 149 | - libgomp=11.2.0=h1234567_1 150 | - libiconv=1.16=h7f8727e_2 151 | - libidn2=2.3.2=h7f8727e_0 152 | - libllvm10=10.0.1=hbcb73fb_5 153 | - libnpp=11.7.4.75=0 154 | - libnpp-dev=11.7.4.75=0 155 | - 
libnvjpeg=11.8.0.2=0 156 | - libnvjpeg-dev=11.8.0.2=0 157 | - libpng=1.6.37=hbc83047_0 158 | - libpq=12.9=h16c4e8d_3 159 | - libprotobuf=3.20.1=h4ff587b_0 160 | - libsodium=1.0.18=h7b6447c_0 161 | - libstdcxx-ng=11.2.0=h1234567_1 162 | - libtasn1=4.16.0=h27cfd23_0 163 | - libtiff=4.4.0=hecacb30_2 164 | - libunistring=0.9.10=h27cfd23_0 165 | - libuuid=1.41.5=h5eee18b_0 166 | - libwebp=1.2.4=h11a3e52_0 167 | - libwebp-base=1.2.4=h5eee18b_0 168 | - libxcb=1.15=h7f8727e_0 169 | - libxkbcommon=1.0.1=hfa300c1_0 170 | - libxml2=2.9.14=h74e7548_0 171 | - libxslt=1.1.35=h4e12654_0 172 | - lxml=4.9.1=py310h1edc446_0 173 | - lz4-c=1.9.4=h6a678d5_0 174 | - markdown=3.4.1=pyhd8ed1ab_0 175 | - markupsafe=2.1.1=py310h7f8727e_0 176 | - matplotlib=3.6.2=py310h06a4308_0 177 | - matplotlib-base=3.6.2=py310h945d387_0 178 | - matplotlib-inline=0.1.6=py310h06a4308_0 179 | - mistune=0.8.4=py310h7f8727e_1000 180 | - mkl=2021.4.0=h06a4308_640 181 | - mkl-service=2.4.0=py310h7f8727e_0 182 | - mkl_fft=1.3.1=py310hd6ae3a3_0 183 | - mkl_random=1.2.2=py310h00e6091_0 184 | - multidict=6.0.2=py310h5764c6d_1 185 | - munkres=1.1.4=py_0 186 | - nbclassic=0.4.8=py310h06a4308_0 187 | - nbclient=0.5.13=py310h06a4308_0 188 | - nbconvert=6.5.4=py310h06a4308_0 189 | - nbformat=5.7.0=py310h06a4308_0 190 | - ncurses=6.3=h5eee18b_3 191 | - nest-asyncio=1.5.5=py310h06a4308_0 192 | - nettle=3.7.3=hbbd107a_1 193 | - notebook=6.5.2=py310h06a4308_0 194 | - notebook-shim=0.2.2=py310h06a4308_0 195 | - nsight-compute=2022.4.0.15=0 196 | - nspr=4.33=h295c915_0 197 | - nss=3.74=h0370c37_0 198 | - numpy=1.23.4=py310hd5efca6_0 199 | - numpy-base=1.23.4=py310h8e6c178_0 200 | - oauthlib=3.2.2=pyhd8ed1ab_0 201 | - openh264=2.1.1=h4ff587b_0 202 | - openssl=1.1.1s=h7f8727e_0 203 | - packaging=22.0=py310h06a4308_0 204 | - pandocfilters=1.5.0=pyhd3eb1b0_0 205 | - parso=0.8.3=pyhd3eb1b0_0 206 | - pcre=8.45=h295c915_0 207 | - pexpect=4.8.0=pyhd3eb1b0_3 208 | - pickleshare=0.7.5=pyhd3eb1b0_1003 209 | - pillow=9.3.0=py310hace64e9_1 210 | - pip=22.3.1=py310h06a4308_0 211 | - pluggy=1.0.0=py310h06a4308_1 212 | - ply=3.11=py310h06a4308_0 213 | - prometheus_client=0.14.1=py310h06a4308_0 214 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 215 | - prompt_toolkit=3.0.20=hd3eb1b0_0 216 | - protobuf=3.20.1=py310h295c915_0 217 | - psutil=5.9.0=py310h5eee18b_0 218 | - ptyprocess=0.7.0=pyhd3eb1b0_2 219 | - pure_eval=0.2.2=pyhd3eb1b0_0 220 | - pyasn1=0.4.8=py_0 221 | - pyasn1-modules=0.2.7=py_0 222 | - pycosat=0.6.4=py310h5eee18b_0 223 | - pycparser=2.21=pyhd3eb1b0_0 224 | - pyg=2.2.0=py310_torch_1.13.0_cu117 225 | - pygments=2.11.2=pyhd3eb1b0_0 226 | - pyjwt=2.6.0=pyhd8ed1ab_0 227 | - pyopenssl=22.0.0=pyhd3eb1b0_0 228 | - pyparsing=3.0.9=py310h06a4308_0 229 | - pyqt=5.15.7=py310h6a678d5_1 230 | - pyrsistent=0.18.0=py310h7f8727e_0 231 | - pysocks=1.7.1=py310h06a4308_0 232 | - python=3.10.8=h7a1cb2a_1 233 | - python-dateutil=2.8.2=pyhd3eb1b0_0 234 | - python-fastjsonschema=2.16.2=py310h06a4308_0 235 | - python_abi=3.10=2_cp310 236 | - pytorch=1.13.1=py3.10_cuda11.7_cudnn8.5.0_0 237 | - pytorch-cluster=1.6.0=py310_torch_1.13.0_cu117 238 | - pytorch-cuda=11.7=h67b0de4_1 239 | - pytorch-mutex=1.0=cuda 240 | - pytorch-scatter=2.1.0=py310_torch_1.13.0_cu117 241 | - pytorch-sparse=0.6.16=py310_torch_1.13.0_cu117 242 | - pytz=2022.7=py310h06a4308_0 243 | - pyu2f=0.1.5=pyhd8ed1ab_0 244 | - pyzmq=23.2.0=py310h6a678d5_0 245 | - qt-main=5.15.2=h327a75a_7 246 | - qt-webengine=5.15.9=hd2b0992_4 247 | - qtconsole=5.3.2=py310h06a4308_0 248 | - qtpy=2.2.0=py310h06a4308_0 249 | - 
qtwebkit=5.212=h4eab89a_4 250 | - readline=8.2=h5eee18b_0 251 | - requests=2.28.1=py310h06a4308_0 252 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 253 | - rsa=4.9=pyhd8ed1ab_0 254 | - ruamel.yaml=0.17.21=py310h5eee18b_0 255 | - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 256 | - scikit-learn=1.1.3=py310h6a678d5_0 257 | - send2trash=1.8.0=pyhd3eb1b0_1 258 | - setuptools=65.5.0=py310h06a4308_0 259 | - sip=6.6.2=py310h6a678d5_0 260 | - six=1.16.0=pyhd3eb1b0_1 261 | - sniffio=1.2.0=py310h06a4308_1 262 | - soupsieve=2.3.2.post1=py310h06a4308_0 263 | - sqlite=3.40.0=h5082296_0 264 | - stack_data=0.2.0=pyhd3eb1b0_0 265 | - tensorboard=2.11.0=pyhd8ed1ab_0 266 | - tensorboard-data-server=0.6.1=py310h52d8a92_0 267 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 268 | - terminado=0.17.1=py310h06a4308_0 269 | - threadpoolctl=2.2.0=pyh0d69192_0 270 | - tinycss2=1.2.1=py310h06a4308_0 271 | - tk=8.6.12=h1ccaba5_0 272 | - toml=0.10.2=pyhd3eb1b0_0 273 | - tomli=2.0.1=py310h06a4308_0 274 | - toolz=0.12.0=py310h06a4308_0 275 | - torchaudio=0.13.1=py310_cu117 276 | - torchvision=0.14.1=py310_cu117 277 | - tornado=6.2=py310h5eee18b_0 278 | - tqdm=4.64.1=py310h06a4308_0 279 | - traitlets=5.7.1=py310h06a4308_0 280 | - typing-extensions=4.4.0=py310h06a4308_0 281 | - typing_extensions=4.4.0=py310h06a4308_0 282 | - tzdata=2022g=h04d1e81_0 283 | - urllib3=1.26.13=py310h06a4308_0 284 | - wcwidth=0.2.5=pyhd3eb1b0_0 285 | - webencodings=0.5.1=py310h06a4308_1 286 | - websocket-client=0.58.0=py310h06a4308_4 287 | - werkzeug=2.2.2=pyhd8ed1ab_0 288 | - wheel=0.37.1=pyhd3eb1b0_0 289 | - widgetsnbextension=3.5.2=py310h06a4308_0 290 | - xz=5.2.8=h5eee18b_0 291 | - yarl=1.7.2=py310h5764c6d_2 292 | - zeromq=4.3.4=h2531618_0 293 | - zipp=3.11.0=pyhd8ed1ab_0 294 | - zlib=1.2.13=h5eee18b_0 295 | - zstd=1.5.2=ha4553b6_0 296 | - pip: 297 | - alembic==1.9.1 298 | - attrs==22.2.0 299 | - autopage==0.5.1 300 | - cliff==4.1.0 301 | - cmaes==0.9.0 302 | - cmd2==2.4.2 303 | - colorlog==6.7.0 304 | - greenlet==2.0.1 305 | - importlib-metadata==4.13.0 306 | - littleutils==0.2.2 307 | - mako==1.2.4 308 | - ogb==1.3.5 309 | - optuna==3.0.5 310 | - outdated==0.2.2 311 | - pandas==1.5.2 312 | - pbr==5.11.0 313 | - prettytable==3.5.0 314 | - pyperclip==1.8.2 315 | - pyqt5-sip==12.11.0 316 | - pyyaml==6.0 317 | - scipy==1.8.1 318 | - sqlalchemy==1.4.45 319 | - stevedore==4.1.1 320 | - torch-tb-profiler==0.4.0 321 | - yapf==0.32.0 322 | prefix: /home/wangxiyuan/miniconda3 323 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn import GCNConv, SAGEConv 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | import torch 6 | from utils import adjoverlap 7 | from torch_sparse.matmul import spmm_max, spmm_mean, spmm_add 8 | from torch_sparse import SparseTensor 9 | import torch_sparse 10 | from torch_scatter import scatter_add 11 | from typing import Iterable, Final 12 | 13 | # a vanilla message passing layer 14 | class PureConv(nn.Module): 15 | aggr: Final[str] 16 | def __init__(self, indim, outdim, aggr="gcn") -> None: 17 | super().__init__() 18 | self.aggr = aggr 19 | if indim == outdim: 20 | self.lin = nn.Identity() 21 | else: 22 | raise NotImplementedError 23 | 24 | def forward(self, x, adj_t): 25 | x = self.lin(x) 26 | if self.aggr == "mean": 27 | return spmm_mean(adj_t, x) 28 | elif self.aggr == "max": 29 | return spmm_max(adj_t, x)[0] 30 | elif self.aggr == "sum": 31 
| return spmm_add(adj_t, x) 32 | elif self.aggr == "gcn": 33 | norm = torch.rsqrt_((1+adj_t.sum(dim=-1))).reshape(-1, 1) 34 | x = norm * x 35 | x = spmm_add(adj_t, x) + x 36 | x = norm * x 37 | return x 38 | 39 | 40 | convdict = { 41 | "gcn": 42 | GCNConv, 43 | "gcn_cached": 44 | lambda indim, outdim: GCNConv(indim, outdim, cached=True), 45 | "sage": 46 | lambda indim, outdim: GCNConv( 47 | indim, outdim, aggr="mean", normalize=False, add_self_loops=False), 48 | "gin": 49 | lambda indim, outdim: GCNConv( 50 | indim, outdim, aggr="sum", normalize=False, add_self_loops=False), 51 | "max": 52 | lambda indim, outdim: GCNConv( 53 | indim, outdim, aggr="max", normalize=False, add_self_loops=False), 54 | "puremax": 55 | lambda indim, outdim: PureConv(indim, outdim, aggr="max"), 56 | "puresum": 57 | lambda indim, outdim: PureConv(indim, outdim, aggr="sum"), 58 | "puremean": 59 | lambda indim, outdim: PureConv(indim, outdim, aggr="mean"), 60 | "puregcn": 61 | lambda indim, outdim: PureConv(indim, outdim, aggr="gcn"), 62 | "none": 63 | None 64 | } 65 | 66 | predictor_dict = {} 67 | 68 | # Edge dropout 69 | class DropEdge(nn.Module): 70 | 71 | def __init__(self, dp: float = 0.0) -> None: 72 | super().__init__() 73 | self.dp = dp 74 | 75 | def forward(self, edge_index: Tensor): 76 | if self.dp == 0: 77 | return edge_index 78 | mask = torch.rand_like(edge_index[0], dtype=torch.float) > self.dp 79 | return edge_index[:, mask] 80 | 81 | # Edge dropout with adjacency matrix as input 82 | class DropAdj(nn.Module): 83 | doscale: Final[bool] # whether to rescale edge weight 84 | def __init__(self, dp: float = 0.0, doscale=True) -> None: 85 | super().__init__() 86 | self.dp = dp 87 | self.register_buffer("ratio", torch.tensor(1/(1-dp))) 88 | self.doscale = doscale 89 | 90 | def forward(self, adj: SparseTensor)->SparseTensor: 91 | if self.dp < 1e-6 or not self.training: 92 | return adj 93 | mask = torch.rand_like(adj.storage.col(), dtype=torch.float) > self.dp 94 | adj = torch_sparse.masked_select_nnz(adj, mask, layout="coo") 95 | if self.doscale: 96 | if adj.storage.has_value(): 97 | adj.storage.set_value_(adj.storage.value()*self.ratio, layout="coo") 98 | else: 99 | adj.fill_value_(1/(1-self.dp), dtype=torch.float) 100 | return adj 101 | 102 | 103 | # Vanilla MPNN composed of several layers. 
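# As constructed in NeighborOverlap.py, the arguments map to the CLI options:
#   hidden_channels/out_channels <- --hiddim, num_layers <- --mplayers, dropout <- --gnndp,
#   ln <- --ln, res <- --res, conv_fn <- --model, jk <- --jk, edrop <- --gnnedp,
#   xdropout <- --xdp, taildropout <- --tdp, noinputlin <- --loadx; max_x comes from the dataset (data.max_x).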
104 | class GCN(nn.Module): 105 | 106 | def __init__(self, 107 | in_channels, 108 | hidden_channels, 109 | out_channels, 110 | num_layers, 111 | dropout, 112 | ln=False, 113 | res=False, 114 | max_x=-1, 115 | conv_fn="gcn", 116 | jk=False, 117 | edrop=0.0, 118 | xdropout=0.0, 119 | taildropout=0.0, 120 | noinputlin=False): 121 | super().__init__() 122 | 123 | self.adjdrop = DropAdj(edrop) 124 | 125 | if max_x >= 0: 126 | tmp = nn.Embedding(max_x + 1, hidden_channels) 127 | nn.init.orthogonal_(tmp.weight) 128 | self.xemb = nn.Sequential(tmp, nn.Dropout(dropout)) 129 | in_channels = hidden_channels 130 | else: 131 | self.xemb = nn.Sequential(nn.Dropout(xdropout)) #nn.Identity() 132 | if not noinputlin and ("pure" in conv_fn or num_layers==0): 133 | self.xemb.append(nn.Linear(in_channels, hidden_channels)) 134 | self.xemb.append(nn.Dropout(dropout, inplace=True) if dropout > 1e-6 else nn.Identity()) 135 | 136 | self.res = res 137 | self.jk = jk 138 | if jk: 139 | self.register_parameter("jkparams", nn.Parameter(torch.randn((num_layers,)))) 140 | 141 | if num_layers == 0 or conv_fn =="none": 142 | self.jk = False 143 | return 144 | 145 | convfn = convdict[conv_fn] 146 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 147 | 148 | if num_layers == 1: 149 | hidden_channels = out_channels 150 | 151 | self.convs = nn.ModuleList() 152 | self.lins = nn.ModuleList() 153 | if "pure" in conv_fn: 154 | self.convs.append(convfn(hidden_channels, hidden_channels)) 155 | for i in range(num_layers-1): 156 | self.lins.append(nn.Identity()) 157 | self.convs.append(convfn(hidden_channels, hidden_channels)) 158 | self.lins.append(nn.Dropout(taildropout, True)) 159 | else: 160 | self.convs.append(convfn(in_channels, hidden_channels)) 161 | self.lins.append( 162 | nn.Sequential(lnfn(hidden_channels, ln), nn.Dropout(dropout, True), 163 | nn.ReLU(True))) 164 | for i in range(num_layers - 1): 165 | self.convs.append( 166 | convfn( 167 | hidden_channels, 168 | hidden_channels if i == num_layers - 2 else out_channels)) 169 | if i < num_layers - 2: 170 | self.lins.append( 171 | nn.Sequential( 172 | lnfn( 173 | hidden_channels if i == num_layers - 174 | 2 else out_channels, ln), 175 | nn.Dropout(dropout, True), nn.ReLU(True))) 176 | else: 177 | self.lins.append(nn.Identity()) 178 | 179 | 180 | def forward(self, x, adj_t): 181 | x = self.xemb(x) 182 | jkx = [] 183 | for i, conv in enumerate(self.convs): 184 | x1 = self.lins[i](conv(x, self.adjdrop(adj_t))) 185 | if self.res and x1.shape[-1] == x.shape[-1]: # residual connection 186 | x = x1 + x 187 | else: 188 | x = x1 189 | if self.jk: 190 | jkx.append(x) 191 | if self.jk: # JumpingKnowledge Connection 192 | jkx = torch.stack(jkx, dim=0) 193 | sftmax = self.jkparams.reshape(-1, 1, 1) 194 | x = torch.sum(jkx*sftmax, dim=0) 195 | return x 196 | 197 | 198 | # GAE predictor 199 | class LinkPredictor(nn.Module): 200 | 201 | def __init__(self, 202 | in_channels, 203 | hidden_channels, 204 | out_channels, 205 | num_layers, 206 | dropout, 207 | edrop=0.0, 208 | ln=False, 209 | **kwargs): 210 | super(LinkPredictor, self).__init__() 211 | 212 | self.lins = nn.Sequential() 213 | 214 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 215 | 216 | if num_layers == 1: 217 | self.lins = nn.Linear(in_channels, out_channels) 218 | else: 219 | self.lins.append(nn.Linear(in_channels, hidden_channels)) 220 | self.lins.append(lnfn(hidden_channels, ln)) 221 | self.lins.append(nn.Dropout(dropout, inplace=True)) 222 | self.lins.append(nn.ReLU(inplace=True)) 223 
| for _ in range(num_layers - 2): 224 | self.lins.append(nn.Linear(hidden_channels, hidden_channels)) 225 | self.lins.append(lnfn(hidden_channels, ln)) 226 | self.lins.append(nn.Dropout(dropout, inplace=True)) 227 | self.lins.append(nn.ReLU(inplace=True)) 228 | self.lins.append(nn.Linear(hidden_channels, out_channels)) 229 | 230 | def multidomainforward(self, 231 | x, 232 | adj, 233 | tar_ei, 234 | filled1: bool = False, 235 | cndropprobs: Iterable[float] = [0.25]): 236 | x = x[tar_ei].prod(dim=0) 237 | x = self.lins(x) 238 | return x.expand(-1, len(cndropprobs) + 1) 239 | 240 | def forward(self, x, adj, tar_ei, filled1: bool = False): 241 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 242 | 243 | 244 | # GAE + CN link predictor 245 | class SCNLinkPredictor(nn.Module): 246 | cndeg: Final[int] 247 | def __init__(self, 248 | in_channels, 249 | hidden_channels, 250 | out_channels, 251 | num_layers, 252 | dropout, 253 | edrop=0.0, 254 | ln=False, 255 | cndeg=-1, 256 | use_xlin=False, 257 | tailact=False, 258 | twolayerlin=False, 259 | beta=1.0): 260 | super().__init__() 261 | 262 | self.register_parameter("beta", nn.Parameter(beta*torch.ones((1)))) 263 | self.dropadj = DropAdj(edrop) 264 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 265 | 266 | self.xlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 267 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 268 | nn.Linear(hidden_channels, hidden_channels), 269 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True)) if use_xlin else lambda x: 0 270 | 271 | self.xcnlin = nn.Sequential( 272 | nn.Linear(1, hidden_channels), 273 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 274 | nn.Linear(hidden_channels, hidden_channels), 275 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 276 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 277 | self.xijlin = nn.Sequential( 278 | nn.Linear(in_channels, hidden_channels), lnfn(hidden_channels, ln), 279 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 280 | nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 281 | self.lin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 282 | lnfn(hidden_channels, ln), 283 | nn.Dropout(dropout, inplace=True), 284 | nn.ReLU(inplace=True), 285 | nn.Linear(hidden_channels, hidden_channels) if twolayerlin else nn.Identity(), 286 | lnfn(hidden_channels, ln) if twolayerlin else nn.Identity(), 287 | nn.Dropout(dropout, inplace=True) if twolayerlin else nn.Identity(), 288 | nn.ReLU(inplace=True) if twolayerlin else nn.Identity(), 289 | nn.Linear(hidden_channels, out_channels)) 290 | self.cndeg = cndeg 291 | 292 | def multidomainforward(self, 293 | x, 294 | adj, 295 | tar_ei, 296 | filled1: bool = False, 297 | cndropprobs: Iterable[float] = []): 298 | adj = self.dropadj(adj) 299 | xi = x[tar_ei[0]] 300 | xj = x[tar_ei[1]] 301 | cn = adjoverlap(adj, adj, tar_ei, filled1, cnsampledeg=self.cndeg) 302 | xcn = cn.sum(dim=-1).float().reshape(-1, 1) 303 | xij = self.xijlin(xi * xj) 304 | 305 | xs = torch.cat( 306 | [self.lin(self.xcnlin(xcn) * self.beta + xij)], 307 | dim=-1) 308 | return xs 309 | 310 | def forward(self, x, adj, tar_ei, filled1: bool = False): 311 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 312 | 313 | # another GAE + CN predictor 314 | class CatSCNLinkPredictor(nn.Module): 315 | cndeg: Final[int] 316 | def __init__(self, 317 | in_channels, 
318 | hidden_channels, 319 | out_channels, 320 | num_layers, 321 | dropout, 322 | edrop=0.0, 323 | ln=False, 324 | cndeg=-1, 325 | use_xlin=False, 326 | tailact=False, 327 | twolayerlin=False, 328 | beta=1.0): 329 | super().__init__() 330 | 331 | self.register_parameter("beta", nn.Parameter(beta*torch.ones((1)))) 332 | self.dropadj = DropAdj(edrop) 333 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 334 | 335 | self.xlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 336 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 337 | nn.Linear(hidden_channels, hidden_channels), 338 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True)) if use_xlin else lambda x: 0 339 | 340 | self.xcnlin = nn.Sequential( 341 | nn.Linear(1, hidden_channels), 342 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 343 | nn.Linear(hidden_channels, hidden_channels), 344 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 345 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 346 | self.xijlin = nn.Sequential( 347 | nn.Linear(in_channels, hidden_channels), lnfn(hidden_channels, ln), 348 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 349 | nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 350 | self.lin = nn.Sequential(nn.Linear(hidden_channels+1, hidden_channels), 351 | lnfn(hidden_channels, ln), 352 | nn.Dropout(dropout, inplace=True), 353 | nn.ReLU(inplace=True), 354 | nn.Linear(hidden_channels, hidden_channels) if twolayerlin else nn.Identity(), 355 | lnfn(hidden_channels, ln) if twolayerlin else nn.Identity(), 356 | nn.Dropout(dropout, inplace=True) if twolayerlin else nn.Identity(), 357 | nn.ReLU(inplace=True) if twolayerlin else nn.Identity(), 358 | nn.Linear(hidden_channels, out_channels)) 359 | self.cndeg = cndeg 360 | 361 | def multidomainforward(self, 362 | x, 363 | adj, 364 | tar_ei, 365 | filled1: bool = False, 366 | cndropprobs: Iterable[float] = []): 367 | adj = self.dropadj(adj) 368 | xi = x[tar_ei[0]] 369 | xj = x[tar_ei[1]] 370 | cn = adjoverlap(adj, adj, tar_ei, filled1, cnsampledeg=self.cndeg) 371 | xcn = cn.sum(dim=-1).float().reshape(-1, 1) 372 | xij = self.xijlin(xi * xj) 373 | 374 | xs = torch.cat( 375 | [self.lin(torch.cat((xcn, xij), dim=-1) )], 376 | dim=-1) 377 | return xs 378 | 379 | def forward(self, x, adj, tar_ei, filled1: bool = False): 380 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 381 | 382 | # GAE + CN predictor boosted by CNC trick 383 | class IncompleteSCN1Predictor(SCNLinkPredictor): 384 | learnablept: Final[bool] 385 | depth: Final[int] 386 | splitsize: Final[int] 387 | def __init__(self, 388 | in_channels, 389 | hidden_channels, 390 | out_channels, 391 | num_layers, 392 | dropout, 393 | edrop=0.0, 394 | ln=False, 395 | cndeg=-1, 396 | use_xlin=False, 397 | tailact=False, 398 | twolayerlin=False, 399 | beta=1.0, 400 | alpha=1.0, 401 | scale=5, 402 | offset=3, 403 | trainresdeg=8, 404 | testresdeg=128, 405 | pt=0.5, 406 | learnablept=False, 407 | depth=1, 408 | splitsize=-1, 409 | ): 410 | super().__init__(in_channels, hidden_channels, out_channels, num_layers, dropout, edrop, ln, cndeg, use_xlin, tailact, twolayerlin, beta) 411 | self.learnablept= learnablept 412 | self.depth = depth 413 | self.splitsize = splitsize 414 | self.lins = nn.Sequential() 415 | self.register_buffer("alpha", torch.tensor([alpha])) 416 | self.register_buffer("pt", torch.tensor([pt])) 417 | 
self.register_buffer("scale", torch.tensor([scale])) 418 | self.register_buffer("offset", torch.tensor([offset])) 419 | 420 | self.trainresdeg = trainresdeg 421 | self.testresdeg = testresdeg 422 | self.ptlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), nn.ReLU(inplace=True), nn.Linear(hidden_channels, 1), nn.Sigmoid()) 423 | # print(self.xcnlin) 424 | 425 | def clampprob(self, prob, pt): 426 | p0 = torch.sigmoid_(self.scale*(prob-self.offset)) 427 | return self.alpha*pt*p0/(pt*p0+1-p0) 428 | 429 | def multidomainforward(self, 430 | x, 431 | adj, 432 | tar_ei, 433 | filled1: bool = False, 434 | cndropprobs: Iterable[float] = [], 435 | depth: int=None): 436 | assert len(cndropprobs) == 0 437 | if depth is None: 438 | depth = self.depth 439 | adj = self.dropadj(adj) 440 | xi = x[tar_ei[0]] 441 | xj = x[tar_ei[1]] 442 | xij = xi*xj 443 | if depth > 0.5: 444 | cn, cnres1, cnres2 = adjoverlap( 445 | adj, 446 | adj, 447 | tar_ei, 448 | filled1, 449 | calresadj=True, 450 | cnsampledeg=self.cndeg, 451 | ressampledeg=self.trainresdeg if self.training else self.testresdeg) 452 | else: 453 | cn = adjoverlap( 454 | adj, 455 | adj, 456 | tar_ei, 457 | filled1, 458 | calresadj=False, 459 | cnsampledeg=self.cndeg, 460 | ressampledeg=self.trainresdeg if self.training else self.testresdeg) 461 | xcns = [cn.sum(dim=-1).float().reshape(-1, 1)] 462 | 463 | if depth > 0.5: 464 | potcn1 = cnres1.coo() 465 | potcn2 = cnres2.coo() 466 | with torch.no_grad(): 467 | if self.splitsize < 0: 468 | ei1 = torch.stack((tar_ei[1][potcn1[0]], potcn1[1])) 469 | probcn1 = self.forward( 470 | x, adj, ei1, 471 | filled1, depth-1).flatten() 472 | ei2 = torch.stack((tar_ei[0][potcn2[0]], potcn2[1])) 473 | probcn2 = self.forward( 474 | x, adj, ei2, 475 | filled1, depth-1).flatten() 476 | else: 477 | num1 = potcn1[1].shape[0] 478 | ei1 = torch.stack((tar_ei[1][potcn1[0]], potcn1[1])) 479 | probcn1 = torch.empty_like(potcn1[1], dtype=torch.float) 480 | for i in range(0, num1, self.splitsize): 481 | probcn1[i:i+self.splitsize] = self.forward(x, adj, ei1[:, i: i+self.splitsize], filled1, depth-1).flatten() 482 | num2 = potcn2[1].shape[0] 483 | ei2 = torch.stack((tar_ei[0][potcn2[0]], potcn2[1])) 484 | probcn2 = torch.empty_like(potcn2[1], dtype=torch.float) 485 | for i in range(0, num2, self.splitsize): 486 | probcn2[i:i+self.splitsize] = self.forward(x, adj, ei2[:, i: i+self.splitsize],filled1, depth-1).flatten() 487 | if self.learnablept: 488 | pt = self.ptlin(xij) 489 | probcn1 = self.clampprob(probcn1, pt[potcn1[0]]) 490 | probcn2 = self.clampprob(probcn2, pt[potcn2[0]]) 491 | else: 492 | probcn1 = self.clampprob(probcn1, self.pt) 493 | probcn2 = self.clampprob(probcn2, self.pt) 494 | probcn1 = probcn1 * potcn1[-1] 495 | probcn2 = probcn2 * potcn2[-1] 496 | cnres1.set_value_(probcn1, layout="coo") 497 | cnres2.set_value_(probcn2, layout="coo") 498 | xcn1 = cnres1.sum(dim=-1).float().reshape(-1, 1) 499 | xcn2 = cnres2.sum(dim=-1).float().reshape(-1, 1) 500 | xcns[0] = xcns[0] + xcn2 + xcn1 501 | xij = self.xijlin(xij) 502 | 503 | xs = torch.cat( 504 | [self.lin(self.xcnlin(xcn) * self.beta + xij) for xcn in xcns], 505 | dim=-1) 506 | return xs 507 | 508 | def setalpha(self, alpha: float): 509 | self.alpha.fill_(alpha) 510 | print(f"set alpha: {alpha}") 511 | 512 | def forward(self, 513 | x, 514 | adj, 515 | tar_ei, 516 | filled1: bool = False, 517 | depth: int = None): 518 | if depth is None: 519 | depth = self.depth 520 | return self.multidomainforward(x, adj, tar_ei, filled1, [], 521 | depth) 522 | 523 | 524 | # 
NCN predictor 525 | class CNLinkPredictor(nn.Module): 526 | cndeg: Final[int] 527 | def __init__(self, 528 | in_channels, 529 | hidden_channels, 530 | out_channels, 531 | num_layers, 532 | dropout, 533 | edrop=0.0, 534 | ln=False, 535 | cndeg=-1, 536 | use_xlin=False, 537 | tailact=False, 538 | twolayerlin=False, 539 | beta=1.0): 540 | super().__init__() 541 | 542 | self.register_parameter("beta", nn.Parameter(beta*torch.ones((1)))) 543 | self.dropadj = DropAdj(edrop) 544 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 545 | 546 | self.xlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 547 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 548 | nn.Linear(hidden_channels, hidden_channels), 549 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True)) if use_xlin else lambda x: 0 550 | 551 | self.xcnlin = nn.Sequential( 552 | nn.Linear(in_channels, hidden_channels), 553 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 554 | nn.Linear(hidden_channels, hidden_channels), 555 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 556 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 557 | self.xijlin = nn.Sequential( 558 | nn.Linear(in_channels, hidden_channels), lnfn(hidden_channels, ln), 559 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 560 | nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 561 | self.lin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 562 | lnfn(hidden_channels, ln), 563 | nn.Dropout(dropout, inplace=True), 564 | nn.ReLU(inplace=True), 565 | nn.Linear(hidden_channels, hidden_channels) if twolayerlin else nn.Identity(), 566 | lnfn(hidden_channels, ln) if twolayerlin else nn.Identity(), 567 | nn.Dropout(dropout, inplace=True) if twolayerlin else nn.Identity(), 568 | nn.ReLU(inplace=True) if twolayerlin else nn.Identity(), 569 | nn.Linear(hidden_channels, out_channels)) 570 | self.cndeg = cndeg 571 | 572 | def multidomainforward(self, 573 | x, 574 | adj, 575 | tar_ei, 576 | filled1: bool = False, 577 | cndropprobs: Iterable[float] = []): 578 | adj = self.dropadj(adj) 579 | xi = x[tar_ei[0]] 580 | xj = x[tar_ei[1]] 581 | x = x + self.xlin(x) 582 | cn = adjoverlap(adj, adj, tar_ei, filled1, cnsampledeg=self.cndeg) 583 | xcns = [spmm_add(cn, x)] 584 | xij = self.xijlin(xi * xj) 585 | 586 | xs = torch.cat( 587 | [self.lin(self.xcnlin(xcn) * self.beta + xij) for xcn in xcns], 588 | dim=-1) 589 | return xs 590 | 591 | def forward(self, x, adj, tar_ei, filled1: bool = False): 592 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 593 | 594 | # GAE predictor for ablation study 595 | class CN0LinkPredictor(nn.Module): 596 | cndeg: Final[int] 597 | def __init__(self, 598 | in_channels, 599 | hidden_channels, 600 | out_channels, 601 | num_layers, 602 | dropout, 603 | edrop=0.0, 604 | ln=False, 605 | cndeg=-1, 606 | use_xlin=False, 607 | tailact=False, 608 | twolayerlin=False, 609 | beta=1.0): 610 | super().__init__() 611 | 612 | self.register_parameter("beta", nn.Parameter(beta*torch.ones((1)))) 613 | self.dropadj = DropAdj(edrop) 614 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 615 | 616 | self.xlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 617 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 618 | nn.Linear(hidden_channels, hidden_channels), 619 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 
nn.ReLU(inplace=True)) if use_xlin else lambda x: 0 620 | 621 | self.xcnlin = nn.Sequential( 622 | nn.Linear(in_channels, hidden_channels), 623 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 624 | nn.Linear(hidden_channels, hidden_channels), 625 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 626 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 627 | self.xijlin = nn.Sequential( 628 | nn.Linear(in_channels, hidden_channels), lnfn(hidden_channels, ln), 629 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 630 | nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 631 | self.lin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), 632 | lnfn(hidden_channels, ln), 633 | nn.Dropout(dropout, inplace=True), 634 | nn.ReLU(inplace=True), 635 | nn.Linear(hidden_channels, hidden_channels) if twolayerlin else nn.Identity(), 636 | lnfn(hidden_channels, ln) if twolayerlin else nn.Identity(), 637 | nn.Dropout(dropout, inplace=True) if twolayerlin else nn.Identity(), 638 | nn.ReLU(inplace=True) if twolayerlin else nn.Identity(), 639 | nn.Linear(hidden_channels, out_channels)) 640 | self.cndeg = cndeg 641 | 642 | def multidomainforward(self, 643 | x, 644 | adj, 645 | tar_ei, 646 | filled1: bool = False, 647 | cndropprobs: Iterable[float] = []): 648 | adj = self.dropadj(adj) 649 | xi = x[tar_ei[0]] 650 | xj = x[tar_ei[1]] 651 | xij = self.xijlin(xi * xj) 652 | 653 | xs = torch.cat( 654 | [self.lin(xij)], 655 | dim=-1) 656 | return xs 657 | 658 | def forward(self, x, adj, tar_ei, filled1: bool = False): 659 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 660 | 661 | # NCNC predictor 662 | class IncompleteCN1Predictor(CNLinkPredictor): 663 | learnablept: Final[bool] 664 | depth: Final[int] 665 | splitsize: Final[int] 666 | def __init__(self, 667 | in_channels, 668 | hidden_channels, 669 | out_channels, 670 | num_layers, 671 | dropout, 672 | edrop=0.0, 673 | ln=False, 674 | cndeg=-1, 675 | use_xlin=False, 676 | tailact=False, 677 | twolayerlin=False, 678 | beta=1.0, 679 | alpha=1.0, 680 | scale=5, 681 | offset=3, 682 | trainresdeg=8, 683 | testresdeg=128, 684 | pt=0.5, 685 | learnablept=False, 686 | depth=1, 687 | splitsize=-1, 688 | ): 689 | super().__init__(in_channels, hidden_channels, out_channels, num_layers, dropout, edrop, ln, cndeg, use_xlin, tailact, twolayerlin, beta) 690 | self.learnablept= learnablept 691 | self.depth = depth 692 | self.splitsize = splitsize 693 | self.lins = nn.Sequential() 694 | self.register_buffer("alpha", torch.tensor([alpha])) 695 | self.register_buffer("pt", torch.tensor([pt])) 696 | self.register_buffer("scale", torch.tensor([scale])) 697 | self.register_buffer("offset", torch.tensor([offset])) 698 | 699 | self.trainresdeg = trainresdeg 700 | self.testresdeg = testresdeg 701 | self.ptlin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels), nn.ReLU(inplace=True), nn.Linear(hidden_channels, 1), nn.Sigmoid()) 702 | 703 | def clampprob(self, prob, pt): 704 | p0 = torch.sigmoid_(self.scale*(prob-self.offset)) 705 | return self.alpha*pt*p0/(pt*p0+1-p0) 706 | 707 | def multidomainforward(self, 708 | x, 709 | adj, 710 | tar_ei, 711 | filled1: bool = False, 712 | cndropprobs: Iterable[float] = [], 713 | depth: int=None): 714 | assert len(cndropprobs) == 0 715 | if depth is None: 716 | depth = self.depth 717 | adj = self.dropadj(adj) 718 | xi = x[tar_ei[0]] 719 | xj = x[tar_ei[1]] 720 | xij = xi*xj 721 | x = x + self.xlin(x) 722 | if 
depth > 0.5: 723 | cn, cnres1, cnres2 = adjoverlap( 724 | adj, 725 | adj, 726 | tar_ei, 727 | filled1, 728 | calresadj=True, 729 | cnsampledeg=self.cndeg, 730 | ressampledeg=self.trainresdeg if self.training else self.testresdeg) 731 | else: 732 | cn = adjoverlap( 733 | adj, 734 | adj, 735 | tar_ei, 736 | filled1, 737 | calresadj=False, 738 | cnsampledeg=self.cndeg, 739 | ressampledeg=self.trainresdeg if self.training else self.testresdeg) 740 | xcns = [spmm_add(cn, x)] 741 | 742 | if depth > 0.5: 743 | potcn1 = cnres1.coo() 744 | potcn2 = cnres2.coo() 745 | with torch.no_grad(): 746 | if self.splitsize < 0: 747 | ei1 = torch.stack((tar_ei[1][potcn1[0]], potcn1[1])) 748 | probcn1 = self.forward( 749 | x, adj, ei1, 750 | filled1, depth-1).flatten() 751 | ei2 = torch.stack((tar_ei[0][potcn2[0]], potcn2[1])) 752 | probcn2 = self.forward( 753 | x, adj, ei2, 754 | filled1, depth-1).flatten() 755 | else: 756 | num1 = potcn1[1].shape[0] 757 | ei1 = torch.stack((tar_ei[1][potcn1[0]], potcn1[1])) 758 | probcn1 = torch.empty_like(potcn1[1], dtype=torch.float) 759 | for i in range(0, num1, self.splitsize): 760 | probcn1[i:i+self.splitsize] = self.forward(x, adj, ei1[:, i: i+self.splitsize], filled1, depth-1).flatten() 761 | num2 = potcn2[1].shape[0] 762 | ei2 = torch.stack((tar_ei[0][potcn2[0]], potcn2[1])) 763 | probcn2 = torch.empty_like(potcn2[1], dtype=torch.float) 764 | for i in range(0, num2, self.splitsize): 765 | probcn2[i:i+self.splitsize] = self.forward(x, adj, ei2[:, i: i+self.splitsize],filled1, depth-1).flatten() 766 | if self.learnablept: 767 | pt = self.ptlin(xij) 768 | probcn1 = self.clampprob(probcn1, pt[potcn1[0]]) 769 | probcn2 = self.clampprob(probcn2, pt[potcn2[0]]) 770 | else: 771 | probcn1 = self.clampprob(probcn1, self.pt) 772 | probcn2 = self.clampprob(probcn2, self.pt) 773 | probcn1 = probcn1 * potcn1[-1] 774 | probcn2 = probcn2 * potcn2[-1] 775 | cnres1.set_value_(probcn1, layout="coo") 776 | cnres2.set_value_(probcn2, layout="coo") 777 | xcn1 = spmm_add(cnres1, x) 778 | xcn2 = spmm_add(cnres2, x) 779 | xcns[0] = xcns[0] + xcn2 + xcn1 780 | 781 | xij = self.xijlin(xij) 782 | 783 | xs = torch.cat( 784 | [self.lin(self.xcnlin(xcn) * self.beta + xij) for xcn in xcns], 785 | dim=-1) 786 | return xs 787 | 788 | def setalpha(self, alpha: float): 789 | self.alpha.fill_(alpha) 790 | print(f"set alpha: {alpha}") 791 | 792 | def forward(self, 793 | x, 794 | adj, 795 | tar_ei, 796 | filled1: bool = False, 797 | depth: int = None): 798 | if depth is None: 799 | depth = self.depth 800 | return self.multidomainforward(x, adj, tar_ei, filled1, [], 801 | depth) 802 | 803 | 804 | # NCN2 predictor 805 | class CNhalf2LinkPredictor(CNLinkPredictor): 806 | 807 | def __init__(self, 808 | in_channels, 809 | hidden_channels, 810 | out_channels, 811 | num_layers, 812 | dropout, 813 | ln=False, 814 | tailact=False, 815 | **kwargs): 816 | super().__init__(in_channels, 817 | hidden_channels, 818 | out_channels, 819 | num_layers, 820 | dropout, ln=ln, tailact=tailact, **kwargs) 821 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 822 | self.xcn12lin = nn.Sequential( 823 | nn.Linear(in_channels, hidden_channels), 824 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 825 | nn.Linear(hidden_channels, hidden_channels), 826 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 827 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 828 | 829 | def multidomainforward(self, 830 | x, 831 | adj, 832 | tar_ei, 833 | 
filled1: bool = False, 834 | cndropprobs: Iterable[float] = []): 835 | adj = self.dropadj(adj) 836 | xi = x[tar_ei[0]] 837 | xj = x[tar_ei[1]] 838 | x = x + self.xlin(x) 839 | cn = adjoverlap(adj, adj, tar_ei, filled1, cnsampledeg=self.cndeg) 840 | adj2 = adj@adj 841 | cn12 = adjoverlap(adj, adj2, tar_ei, filled1, cnsampledeg=self.cndeg) 842 | cn21 = adjoverlap(adj2, adj, tar_ei, filled1, cnsampledeg=self.cndeg) 843 | 844 | xcns = [(spmm_add(cn, x), spmm_add(cn12, x)+spmm_add(cn21, x))] 845 | xij = self.xijlin(xi * xj) 846 | 847 | xs = torch.cat( 848 | [self.lin(self.xcnlin(xcn[0]) * self.beta + self.xcn12lin(xcn[1]) + xij) for xcn in xcns], 849 | dim=-1) 850 | return xs 851 | 852 | def forward(self, x, adj, tar_ei, filled1: bool = False): 853 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 854 | 855 | 856 | 857 | # NCN-diff 858 | class CNResLinkPredictor(CNLinkPredictor): 859 | 860 | def __init__(self, 861 | in_channels, 862 | hidden_channels, 863 | out_channels, 864 | num_layers, 865 | dropout, 866 | ln=False, 867 | tailact=False, 868 | **kwargs): 869 | super().__init__(in_channels, 870 | hidden_channels, 871 | out_channels, 872 | num_layers, 873 | dropout, ln=ln, tailact=tailact, **kwargs) 874 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 875 | self.xcnreslin = nn.Sequential( 876 | nn.Linear(in_channels, hidden_channels), 877 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 878 | nn.Linear(hidden_channels, hidden_channels), 879 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 880 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels) if not tailact else nn.Identity()) 881 | 882 | def multidomainforward(self, 883 | x, 884 | adj, 885 | tar_ei, 886 | filled1: bool = False, 887 | cndropprobs: Iterable[float] = []): 888 | adj = self.dropadj(adj) 889 | xi = x[tar_ei[0]] 890 | xj = x[tar_ei[1]] 891 | x = x + self.xlin(x) 892 | cn, cnres1, cnres2 = adjoverlap(adj, adj, tar_ei, filled1, cnsampledeg=self.cndeg, calresadj=True) 893 | 894 | xcns = [(spmm_add(cn, x), spmm_add(cnres1, x)+spmm_add(cnres2, x))] 895 | xij = self.xijlin(xi * xj) 896 | 897 | xs = torch.cat( 898 | [self.lin(self.xcnlin(xcn[0]) * self.beta + self.xcnreslin(xcn[1]) + xij) for xcn in xcns], 899 | dim=-1) 900 | return xs 901 | 902 | def forward(self, x, adj, tar_ei, filled1: bool = False): 903 | return self.multidomainforward(x, adj, tar_ei, filled1, []) 904 | 905 | # NCN with higher order neighborhood overlaps than NCN-2 906 | class CN2LinkPredictor(nn.Module): 907 | 908 | def __init__(self, 909 | in_channels, 910 | hidden_channels, 911 | out_channels, 912 | num_layers, 913 | dropout, 914 | edrop=0.0, 915 | ln=False, 916 | cndeg=-1): 917 | super().__init__() 918 | 919 | self.lins = nn.Sequential() 920 | 921 | self.register_parameter("alpha", nn.Parameter(torch.ones((3)))) 922 | self.register_parameter("beta", nn.Parameter(torch.ones((1)))) 923 | lnfn = lambda dim, ln: nn.LayerNorm(dim) if ln else nn.Identity() 924 | 925 | self.xcn1lin = nn.Sequential( 926 | nn.Linear(in_channels, hidden_channels), 927 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 928 | nn.Linear(hidden_channels, hidden_channels), 929 | lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True), 930 | nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels)) 931 | self.xcn2lin = nn.Sequential( 932 | nn.Linear(in_channels, hidden_channels), 933 | nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True), 934 | nn.Linear(hidden_channels, hidden_channels), 935 | 
lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True),
936 |             nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels))
937 |         self.xcn4lin = nn.Sequential(
938 |             nn.Linear(in_channels, hidden_channels),
939 |             nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True),
940 |             nn.Linear(hidden_channels, hidden_channels),
941 |             lnfn(hidden_channels, ln), nn.Dropout(dropout, inplace=True),
942 |             nn.ReLU(inplace=True), nn.Linear(hidden_channels, hidden_channels))
943 |         self.xijlin = nn.Sequential(
944 |             nn.Linear(in_channels, hidden_channels), lnfn(hidden_channels, ln),
945 |             nn.Dropout(dropout, inplace=True), nn.ReLU(inplace=True),
946 |             nn.Linear(hidden_channels, hidden_channels))
947 |         self.lin = nn.Sequential(nn.Linear(hidden_channels, hidden_channels),
948 |                                  lnfn(hidden_channels, ln),
949 |                                  nn.Dropout(dropout, inplace=True),
950 |                                  nn.ReLU(inplace=True),
951 |                                  nn.Linear(hidden_channels, out_channels))
952 | 
953 |     def forward(self, x, adj: SparseTensor, tar_ei, filled1: bool = False):
954 |         spadj = adj.to_torch_sparse_coo_tensor()
955 |         adj2 = SparseTensor.from_torch_sparse_coo_tensor(spadj @ spadj, False)
956 |         cn1 = adjoverlap(adj, adj, tar_ei, filled1)
957 |         cn2 = adjoverlap(adj, adj2, tar_ei, filled1)
958 |         cn3 = adjoverlap(adj2, adj, tar_ei, filled1)
959 |         cn4 = adjoverlap(adj2, adj2, tar_ei, filled1)
960 |         xij = self.xijlin(x[tar_ei[0]] * x[tar_ei[1]])
961 |         xcn1 = self.xcn1lin(spmm_add(cn1, x))
962 |         xcn2 = self.xcn2lin(spmm_add(cn2, x))
963 |         xcn3 = self.xcn2lin(spmm_add(cn3, x))
964 |         xcn4 = self.xcn4lin(spmm_add(cn4, x))
965 |         alpha = torch.sigmoid(self.alpha).cumprod(-1)
966 |         x = self.lin(alpha[0] * xcn1 + alpha[1] * xcn2 * xcn3 +
967 |                      alpha[2] * xcn4 + self.beta * xij)
968 |         return x
969 | 
970 | 
971 | predictor_dict = {
972 |     "cn0": CN0LinkPredictor,
973 |     "catscn1": CatSCNLinkPredictor,
974 |     "scn1": SCNLinkPredictor,
975 |     "sincn1cn1": IncompleteSCN1Predictor,
976 |     "cn1": CNLinkPredictor,
977 |     "cn1.5": CNhalf2LinkPredictor,
978 |     "cn1res": CNResLinkPredictor,
979 |     "cn2": CN2LinkPredictor,
980 |     "incn1cn1": IncompleteCN1Predictor
981 | }
--------------------------------------------------------------------------------
/ogbdataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.metrics import roc_auc_score, average_precision_score
3 | from ogb.linkproppred import PygLinkPropPredDataset
4 | import torch_geometric.transforms as T
5 | from torch_sparse import SparseTensor
6 | from torch_geometric.datasets import Planetoid
7 | from torch_geometric.utils import train_test_split_edges, negative_sampling, to_undirected
8 | from torch_geometric.transforms import RandomLinkSplit
9 | 
10 | # random split dataset
11 | def randomsplit(dataset, val_ratio: float=0.10, test_ratio: float=0.2):
12 |     def removerepeated(ei):
13 |         ei = to_undirected(ei)
14 |         ei = ei[:, ei[0] < ei[1]]
...
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
...
    def __init__(self, device, size, bs, training=True) -> None:
13 |         self.bs = bs
14 |         self.training = training
15 |         self.idx = torch.randperm(
16 |             size, device=device) if training else torch.arange(size,
17 |                                                                device=device)
18 | 
19 |     def __len__(self):
20 |         return (self.idx.shape[0] + (self.bs - 1) *
21 |                 (not self.training)) // self.bs
22 | 
23 |     def __iter__(self):
24 |         self.ptr = 0
25 |         return self
26 | 
27 |     def __next__(self):
28 |         if self.ptr + self.bs * self.training > self.idx.shape[0]:
29 |             raise StopIteration
30 |         ret = self.idx[self.ptr:self.ptr + self.bs]
31 |         self.ptr += self.bs
32 |         return ret
33 | 
34 | 
35 | def sparsesample(adj: SparseTensor, deg: int) -> SparseTensor:
36 |     '''
37 |     sampling elements from an adjacency matrix
38 |     '''
39 |     rowptr, col, _ = adj.csr()
40 |     rowcount = adj.storage.rowcount()
41 |     mask = rowcount > 0
42 |     rowcount = rowcount[mask]
43 |     rowptr = rowptr[:-1][mask]
44 | 
45 |     rand = torch.rand((rowcount.size(0), deg), device=col.device)
46 |     rand.mul_(rowcount.to(rand.dtype).reshape(-1, 1))
47 |     rand = rand.to(torch.long)
48 |     rand.add_(rowptr.reshape(-1, 1))
49 | 
50 |     samplecol = col[rand]
51 | 
52 |     samplerow = torch.arange(adj.size(0), device=adj.device())[mask]
53 | 
54 |     ret = SparseTensor(row=samplerow.reshape(-1, 1).expand(-1, deg).flatten(),
55 |                        col=samplecol.flatten(),
56 |                        sparse_sizes=adj.sparse_sizes()).to_device(
57 |                            adj.device()).coalesce().fill_value_(1.0)
58 |     #print(ret.storage.value())
59 |     return ret
60 | 
61 | 
62 | def sparsesample2(adj: SparseTensor, deg: int) -> SparseTensor:
63 |     '''
64 |     another implementation for sampling elements from an adjacency matrix
65 |     '''
66 |     rowptr, col, _ = adj.csr()
67 |     rowcount = adj.storage.rowcount()
68 |     mask = rowcount > deg
69 | 
70 |     rowcount = rowcount[mask]
71 |     rowptr = rowptr[:-1][mask]
72 | 
73 |     rand = torch.rand((rowcount.size(0), deg), device=col.device)
74 |     rand.mul_(rowcount.to(rand.dtype).reshape(-1, 1))
75 |     rand = rand.to(torch.long)
76 |     rand.add_(rowptr.reshape(-1, 1))
77 | 
78 |     samplecol = col[rand].flatten()
79 | 
80 |     samplerow = torch.arange(adj.size(0), device=adj.device())[mask].reshape(
81 |         -1, 1).expand(-1, deg).flatten()
82 | 
83 |     mask = torch.logical_not(mask)
84 |     nosamplerow, nosamplecol = adj[mask].coo()[:2]
85 |     nosamplerow = torch.arange(adj.size(0),
86 |                                device=adj.device())[mask][nosamplerow]
87 | 
88 |     ret = SparseTensor(
89 |         row=torch.cat((samplerow, nosamplerow)),
90 |         col=torch.cat((samplecol, nosamplecol)),
91 |         sparse_sizes=adj.sparse_sizes()).to_device(
92 |             adj.device()).fill_value_(1.0).coalesce()  #.fill_value_(1)
93 |     #assert (ret.sum(dim=-1) == torch.clip(adj.sum(dim=-1), 0, deg)).all()
94 |     return ret
95 | 
96 | 
97 | def sparsesample_reweight(adj: SparseTensor, deg: int) -> SparseTensor:
98 |     '''
99 |     another implementation for sampling elements from an adjacency matrix. It also scales the sampled elements.
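    Rows with more than `deg` nonzeros are subsampled to `deg` entries (with replacement),
    and each kept entry is weighted by rowcount / deg so that the total weight of a sampled
    row still matches its original degree; shorter rows are kept unchanged with weight 1.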
100 | 
101 |     '''
102 |     rowptr, col, _ = adj.csr()
103 |     rowcount = adj.storage.rowcount()
104 |     mask = rowcount > deg
105 | 
106 |     rowcount = rowcount[mask]
107 |     rowptr = rowptr[:-1][mask]
108 | 
109 |     rand = torch.rand((rowcount.size(0), deg), device=col.device)
110 |     rand.mul_(rowcount.to(rand.dtype).reshape(-1, 1))
111 |     rand = rand.to(torch.long)
112 |     rand.add_(rowptr.reshape(-1, 1))
113 | 
114 |     samplecol = col[rand].flatten()
115 | 
116 |     samplerow = torch.arange(adj.size(0), device=adj.device())[mask].reshape(
117 |         -1, 1).expand(-1, deg).flatten()
118 |     samplevalue = (rowcount * (1/deg)).reshape(-1, 1).expand(-1, deg).flatten()
119 | 
120 |     mask = torch.logical_not(mask)
121 |     nosamplerow, nosamplecol = adj[mask].coo()[:2]
122 |     nosamplerow = torch.arange(adj.size(0),
123 |                                device=adj.device())[mask][nosamplerow]
124 | 
125 |     ret = SparseTensor(row=torch.cat((samplerow, nosamplerow)),
126 |                        col=torch.cat((samplecol, nosamplecol)),
127 |                        value=torch.cat((samplevalue,
128 |                                         torch.ones_like(nosamplerow))),
129 |                        sparse_sizes=adj.sparse_sizes()).to_device(
130 |                            adj.device()).coalesce()  #.fill_value_(1)
131 |     #assert (ret.sum(dim=-1) == torch.clip(adj.sum(dim=-1), 0, deg)).all()
132 |     return ret
133 | 
134 | 
135 | def elem2spm(element: Tensor, sizes: List[int]) -> SparseTensor:
136 |     # Convert packed 1-d (row, col) keys back into a sparse adjacency matrix
137 |     col = torch.bitwise_and(element, 0xffffffff)
138 |     row = torch.bitwise_right_shift(element, 32)
139 |     return SparseTensor(row=row, col=col, sparse_sizes=sizes).to_device(
140 |         element.device).fill_value_(1.0)
141 | 
142 | 
143 | def spm2elem(spm: SparseTensor) -> Tensor:
144 |     # Pack a sparse adjacency matrix into a 1-d vector of (row << 32) + col keys
145 |     sizes = spm.sizes()
146 |     elem = torch.bitwise_left_shift(spm.storage.row(),
147 |                                     32).add_(spm.storage.col())
148 |     #elem = spm.storage.row()*sizes[-1] + spm.storage.col()
149 |     #assert torch.all(torch.diff(elem) > 0)
150 |     return elem
151 | 
152 | 
153 | def spmoverlap_(adj1: SparseTensor, adj2: SparseTensor) -> SparseTensor:
154 |     '''
155 |     Compute the overlap of neighbors (rows in adj). The returned matrix is similar to the Hadamard product of adj1 and adj2.
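    Both inputs are first packed into sorted 1-d arrays of (row << 32) + col keys via spm2elem,
    and the shared keys are found with torch.searchsorted.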
156 |     '''
157 |     assert adj1.sizes() == adj2.sizes()
158 |     element1 = spm2elem(adj1)
159 |     element2 = spm2elem(adj2)
160 | 
161 |     if element2.shape[0] > element1.shape[0]:
162 |         element1, element2 = element2, element1
163 | 
164 |     idx = torch.searchsorted(element1[:-1], element2)
165 |     mask = (element1[idx] == element2)
166 |     retelem = element2[mask]
167 |     '''
168 |     nnz1 = adj1.nnz()
169 |     element = torch.cat((adj1.storage.row(), adj2.storage.row()), dim=-1)
170 |     element.bitwise_left_shift_(32)
171 |     element[:nnz1] += adj1.storage.col()
172 |     element[nnz1:] += adj2.storage.col()
173 | 
174 |     element = torch.sort(element, dim=-1)[0]
175 |     mask = (element[1:] == element[:-1])
176 |     retelem = element[:-1][mask]
177 |     '''
178 | 
179 |     return elem2spm(retelem, adj1.sizes())
180 | 
181 | 
182 | def spmnotoverlap_(adj1: SparseTensor,
183 |                    adj2: SparseTensor) -> Tuple[SparseTensor, SparseTensor]:
184 |     '''
185 |     return elements in adj1 but not in adj2, and elements in adj2 but not in adj1
186 |     '''
187 |     # assert adj1.sizes() == adj2.sizes()
188 |     element1 = spm2elem(adj1)
189 |     element2 = spm2elem(adj2)
190 | 
191 |     idx = torch.searchsorted(element1[:-1], element2)
192 |     matchedmask = (element1[idx] == element2)
193 | 
194 |     maskelem1 = torch.ones_like(element1, dtype=torch.bool)
195 |     maskelem1[idx[matchedmask]] = 0
196 |     retelem1 = element1[maskelem1]
197 | 
198 |     retelem2 = element2[torch.logical_not(matchedmask)]
199 |     return elem2spm(retelem1, adj1.sizes()), elem2spm(retelem2, adj2.sizes())
200 | 
201 | 
202 | def spmoverlap_notoverlap_(
203 |         adj1: SparseTensor,
204 |         adj2: SparseTensor) -> Tuple[SparseTensor, SparseTensor, SparseTensor]:
205 |     '''
206 |     return the overlap of adj1 and adj2, elements in adj1 but not in adj2, and elements in adj2 but not in adj1
207 |     '''
208 |     # assert adj1.sizes() == adj2.sizes()
209 |     element1 = spm2elem(adj1)
210 |     element2 = spm2elem(adj2)
211 | 
212 |     if element1.shape[0] == 0:
213 |         retoverlap = element1
214 |         retelem1 = element1
215 |         retelem2 = element2
216 |     else:
217 |         idx = torch.searchsorted(element1[:-1], element2)
218 |         matchedmask = (element1[idx] == element2)
219 | 
220 |         maskelem1 = torch.ones_like(element1, dtype=torch.bool)
221 |         maskelem1[idx[matchedmask]] = 0
222 |         retelem1 = element1[maskelem1]
223 | 
224 |         retoverlap = element2[matchedmask]
225 |         retelem2 = element2[torch.logical_not(matchedmask)]
226 |     sizes = adj1.sizes()
227 |     return elem2spm(retoverlap,
228 |                     sizes), elem2spm(retelem1,
229 |                                      sizes), elem2spm(retelem2, sizes)
230 | 
231 | 
232 | def adjoverlap(adj1: SparseTensor,
233 |                adj2: SparseTensor,
234 |                tarei: Tensor,
235 |                filled1: bool = False,
236 |                calresadj: bool = False,
237 |                cnsampledeg: int = -1,
238 |                ressampledeg: int = -1):
239 |     # a wrapper for functions above.
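    # adj1/adj2 are first restricted to the rows indexed by the two endpoints of
    # tarei, so row k of every returned SparseTensor corresponds to the k-th
    # target edge. With calresadj=True the residual neighborhoods (neighbors of
    # one endpoint that are not neighbors of the other) are returned as well;
    # cnsampledeg / ressampledeg > 0 subsample them via sparsesample_reweight.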
240 | adj1 = adj1[tarei[0]] 241 | adj2 = adj2[tarei[1]] 242 | if calresadj: 243 | adjoverlap, adjres1, adjres2 = spmoverlap_notoverlap_(adj1, adj2) 244 | if cnsampledeg > 0: 245 | adjoverlap = sparsesample_reweight(adjoverlap, cnsampledeg) 246 | if ressampledeg > 0: 247 | adjres1 = sparsesample_reweight(adjres1, ressampledeg) 248 | adjres2 = sparsesample_reweight(adjres2, ressampledeg) 249 | return adjoverlap, adjres1, adjres2 250 | else: 251 | adjoverlap = spmoverlap_(adj1, adj2) 252 | if cnsampledeg > 0: 253 | adjoverlap = sparsesample_reweight(adjoverlap, cnsampledeg) 254 | return adjoverlap 255 | 256 | 257 | if __name__ == "__main__": 258 | adj1 = SparseTensor.from_edge_index( 259 | torch.LongTensor([[0, 0, 1, 2, 3], [0, 1, 1, 2, 3]])) 260 | adj2 = SparseTensor.from_edge_index( 261 | torch.LongTensor([[0, 3, 1, 2, 3], [0, 1, 1, 2, 3]])) 262 | adj3 = SparseTensor.from_edge_index( 263 | torch.LongTensor([[0, 1, 2, 2, 2,2, 3, 3, 3], [1, 0, 2,3,4, 5, 4, 5, 6]])) 264 | print(spmnotoverlap_(adj1, adj2)) 265 | print(spmoverlap_(adj1, adj2)) 266 | print(spmoverlap_notoverlap_(adj1, adj2)) 267 | print(sparsesample2(adj3, 2)) 268 | print(sparsesample_reweight(adj3, 2)) --------------------------------------------------------------------------------
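The snippet below is a minimal usage sketch, not part of the repository: it assumes utils.py (as reproduced above) is importable from the working directory and uses a made-up 4-node toy graph. It shows how adjoverlap turns a batch of target edges into the common-neighbor structure that CNLinkPredictor (NCN) and, with calresadj=True, IncompleteCN1Predictor (NCNC) consume.

# example sketch, assuming utils.py from this repository is on the import path
import torch
from torch_sparse import SparseTensor
from utils import adjoverlap

# toy undirected graph on 4 nodes: edges 0-1, 1-2, 2-3, 0-2 (both directions)
ei = torch.LongTensor([[0, 1, 1, 2, 2, 3, 0, 2],
                       [1, 0, 2, 1, 3, 2, 2, 0]])
adj = SparseTensor.from_edge_index(ei, sparse_sizes=(4, 4))

# candidate links to score: (0, 3) and (1, 3)
tar_ei = torch.LongTensor([[0, 1],
                           [3, 3]])

# row k of cn holds the common neighbors of the k-th target pair
cn = adjoverlap(adj, adj, tar_ei)
print(cn.sum(dim=-1))  # common-neighbor counts: tensor([1., 1.]) for this toy graph

# with calresadj=True the residual neighborhoods are returned as well,
# which is what the NCNC predictor completes into soft common neighbors
cn, res1, res2 = adjoverlap(adj, adj, tar_ei, calresadj=True)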