├── figure.png ├── requirements.txt ├── .gitignore ├── README.md ├── kernels.py ├── gnns.py ├── logger.py ├── parse.py ├── models.py ├── data_utils.py ├── main.py └── dataset.py /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chr26195/GKD/HEAD/figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | googledrivedownloader==0.4 2 | numpy==1.21.5 3 | ogb==1.3.3 4 | scikit_learn==1.1.3 5 | scipy==1.7.3 6 | torch==1.9.0 7 | torch-cluster==1.5.9 8 | torch_geometric==1.7.2 9 | torch_scatter==2.0.7 10 | torch_sparse==0.6.10 11 | torch-spline-conv==1.2.1 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | **/__pycache__/ 6 | 7 | # Pycharm 8 | .idea/ 9 | 10 | 11 | # Saved models 12 | **/model/ 13 | **/checkpoint/ 14 | **/saved_models/ 15 | *.pkl 16 | *.sh 17 | 18 | # SLURM script 19 | *.slurm 20 | *.err 21 | *.out 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | env/ 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | **/logs/ 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *,cover 64 | .hypothesis/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # IPython Notebook 88 | .ipynb_checkpoints 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # dotenv 97 | .env 98 | 99 | # virtualenv 100 | venv/ 101 | ENV/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | 106 | # Rope project settings 107 | .ropeproject -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks 2 | 3 | The official implementation for "Geometric Knowledge Distillation: Topology Compression for Graph Neural Networks" which is accepted to NeurIPS22. 4 | 5 | **Abstract:** We study a new paradigm of knowledge transfer that aims at encoding graph topological information into graph neural networks (GNNs) by distilling knowledge from a teacher GNN model trained on a complete graph to a student GNN model operating on a smaller or sparser graph. 
To this end, we revisit the connection between thermodynamics and the behavior of GNNs, based on which we propose the Neural Heat Kernel (NHK) to encapsulate the geometric properties of the underlying manifold with respect to the architecture of GNNs. A fundamental and principled solution is derived by aligning NHKs on teacher and student models, dubbed Geometric Knowledge Distillation. We develop non-parametric and parametric instantiations and demonstrate their efficacy in various experimental settings for knowledge distillation regarding different types of privileged topological information and teacher-student schemes. 6 | 7 | Related materials: 8 | [paper](https://openreview.net/pdf?id=7WGNT3MHyBm) 9 | 10 | 11 | 12 | ### Use the Code 13 | 14 | - Install the required packages according to `requirements.txt`. 15 | - Specify your own data path in `parse.py` and download the datasets. 16 | - Pretrain teacher models, which will be saved in the folder `saved_models/`, e.g., 17 | ``` 18 | python main.py --dataset cora --rand_split --use_bn --base_model gcn --mode pretrain --dist_mode no --save_model 19 | ``` 20 | - Train student models, e.g., 21 | ``` 22 | python main.py --dataset cora --rand_split --use_bn --base_model gcn --mode train --priv_type edge --dist_mode gkd --kernel sigmoid 23 | ``` 24 | 25 | ### ACK 26 | The training and preprocessing pipeline is developed on the basis of the Non-Homophilous Benchmark project. 27 | -------------------------------------------------------------------------------- /kernels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_sparse import SparseTensor, matmul 5 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 6 | from torch_geometric.utils import num_nodes, to_dense_adj 7 | from gnns import * 8 | import numpy as np 9 | 10 | class Kernel(nn.Module): 11 | def __init__(self, hidden_channels, out_channels, args, num_node): 12 | super(Kernel, self).__init__() 13 | self.args = args 14 | self.hidden_channels = hidden_channels 15 | self.num_node = num_node 16 | self.bns = nn.BatchNorm1d(hidden_channels, eps=1e-10, affine=False, track_running_stats=False) 17 | if self.args.dist_mode == 'pgkd': 18 | self.phi = nn.Parameter(torch.randn(hidden_channels, hidden_channels * self.args.s)) 19 | 20 | def normalize(self, mat): 21 | norm = mat.diagonal().sqrt() 22 | norm_mat = torch.outer(norm, norm) 23 | mat = mat / norm_mat 24 | return mat 25 | 26 | def forward(self, x, y): 27 | if self.args.kernel == 'sigmoid': 28 | mat = nn.Tanh()(x @ y.transpose(-1, -2)) 29 | elif self.args.kernel == 'gaussian': 30 | mat = torch.cdist(x, y, p=2) 31 | mat = (-mat/self.args.t).exp() 32 | return self.normalize(mat) if self.args.ker_norm else mat 33 | 34 | def random(self, xt, xs): # n * d 35 | d = xt.shape[-1] 36 | W = torch.randn(self.args.m, d, d * self.args.s).to(xs.device) 37 | 38 | xt_ = torch.einsum('nd,adk->ank', xt, W).tanh() 39 | xs_ = torch.einsum('nd,adk->ank', xs, W).tanh() 40 | te_mat = torch.einsum('aij,akj->aik', xt_, xt_) 41 | st_mat = torch.einsum('aij,akj->aik', xs_, xs_) 42 | 43 | e = torch.randn(self.args.m).to(xt_.device) 44 | te_mat = torch.einsum('a,aij->ij', e, te_mat) 45 | st_mat = torch.einsum('a,aij->ij', e, st_mat) 46 | return te_mat.detach(), st_mat 47 | 48 | def parametric(self, xt, xs, detach = False): 49 | if detach: 50 | xt_ = (xt @ self.phi.detach().clone()).tanh() 51 | xs_ = (xs @ self.phi.detach().clone()).tanh() 52 | else: 53 | xt_ = (xt @ 
self.phi).tanh() 54 | xs_ = (xs @ self.phi).tanh() 55 | te_mat = xt_ @ xt_.transpose(-1, -2) 56 | st_mat = xs_ @ xs_.transpose(-1, -2) 57 | return te_mat, st_mat 58 | 59 | def dist_loss(self, mt, ms, A = None): 60 | if A == None or self.args.delta == 1.0: 61 | return nn.MSELoss(reduction='sum')(mt, ms) 62 | else: 63 | return (nn.MSELoss(reduction='none')(mt, ms) * (A + (1-A) * self.args.delta)).sum() 64 | 65 | def rec_loss(self, x2, x1, m): 66 | return nn.MSELoss(reduction='sum')(self.bns(m @ x2.detach()), self.bns(x1.detach())) -------------------------------------------------------------------------------- /gnns.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_sparse import SparseTensor, matmul 5 | from torch_geometric.nn import GCNConv, SGConv, GATConv, JumpingKnowledge, APPNP, MessagePassing 6 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 7 | from torch_geometric.utils import remove_self_loops, add_self_loops, degree 8 | import scipy.sparse 9 | import numpy as np 10 | import math 11 | 12 | class SpecialSpmmFunction(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, indices, values, shape, b): 15 | assert indices.requires_grad == False 16 | a = torch.sparse_coo_tensor(indices, values, shape) 17 | ctx.save_for_backward(a, b) 18 | ctx.N = shape[0] 19 | return torch.matmul(a, b) 20 | 21 | @staticmethod 22 | def backward(ctx, grad_output): 23 | a, b = ctx.saved_tensors 24 | grad_values = grad_b = None 25 | if ctx.needs_input_grad[1]: 26 | grad_a_dense = grad_output.matmul(b.t()) 27 | edge_idx = a._indices()[0, :] * ctx.N + a._indices()[1, :] 28 | grad_values = grad_a_dense.view(-1)[edge_idx] 29 | if ctx.needs_input_grad[3]: 30 | grad_b = a.t().matmul(grad_output) 31 | return None, grad_values, None, grad_b 32 | 33 | 34 | class SpecialSpmm(nn.Module): 35 | def forward(self, indices, values, shape, b): 36 | return SpecialSpmmFunction.apply(indices, values, shape, b) 37 | 38 | 39 | 40 | class GCNLayer(nn.Module): 41 | def __init__(self, in_channels, out_channels, dropout=0.0, simple=False): 42 | super(GCNLayer, self).__init__() 43 | self.simple = simple 44 | if not simple: 45 | self.W = nn.Parameter(torch.zeros(in_channels, out_channels)) 46 | self.dropout = dropout 47 | self.specialspmm = SpecialSpmm() 48 | self.reset_parameters() 49 | 50 | def reset_parameters(self): 51 | if not self.simple: nn.init.xavier_uniform_(self.W.data, gain=1.414) 52 | 53 | def forward(self, x, edge_index, ff = True): 54 | if not self.simple and ff: h = torch.matmul(x, self.W) 55 | else: h = x 56 | N = h.size(0) 57 | 58 | # weight_mat: hard and differentiable affinity matrix 59 | edge_index, _ = remove_self_loops(edge_index) 60 | edge_index, _ = add_self_loops(edge_index, num_nodes=N) 61 | src, dst = edge_index 62 | 63 | deg = degree(dst, num_nodes=N) 64 | deg_src = deg[src].pow_(-0.5) 65 | deg_src.masked_fill_(deg_src == float('inf'), 0) 66 | deg_dst = deg[dst].pow_(-0.5) 67 | deg_dst.masked_fill_(deg_dst == float('inf'), 0) 68 | edge_weight = deg_src * deg_dst 69 | 70 | h_prime = self.specialspmm(edge_index, edge_weight, torch.Size([N, N]), h) 71 | return h_prime 72 | 73 | 74 | 75 | 76 | class GATLayer(nn.Module): 77 | def __init__(self, in_channels, out_channels, dropout=0.0, alpha=0.2): 78 | super(GATLayer, self).__init__() 79 | 80 | self.W = nn.Parameter(torch.zeros(in_channels, out_channels)) 81 | self.a = nn.Parameter(torch.zeros(1, out_channels * 2)) 82 | 
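# 'W' above is the shared linear transformation and 'a' the attention vector of this GAT layer:
# in forward() below, each edge (i, j) is scored as e_ij = LeakyReLU(a^T [W x_i || W x_j]),
# exponentiated, and normalized over each node's incident edges via a sparse row-sum
# (SpecialSpmm), i.e., an edge-wise softmax.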
83 | self.leakyrelu = nn.LeakyReLU(alpha) 84 | self.specialspmm = SpecialSpmm() 85 | 86 | self.dropout = dropout 87 | self.out_channels = out_channels 88 | self.reset_parameters() 89 | 90 | def reset_parameters(self): 91 | nn.init.xavier_uniform_(self.W.data, gain=1.414) 92 | nn.init.xavier_uniform_(self.a.data, gain=1.414) 93 | 94 | def forward(self, x, edge_index, edge_weight=None, output_weight=False): 95 | h = torch.matmul(x, self.W) 96 | N = h.size(0) 97 | edge_index, _ = remove_self_loops(edge_index) 98 | edge_index, _ = add_self_loops(edge_index, num_nodes=N) 99 | src, dst = edge_index 100 | 101 | # edge_index, _ = remove_self_loops(edge_index) 102 | # edge_index, _ = add_self_loops(edge_index, num_nodes=N) 103 | 104 | if edge_weight is None: 105 | edge_h = torch.cat((h[src], h[dst]), dim=-1) # [E, 2*D] 106 | 107 | edge_e = (edge_h * self.a).sum(dim=-1) # [E] 108 | edge_e = torch.exp(self.leakyrelu(edge_e)) # [E] 109 | edge_e = F.dropout(edge_e, p=self.dropout, training=self.training) # [E] 110 | # e = torch.sparse_coo_tensor(edge_index, edge_e, size=torch.Size([N, N])) 111 | e_expsum = self.specialspmm(edge_index, edge_e, torch.Size([N, N]), torch.ones(N, 1).to(x.device)) 112 | assert not torch.isnan(e_expsum).any() 113 | 114 | # edge_e_ = F.dropout(edge_e, p=0.8, training=self.training) 115 | h_prime = self.specialspmm(edge_index, edge_e, torch.Size([N, N]), h) 116 | h_prime = torch.div(h_prime, e_expsum) # [N, D] tensor 117 | else: 118 | h_prime = self.specialspmm(edge_index, edge_weight, torch.Size([N, N]), h) 119 | 120 | if output_weight: 121 | edge_expsum = e_expsum[dst].squeeze(1) 122 | return h_prime, torch.div(edge_e, edge_expsum) 123 | else: 124 | return h_prime 125 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | 4 | printable_method = {'transgnn', 'gat'} 5 | 6 | 7 | def create_print_dict(args): 8 | if args.method == 'transgnn': 9 | return {'n_layer': args.num_layers, 10 | 'hidden_channels': args.hidden_channels, 11 | 'trans_heads': args.trans_heads, 12 | 'lr': args.lr, 13 | 'epochs': args.epochs} 14 | elif args.method == 'gat': 15 | return {'n_layer': args.num_layers, 16 | 'hidden_channels': args.hidden_channels, 17 | 'gat_heads': args.gat_heads, 18 | 'lr': args.lr, 19 | 'epochs': args.epochs 20 | } 21 | else: 22 | return None 23 | 24 | 25 | class Logger(object): 26 | """ Adapted from https://github.com/snap-stanford/ogb/ """ 27 | 28 | def __init__(self, runs, info=None): 29 | self.info = info 30 | self.results = [[] for _ in range(runs)] 31 | 32 | def add_result(self, run, result): 33 | assert len(result) == 4 34 | assert run >= 0 and run < len(self.results) 35 | self.results[run].append(result) 36 | 37 | def print_statistics(self, run=None, mode='max_acc'): 38 | if run is not None: 39 | result = 100 * torch.tensor(self.results[run]) 40 | argmax = result[:, 1].argmax().item() 41 | argmin = result[:, 3].argmin().item() 42 | if mode == 'max_acc': 43 | ind = argmax 44 | else: 45 | ind = argmin 46 | print(f'Run {run + 1:02d}:') 47 | print(f'Highest Train: {result[:, 0].max():.2f}') 48 | print(f'Highest Valid: {result[:, 1].max():.2f}') 49 | print(f'Highest Test: {result[:, 2].max():.2f}') 50 | print(f'Chosen epoch: {ind + 1}') 51 | print(f'Final Train: {result[ind, 0]:.2f}') 52 | print(f'Final Test: {result[ind, 2]:.2f}') 53 | self.test = result[ind, 2] 54 | else: 55 | result = 100 * 
torch.tensor(self.results) 56 | 57 | best_results = [] 58 | for r in result: 59 | train1 = r[:, 0].max().item() 60 | test1 = r[:, 2].max().item() 61 | valid = r[:, 1].max().item() 62 | if mode == 'max_acc': 63 | train2 = r[r[:, 1].argmax(), 0].item() 64 | 65 | test2 = r[r[:, 1].argmax(), 2].item() 66 | else: 67 | train2 = r[r[:, 3].argmin(), 0].item() 68 | test2 = r[r[:, 3].argmin(), 2].item() 69 | best_results.append((train1, test1, valid, train2, test2)) 70 | best_result = torch.tensor(best_results) 71 | 72 | print(f'All runs:') 73 | r = best_result[:, 0] 74 | print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}') 75 | r = best_result[:, 1] 76 | print(f'Highest Test: {r.mean():.2f} ± {r.std():.2f}') 77 | r = best_result[:, 2] 78 | print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}') 79 | r = best_result[:, 3] 80 | print(f' Final Train: {r.mean():.2f} ± {r.std():.2f}') 81 | r = best_result[:, 4] 82 | print(f' Final Test: {r.mean():.2f} ± {r.std():.2f}') 83 | 84 | self.test = r.mean() 85 | return best_result[:, 4] 86 | 87 | def output(self, out_path, info): 88 | with open(out_path, 'a') as f: 89 | f.write(info) 90 | f.write(f'test acc:{self.test}\n') 91 | 92 | 93 | class SimpleLogger(object): 94 | """ Adapted from https://github.com/CUAI/CorrectAndSmooth """ 95 | 96 | def __init__(self, desc, param_names, num_values=2): 97 | self.results = defaultdict(dict) 98 | self.param_names = tuple(param_names) 99 | self.used_args = list() 100 | self.desc = desc 101 | self.num_values = num_values 102 | 103 | def add_result(self, run, args, values): 104 | """Takes run=int, args=tuple, value=tuple(float)""" 105 | assert (len(args) == len(self.param_names)) 106 | assert (len(values) == self.num_values) 107 | self.results[run][args] = values 108 | if args not in self.used_args: 109 | self.used_args.append(args) 110 | 111 | def get_best(self, top_k=1): 112 | all_results = [] 113 | for args in self.used_args: 114 | results = [i[args] for i in self.results.values() if args in i] 115 | results = torch.tensor(results) * 100 116 | results_mean = results.mean(dim=0)[-1] 117 | results_std = results.std(dim=0) 118 | 119 | all_results.append((args, results_mean)) 120 | results = sorted(all_results, key=lambda x: x[1], reverse=True)[:top_k] 121 | return [i[0] for i in results] 122 | 123 | def prettyprint(self, x): 124 | if isinstance(x, float): 125 | return '%.2f' % x 126 | return str(x) 127 | 128 | def display(self, args=None): 129 | disp_args = self.used_args if args is None else args 130 | if len(disp_args) > 1: 131 | print(f'{self.desc} {self.param_names}, {len(self.results.keys())} runs') 132 | for args in disp_args: 133 | results = [i[args] for i in self.results.values() if args in i] 134 | results = torch.tensor(results) * 100 135 | results_mean = results.mean(dim=0) 136 | results_std = results.std(dim=0) 137 | res_str = f'{results_mean[0]:.2f} ± {results_std[0]:.2f}' 138 | for i in range(1, self.num_values): 139 | res_str += f' -> {results_mean[i]:.2f} ± {results_std[1]:.2f}' 140 | print(f'Args {[self.prettyprint(x) for x in args]}: {res_str}') 141 | if len(disp_args) > 1: 142 | print() 143 | return results 144 | -------------------------------------------------------------------------------- /parse.py: -------------------------------------------------------------------------------- 1 | from models import * 2 | from data_utils import normalize 3 | 4 | 5 | def parse_method(args, dataset, n, c, d, device): 6 | model = GeoDist(d, args.hidden_channels, c, args, n, num_layers=args.num_layers, 
use_bn=args.use_bn).to(device) 7 | return model 8 | 9 | 10 | def parser_add_main_args(parser): 11 | # setup and protocol 12 | parser.add_argument('--data_dir', type=str, default='data') 13 | parser.add_argument('--dataset', type=str, default='cora') 14 | parser.add_argument('--sub_dataset', type=str, default='') 15 | parser.add_argument('--cpu', action='store_true') 16 | parser.add_argument('--seed', type=int, default=42) 17 | parser.add_argument('--train_prop', type=float, default=.5, 18 | help='training label proportion') 19 | parser.add_argument('--valid_prop', type=float, default=.25, 20 | help='validation label proportion') 21 | 22 | 23 | parser.add_argument('--protocol', type=str, default='semi', 24 | help='protocol for cora datasets, semi or supervised') 25 | parser.add_argument('--rand_split', action='store_true', help='use random splits') 26 | parser.add_argument('--metric', type=str, default='acc', choices=['acc', 'rocauc', 'f1'], 27 | help='evaluation metric') 28 | parser.add_argument('--runs', type=int, default=5, help='number of distinct runs') 29 | parser.add_argument('--epochs', type=int, default=200) 30 | 31 | parser.add_argument('--hidden_channels', type=int, default=32) 32 | 33 | parser.add_argument('--num_layers', type=int, default=3, 34 | help='number of layers for deep methods') 35 | parser.add_argument('--gat_heads', type=int, default=8, 36 | help='attention heads for gat') 37 | parser.add_argument('--out_heads', type=int, default=1, 38 | help='out heads for gat') 39 | parser.add_argument('--hops', type=int, default=1, 40 | help='power of adjacency matrix for certain methods') 41 | parser.add_argument('--lp_alpha', type=float, default=.1, 42 | help='alpha for label prop') 43 | parser.add_argument('--gpr_alpha', type=float, default=.1, 44 | help='alpha for gprgnn') 45 | parser.add_argument('--jk_type', type=str, default='max', choices=['max', 'lstm', 'cat'], 46 | help='jumping knowledge type') 47 | parser.add_argument('--directed', action='store_true', 48 | help='set to not symmetrize adjacency') 49 | parser.add_argument('--num_mlp_layers', type=int, default=1, 50 | help='number of mlp layers in h2gcn') 51 | 52 | # display and utility 53 | parser.add_argument('--display_step', type=int, 54 | default=5, help='how often to print') 55 | parser.add_argument('--cached', action='store_true', 56 | help='set to use faster sgc') 57 | parser.add_argument('--print_prop', action='store_true', 58 | help='print proportions of predicted class') 59 | 60 | parser.add_argument('--priv_type', type=str, choices=['edge', 'node'], 61 | default='edge', help='type for privileged information') 62 | parser.add_argument('--priv_ratio', type=float, 63 | default=0.5, help='ratio for privileged nodes/edges') 64 | 65 | parser.add_argument('--save_model', action='store_true', help='save model') 66 | parser.add_argument('--save_name', type=str, default='gcn', help='saved model name') 67 | parser.add_argument('--log_name', type=str, default='none', help='log file appendix name') 68 | 69 | parser.add_argument('--base_model', type=str, default='gcn', choices=['gcn', 'gat'], help='which model') 70 | parser.add_argument('--not_load_teacher', action='store_true', help='whether not load teacher model') 71 | 72 | parser.add_argument('--mode', type=str, choices=['pretrain', 'train'], 73 | default='pretrain', help='mode for pretrain teacher or train student') 74 | parser.add_argument('--dist_mode', type=str, 75 | default='label', help='mode for knowledge distillation') 76 | 77 | # training 78 | 
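# How these training knobs are combined in main.py: the student objective is
#   loss = (1 - alpha) * supervised_loss + alpha * gkd_loss        (for --dist_mode gkd/pgkd)
# and, if --use_kd is set, an auxiliary vanilla-KD term beta * tau^2 * KL(teacher || student)
# on temperature-scaled logits is added; --dist_mode pgkd further trains the parametric
# kernel (model.k) with its own Adam optimizer using --lr2 on a reconstruction loss.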
parser.add_argument('--weight_decay', type=float, default=0.05) 79 | parser.add_argument('--dropout', type=float, default=0.0) 80 | parser.add_argument('--lr', type=float, default=0.01) 81 | parser.add_argument('--use_bn', action='store_true', help='use batch norm') 82 | parser.add_argument('--use_batch', action='store_true') 83 | parser.add_argument('--batch_size', type=int, default=2048) 84 | parser.add_argument('--oracle', action='store_true', help='whether using the complete graph for testing') 85 | parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)') 86 | 87 | 88 | # for distillation loss 89 | parser.add_argument('--alpha', type=float, default=0.5, help='distillation loss (GKD) weight') 90 | parser.add_argument('--delta', type=float, default=0.1, help='distillation loss hyperparameter') 91 | 92 | parser.add_argument('--use_kd', action='store_true', help='whether use vanilla KD loss') 93 | parser.add_argument('--beta', type=float, default=0.0, help='weight for auxiliary KD loss') 94 | parser.add_argument('--tau', type=float, default=1.0, help='KD loss temperature') 95 | 96 | parser.add_argument('--kernel', type=str, default='sigmoid', help='sigmoid/gaussian/random') 97 | parser.add_argument('--t', type=float, default=1.0, help='hyperparameter for Gauss-Weierstras kernel') 98 | parser.add_argument('--include_last', action='store_true', help='whether include last layer') 99 | parser.add_argument('--s', type=int, default=2, help='hyperparameter for random and parametric kernel') 100 | parser.add_argument('--m', type=int, default=1, help='hyperparameter for random kernel') 101 | parser.add_argument('--lr2', type=float, default=0.001) 102 | 103 | parser.add_argument('--sim', type=str, default='l2') 104 | parser.add_argument('--ker_norm', action='store_true') -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_sparse import SparseTensor, matmul 5 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 6 | from torch_geometric.utils import num_nodes, to_dense_adj 7 | from gnns import * 8 | import numpy as np 9 | from kernels import Kernel 10 | 11 | class GeoDist(nn.Module): 12 | ''' 13 | in_channels: number of features 14 | hidden_channels: hidden size 15 | ''' 16 | def __init__(self, in_channels, hidden_channels, out_channels, args, num_node, num_layers=2, 17 | use_bn=False): 18 | super(GeoDist, self).__init__() 19 | self.args = args 20 | self.dropout = args.dropout 21 | self.num_node = num_node 22 | self.k = Kernel(hidden_channels, out_channels, args, num_node) 23 | 24 | self.teacher_gnn = nn.ModuleList([nn.Linear(in_channels, hidden_channels, bias=False)]) 25 | self.student_gnn = nn.ModuleList([nn.Linear(in_channels, hidden_channels, bias=False)]) 26 | self.bns = nn.BatchNorm1d(hidden_channels, eps=1e-10, affine=False, track_running_stats=False) 27 | self.bns2d = nn.BatchNorm2d(hidden_channels, affine=False, track_running_stats=False) 28 | 29 | if args.base_model == 'gcn': 30 | self.teacher_gnn.append(GCNLayer(hidden_channels, hidden_channels, dropout=self.dropout, simple=True)) 31 | self.student_gnn.append(GCNLayer(hidden_channels, hidden_channels, dropout=self.dropout, simple=True)) 32 | 33 | for _ in range(num_layers - 2): 34 | self.teacher_gnn.append( 35 | GCNLayer(hidden_channels, hidden_channels)) 36 | self.student_gnn.append( 
37 | GCNLayer(hidden_channels, hidden_channels)) 38 | 39 | self.teacher_gnn.append(GCNLayer(hidden_channels, out_channels)) 40 | self.student_gnn.append(GCNLayer(hidden_channels, out_channels)) 41 | 42 | elif args.base_model == 'gat': 43 | self.teacher_gnn.append(GATLayer(hidden_channels, hidden_channels, dropout=self.dropout, simple=True)) 44 | self.student_gnn.append(GATLayer(hidden_channels, hidden_channels, dropout=self.dropout, simple=True)) 45 | 46 | for _ in range(num_layers - 2): 47 | self.teacher_gnn.append( 48 | GATLayer(hidden_channels, hidden_channels, dropout=self.dropout)) 49 | self.student_gnn.append( 50 | GATLayer(hidden_channels, hidden_channels, dropout=self.dropout)) 51 | 52 | self.teacher_gnn.append(GATLayer(hidden_channels, out_channels, dropout=self.dropout)) 53 | self.student_gnn.append(GATLayer(hidden_channels, out_channels, dropout=self.dropout)) 54 | 55 | self.activation = F.relu 56 | self.use_bn = use_bn 57 | 58 | def reset_parameters(self): 59 | for conv in self.teacher_gnn: 60 | conv.reset_parameters() 61 | for conv in self.student_gnn: 62 | conv.reset_parameters() 63 | 64 | def forward(self, data_full, data=None, mode='pretrain', dist_mode='label', t=1.0): 65 | if mode == 'pretrain': 66 | return self.forward_teacher(data_full) 67 | elif mode == 'train': 68 | return self.forward_student(data_full, data, dist_mode, t) 69 | else: 70 | NotImplementedError 71 | 72 | def forward_teacher(self, data_full): 73 | x, edge_index = data_full.graph['node_feat'], data_full.graph['edge_index'] 74 | x = self.teacher_gnn[0](x) 75 | for i in range(1, len(self.teacher_gnn) - 1): 76 | x = self.teacher_gnn[i](x, edge_index) # [n, h] 77 | if self.use_bn: x = self.bns(x) 78 | x = self.activation(x) 79 | x = F.dropout(x, p=self.dropout, training=self.training) 80 | x = self.teacher_gnn[-1](x, edge_index) 81 | return x 82 | 83 | def forward_student(self, data_full, data, dist_mode, t = 1): 84 | x_full, edge_index_full = data_full.graph['node_feat'], data_full.graph['edge_index'] 85 | xt = xs = x_full # edge missing, node is seen for both teacher and student 86 | edge_index = data.graph['edge_index'] 87 | 88 | if self.args.use_batch: 89 | idx = torch.randperm(len(data.share_node_idx))[:self.args.batch_size] 90 | share_node_idx = data.share_node_idx[idx] 91 | else: share_node_idx = data.share_node_idx 92 | train_idx = data.train_idx 93 | 94 | if dist_mode == 'no': 95 | return self.no_kd(xs, xt, edge_index, edge_index_full, train_idx, share_node_idx) 96 | elif dist_mode == 'gkd': 97 | return self.gkd(xs, xt, edge_index, edge_index_full, train_idx, share_node_idx) 98 | elif dist_mode == 'pgkd': 99 | return self.pgkd(xs, xt, edge_index, edge_index_full, train_idx, share_node_idx) 100 | else: 101 | NotImplementedError 102 | 103 | def inference(self, data, mode='pretrain'): 104 | x, edge_index = data.graph['node_feat'], data.graph['edge_index'] 105 | if mode == 'pretrain': 106 | x = self.teacher_gnn[0](x) 107 | for i in range(1, len(self.teacher_gnn) - 1): 108 | x = self.teacher_gnn[i](x, edge_index) 109 | if self.use_bn: 110 | if 'share_node_idx' in data.__dict__.keys(): 111 | x[data.share_node_idx] = self.bns(x[data.share_node_idx]) 112 | else: x = self.bns(x) 113 | x = self.activation(x) 114 | x = self.teacher_gnn[-1](x, edge_index) 115 | return x 116 | elif mode == 'train': 117 | share_node_idx = data.share_node_idx 118 | x = self.student_gnn[0](x) 119 | for i in range(1, len(self.student_gnn) - 1): 120 | x = self.student_gnn[i](x, edge_index) 121 | if self.use_bn: x[share_node_idx] = 
self.bns(x[share_node_idx]) 122 | x = self.activation(x) 123 | x = self.student_gnn[-1](x, edge_index) 124 | return x 125 | 126 | def no_kd(self, xs, xt, edge_index, edge_index_full, train_idx, share_node_idx): 127 | xs = self.student_gnn[0](xs) 128 | for i in range(1, len(self.student_gnn) - 1): 129 | xs = self.student_gnn[i](xs, edge_index) 130 | xs[share_node_idx] = self.bns(xs[share_node_idx]) 131 | xs = self.activation(xs) 132 | y_logit_s = self.student_gnn[-1](xs, edge_index) 133 | return y_logit_s 134 | 135 | def gkd(self, xs, xt, edge_index, edge_index_full, train_idx, share_node_idx): 136 | if self.args.delta != 1: 137 | A = to_dense_adj(edge_index, max_num_nodes=self.num_node).squeeze().fill_diagonal_(1.) 138 | A = A[share_node_idx, :][:, share_node_idx] 139 | else: A = None 140 | 141 | loss_list = [] 142 | xt = self.teacher_gnn[0](xt) 143 | xs = self.student_gnn[0](xs) 144 | 145 | for i in range(1, len(self.teacher_gnn) - 1): 146 | xt = self.teacher_gnn[i](xt, edge_index_full) 147 | xs = self.student_gnn[i](xs, edge_index) 148 | xt, xs = self.bns(xt), self.bns(xs) 149 | if self.args.kernel != 'random': 150 | mt = self.k(xt[share_node_idx], xt[share_node_idx]).detach() 151 | ms = self.k(xs[share_node_idx], xs[share_node_idx]) 152 | else: 153 | mt, ms = self.k.random(xt[share_node_idx], xs[share_node_idx]) 154 | loss_list.append(self.k.dist_loss(mt, ms, A)) 155 | xt, xs = self.activation(xt), self.activation(xs) 156 | 157 | y_logit_t = self.teacher_gnn[-1](xt, edge_index_full) 158 | y_logit_s = self.student_gnn[-1](xs, edge_index) 159 | if self.args.include_last: 160 | if self.args.kernel != 'random': 161 | mt = self.k(y_logit_t[share_node_idx], y_logit_t[share_node_idx]).detach() 162 | ms = self.k(y_logit_s[share_node_idx], y_logit_s[share_node_idx]) 163 | else: 164 | mt, ms = self.k.random(y_logit_t[share_node_idx], y_logit_s[share_node_idx]) 165 | loss_list.append(self.k.dist_loss(mt, ms, A)) 166 | gkd_dist_loss = sum(loss_list)/len(loss_list) 167 | if self.args.use_kd: 168 | dist_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y_logit_s[train_idx]/self.args.tau, dim=1), F.softmax(y_logit_t[train_idx].detach()/self.args.tau, dim=1)) 169 | return y_logit_s, gkd_dist_loss, dist_loss 170 | else: 171 | return y_logit_s, gkd_dist_loss 172 | 173 | def pgkd(self, xs, xt, edge_index, edge_index_full, train_idx, share_node_idx): 174 | if self.args.delta != 1: 175 | A = to_dense_adj(edge_index, max_num_nodes=self.num_node).squeeze().fill_diagonal_(1.) 
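# A is the student graph's dense adjacency with self-loops; it is restricted to the shared
# nodes on the next line and passed to Kernel.dist_loss, which keeps full weight on entries
# for connected node pairs and down-weights all other pairs by args.delta.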
176 | A = A[share_node_idx, :][:, share_node_idx] 177 | else: A = None 178 | 179 | xt = self.teacher_gnn[0](xt) 180 | xs = self.student_gnn[0](xs) 181 | xt1, xs1 = xt[share_node_idx].clone(), xs[share_node_idx].clone() 182 | 183 | for i in range(1, len(self.teacher_gnn) - 1): 184 | xt = self.teacher_gnn[i](xt, edge_index_full) 185 | xs = self.student_gnn[i](xs, edge_index) 186 | xt, xs = self.bns(xt), self.bns(xs) 187 | xt, xs = self.activation(xt), self.activation(xs) 188 | 189 | y_logit_t = self.teacher_gnn[-1](xt, edge_index_full, ff=False) 190 | y_logit_s = self.student_gnn[-1](xs, edge_index, ff=False) 191 | 192 | y_logit_t0 = y_logit_t[share_node_idx] 193 | y_logit_s0 = y_logit_s[share_node_idx] 194 | 195 | mt, ms = self.k.parametric(y_logit_t0.detach(), y_logit_s0, detach=True) 196 | gkd_dist_loss = self.k.dist_loss(mt, ms, A) 197 | 198 | mt2, ms2 = self.k.parametric(y_logit_t0, y_logit_s0) 199 | rec_loss = self.k.rec_loss(y_logit_t0.detach(), xt1.detach(), mt2) 200 | rec_loss += self.k.rec_loss(y_logit_s0.detach(), xs1.detach(), ms2) 201 | 202 | y_logit_t_, y_logit_s_ = torch.matmul(y_logit_t, self.teacher_gnn[-1].W), torch.matmul(y_logit_s, self.student_gnn[-1].W) 203 | if self.args.use_kd: 204 | dist_loss = nn.KLDivLoss(reduction='batchmean')(F.log_softmax(y_logit_s[train_idx]/self.args.tau, dim=1), F.softmax(y_logit_t[train_idx].detach()/self.args.tau, dim=1)) 205 | return y_logit_s_, gkd_dist_loss, rec_loss, dist_loss 206 | else: 207 | return y_logit_s_, gkd_dist_loss, rec_loss -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from scipy import sparse as sp 8 | from sklearn.metrics import roc_auc_score, f1_score 9 | 10 | from torch_sparse import SparseTensor 11 | from google_drive_downloader import GoogleDriveDownloader as gdd 12 | 13 | def rand_train_test_idx(label, train_prop=.5, valid_prop=.25, ignore_negative=True): 14 | """ randomly splits label into train/valid/test splits """ 15 | if ignore_negative: 16 | labeled_nodes = torch.where(label != -1)[0] 17 | else: 18 | labeled_nodes = label 19 | 20 | n = labeled_nodes.shape[0] 21 | train_num = int(n * train_prop) 22 | valid_num = int(n * valid_prop) 23 | 24 | perm = torch.as_tensor(np.random.permutation(n)) 25 | 26 | train_indices = perm[:train_num] 27 | val_indices = perm[train_num:train_num + valid_num] 28 | test_indices = perm[train_num + valid_num:] 29 | 30 | if not ignore_negative: 31 | return train_indices, val_indices, test_indices 32 | 33 | train_idx = labeled_nodes[train_indices] 34 | valid_idx = labeled_nodes[val_indices] 35 | test_idx = labeled_nodes[test_indices] 36 | 37 | return train_idx, valid_idx, test_idx 38 | 39 | 40 | def even_quantile_labels(vals, nclasses, verbose=True): 41 | """ partitions vals into nclasses by a quantile based split, 42 | where the first class is less than the 1/nclasses quantile, 43 | second class is less than the 2/nclasses quantile, and so on 44 | 45 | vals is np array 46 | returns an np array of int class labels 47 | """ 48 | label = -1 * np.ones(vals.shape[0], dtype=np.int) 49 | interval_lst = [] 50 | lower = -np.inf 51 | for k in range(nclasses - 1): 52 | upper = np.quantile(vals, (k + 1) / nclasses) 53 | interval_lst.append((lower, upper)) 54 | inds = (vals >= lower) * (vals < upper) 55 | label[inds] = k 56 | lower = 
upper 57 | label[vals >= lower] = nclasses - 1 58 | interval_lst.append((lower, np.inf)) 59 | if verbose: 60 | print('Class Label Intervals:') 61 | for class_idx, interval in enumerate(interval_lst): 62 | print(f'Class {class_idx}: [{interval[0]}, {interval[1]})]') 63 | return label 64 | 65 | 66 | def to_planetoid(dataset): 67 | """ 68 | Takes in a NCDataset and returns the dataset in H2GCN Planetoid form, as follows: 69 | x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object; 70 | tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object; 71 | allx => the feature vectors of both labeled and unlabeled training instances 72 | (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object; 73 | y => the one-hot labels of the labeled training instances as numpy.ndarray object; 74 | ty => the one-hot labels of the test instances as numpy.ndarray object; 75 | ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object; 76 | graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict 77 | object; 78 | split_idx => The ogb dictionary that contains the train, valid, test splits 79 | """ 80 | split_idx = dataset.get_idx_split('random', 0.25) 81 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 82 | 83 | graph, label = dataset[0] 84 | 85 | label = torch.squeeze(label) 86 | 87 | print("generate x") 88 | x = graph['node_feat'][train_idx].numpy() 89 | x = sp.csr_matrix(x) 90 | 91 | tx = graph['node_feat'][test_idx].numpy() 92 | tx = sp.csr_matrix(tx) 93 | 94 | allx = graph['node_feat'].numpy() 95 | allx = sp.csr_matrix(allx) 96 | 97 | y = F.one_hot(label[train_idx]).numpy() 98 | ty = F.one_hot(label[test_idx]).numpy() 99 | ally = F.one_hot(label).numpy() 100 | 101 | edge_index = graph['edge_index'].T 102 | 103 | graph = defaultdict(list) 104 | 105 | for i in range(0, label.shape[0]): 106 | graph[i].append(i) 107 | 108 | for start_edge, end_edge in edge_index: 109 | graph[start_edge.item()].append(end_edge.item()) 110 | 111 | return x, tx, allx, y, ty, ally, graph, split_idx 112 | 113 | 114 | def to_sparse_tensor(edge_index, edge_feat, num_nodes): 115 | """ converts the edge_index into SparseTensor 116 | """ 117 | num_edges = edge_index.size(1) 118 | 119 | (row, col), N, E = edge_index, num_nodes, num_edges 120 | perm = (col * N + row).argsort() 121 | row, col = row[perm], col[perm] 122 | 123 | value = edge_feat[perm] 124 | adj_t = SparseTensor(row=col, col=row, value=value, 125 | sparse_sizes=(N, N), is_sorted=True) 126 | 127 | # Pre-process some important attributes. 
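# rowptr()/csr2csc() precompute and cache the CSR row pointers and the CSC permutation on
# the SparseTensor's storage, so that subsequent sparse matrix multiplications can reuse
# the cached layout instead of rebuilding it on every call.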
128 | adj_t.storage.rowptr() 129 | adj_t.storage.csr2csc() 130 | 131 | return adj_t 132 | 133 | 134 | def normalize(edge_index): 135 | """ normalizes the edge_index 136 | """ 137 | adj_t = edge_index.set_diag() 138 | deg = adj_t.sum(dim=1).to(torch.float) 139 | deg_inv_sqrt = deg.pow(-0.5) 140 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 141 | adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1) 142 | return adj_t 143 | 144 | 145 | def gen_normalized_adjs(dataset): 146 | """ returns the normalized adjacency matrix 147 | """ 148 | row, col = dataset.graph['edge_index'] 149 | N = dataset.graph['num_nodes'] 150 | adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N)) 151 | deg = adj.sum(dim=1).to(torch.float) 152 | D_isqrt = deg.pow(-0.5) 153 | D_isqrt[D_isqrt == float('inf')] = 0 154 | 155 | DAD = D_isqrt.view(-1,1) * adj * D_isqrt.view(1,-1) 156 | DA = D_isqrt.view(-1,1) * D_isqrt.view(-1,1) * adj 157 | AD = adj * D_isqrt.view(1,-1) * D_isqrt.view(1,-1) 158 | return DAD, DA, AD 159 | 160 | def eval_f1(y_true, y_pred): 161 | acc_list = [] 162 | y_true = y_true.detach().cpu().numpy() 163 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 164 | 165 | for i in range(y_true.shape[1]): 166 | f1 = f1_score(y_true, y_pred, average='micro') 167 | acc_list.append(f1) 168 | 169 | return sum(acc_list)/len(acc_list) 170 | 171 | def eval_acc(y_true, y_pred): 172 | acc_list = [] 173 | y_true = y_true.detach().cpu().numpy() 174 | y_pred = y_pred.argmax(dim=-1, keepdim=True).detach().cpu().numpy() 175 | 176 | for i in range(y_true.shape[1]): 177 | is_labeled = y_true[:, i] == y_true[:, i] 178 | correct = y_true[is_labeled, i] == y_pred[is_labeled, i] 179 | acc_list.append(float(np.sum(correct))/len(correct)) 180 | 181 | return sum(acc_list)/len(acc_list) 182 | 183 | 184 | def eval_rocauc(y_true, y_pred): 185 | """ adapted from ogb 186 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/evaluate.py""" 187 | rocauc_list = [] 188 | y_true = y_true.detach().cpu().numpy() 189 | if y_true.shape[1] == 1: 190 | # use the predicted class for single-class classification 191 | y_pred = F.softmax(y_pred, dim=-1)[:,1].unsqueeze(1).cpu().numpy() 192 | else: 193 | y_pred = y_pred.detach().cpu().numpy() 194 | 195 | for i in range(y_true.shape[1]): 196 | # AUC is only defined when there is at least one positive data. 197 | if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: 198 | is_labeled = y_true[:, i] == y_true[:, i] 199 | score = roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) 200 | 201 | rocauc_list.append(score) 202 | 203 | if len(rocauc_list) == 0: 204 | raise RuntimeError( 205 | 'No positively labeled data available. 
Cannot compute ROC-AUC.') 206 | 207 | return sum(rocauc_list)/len(rocauc_list) 208 | 209 | 210 | @torch.no_grad() 211 | def evaluate(model, dataset, split_idx, eval_func, criterion, args, result=None, test_dataset = None): 212 | if result is not None: 213 | out = result 214 | else: 215 | model.eval() 216 | out = model.inference(dataset, mode=args.mode) 217 | 218 | if test_dataset == None: 219 | test_dataset = dataset 220 | test_out = out 221 | else: 222 | test_out = model.inference(test_dataset, mode=args.mode) # teacher trained/valid on full dataset, test on mask dataset 223 | 224 | train_acc = eval_func( 225 | dataset.label[dataset.train_idx], out[dataset.train_idx]) 226 | valid_acc = eval_func( 227 | dataset.label[split_idx['valid']], out[split_idx['valid']]) 228 | test_acc = eval_func( 229 | dataset.label[split_idx['test']], test_out[split_idx['test']]) 230 | 231 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 232 | if dataset.label.shape[1] == 1: 233 | true_label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 234 | else: 235 | true_label = dataset.label 236 | valid_loss = criterion(out[split_idx['valid']], true_label.squeeze(1)[ 237 | split_idx['valid']].to(torch.float)) 238 | else: 239 | out = F.log_softmax(out, dim=1) 240 | valid_loss = criterion( 241 | out[split_idx['valid']], dataset.label.squeeze(1)[split_idx['valid']]) 242 | 243 | return train_acc, valid_acc, test_acc, valid_loss, out 244 | 245 | @torch.no_grad() 246 | def evaluate_whole_graph(model, datasets, eval_func): 247 | dataset_tr, dataset_val, dataset_te = datasets[0], datasets[1], datasets[2] 248 | model.eval() 249 | train_out = model(dataset_tr) 250 | train_acc = eval_func(dataset_tr.label, train_out) 251 | valid_out = model(dataset_val) 252 | valid_acc = eval_func(dataset_val.label, valid_out) 253 | test_out = model(dataset_te) 254 | test_acc = eval_func(dataset_te.label, test_out) 255 | 256 | return train_acc, valid_acc, test_acc, test_out 257 | 258 | 259 | def load_fixed_splits(dataset, name, protocol): 260 | 261 | splits_lst = [] 262 | if name in ['cora', 'citeseer', 'pubmed'] and protocol == 'semi': 263 | splits = {} 264 | splits['train'] = torch.as_tensor(dataset.train_idx) 265 | splits['valid'] = torch.as_tensor(dataset.valid_idx) 266 | splits['test'] = torch.as_tensor(dataset.test_idx) 267 | splits_lst.append(splits) 268 | elif name in ['cora', 'citeseer', 'pubmed', 'chameleon', 'squirrel', 'film', 'cornell', 'texas', 'wisconsin']: 269 | for i in range(10): 270 | splits_file_path = '../../data/geom-gcn/splits/{}'.format(name) + '_split_0.6_0.2_'+str(i)+'.npz' 271 | splits = {} 272 | with np.load(splits_file_path) as splits_file: 273 | splits['train'] = torch.BoolTensor(splits_file['train_mask']) 274 | splits['valid'] = torch.BoolTensor(splits_file['val_mask']) 275 | splits['test'] = torch.BoolTensor(splits_file['test_mask']) 276 | splits_lst.append(splits) 277 | else: 278 | raise NotImplementedError 279 | 280 | return splits_lst 281 | 282 | def convert_to_adj(edge_index,n_node): 283 | '''convert from pyg format edge_index to n by n adj matrix''' 284 | adj=torch.zeros((n_node,n_node)) 285 | row,col=edge_index 286 | adj[row,col]=1 287 | return adj 288 | 289 | def remove_edges(edge_index, ratio): 290 | E = edge_index.size(0) 291 | e = int(E * ratio) 292 | idx = torch.randperm(E)[:e] 293 | edge_index_new = edge_index[:, idx] 294 | return edge_index_new 295 | 296 | dataset_drive_url = { 297 | 'snap-patents' : '1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia', 298 | 
'pokec' : '1dNs5E7BrWJbgcHeQ_zuy5Ozp2tRCWG0y', 299 | 'yelp-chi': '1fAXtTVQS4CfEk4asqrFw9EPmlUPGbGtJ', 300 | } 301 | 302 | splits_drive_url = { 303 | 'snap-patents' : '12xbBRqd8mtG_XkNLH8dRRNZJvVM4Pw-N', 304 | 'pokec' : '1ZhpAiyTNc0cE_hhgyiqxnkKREHK7MK-_', 305 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os, random 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import time 9 | from torch.utils import data 10 | from torch_geometric.utils import to_undirected, subgraph, add_remaining_self_loops, add_self_loops 11 | from torch_scatter import scatter 12 | 13 | from logger import Logger, SimpleLogger 14 | from dataset import load_nc_dataset 15 | from data_utils import normalize, gen_normalized_adjs, evaluate, eval_acc, eval_rocauc, eval_f1, to_sparse_tensor, \ 16 | load_fixed_splits, remove_edges 17 | from parse import parse_method, parser_add_main_args 18 | 19 | import copy 20 | 21 | torch.autograd.set_detect_anomaly(True) 22 | 23 | # NOTE: data splits are consistent given fixed seed, see data_utils.rand_train_test_idx 24 | def fix_seed(seed): 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed(seed) 29 | torch.backends.cudnn.deterministic = True 30 | 31 | ### Parse args ### 32 | parser = argparse.ArgumentParser(description='General Training Pipeline') 33 | parser_add_main_args(parser) 34 | args = parser.parse_args() 35 | print(args) 36 | 37 | fix_seed(args.seed) 38 | 39 | if args.cpu: 40 | device = torch.device("cpu") 41 | else: 42 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 43 | print(device) 44 | 45 | 46 | ### Load and preprocess data ### 47 | dataset = load_nc_dataset(args.dataset, args.sub_dataset, args.data_dir) 48 | 49 | if len(dataset.label.shape) == 1: 50 | dataset.label = dataset.label.unsqueeze(1) 51 | 52 | print(dataset.label.shape) 53 | dataset.label = dataset.label.to(device) 54 | 55 | # get the splits for all runs 56 | if args.rand_split: 57 | split_idx_lst = [dataset.get_idx_split(train_prop=args.train_prop, valid_prop=args.valid_prop) 58 | for _ in range(args.runs)] 59 | elif args.dataset in ['ogbn-proteins', 'ogbn-arxiv', 'ogbn-products']: 60 | split_idx_lst = [dataset.load_fixed_splits() 61 | for _ in range(args.runs)] 62 | else: 63 | split_idx_lst = load_fixed_splits(dataset, name=args.dataset, protocol=args.protocol) 64 | 65 | if args.dataset == 'ogbn-proteins': 66 | if args.method == 'mlp' or args.method == 'cs': 67 | dataset.graph['node_feat'] = scatter(dataset.graph['edge_feat'], dataset.graph['edge_index'][0], 68 | dim=0, dim_size=dataset.graph['num_nodes'], reduce='mean') 69 | else: 70 | dataset.graph['edge_index'] = to_sparse_tensor(dataset.graph['edge_index'], 71 | dataset.graph['edge_feat'], dataset.graph['num_nodes']) 72 | dataset.graph['node_feat'] = dataset.graph['edge_index'].mean(dim=1) 73 | dataset.graph['edge_index'].set_value_(None) 74 | dataset.graph['edge_feat'] = None 75 | 76 | n = dataset.graph['num_nodes'] 77 | # infer the number of classes for non one-hot and one-hot labels 78 | c = max(dataset.label.max().item() + 1, dataset.label.shape[1]) 79 | d = dataset.graph['node_feat'].shape[1] 80 | 81 | # whether or not to symmetrize 82 | if not args.directed and args.dataset != 'ogbn-proteins': 83 | 
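# The graph is symmetrized first; 'edge_index_directed' then keeps a single direction of each
# undirected edge (second row index >= first), so the per-run privileged-edge masking in the
# training loop samples whole undirected edges and re-symmetrizes them with to_undirected.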
dataset.graph['edge_index'] = to_undirected(dataset.graph['edge_index']) 84 | edge_index_directed = dataset.graph['edge_index'][:, dataset.graph['edge_index'][1,:] >= dataset.graph['edge_index'][0,:] ] 85 | edge_index_directed = edge_index_directed.to(device) 86 | 87 | 88 | print(f"num nodes {n} | num classes {c} | num node feats {d}") 89 | 90 | ### Load method ### 91 | model = parse_method(args, dataset, n, c, d, device) 92 | 93 | # using rocauc as the eval function 94 | if args.dataset in ('yelp-chi', 'deezer-europe', 'twitch-e', 'fb100', 'ogbn-proteins'): 95 | criterion = nn.BCEWithLogitsLoss() 96 | else: 97 | criterion = nn.NLLLoss() 98 | if args.metric == 'rocauc': 99 | eval_func = eval_rocauc 100 | elif args.metric == 'f1': 101 | eval_func = eval_f1 102 | else: 103 | eval_func = eval_acc 104 | 105 | logger = Logger(args.runs, args) 106 | model.train() 107 | print('MODEL:', model) 108 | dataset.graph['edge_index'], dataset.graph['node_feat'] = \ 109 | dataset.graph['edge_index'].to(device), dataset.graph['node_feat'].to(device) 110 | 111 | if args.dataset in ('yelp-chi', 'deezer-europe', 'fb100', 'twitch-e', 'ogbn-proteins'): 112 | if dataset.label.shape[1] == 1: 113 | dataset.label = F.one_hot(dataset.label, dataset.label.max() + 1).squeeze(1) 114 | 115 | dataset_mask = copy.deepcopy(dataset) 116 | 117 | ### Training loop ### 118 | for run in range(args.runs): 119 | if args.dataset in ['cora', 'citeseer', 'pubmed'] and args.protocol == 'semi': 120 | split_idx = split_idx_lst[0] 121 | else: 122 | split_idx = split_idx_lst[run] 123 | train_idx = split_idx['train'].to(device) 124 | dataset.train_idx = train_idx 125 | 126 | # Processing for privileged information in each run 127 | if args.priv_type == 'edge': 128 | num = int(edge_index_directed.size(1) * (1 - args.priv_ratio)) # priv_ratio: information loss ratio 129 | idx = torch.randperm(edge_index_directed.size(1))[:num] 130 | edge_index_share = edge_index_directed[:, idx] 131 | try: dataset_mask.graph['edge_index'] = to_undirected(edge_index_share) 132 | except: dataset_mask.graph['edge_index'] = edge_index_share 133 | dataset_mask.train_idx = train_idx 134 | dataset_mask.share_node_idx = torch.cat([train_idx, split_idx['valid'].to(device), split_idx['test'].to(device)], dim=-1) 135 | 136 | elif args.priv_type == 'node': 137 | train_num = train_idx.shape[0] 138 | num = int((1 - args.priv_ratio) * train_num) # removing certain ratio of train nodes on training node 139 | assert num < train_num 140 | share_train_idx = train_idx[torch.randperm(train_num)[:num]] 141 | share_node_idx = torch.cat([share_train_idx, split_idx['valid'].to(device), split_idx['test'].to(device)], dim=-1) 142 | dataset_mask.graph['edge_index'] = subgraph(share_node_idx, dataset.graph['edge_index'])[0] 143 | dataset_mask.train_idx = share_train_idx 144 | dataset_mask.share_node_idx = share_node_idx 145 | else: 146 | raise NotImplementedError 147 | 148 | model.reset_parameters() 149 | 150 | if args.mode == 'train': # loading teacher model 151 | model_dir = f'saved_models/{args.base_model}_{args.dataset}_{run}.pkl' 152 | if not os.path.exists(model_dir): 153 | raise FileNotFoundError 154 | else: 155 | model_dict = torch.load(model_dir) 156 | if not args.not_load_teacher: 157 | model.teacher_gnn.load_state_dict(model_dict) 158 | 159 | optimizer_te = torch.optim.Adam([{'params': model.teacher_gnn.parameters()}], lr=args.lr, weight_decay=args.weight_decay) 160 | optimizer_st = torch.optim.Adam([{'params': model.student_gnn.parameters()}], lr=args.lr, 
weight_decay=args.weight_decay) 161 | if args.dist_mode == 'pgkd': optimizer_k = torch.optim.Adam([{'params': model.k.parameters()}], lr=args.lr2, weight_decay=args.weight_decay) 162 | 163 | best_val = float('-inf') 164 | for epoch in range(args.epochs): 165 | model.train() 166 | train_start = time.time() 167 | if args.mode == 'pretrain': 168 | optimizer_te.zero_grad() 169 | out = model(dataset, mode='pretrain') 170 | if args.dataset in ('yelp-chi', 'deezer-europe', 'fb100', 'twitch-e', 'ogbn-proteins'): # binary classification 171 | loss = criterion(out[train_idx], dataset.label.squeeze(1)[train_idx].to(torch.float)) 172 | else: 173 | out = F.log_softmax(out, dim=1) 174 | loss = criterion(out[train_idx], dataset.label.squeeze(1)[train_idx]) 175 | loss.backward() 176 | optimizer_te.step() 177 | 178 | elif args.mode == 'train' and args.dist_mode != 'pgkd': 179 | optimizer_st.zero_grad() 180 | outputs = model(dataset, dataset_mask, mode='train', dist_mode=args.dist_mode, t=args.t) 181 | out = outputs[0] if type(outputs) == tuple else outputs 182 | 183 | if args.dataset in ('yelp-chi', 'deezer-europe', 'fb100', 'twitch-e', 'ogbn-proteins'): 184 | sup_loss = criterion(out[dataset_mask.train_idx], dataset_mask.label.squeeze(1)[dataset_mask.train_idx].to(torch.float)) 185 | else: 186 | out = F.log_softmax(out, dim=1) 187 | sup_loss = criterion(out[dataset_mask.train_idx], dataset_mask.label.squeeze(1)[dataset_mask.train_idx]) 188 | 189 | if args.dist_mode == 'no': loss = sup_loss 190 | elif args.dist_mode == 'gkd' and not args.use_kd: 191 | loss = (1 - args.alpha) * sup_loss + args.alpha * outputs[1] 192 | elif args.dist_mode == 'gkd' and args.use_kd: 193 | loss = (1 - args.alpha) * sup_loss + args.alpha * outputs[1] + args.beta * outputs[2] * args.tau * args.tau 194 | 195 | loss.backward() 196 | optimizer_st.step() 197 | 198 | elif args.mode == 'train' and args.dist_mode == 'pgkd': 199 | outputs = model(dataset, dataset_mask, mode='train', dist_mode=args.dist_mode, t=args.t) 200 | out = outputs[0] if type(outputs) == tuple else outputs 201 | 202 | if args.dataset in ('yelp-chi', 'deezer-europe', 'fb100', 'twitch-e', 'ogbn-proteins'): 203 | sup_loss = criterion(out[dataset_mask.train_idx], dataset_mask.label.squeeze(1)[dataset_mask.train_idx].to(torch.float)) 204 | else: 205 | out = F.log_softmax(out, dim=1) 206 | sup_loss = criterion(out[dataset_mask.train_idx], dataset_mask.label.squeeze(1)[dataset_mask.train_idx]) 207 | 208 | if not args.use_kd: 209 | loss = (1 - args.alpha) * sup_loss + args.alpha * outputs[1] 210 | else: 211 | loss = (1 - args.alpha) * sup_loss + args.alpha * outputs[1] + args.beta * outputs[3] * args.tau * args.tau 212 | 213 | optimizer_k.zero_grad() 214 | rec_loss = outputs[2] 215 | rec_loss.backward(retain_graph=True) 216 | optimizer_k.step() 217 | 218 | optimizer_st.zero_grad() 219 | loss.backward() 220 | optimizer_st.step() 221 | 222 | train_time = time.time() - train_start 223 | 224 | if args.mode == 'pretrain': 225 | if args.oracle: 226 | result = evaluate(model, dataset, split_idx, eval_func, criterion, args, test_dataset=dataset) 227 | else: 228 | result = evaluate(model, dataset, split_idx, eval_func, criterion, args, test_dataset=dataset_mask) 229 | elif args.mode == 'train': 230 | result = evaluate(model, dataset_mask, split_idx, eval_func, criterion, args) 231 | 232 | 233 | logger.add_result(run, result[:-1]) 234 | if result[1] > best_val: 235 | best_val = result[1] 236 | if args.dataset != 'ogbn-proteins': 237 | best_out = F.softmax(result[-1], dim=1) 238 | 
else: 239 | best_out = result[-1] 240 | if args.mode == 'pretrain' and args.save_model: 241 | torch.save(model.teacher_gnn.state_dict(), f'saved_models/{args.base_model}_{args.dataset}_{run}.pkl') 242 | 243 | if epoch % args.display_step == 0: 244 | print(f'Epoch: {epoch:02d}, ' 245 | f'Loss: {loss:.4f}, ' 246 | f'Train: {100 * result[0]:.2f}%, ' 247 | f'Valid: {100 * result[1]:.2f}%, ' 248 | f'Test: {100 * result[2]:.2f}%') 249 | if args.print_prop: 250 | pred = out.argmax(dim=-1, keepdim=True) 251 | print("Predicted proportions:", pred.unique(return_counts=True)[1].float() / pred.shape[0]) 252 | 253 | results = logger.print_statistics(run) 254 | 255 | results = logger.print_statistics() 256 | 257 | # ### Save results ### 258 | filename = f'logs/{args.dataset}_{args.priv_type}.csv' 259 | print(f"Saving results to {filename}") 260 | with open(f"{filename}", 'a+') as write_obj: 261 | sub_dataset = f'{args.sub_dataset},' if args.sub_dataset else '' 262 | write_obj.write(f"data({args.dataset},{args.priv_type}{args.priv_ratio}), model({args.log_name},{args.base_model},{args.dist_mode}),\ 263 | \t lr({args.lr}), wd({args.weight_decay}), alpha({args.alpha}), t({args.t}), dt({args.delta}) \t") 264 | write_obj.write("perf: {} $\pm$ {}\n".format(format(results.mean(), '.2f'), format(results.std(), '.2f'))) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | import scipy 6 | import scipy.io 7 | import csv 8 | import json 9 | from sklearn.preprocessing import label_binarize 10 | import torch_geometric.transforms as T 11 | 12 | from data_utils import rand_train_test_idx, even_quantile_labels, to_sparse_tensor, dataset_drive_url 13 | 14 | from torch_geometric.datasets import Planetoid, Amazon, CoraFull, Coauthor 15 | from torch_geometric.transforms import NormalizeFeatures 16 | from os import path 17 | 18 | from torch_sparse import SparseTensor 19 | from google_drive_downloader import GoogleDriveDownloader as gdd 20 | 21 | import networkx as nx 22 | import scipy.sparse as sp 23 | 24 | from ogb.nodeproppred import NodePropPredDataset 25 | 26 | 27 | class NCDataset(object): 28 | def __init__(self, name, data_dir=None): 29 | """ 30 | based off of ogb NodePropPredDataset 31 | https://github.com/snap-stanford/ogb/blob/master/ogb/nodeproppred/dataset.py 32 | Gives torch tensors instead of numpy arrays 33 | - name (str): name of the dataset 34 | - root (str): root directory to store the dataset folder 35 | - meta_dict: dictionary that stores all the meta-information about data. Default is None, 36 | but when something is passed, it uses its information. Useful for debugging for external contributers. 
37 | 38 | Usage after construction: 39 | 40 | split_idx = dataset.get_idx_split() 41 | train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] 42 | graph, label = dataset[0] 43 | 44 | Where the graph is a dictionary of the following form: 45 | dataset.graph = {'edge_index': edge_index, 46 | 'edge_feat': None, 47 | 'node_feat': node_feat, 48 | 'num_nodes': num_nodes} 49 | For additional documentation, see OGB Library-Agnostic Loader https://ogb.stanford.edu/docs/nodeprop/ 50 | 51 | """ 52 | 53 | self.name = name # original name, e.g., ogbn-proteins 54 | self.graph = {} 55 | self.label = None 56 | 57 | def get_idx_split(self, split_type='random', train_prop=.5, valid_prop=.25): 58 | """ 59 | train_prop: The proportion of dataset for train split. Between 0 and 1. 60 | valid_prop: The proportion of dataset for validation split. Between 0 and 1. 61 | """ 62 | 63 | if split_type == 'random': 64 | ignore_negative = False if self.name == 'ogbn-proteins' else True 65 | train_idx, valid_idx, test_idx = rand_train_test_idx( 66 | self.label, train_prop=train_prop, valid_prop=valid_prop, ignore_negative=ignore_negative) 67 | split_idx = {'train': train_idx, 68 | 'valid': valid_idx, 69 | 'test': test_idx} 70 | 71 | return split_idx 72 | 73 | def __getitem__(self, idx): 74 | assert idx == 0, 'This dataset has only one graph' 75 | return self.graph, self.label 76 | 77 | def __len__(self): 78 | return 1 79 | 80 | def __repr__(self): 81 | return '{}({})'.format(self.__class__.__name__, len(self)) 82 | 83 | 84 | def load_nc_dataset(dataname, sub_dataname='', data_dir='../../data/'): 85 | """ Loader for NCDataset 86 | Returns NCDataset 87 | """ 88 | print(dataname) 89 | if dataname == 'twitch-e': 90 | # twitch-explicit graph 91 | if sub_dataname not in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'): 92 | print('Invalid sub_dataname, deferring to DE graph') 93 | sub_dataname = 'DE' 94 | dataset = load_twitch_dataset(sub_dataname, data_dir) 95 | elif dataname == 'fb100': 96 | if sub_dataname not in ('Penn94', 'Amherst41', 'Cornell5', 'Johns Hopkins55', 'Reed98'): 97 | print('Invalid sub_dataname, deferring to Penn94 graph') 98 | sub_dataname = 'Penn94' 99 | dataset = load_fb100_dataset(sub_dataname, data_dir) 100 | elif dataname == 'ogbn-proteins': 101 | dataset = load_proteins_dataset(data_dir) 102 | elif dataname == 'deezer-europe': 103 | dataset = load_deezer_dataset(data_dir) 104 | elif dataname == 'arxiv-year': 105 | dataset = load_arxiv_year_dataset(data_dir) 106 | elif dataname == 'pokec': 107 | dataset = load_pokec_mat(data_dir) 108 | elif dataname == 'snap-patents': 109 | dataset = load_snap_patents_mat(data_dir) 110 | elif dataname == 'yelp-chi': 111 | dataset = load_yelpchi_dataset(data_dir) 112 | elif dataname in ('ogbn-arxiv', 'ogbn-products'): 113 | dataset = load_ogb_dataset(dataname, data_dir) 114 | elif dataname in ('cora', 'citeseer', 'pubmed'): 115 | dataset = load_planetoid_dataset(dataname, data_dir) 116 | elif dataname in ('amazon-computer', 'amazon-photo'): 117 | dataset = load_planetoid_dataset(dataname, data_dir) 118 | elif dataname in ('coauthor-cs', 'coauthor-physics', 'corafull'): 119 | dataset = load_planetoid_dataset(dataname, data_dir) 120 | elif dataname in ('chameleon', 'cornell', 'film', 'squirrel', 'texas', 'wisconsin'): 121 | dataset = load_geom_gcn_dataset(dataname, data_dir) 122 | else: 123 | raise ValueError('Invalid dataname') 124 | return dataset 125 | 126 | 127 | def load_twitch_dataset(lang, data_dir): 128 | assert lang in ('DE', 
'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset' 129 | filepath = data_dir + f"twitch/{lang}" 130 | label = [] 131 | node_ids = [] 132 | src = [] 133 | targ = [] 134 | uniq_ids = set() 135 | with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f: 136 | reader = csv.reader(f) 137 | next(reader) 138 | for row in reader: 139 | node_id = int(row[5]) 140 | # handle FR case of non-unique rows 141 | if node_id not in uniq_ids: 142 | uniq_ids.add(node_id) 143 | label.append(int(row[2] == "True")) 144 | node_ids.append(int(row[5])) 145 | 146 | node_ids = np.array(node_ids, dtype=int) 147 | with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f: 148 | reader = csv.reader(f) 149 | next(reader) 150 | for row in reader: 151 | src.append(int(row[0])) 152 | targ.append(int(row[1])) 153 | with open(f"{filepath}/musae_{lang}_features.json", 'r') as f: 154 | j = json.load(f) 155 | src = np.array(src) 156 | targ = np.array(targ) 157 | label = np.array(label) 158 | inv_node_ids = {node_id: idx for (idx, node_id) in enumerate(node_ids)} 159 | reorder_node_ids = np.zeros_like(node_ids) 160 | for i in range(label.shape[0]): 161 | reorder_node_ids[i] = inv_node_ids[i] 162 | 163 | n = label.shape[0] 164 | A = scipy.sparse.csr_matrix((np.ones(len(src)), 165 | (np.array(src), np.array(targ))), 166 | shape=(n, n)) 167 | features = np.zeros((n, 3170)) 168 | for node, feats in j.items(): 169 | if int(node) >= n: 170 | continue 171 | features[int(node), np.array(feats, dtype=int)] = 1 172 | # features = features[:, np.sum(features, axis=0) != 0] # remove zero cols. not needed for cross graph task 173 | new_label = label[reorder_node_ids] 174 | label = new_label 175 | 176 | dataset = NCDataset(lang) 177 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long) 178 | node_feat = torch.tensor(features, dtype=torch.float) 179 | num_nodes = node_feat.shape[0] 180 | dataset.graph = {'edge_index': edge_index, 181 | 'edge_feat': None, 182 | 'node_feat': node_feat, 183 | 'num_nodes': num_nodes} 184 | dataset.label = torch.tensor(label) 185 | return dataset 186 | 187 | 188 | def load_fb100_dataset(filename, data_dir): 189 | feature_vals_all = np.empty((0, 6)) 190 | for f in ['Penn94', 'Amherst41', 'Cornell5', 'Johns Hopkins55', 'Reed98']: 191 | mat = scipy.io.loadmat(data_dir + 'facebook100/' + f + '.mat') 192 | metadata = mat['local_info'] 193 | metadata = metadata.astype(int) 194 | feature_vals = np.hstack( 195 | (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:])) 196 | feature_vals_all = np.vstack( 197 | (feature_vals_all, feature_vals) 198 | ) 199 | 200 | mat = scipy.io.loadmat(data_dir + 'facebook100/' + filename + '.mat') 201 | A = mat['A'] 202 | metadata = mat['local_info'] 203 | dataset = NCDataset(filename) 204 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long) 205 | metadata = metadata.astype(int) 206 | label = metadata[:, 1] - 1 # gender label, -1 means unlabeled 207 | 208 | # make features into one-hot encodings 209 | feature_vals = np.hstack( 210 | (np.expand_dims(metadata[:, 0], 1), metadata[:, 2:])) 211 | features = np.empty((A.shape[0], 0)) 212 | for col in range(feature_vals.shape[1]): 213 | feat_col = feature_vals[:, col] 214 | # feat_onehot = label_binarize(feat_col, classes=np.unique(feat_col)) 215 | feat_onehot = label_binarize(feat_col, classes=np.unique(feature_vals_all[:, col])) 216 | features = np.hstack((features, feat_onehot)) 217 | 218 | node_feat = torch.tensor(features, dtype=torch.float) 219 | num_nodes = metadata.shape[0] 220 | dataset.graph = {'edge_index':
edge_index, 221 | 'edge_feat': None, 222 | 'node_feat': node_feat, 223 | 'num_nodes': num_nodes} 224 | dataset.label = torch.tensor(label) 225 | dataset.label = torch.where(dataset.label > 0, 1, 0) 226 | return dataset 227 | 228 | 229 | def load_deezer_dataset(data_dir): 230 | filename = 'deezer-europe' 231 | dataset = NCDataset(filename) 232 | deezer = scipy.io.loadmat(f'{data_dir}/deezer/deezer-europe.mat') 233 | 234 | A, label, features = deezer['A'], deezer['label'], deezer['features'] 235 | edge_index = torch.tensor(A.nonzero(), dtype=torch.long) 236 | node_feat = torch.tensor(features.todense(), dtype=torch.float) 237 | label = torch.tensor(label, dtype=torch.long).squeeze() 238 | num_nodes = label.shape[0] 239 | 240 | dataset.graph = {'edge_index': edge_index, 241 | 'edge_feat': None, 242 | 'node_feat': node_feat, 243 | 'num_nodes': num_nodes} 244 | dataset.label = label 245 | return dataset 246 | 247 | 248 | def load_arxiv_year_dataset(data_dir, nclass=5): 249 | filename = 'arxiv-year' 250 | dataset = NCDataset(filename) 251 | ogb_dataset = NodePropPredDataset(name='ogbn-arxiv', root=f'{data_dir}/ogb') 252 | dataset.graph = ogb_dataset.graph 253 | dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index']) 254 | dataset.graph['node_feat'] = torch.as_tensor(dataset.graph['node_feat']) 255 | 256 | label = even_quantile_labels( 257 | dataset.graph['node_year'].flatten(), nclass, verbose=False) 258 | dataset.label = torch.as_tensor(label).reshape(-1, 1) 259 | return dataset 260 | 261 | 262 | def load_proteins_dataset(data_dir): 263 | ogb_dataset = NodePropPredDataset(name='ogbn-proteins', root=f'{data_dir}/ogb') 264 | dataset = NCDataset('ogbn-proteins') 265 | 266 | def protein_orig_split(**kwargs): 267 | split_idx = ogb_dataset.get_idx_split() 268 | return {'train': torch.as_tensor(split_idx['train']), 269 | 'valid': torch.as_tensor(split_idx['valid']), 270 | 'test': torch.as_tensor(split_idx['test'])} 271 | 272 | dataset.load_fixed_splits = protein_orig_split 273 | dataset.graph, dataset.label = ogb_dataset.graph, ogb_dataset.labels 274 | 275 | dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index']) 276 | dataset.graph['edge_feat'] = torch.as_tensor(dataset.graph['edge_feat']) 277 | dataset.label = torch.as_tensor(dataset.label) 278 | return dataset 279 | 280 | 281 | def load_ogb_dataset(name, data_dir): 282 | dataset = NCDataset(name) 283 | ogb_dataset = NodePropPredDataset(name=name, root=f'{data_dir}/ogb') 284 | dataset.graph = ogb_dataset.graph 285 | dataset.graph['edge_index'] = torch.as_tensor(dataset.graph['edge_index']) 286 | dataset.graph['node_feat'] = torch.as_tensor(dataset.graph['node_feat']) 287 | 288 | def ogb_idx_to_tensor(): 289 | split_idx = ogb_dataset.get_idx_split() 290 | tensor_split_idx = {key: torch.as_tensor( 291 | split_idx[key]) for key in split_idx} 292 | return tensor_split_idx 293 | 294 | dataset.load_fixed_splits = ogb_idx_to_tensor # ogb_dataset.get_idx_split 295 | dataset.label = torch.as_tensor(ogb_dataset.labels).reshape(-1, 1) 296 | return dataset 297 | 298 | 299 | def load_pokec_mat(data_dir): 300 | """ requires pokec.mat """ 301 | if not path.exists(f'{data_dir}pokec.mat'): 302 | gdd.download_file_from_google_drive( 303 | file_id=dataset_drive_url['pokec'], \ 304 | dest_path=f'{data_dir}pokec.mat', showsize=True) 305 | 306 | fulldata = scipy.io.loadmat(f'{data_dir}pokec.mat') 307 | 308 | dataset = NCDataset('pokec') 309 | edge_index = torch.tensor(fulldata['edge_index'], dtype=torch.long) 310 | node_feat = 
torch.tensor(fulldata['node_feat']).float() 311 | num_nodes = int(fulldata['num_nodes']) 312 | dataset.graph = {'edge_index': edge_index, 313 | 'edge_feat': None, 314 | 'node_feat': node_feat, 315 | 'num_nodes': num_nodes} 316 | 317 | label = fulldata['label'].flatten() 318 | dataset.label = torch.tensor(label, dtype=torch.long) 319 | 320 | return dataset 321 | 322 | 323 | def load_snap_patents_mat(data_dir, nclass=5): 324 | if not path.exists(f'{data_dir}snap_patents.mat'): 325 | p = dataset_drive_url['snap-patents'] 326 | print(f"Snap patents url: {p}") 327 | gdd.download_file_from_google_drive( 328 | file_id=dataset_drive_url['snap-patents'], \ 329 | dest_path=f'{data_dir}snap_patents.mat', showsize=True) 330 | 331 | fulldata = scipy.io.loadmat(f'{data_dir}snap_patents.mat') 332 | 333 | dataset = NCDataset('snap_patents') 334 | edge_index = torch.tensor(fulldata['edge_index'], dtype=torch.long) 335 | node_feat = torch.tensor( 336 | fulldata['node_feat'].todense(), dtype=torch.float) 337 | num_nodes = int(fulldata['num_nodes']) 338 | dataset.graph = {'edge_index': edge_index, 339 | 'edge_feat': None, 340 | 'node_feat': node_feat, 341 | 'num_nodes': num_nodes} 342 | 343 | years = fulldata['years'].flatten() 344 | label = even_quantile_labels(years, nclass, verbose=False) 345 | dataset.label = torch.tensor(label, dtype=torch.long) 346 | 347 | return dataset 348 | 349 | 350 | def load_yelpchi_dataset(data_dir): 351 | if not path.exists(f'{data_dir}YelpChi.mat'): 352 | gdd.download_file_from_google_drive( 353 | file_id=dataset_drive_url['yelp-chi'], \ 354 | dest_path=f'{data_dir}YelpChi.mat', showsize=True) 355 | fulldata = scipy.io.loadmat(f'{data_dir}YelpChi.mat') 356 | A = fulldata['homo'] 357 | edge_index = np.array(A.nonzero()) 358 | node_feat = fulldata['features'] 359 | label = np.array(fulldata['label'], dtype=int).flatten() 360 | num_nodes = node_feat.shape[0] 361 | 362 | dataset = NCDataset('YelpChi') 363 | edge_index = torch.tensor(edge_index, dtype=torch.long) 364 | node_feat = torch.tensor(node_feat.todense(), dtype=torch.float) 365 | dataset.graph = {'edge_index': edge_index, 366 | 'node_feat': node_feat, 367 | 'edge_feat': None, 368 | 'num_nodes': num_nodes} 369 | label = torch.tensor(label, dtype=torch.long) 370 | dataset.label = label 371 | return dataset 372 | 373 | 374 | def load_planetoid_dataset(name, data_dir): 375 | transform = T.NormalizeFeatures() 376 | if name in ('cora', 'citeseer', 'pubmed'): 377 | torch_dataset = Planetoid(root=f'{data_dir}Planetoid', 378 | name=name, transform=transform) 379 | elif name == 'amazon-computer': 380 | torch_dataset = Amazon(root=f'{data_dir}Amazon', name='Computers', transform=transform) 381 | elif name == 'amazon-photo': 382 | torch_dataset = Amazon(root=f'{data_dir}Amazon', name='Photo', transform=transform) 383 | elif name == 'coauthor-cs': 384 | torch_dataset = Coauthor(root=f'{data_dir}Coauthor', name='CS', transform=transform) 385 | elif name == 'coauthor-physics': 386 | torch_dataset = Coauthor(root=f'{data_dir}Coauthor', name='Physics', transform=transform) 387 | elif name in ('corafull', 'cora-full'): 388 | torch_dataset = CoraFull(root=f'{data_dir}CoraFull', transform=transform) 389 | data = torch_dataset[0] 390 | 391 | edge_index = data.edge_index 392 | node_feat = data.x 393 | label = data.y 394 | num_nodes = data.num_nodes 395 | print(f"Num nodes: {num_nodes}") 396 | 397 | dataset = NCDataset(name) 398 | 399 | if name in ('cora', 'citeseer', 'pubmed'): 400 | dataset.train_idx = torch.where(data.train_mask)[0] 401 |
dataset.valid_idx = torch.where(data.val_mask)[0] 402 | dataset.test_idx = torch.where(data.test_mask)[0] 403 | 404 | dataset.graph = {'edge_index': edge_index, 405 | 'node_feat': node_feat, 406 | 'edge_feat': None, 407 | 'num_nodes': num_nodes} 408 | dataset.label = label 409 | 410 | return dataset 411 | 412 | def load_geom_gcn_dataset(name, data_dir): 413 | graph_adjacency_list_file_path = f'{data_dir}geom-gcn/{name}/out1_graph_edges.txt' 414 | graph_node_features_and_labels_file_path = f'{data_dir}geom-gcn/{name}/out1_node_feature_label.txt' 415 | 416 | G = nx.DiGraph() 417 | graph_node_features_dict = {} 418 | graph_labels_dict = {} 419 | 420 | if name == 'film': 421 | with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: 422 | graph_node_features_and_labels_file.readline() 423 | for line in graph_node_features_and_labels_file: 424 | line = line.rstrip().split('\t') 425 | assert (len(line) == 3) 426 | assert (int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict) 427 | feature_blank = np.zeros(932, dtype=np.uint8) 428 | feature_blank[np.array(line[1].split(','), dtype=np.uint16)] = 1 429 | graph_node_features_dict[int(line[0])] = feature_blank 430 | graph_labels_dict[int(line[0])] = int(line[2]) 431 | else: 432 | with open(graph_node_features_and_labels_file_path) as graph_node_features_and_labels_file: 433 | graph_node_features_and_labels_file.readline() 434 | for line in graph_node_features_and_labels_file: 435 | line = line.rstrip().split('\t') 436 | assert (len(line) == 3) 437 | assert (int(line[0]) not in graph_node_features_dict and int(line[0]) not in graph_labels_dict) 438 | graph_node_features_dict[int(line[0])] = np.array(line[1].split(','), dtype=np.uint8) 439 | graph_labels_dict[int(line[0])] = int(line[2]) 440 | 441 | with open(graph_adjacency_list_file_path) as graph_adjacency_list_file: 442 | graph_adjacency_list_file.readline() 443 | for line in graph_adjacency_list_file: 444 | line = line.rstrip().split('\t') 445 | assert (len(line) == 2) 446 | if int(line[0]) not in G: 447 | G.add_node(int(line[0]), features=graph_node_features_dict[int(line[0])], 448 | label=graph_labels_dict[int(line[0])]) 449 | if int(line[1]) not in G: 450 | G.add_node(int(line[1]), features=graph_node_features_dict[int(line[1])], 451 | label=graph_labels_dict[int(line[1])]) 452 | G.add_edge(int(line[0]), int(line[1])) 453 | 454 | adj = nx.adjacency_matrix(G, sorted(G.nodes())) 455 | adj = sp.coo_matrix(adj) 456 | adj = adj + sp.eye(adj.shape[0]) 457 | adj = adj.tocoo().astype(np.float32) 458 | features = np.array( 459 | [features for _, features in sorted(G.nodes(data='features'), key=lambda x: x[0])]) 460 | labels = np.array( 461 | [label for _, label in sorted(G.nodes(data='label'), key=lambda x: x[0])]) 462 | print(features.shape) 463 | 464 | def preprocess_features(feat): 465 | """Row-normalize feature matrix and convert to tuple representation""" 466 | rowsum = np.array(feat.sum(1)) 467 | rowsum = (rowsum == 0) * 1 + rowsum 468 | r_inv = np.power(rowsum, -1).flatten() 469 | r_inv[np.isinf(r_inv)] = 0. 
470 | r_mat_inv = sp.diags(r_inv) 471 | feat = r_mat_inv.dot(feat) 472 | return feat 473 | 474 | features = preprocess_features(features) 475 | 476 | edge_index = torch.from_numpy( 477 | np.vstack((adj.row, adj.col)).astype(np.int64)) 478 | node_feat = torch.FloatTensor(features) 479 | labels = torch.LongTensor(labels) 480 | num_nodes = node_feat.shape[0] 481 | print(f"Num nodes: {num_nodes}") 482 | 483 | dataset = NCDataset(name) 484 | 485 | dataset.graph = {'edge_index': edge_index, 486 | 'node_feat': node_feat, 487 | 'edge_feat': None, 488 | 'num_nodes': num_nodes} 489 | dataset.label = labels 490 | 491 | return dataset 492 | 493 | 494 | 495 | --------------------------------------------------------------------------------
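Usage note: a minimal sketch of how the loaders above fit together. The `cora` choice and the `data_dir` path are placeholders, and the sketch assumes the underlying PyG/OGB loaders download the data on first use.

```
# Hypothetical driver snippet, not part of the repository files above.
from dataset import load_nc_dataset

# Build an NCDataset; 'cora' dispatches to load_planetoid_dataset.
dataset = load_nc_dataset('cora', data_dir='../../data/')
graph, label = dataset[0]  # graph dict and label tensor
print(graph['num_nodes'], graph['node_feat'].shape, graph['edge_index'].shape)

# Random 50/25/25 split via rand_train_test_idx (see data_utils.py).
split_idx = dataset.get_idx_split(train_prop=0.5, valid_prop=0.25)
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
```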