├── graphsaint
│   ├── __init__.py
│   ├── norm_aggr.cp36-win_amd64.pyd
│   ├── cython_utils.cp36-win_amd64.pyd
│   ├── cython_sampler.cp36-win_amd64.pyd
│   ├── __pycache__
│   │   └── __init__.cpython-36.pyc
│   ├── setup.py
│   ├── norm_aggr.pyx
│   ├── cython_utils.pxd
│   ├── cython_utils.pyx
│   └── cython_sampler.pyx
├── utility
│   ├── __pycache__
│   │   ├── globals.cpython-36.pyc
│   │   ├── globals.cpython-39.pyc
│   │   ├── metric.cpython-36.pyc
│   │   ├── utils.cpython-36.pyc
│   │   └── graph_samplers.cpython-36.pyc
│   ├── metric.py
│   ├── globals.py
│   ├── graph_samplers.py
│   └── utils.py
├── data
│   └── toy
│       └── readme.txt
├── train_config
│   └── toy.yml
├── embedding
│   ├── user_model.py
│   ├── text_model.py
│   ├── minibatch.py
│   └── layers.py
├── README.md
└── train.py

/graphsaint/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/graphsaint/norm_aggr.cp36-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/graphsaint/norm_aggr.cp36-win_amd64.pyd
--------------------------------------------------------------------------------
/graphsaint/cython_utils.cp36-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/graphsaint/cython_utils.cp36-win_amd64.pyd
--------------------------------------------------------------------------------
/utility/__pycache__/globals.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/utility/__pycache__/globals.cpython-36.pyc
--------------------------------------------------------------------------------
/utility/__pycache__/globals.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/utility/__pycache__/globals.cpython-39.pyc
--------------------------------------------------------------------------------
/utility/__pycache__/metric.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/utility/__pycache__/metric.cpython-36.pyc
--------------------------------------------------------------------------------
/utility/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/utility/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/graphsaint/cython_sampler.cp36-win_amd64.pyd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/graphsaint/cython_sampler.cp36-win_amd64.pyd
--------------------------------------------------------------------------------
/data/toy/readme.txt:
--------------------------------------------------------------------------------
1 | Please download the toy dataset:
2 | https://drive.google.com/drive/folders/18IwOQ7hc0S6QaOQxdp7AIHhZezzMZ0CU?usp=sharing
3 | 
--------------------------------------------------------------------------------
/graphsaint/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/graphsaint/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utility/__pycache__/graph_samplers.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xingsumq/Us-DeFake/HEAD/utility/__pycache__/graph_samplers.cpython-36.pyc -------------------------------------------------------------------------------- /train_config/toy.yml: -------------------------------------------------------------------------------- 1 | network: 2 | - dim: 512 3 | aggr: 'concat' 4 | loss: 'softmax' 5 | arch: '1-1-0' 6 | act: 'relu' 7 | bias: 'norm' 8 | params: 9 | - lr: 0.01 10 | dropout: 0.1 11 | weight_decay: 0.0 12 | sample_coverage: 50 13 | phase: 14 | - end: 30 15 | sampler: 'rw' 16 | depth: 2 17 | num_root: 3000 18 | -------------------------------------------------------------------------------- /graphsaint/setup.py: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | from distutils.core import setup, Extension 3 | from Cython.Build import cythonize 4 | import numpy 5 | # import cython_utils 6 | 7 | import os 8 | os.environ["CC"] = "g++" 9 | os.environ["CXX"] = "g++" 10 | 11 | setup(ext_modules = cythonize(["graphsaint/cython_sampler.pyx","graphsaint/cython_utils.pyx","graphsaint/norm_aggr.pyx"]), include_dirs = [numpy.get_include()]) 12 | # to compile: python graphsaint/setup.py build_ext --inplace 13 | -------------------------------------------------------------------------------- /utility/metric.py: -------------------------------------------------------------------------------- 1 | from sklearn import metrics 2 | import numpy as np 3 | from sklearn.metrics import precision_recall_fscore_support 4 | 5 | def evaluation(y_true, y_pred,is_sigmoid): 6 | if not is_sigmoid: 7 | y_true = np.argmax(y_true, axis=1) 8 | y_pred = np.argmax(y_pred, axis=1) 9 | else: 10 | y_pred[y_pred > 0.5] = 1 11 | y_pred[y_pred <= 0.5] = 0 12 | a = metrics.accuracy_score(y_true, y_pred) 13 | p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average='macro') 14 | return a, p, r, f 15 | 16 | 17 | -------------------------------------------------------------------------------- /graphsaint/norm_aggr.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -fopenmp -std=c++11 4 | # distutils: extra_link_args = -fopenmp 5 | 6 | import numpy as np 7 | cimport numpy as np 8 | cimport cython 9 | from cython.parallel import prange,parallel 10 | from libcpp.map cimport map 11 | from libcpp.unordered_map cimport unordered_map 12 | from libcpp.vector cimport vector 13 | from libcpp.string cimport string 14 | from cython.operator cimport dereference, preincrement 15 | 16 | def norm_aggr(data,edge_index,norm_aggr,num_proc=20): 17 | cdef int num_proc_view=num_proc 18 | cdef float [:] data_view=data 19 | cdef int length=data.shape[0] 20 | cdef int [:] edge_index_view=edge_index 21 | cdef float [:] norm_aggr_view=norm_aggr 22 | cdef int i 23 | for i in prange(length,schedule='static',nogil=True,num_threads=num_proc_view): 24 | data_view[i]=norm_aggr_view[edge_index_view[i]] -------------------------------------------------------------------------------- /graphsaint/cython_utils.pxd: 
-------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -fopenmp -std=c++11 4 | # distutils: extra_link_args = -fopenmp 5 | 6 | from libcpp.vector cimport vector 7 | cimport cython 8 | import numpy as np 9 | cimport numpy as np 10 | 11 | 12 | cdef extern from "" namespace "std" nogil: 13 | T move[T](T) 14 | 15 | 16 | cdef class array_wrapper_float: 17 | cdef vector[float] vec 18 | cdef Py_ssize_t shape[1] 19 | cdef Py_ssize_t strides[1] 20 | cdef void set_data(self,vector[float]& data) 21 | 22 | cdef class array_wrapper_int: 23 | cdef vector[int] vec 24 | cdef Py_ssize_t shape[1] 25 | cdef Py_ssize_t strides[1] 26 | cdef void set_data(self,vector[int]& data) 27 | 28 | cdef inline void npy2vec_int(np.ndarray[int,ndim=1,mode='c'] nda, vector[int]& vec): 29 | cdef int size = nda.size 30 | cdef int* vec_c = &(nda[0]) 31 | vec.assign(vec_c,vec_c+size) 32 | 33 | cdef inline void npy2vec_float(np.ndarray[float,ndim=1,mode='c'] nda, vector[float]& vec): 34 | cdef int size = nda.size 35 | cdef float* vec_c = &(nda[0]) 36 | vec.assign(vec_c,vec_c+size) 37 | 38 | -------------------------------------------------------------------------------- /graphsaint/cython_utils.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -fopenmp -std=c++11 4 | # distutils: extra_link_args = -fopenmp 5 | cimport cython 6 | from cython.parallel import prange,parallel 7 | from cython.operator import dereference, postincrement 8 | from cython cimport Py_buffer 9 | from libcpp.vector cimport vector 10 | from libcpp.map cimport map 11 | from libcpp.utility cimport pair 12 | import numpy as np 13 | cimport numpy as np 14 | from libc.stdio cimport printf 15 | import time 16 | 17 | # reference: https://stackoverflow.com/questions/45133276/passing-c-vector-to-numpy-through-cython-without-copying-and-taking-care-of-me 18 | cdef class array_wrapper_float: 19 | 20 | cdef void set_data(self, vector[float]& data): 21 | self.vec = move(data) 22 | 23 | # now implement the buffer protocol for the class 24 | # which makes it generally useful to anything that expects an array 25 | def __getbuffer__(self, Py_buffer *buffer, int flags): 26 | # relevant documentation http://cython.readthedocs.io/en/latest/src/userguide/buffer.html#a-matrix-class 27 | cdef Py_ssize_t itemsize = sizeof(self.vec[0]) 28 | self.shape[0] = self.vec.size() 29 | self.strides[0] = sizeof(float) 30 | buffer.buf = &(self.vec[0]) 31 | buffer.format = 'f' 32 | buffer.internal = NULL 33 | buffer.itemsize = itemsize 34 | buffer.len = self.vec.size() * itemsize 35 | buffer.ndim = 1 36 | buffer.obj = self 37 | buffer.readonly = 0 38 | buffer.shape = self.shape 39 | buffer.strides = self.strides 40 | buffer.suboffsets = NULL 41 | 42 | def __releasebuffer__(self,Py_buffer *buffer): 43 | pass 44 | 45 | 46 | cdef class array_wrapper_int: 47 | 48 | cdef void set_data(self, vector[int]& data): 49 | self.vec = move(data) 50 | 51 | def __getbuffer__(self, Py_buffer *buffer, int flags): 52 | # relevant documentation http://cython.readthedocs.io/en/latest/src/userguide/buffer.html#a-matrix-class 53 | cdef Py_ssize_t itemsize = sizeof(self.vec[0]) 54 | self.shape[0] = self.vec.size() 55 | self.strides[0] = sizeof(int) 56 | buffer.buf = &(self.vec[0]) 57 | buffer.format = 'i' 58 | buffer.internal = NULL 59 | buffer.itemsize = itemsize 
60 | buffer.len = self.vec.size() * itemsize 61 | buffer.ndim = 1 62 | buffer.obj = self 63 | buffer.readonly = 0 64 | buffer.shape = self.shape 65 | buffer.strides = self.strides 66 | buffer.suboffsets = NULL 67 | 68 | def __releasebuffer__(self,Py_buffer *buffer): 69 | pass 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /utility/globals.py: -------------------------------------------------------------------------------- 1 | import os,sys,time,datetime 2 | import argparse 3 | import subprocess 4 | 5 | git_rev = subprocess.Popen("git rev-parse --short HEAD", shell=True, stdout=subprocess.PIPE, universal_newlines=True).communicate()[0] 6 | git_branch = subprocess.Popen("git symbolic-ref --short -q HEAD", shell=True, stdout=subprocess.PIPE, universal_newlines=True).communicate()[0] 7 | 8 | timestamp = time.time() 9 | timestamp = datetime.datetime.fromtimestamp(int(timestamp)).strftime('%Y-%m-%d %H-%M-%S') 10 | 11 | 12 | parser = argparse.ArgumentParser(description="argument for GraphSAINT training") 13 | parser.add_argument("--num_cpu_core",default=20,type=int,help="Number of CPU cores for parallel sampling") 14 | parser.add_argument("--log_device_placement",default=False,action="store_true",help="Whether to log device placement") 15 | parser.add_argument("--data_prefix",default="./data/toy",type=str,help="prefix identifying training data") 16 | parser.add_argument("--fold",default="1",type=str,help="k-fold") 17 | parser.add_argument("--dir_log",default=".",type=str,help="base directory for logging and saving embeddings") 18 | parser.add_argument("--gpu",default="-1",type=str,help="Whether use GPU") 19 | parser.add_argument("--eval_train_every",default=15,type=int,help="How often to evaluate training subgraph accuracy") 20 | parser.add_argument("--train_config",default="./train_config/toy.yml",type=str,help="path to the configuration of training (*.yml)") 21 | parser.add_argument("--dtype",default="s",type=str,help="d for double, s for single precision floating point") 22 | parser.add_argument("--timeline",default=False,action="store_true",help="to save timeline.json or not") 23 | parser.add_argument("--tensorboard",default=False,action="store_true",help="to save data to tensorboard or not") 24 | parser.add_argument("--cpu_eval",default=False,action="store_true",help="whether to use CPU to do evaluation") 25 | parser.add_argument("--saved_model_path",default="./",type=str,help="path to pretrained model file") 26 | args_global = parser.parse_args() 27 | 28 | 29 | NUM_PAR_SAMPLER = args_global.num_cpu_core 30 | SAMPLES_PER_PROC = -(-200 // NUM_PAR_SAMPLER) # round up division 31 | 32 | EVAL_VAL_EVERY_EP = 1 # get accuracy on the validation set every this # epochs 33 | 34 | 35 | # auto choosing available NVIDIA GPU 36 | gpu_selected = args_global.gpu 37 | if gpu_selected == '-1234': 38 | # auto detect gpu by filtering on the nvidia-smi command output 39 | gpu_stat = subprocess.Popen("nvidia-smi",shell=True,stdout=subprocess.PIPE,universal_newlines=True).communicate()[0] 40 | gpu_avail = set([str(i) for i in range(8)]) 41 | for line in gpu_stat.split('\n'): 42 | if 'python' in line: 43 | if line.split()[1] in gpu_avail: 44 | gpu_avail.remove(line.split()[1]) 45 | if len(gpu_avail) == 0: 46 | gpu_selected = -2 47 | else: 48 | gpu_selected = sorted(list(gpu_avail))[0] 49 | if gpu_selected == -1: 50 | gpu_selected = '0' 51 | args_global.gpu = int(gpu_selected) 52 | if str(gpu_selected).startswith('nvlink'): 53 | 
os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_selected).split('nvlink')[1] 54 | elif int(gpu_selected) >= 0: 55 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" 56 | os.environ["CUDA_VISIBLE_DEVICES"]=str(gpu_selected) 57 | GPU_MEM_FRACTION = 0.8 58 | else: 59 | os.environ["CUDA_VISIBLE_DEVICES"]="-1" 60 | args_global.gpu = int(args_global.gpu) 61 | 62 | # global vars 63 | 64 | f_mean = lambda l: sum(l)/len(l) 65 | 66 | DTYPE = "float32" if args_global.dtype=='s' else "float64" # NOTE: currently not supporting float64 yet 67 | -------------------------------------------------------------------------------- /utility/graph_samplers.py: -------------------------------------------------------------------------------- 1 | from utility.globals import * 2 | import numpy as np 3 | import scipy.sparse 4 | import graphsaint.cython_sampler as cy 5 | 6 | 7 | class GraphSampler: 8 | def __init__(self, adj_train, node_train, size_subgraph, args_preproc): 9 | self.adj_train = adj_train 10 | self.node_train = np.unique(node_train).astype(np.int32) 11 | # size in terms of number of vertices in subgraph 12 | self.size_subgraph = size_subgraph 13 | self.name_sampler = 'None' 14 | self.node_subgraph = None 15 | self.preproc(**args_preproc) 16 | 17 | def preproc(self, **kwargs): 18 | pass 19 | 20 | def par_sample(self, stage, **kwargs): 21 | return self.cy_sampler.par_sample() 22 | 23 | 24 | class rw_sampling(GraphSampler): 25 | def __init__(self, adj_train, node_train, size_subgraph, size_root, size_depth): 26 | self.size_root = size_root 27 | self.size_depth = size_depth 28 | size_subgraph = size_root * size_depth 29 | super().__init__(adj_train, node_train, size_subgraph, {}) 30 | self.cy_sampler = cy.RW( 31 | self.adj_train.indptr, 32 | self.adj_train.indices, 33 | self.node_train, 34 | NUM_PAR_SAMPLER, 35 | SAMPLES_PER_PROC, 36 | self.size_root, 37 | self.size_depth 38 | ) 39 | 40 | def preproc(self, **kwargs): 41 | pass 42 | 43 | 44 | class edge_sampling(GraphSampler): 45 | def __init__(self,adj_train,node_train,num_edges_subgraph): 46 | """ 47 | The sampler picks edges from the training graph independently, following 48 | a pre-computed edge probability distribution. i.e., 49 | p_{u,v} \\propto 1 / deg_u + 1 / deg_v 50 | Such prob. dist. is derived to minimize the variance of the minibatch 51 | estimator (see Thm 3.2 of the GraphSAINT paper). 52 | """ 53 | self.num_edges_subgraph = num_edges_subgraph 54 | # num subgraph nodes may not be num_edges_subgraph * 2 in many cases, 55 | # but it is not too important to have an accurate estimation of subgraph 56 | # size. So it's probably just fine to use this number. 57 | self.size_subgraph = num_edges_subgraph * 2 58 | self.deg_train = np.array(adj_train.sum(1)).flatten() 59 | self.adj_train_norm = scipy.sparse.dia_matrix((1 / self.deg_train, 0), shape=adj_train.shape).dot(adj_train) 60 | super().__init__(adj_train, node_train, self.size_subgraph, {}) 61 | self.cy_sampler = cy.Edge2( 62 | self.adj_train.indptr, 63 | self.adj_train.indices, 64 | self.node_train, 65 | NUM_PAR_SAMPLER, 66 | SAMPLES_PER_PROC, 67 | self.edge_prob_tri.row, 68 | self.edge_prob_tri.col, 69 | self.edge_prob_tri.data.cumsum(), 70 | self.num_edges_subgraph, 71 | ) 72 | 73 | def preproc(self,**kwargs): 74 | """ 75 | Compute the edge probability distribution p_{u,v}. 
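        As stated in the class docstring above, p_{u,v} \propto 1/deg_u + 1/deg_v. The code
        below first fills a CSR matrix with a_{u,v} = 1/deg_u (the row-normalized adj),
        adds its transpose to obtain a_{u,v} + a_{v,u}, rescales so the entries sum to
        2 * num_edges_subgraph, and keeps only the upper triangle (in COO format) since
        the graph is assumed undirected.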
76 | """ 77 | self.edge_prob = scipy.sparse.csr_matrix( 78 | ( 79 | np.zeros(self.adj_train.size), 80 | self.adj_train.indices, 81 | self.adj_train.indptr 82 | ), 83 | shape=self.adj_train.shape, 84 | ) 85 | self.edge_prob.data[:] = self.adj_train_norm.data[:] 86 | _adj_trans = scipy.sparse.csr_matrix.tocsc(self.adj_train_norm) 87 | self.edge_prob.data += _adj_trans.data # P_e \propto a_{u,v} + a_{v,u} 88 | self.edge_prob.data *= 2 * self.num_edges_subgraph / self.edge_prob.data.sum() 89 | # now edge_prob is a symmetric matrix, we only keep the 90 | # upper triangle part, since adj is assumed to be undirected. 91 | self.edge_prob_tri = scipy.sparse.triu(self.edge_prob).astype(np.float32) # NOTE: in coo format 92 | 93 | -------------------------------------------------------------------------------- /embedding/user_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from utility.utils import * 6 | import embedding.layers as layers 7 | 8 | 9 | class UnGraphSAINT(nn.Module): 10 | def __init__(self, arch_gcn, train_params, feat_full, cpu_eval=False): 11 | super(UnGraphSAINT,self).__init__() 12 | self.use_cuda = (args_global.gpu >= 0) 13 | if cpu_eval: 14 | self.use_cuda=False 15 | if "attention" in arch_gcn: 16 | if "gated_attention" in arch_gcn: 17 | if arch_gcn['gated_attention']: 18 | self.aggregator_cls = layers.GatedAttentionAggregator 19 | self.mulhead = int(arch_gcn['attention']) 20 | else: 21 | self.aggregator_cls = layers.AttentionAggregator 22 | self.mulhead = int(arch_gcn['attention']) 23 | else: 24 | self.aggregator_cls = layers.HighOrderAggregator 25 | self.mulhead = 1 26 | self.num_layers = len(arch_gcn['arch'].split('-')) 27 | self.weight_decay = train_params['weight_decay'] 28 | self.dropout = train_params['dropout'] 29 | self.lr = train_params['lr'] 30 | self.arch_gcn = arch_gcn 31 | self.out_dim = arch_gcn['dim'] 32 | self.feat_full = torch.from_numpy(feat_full.astype(np.float32)) 33 | if self.use_cuda: 34 | self.feat_full = self.feat_full.cuda() 35 | _dims, self.order_layer, self.act_layer, self.bias_layer, self.aggr_layer \ 36 | = parse_layer_yml(arch_gcn, self.feat_full.shape[1]) 37 | # get layer index for each conv layer, useful for jk net last layer aggregation 38 | self.set_idx_conv() 39 | self.set_dims(_dims) 40 | self.opt_op = None 41 | 42 | # build the model below 43 | self.num_params = 0 44 | self.aggregators, num_param = self.get_aggregators() 45 | self.num_params += num_param 46 | self.conv_layers = nn.Sequential(*self.aggregators) 47 | self.classifier = layers.HighOrderAggregator(self.dims_feat[-1], self.out_dim,\ 48 | act='I', order=0, dropout=self.dropout, bias='bias') 49 | self.num_params += self.classifier.num_param 50 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 51 | 52 | def set_dims(self, dims): 53 | """ 54 | Set the feature dimension / weight dimension for each GNN or MLP layer. 55 | We will use the dimensions set here to initialize PyTorch layers. 
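        Note: with aggr == 'concat', a layer of order k outputs (k + 1) * dim features
        (the self term plus each hop concatenated); with other aggregations the output
        stays at dim. This is what the dims_feat expression below computes.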
56 | 57 | Inputs: 58 | dims list, length of node feature for each hidden layer 59 | 60 | Outputs: 61 | None 62 | """ 63 | self.dims_feat = [dims[0]] + [ 64 | ((self.aggr_layer[l]=='concat') * self.order_layer[l] + 1) * dims[l+1] 65 | for l in range(len(dims) - 1) 66 | ] 67 | self.dims_weight = [(self.dims_feat[l],dims[l+1]) for l in range(len(dims)-1)] 68 | 69 | def set_idx_conv(self): 70 | """ 71 | Set the index of GNN layers for the full neural net. For example, if 72 | the full NN is having 1-0-1-0 arch (1-hop graph conv, followed by 0-hop 73 | MLP, ...). Then the layer indices will be 0, 2. 74 | """ 75 | idx_conv = np.where(np.array(self.order_layer) >= 1)[0] 76 | idx_conv = list(idx_conv[1:] - 1) 77 | idx_conv.append(len(self.order_layer) - 1) 78 | _o_arr = np.array(self.order_layer)[idx_conv] 79 | if np.prod(np.ediff1d(_o_arr)) == 0: 80 | self.idx_conv = idx_conv 81 | else: 82 | self.idx_conv = list(np.where(np.array(self.order_layer) == 1)[0]) 83 | 84 | 85 | def forward(self, node_subgraph, adj_subgraph): 86 | feat_subg = self.feat_full[node_subgraph] # 读取子图节点属性 87 | _, emb_subg = self.conv_layers((adj_subgraph, feat_subg)) # 处理子图属性 88 | emb_subg_norm = F.normalize(emb_subg, p=2, dim=1) 89 | pred_subg = self.classifier((None, emb_subg_norm))[1] #默认调用了layer中的HighOrderAggregator 90 | return pred_subg 91 | 92 | def decoder(self, pred_subg): 93 | # pred_subg = pred_subg.to_sparse() 94 | adj_rec = torch.sigmoid(torch.matmul(pred_subg.cpu(), pred_subg.cpu().t())) # 解码器点乘还原邻接矩阵A' 95 | return adj_rec 96 | 97 | def rec_loss(self, adj_subg, pred_subg, norm_loss): 98 | rec_subg = self.decoder(pred_subg) 99 | mse = torch.nn.MSELoss() 100 | _rl = mse(adj_subg.cpu().to_dense(), rec_subg.cpu()) 101 | 102 | return (norm_loss*_rl).sum() 103 | 104 | 105 | def get_aggregators(self): 106 | """ 107 | Return a list of aggregator instances. to be used in self.build() 108 | """ 109 | num_param = 0 110 | aggregators = [] 111 | for l in range(self.num_layers): 112 | aggr = self.aggregator_cls( 113 | *self.dims_weight[l], 114 | dropout=self.dropout, 115 | act=self.act_layer[l], 116 | order=self.order_layer[l], 117 | aggr=self.aggr_layer[l], 118 | bias=self.bias_layer[l], 119 | mulhead=self.mulhead, 120 | ) 121 | num_param += aggr.num_param 122 | aggregators.append(aggr) 123 | return aggregators, num_param 124 | 125 | 126 | def predict(self, preds): #对embedding按行进行softmax,使得每行所有数值加起来等于1,以进行分类 127 | return F.softmax(preds, dim=1) 128 | 129 | 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mining User-aware Multi-relations for Fake News Detection in Large Scale Online Social Networks 2 | 3 | This repository is the official PyTorch implementation of Us-DeFake in the paper: 4 | 5 | Mining User-aware Multi-relations for Fake News Detection in Large Scale Online Social Networks, accepted by [*the the 16th ACM International Conference on Web Search and Data Mining*](https://www.wsdm-conference.org/2023/program/accepted-papers) (WSDM '23) [[arXiv](https://arxiv.org/pdf/2212.10778.pdf)]. 6 | 7 | 8 | ## Dependencies 9 | 10 | * python >= 3.6.8 11 | * pytorch >= 1.1.0 12 | * cython >=0.29.2 13 | * numpy >= 1.14.3 14 | * scipy >= 1.1.0 15 | * scikit-learn >= 0.19.1 16 | * pyyaml >= 3.12 17 | * g++ >= 5.4.0 18 | * openmp >= 4.0 19 | 20 | 21 | ## Datasets 22 | 23 | To show the input formats of datasets, we give an example dataset "toy" in /data/ directory. 
The toy dataset is just used to show the input format, it's not suitable for experiments. The structure of the /data/toy/ directory should be as follows.[Download Example Dataset](https://drive.google.com/drive/folders/18IwOQ7hc0S6QaOQxdp7AIHhZezzMZ0CU?usp=sharing) 24 | 25 | ``` 26 | data/ 27 | │ 28 | └───toy/ 29 | │ class_map.json 30 | │ post_graph.txt 31 | │ text_adj_full.npz 32 | │ text_feats.npy 33 | │ user_adj_full.npz 34 | │ user_feats.npy 35 | │ user_graph 36 | │ user_post_graph.txt 37 | └───1/ 38 | │ text_adj_train.npz 39 | │ text_role.json 40 | │ user_adj_train.npz 41 | └─── user_role.json 42 | ``` 43 | * `class_map.json`: a dictionary of length N. Each key is a node index, and each value is 0 (real news) or 1 (fake news). 44 | * `post_graph.txt`: propagation graph of news. It is not input to Us-DeFake, but is intended to intuitively show the contents of the text_adj_*.npz file of news. 45 | * `text_adj_full.npz`: a sparse matrix in CSR format of `post_graph.txt`, stored as a `scipy.sparse.csr_matrix`. The shape is N by N. Non-zeros in the matrix correspond to all the edges in the full graph. It doesn't matter if the two nodes connected by an edge are training, validation or test nodes. 46 | * `text_feats.npy`: attributes of news. They are learned by RoBERT algorithm with 768 dimensions. 47 | * `user_adj_full.npz`: a sparse matrix in CSR format of `user_graph.txt`, stored as a `scipy.sparse.csr_matrix`. The shape is M by M. Non-zeros in the matrix correspond to all the edges in the full graph. It doesn't matter if the two nodes connected by an edge are training, validation or test nodes. 48 | * `user_feats.npy`: attributes of users. They are representive information of users, e.g., the number of followers, the number of following, and so on. 49 | * `user_graph.txt`: interaction graph of users. It is not input to Us-DeFake, but is intended to intuitively show the contents of the user_adj_*.npz file of users. 50 | * `user_post_graph.txt`: posting graph of news and users. 51 | * `1`: 1st fold of k-fold cross validation. 52 | * `text_adj_train.npz`: a sparse matrix in CSR format of training news, stored as a `scipy.sparse.csr_matrix`. The shape is also N by N. However, non-zeros in the matrix only correspond to edges connecting two training nodes. The graph sampler only picks nodes/edges from this `text_adj_train`, not `text_adj_full`. Therefore, neither the attribute information nor the structural information are revealed during training. Also, note that only aN rows and cols of `text_adj_train` contains non-zeros. For unweighted graph, the non-zeros are all 1. 53 | * `text_role.json`: a dictionary of four keys. Key `'tr'` corresponds to the list of all training node indices. Key `'va'` corresponds to the list of all validation node indices. Key `'te'` corresponds to the list of all test node indices. Note that in the raw data, nodes may have string-type ID. Key `'source news'` corresponds to the source news. You would need to re-assign numerical ID (0 to N-1) to the nodes, so that you can index into the matrices of adj, features and class labels. 54 | * `user_adj_train.npz`: a sparse matrix in CSR format of training users, stored as a `scipy.sparse.csr_matrix`. The shape is also M by M. However, non-zeros in the matrix only correspond to edges connecting two training nodes. The graph sampler only picks nodes/edges from this `user_adj_train`, not `user_adj_full`. Therefore, neither the attribute information nor the structural information are revealed during training. 
Also, note that only aN rows and cols of `user_adj_train` contains non-zeros. For unweighted graph, the non-zeros are all 1. 55 | * `user_role.json`: a dictionary of four keys. Key `'tr'` corresponds to the list of all training node indices. Key `'va'` corresponds to the list of all validation node indices. Key `'te'` corresponds to the list of all test node indices. Note that in the raw data, nodes may have string-type ID. You would need to re-assign numerical ID (0 to N-1) to the nodes, so that you can index into the matrices of adj, features and class labels. 56 | 57 | 58 | 59 | ## Cython Implemented Parallel Graph Sampler 60 | 61 | We have a cython module which need compilation before training can start. Compile the module by running the following from the root directory: 62 | 63 | `python graphsaint/setup.py build_ext --inplace` 64 | 65 | 66 | ## Training Configuration 67 | 68 | The hyperparameters needed in training can be set via the configuration file: `./train_config/.yml`. 69 | 70 | 71 | ## Run Training 72 | 73 | First of all, please compile cython samplers (see above). 74 | We suggest looking through the available command line arguments defined in `./utility/globals.py`. 75 | 76 | To run the code on CPU 77 | 78 | ``` 79 | python -m train --data_prefix ./data/ --fold --train_config ./train_config/.yml --gpu -1 80 | ``` 81 | 82 | 83 | To run the code on GPU 84 | 85 | ``` 86 | python -m train --data_prefix ./data/ --fold --train_config ./train_config/.yml --gpu 0 87 | ``` 88 | 89 | For example, to run dataset 'toy' on CPU: 90 | ``` 91 | python -m train --data_prefix ./data/toy --fold 1 --train_config ./train_config/toy.yml --gpu -1 92 | ``` 93 | 94 | 95 | ## Citation & Acknowledgement 96 | 97 | We thank Hanqing Zeng et al. proposed the GraphSAINT [paper](https://arxiv.org/abs/1907.04931) and released the [code](https://github.com/GraphSAINT/GraphSAINT). Us-DeFake employs GraphSAINT to learn representations of news and users in large scale online social networks. 98 | 99 | If you find this method helpful for your research, please cite our paper. 100 | 101 | ``` 102 | @article{su2022mining, 103 | title={Mining User-aware Multi-relations for Fake News Detection in Large Scale Online Social Networks}, 104 | author={Su, Xing and Yang, Jian and Wu, Jia and Zhang, Yuchen}, 105 | journal={arXiv preprint arXiv:2212.10778}, 106 | year={2022} 107 | } 108 | ``` 109 | 110 | 111 | -------------------------------------------------------------------------------- /embedding/text_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from utility.utils import * 6 | import embedding.layers as layers 7 | 8 | 9 | class GraphSAINT(nn.Module): 10 | def __init__(self, num_classes, arch_gcn, train_params, feat_full, label_full, cpu_eval=False): 11 | """ 12 | Build the multi-layer GNN architecture. 13 | 14 | Inputs: 15 | num_classes int, number of classes a node can belong to 16 | arch_gcn dict, config for each GNN layer 17 | train_params dict, training hyperparameters (e.g., learning rate) 18 | feat_full np array of shape N x f, where N is the total num of 19 | nodes and f is the dimension for input node feature 20 | label_full np array, for single-class classification, the shape 21 | is N x 1 and for multi-class classification, the 22 | shape is N x c (where c = num_classes) 23 | cpu_eval bool, if True, will put the model on CPU. 
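                          As an illustration, parsing train_config/toy.yml with
                          utility.utils.parse_n_prepare yields roughly
                              arch_gcn     = {'dim': 512, 'aggr': 'concat', 'loss': 'softmax',
                                              'arch': '1-1-0', 'act': 'relu', 'bias': 'norm'}
                              train_params = {'lr': 0.01, 'dropout': 0.1, 'weight_decay': 0.0,
                                              'sample_coverage': 50, 'norm_loss': True,
                                              'norm_aggr': True, 'q_threshold': 50, 'q_offset': 0}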
24 | 25 | Outputs: 26 | None 27 | """ 28 | super(GraphSAINT,self).__init__() 29 | self.use_cuda = (args_global.gpu >= 0) 30 | if cpu_eval: 31 | self.use_cuda=False 32 | if "attention" in arch_gcn: 33 | if "gated_attention" in arch_gcn: 34 | if arch_gcn['gated_attention']: 35 | self.aggregator_cls = layers.GatedAttentionAggregator 36 | self.mulhead = int(arch_gcn['attention']) 37 | else: 38 | self.aggregator_cls = layers.AttentionAggregator 39 | self.mulhead = int(arch_gcn['attention']) 40 | else: 41 | self.aggregator_cls = layers.HighOrderAggregator 42 | self.mulhead = 1 43 | self.num_layers = len(arch_gcn['arch'].split('-')) 44 | self.weight_decay = train_params['weight_decay'] 45 | self.dropout = train_params['dropout'] 46 | self.lr = train_params['lr'] 47 | self.arch_gcn = arch_gcn 48 | self.sigmoid_loss = (arch_gcn['loss'] == 'sigmoid') 49 | self.out_dim = arch_gcn['dim'] 50 | self.feat_full = torch.from_numpy(feat_full.astype(np.float32)) 51 | self.label_full = torch.from_numpy(label_full.astype(np.float32)) 52 | if self.use_cuda: 53 | self.feat_full = self.feat_full.cuda() 54 | self.label_full = self.label_full.cuda() 55 | if not self.sigmoid_loss: 56 | self.label_full_cat = torch.from_numpy(label_full.argmax(axis=1).astype(np.int64)) 57 | if self.use_cuda: 58 | self.label_full_cat = self.label_full_cat.cuda() 59 | self.num_classes = num_classes 60 | _dims, self.order_layer, self.act_layer, self.bias_layer, self.aggr_layer \ 61 | = parse_layer_yml(arch_gcn, self.feat_full.shape[1]) 62 | # get layer index for each conv layer, useful for jk net last layer aggregation 63 | self.set_idx_conv() 64 | self.set_dims(_dims) 65 | 66 | self.loss = 0 67 | self.opt_op = None 68 | 69 | # build the model below 70 | self.num_params = 0 71 | self.aggregators, num_param = self.get_aggregators() 72 | self.num_params += num_param 73 | self.conv_layers = nn.Sequential(*self.aggregators) 74 | self.classifier = layers.HighOrderAggregator(self.dims_feat[-1], self.out_dim,\ 75 | act='I', order=0, dropout=self.dropout, bias='bias') 76 | self.num_params += self.classifier.num_param 77 | self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) 78 | 79 | def set_dims(self, dims): 80 | """ 81 | Set the feature dimension / weight dimension for each GNN or MLP layer. 82 | We will use the dimensions set here to initialize PyTorch layers. 83 | 84 | Inputs: 85 | dims list, length of node feature for each hidden layer 86 | 87 | Outputs: 88 | None 89 | """ 90 | self.dims_feat = [dims[0]] + [ 91 | ((self.aggr_layer[l]=='concat') * self.order_layer[l] + 1) * dims[l+1] 92 | for l in range(len(dims) - 1) 93 | ] 94 | self.dims_weight = [(self.dims_feat[l],dims[l+1]) for l in range(len(dims)-1)] 95 | 96 | def set_idx_conv(self): 97 | """ 98 | Set the index of GNN layers for the full neural net. For example, if 99 | the full NN is having 1-0-1-0 arch (1-hop graph conv, followed by 0-hop 100 | MLP, ...). Then the layer indices will be 0, 2. 
101 | """ 102 | idx_conv = np.where(np.array(self.order_layer) >= 1)[0] 103 | idx_conv = list(idx_conv[1:] - 1) 104 | idx_conv.append(len(self.order_layer) - 1) 105 | _o_arr = np.array(self.order_layer)[idx_conv] 106 | if np.prod(np.ediff1d(_o_arr)) == 0: 107 | self.idx_conv = idx_conv 108 | else: 109 | self.idx_conv = list(np.where(np.array(self.order_layer) == 1)[0]) 110 | 111 | 112 | def forward(self, node_subgraph, adj_subgraph): 113 | feat_subg = self.feat_full[node_subgraph] 114 | label_subg = self.label_full[node_subgraph] 115 | label_subg_converted = label_subg if self.sigmoid_loss else self.label_full_cat[node_subgraph] 116 | _, emb_subg = self.conv_layers((adj_subgraph, feat_subg)) 117 | emb_subg_norm = F.normalize(emb_subg, p=2, dim=1) 118 | pred_subg = self.classifier((None, emb_subg_norm))[1] 119 | return pred_subg, label_subg, label_subg_converted 120 | 121 | 122 | def _loss(self, preds, labels, norm_loss): 123 | """ 124 | The predictor performs sigmoid (for multi-class) or softmax (for single-class) 125 | """ 126 | if self.sigmoid_loss: 127 | norm_loss = norm_loss.unsqueeze(1) 128 | return torch.nn.BCEWithLogitsLoss(weight=norm_loss,reduction='sum')(preds, labels) 129 | else: 130 | _ls = torch.nn.CrossEntropyLoss(reduction='none')(preds, labels) 131 | return (norm_loss*_ls).sum() 132 | 133 | 134 | def get_aggregators(self): 135 | """ 136 | Return a list of aggregator instances. to be used in self.build() 137 | """ 138 | num_param = 0 139 | aggregators = [] 140 | for l in range(self.num_layers): 141 | aggr = self.aggregator_cls( 142 | *self.dims_weight[l], 143 | dropout=self.dropout, 144 | act=self.act_layer[l], 145 | order=self.order_layer[l], 146 | aggr=self.aggr_layer[l], 147 | bias=self.bias_layer[l], 148 | mulhead=self.mulhead, 149 | ) 150 | num_param += aggr.num_param 151 | aggregators.append(aggr) 152 | return aggregators, num_param 153 | 154 | def predict(self, preds): 155 | return nn.Sigmoid()(preds) if self.sigmoid_loss else F.softmax(preds, dim=1) 156 | 157 | 158 | def train_step(self, node_subgraph, adj_subgraph, norm_loss_subgraph): 159 | """ 160 | Forward and backward propagation 161 | """ 162 | self.train() 163 | self.optimizer.zero_grad() 164 | preds, labels, labels_converted = self(node_subgraph, adj_subgraph) 165 | loss = self._loss(preds, labels_converted, norm_loss_subgraph) # labels.squeeze()? 
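        # norm_loss_subgraph holds the per-node loss-normalization coefficients
        # estimated in Minibatch.set_sampler(); weighting the per-node losses by
        # them inside _loss() is the GraphSAINT loss-normalization step that makes
        # the sampled-subgraph loss approximate the full training-graph loss.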
166 | loss.backward() 167 | torch.nn.utils.clip_grad_norm_(self.parameters(), 5) 168 | self.optimizer.step() 169 | return loss, self.predict(preds), labels 170 | 171 | def eval_step(self, node_subgraph, adj_subgraph, norm_loss_subgraph): 172 | """ 173 | Forward propagation only 174 | """ 175 | self.eval() 176 | with torch.no_grad(): 177 | preds,labels,labels_converted = self(node_subgraph, adj_subgraph) 178 | loss = self._loss(preds,labels_converted,norm_loss_subgraph) 179 | return loss, self.predict(preds), labels 180 | -------------------------------------------------------------------------------- /embedding/minibatch.py: -------------------------------------------------------------------------------- 1 | from utility.globals import * 2 | import math 3 | from utility.utils import * 4 | from utility.graph_samplers import * 5 | from graphsaint.norm_aggr import * 6 | import torch 7 | import scipy.sparse as sp 8 | import scipy 9 | 10 | import numpy as np 11 | import time 12 | 13 | 14 | 15 | def _coo_scipy2torch(adj): 16 | """ 17 | convert a scipy sparse COO matrix to torch 18 | """ 19 | values = adj.data 20 | indices = np.vstack((adj.row, adj.col)) 21 | i = torch.LongTensor(indices) 22 | v = torch.FloatTensor(values) 23 | return torch.sparse.FloatTensor(i,v, torch.Size(adj.shape)) 24 | 25 | 26 | 27 | class Minibatch: 28 | """ 29 | Provides minibatches for the trainer or evaluator. This class is responsible for 30 | calling the proper graph sampler and estimating normalization coefficients. 31 | """ 32 | def __init__(self, adj_full_norm, adj_train, role, train_params, cpu_eval=False): 33 | """ 34 | Inputs: 35 | adj_full_norm scipy CSR, adj matrix for the full graph (row-normalized) 36 | adj_train scipy CSR, adj matrix for the traing graph. Since we are 37 | under transductive setting, for any edge in this adj, 38 | both end points must be training nodes. 39 | role dict, key 'tr' -> list of training node IDs; 40 | key 'va' -> list of validation node IDs; 41 | key 'te' -> list of test node IDs. 42 | train_params dict, additional parameters related to training. e.g., 43 | how many subgraphs we want to get to estimate the norm 44 | coefficients. 45 | cpu_eval bool, whether or not we want to run full-batch evaluation 46 | on the CPU. 47 | 48 | Outputs: 49 | None 50 | """ 51 | self.use_cuda = (args_global.gpu >= 0) 52 | if cpu_eval: 53 | self.use_cuda=False 54 | 55 | self.node_train = np.array(role['tr']) 56 | self.node_val = np.array(role['va']) 57 | self.node_test = np.array(role['te']) 58 | 59 | self.adj_full_norm = _coo_scipy2torch(adj_full_norm.tocoo()) 60 | self.adj_train = adj_train 61 | # ----------------------- 62 | # sanity check (optional) 63 | # ----------------------- 64 | #for role_set in [self.node_val, self.node_test]: 65 | # for v in role_set: 66 | # assert self.adj_train.indptr[v+1] == self.adj_train.indptr[v] 67 | #_adj_train_T = sp.csr_matrix.tocsc(self.adj_train) 68 | #assert np.abs(_adj_train_T.indices - self.adj_train.indices).sum() == 0 69 | #assert np.abs(_adj_train_T.indptr - self.adj_train.indptr).sum() == 0 70 | #_adj_full_T = sp.csr_matrix.tocsc(adj_full_norm) 71 | #assert np.abs(_adj_full_T.indices - adj_full_norm.indices).sum() == 0 72 | #assert np.abs(_adj_full_T.indptr - adj_full_norm.indptr).sum() == 0 73 | #printf("SANITY CHECK PASSED", style="yellow") 74 | if self.use_cuda: 75 | # now i put everything on GPU. 
Ideally, full graph adj/feat 76 | # should be optionally placed on CPU 77 | self.adj_full_norm = self.adj_full_norm.cuda() 78 | 79 | # below: book-keeping for mini-batch 80 | self.node_subgraph = None 81 | self.batch_num = -1 82 | 83 | self.method_sample = None 84 | self.subgraphs_remaining_indptr = [] 85 | self.subgraphs_remaining_indices = [] 86 | self.subgraphs_remaining_data = [] 87 | self.subgraphs_remaining_nodes = [] 88 | self.subgraphs_remaining_edge_index = [] 89 | 90 | self.norm_loss_train = np.zeros(self.adj_train.shape[0]) 91 | # norm_loss_test is used in full batch evaluation (without sampling). 92 | # so neighbor features are simply averaged. 93 | self.norm_loss_test = np.zeros(self.adj_full_norm.shape[0]) 94 | _denom = len(self.node_train) + len(self.node_val) + len(self.node_test) 95 | self.norm_loss_test[self.node_train] = 1. / _denom 96 | self.norm_loss_test[self.node_val] = 1. / _denom 97 | self.norm_loss_test[self.node_test] = 1. / _denom 98 | self.norm_loss_test = torch.from_numpy(self.norm_loss_test.astype(np.float32)) 99 | if self.use_cuda: 100 | self.norm_loss_test = self.norm_loss_test.cuda() 101 | self.norm_aggr_train = np.zeros(self.adj_train.size) 102 | 103 | self.sample_coverage = train_params['sample_coverage'] 104 | self.deg_train = np.array(self.adj_train.sum(1)).flatten() 105 | 106 | def set_sampler(self, train_phases): 107 | """ 108 | Pick the proper graph sampler. Run the warm-up phase to estimate 109 | loss / aggregation normalization coefficients. 110 | 111 | Inputs: 112 | train_phases dict, config / params for the graph sampler 113 | 114 | Outputs: 115 | None 116 | """ 117 | self.subgraphs_remaining_indptr = [] 118 | self.subgraphs_remaining_indices = [] 119 | self.subgraphs_remaining_data = [] 120 | self.subgraphs_remaining_nodes = [] 121 | self.subgraphs_remaining_edge_index = [] 122 | self.method_sample == 'rw' 123 | self.size_subg_budget = train_phases['num_root'] * train_phases['depth'] 124 | self.graph_sampler = rw_sampling( 125 | self.adj_train, 126 | self.node_train, 127 | self.size_subg_budget, 128 | int(train_phases['num_root']), 129 | int(train_phases['depth']), 130 | ) 131 | self.norm_loss_train = np.zeros(self.adj_train.shape[0]) 132 | self.norm_aggr_train = np.zeros(self.adj_train.size).astype(np.float32) 133 | 134 | # ------------------------------------------------------------- 135 | # BELOW: estimation of loss / aggregation normalization factors 136 | # ------------------------------------------------------------- 137 | # For some special sampler, no need to estimate norm factors, we can calculate 138 | # the node / edge probabilities directly. 139 | # However, for integrity of the framework, we follow the same procedure 140 | # for all samplers: 141 | # 1. sample enough number of subgraphs 142 | # 2. update the counter for each node / edge in the training graph 143 | # 3. 
estimate norm factor alpha and lambda 144 | tot_sampled_nodes = 0 145 | while True: 146 | self.par_graph_sample('train') 147 | tot_sampled_nodes = sum([len(n) for n in self.subgraphs_remaining_nodes]) 148 | if tot_sampled_nodes > self.sample_coverage * self.node_train.size: 149 | break 150 | print() 151 | num_subg = len(self.subgraphs_remaining_nodes) 152 | for i in range(num_subg): 153 | self.norm_aggr_train[self.subgraphs_remaining_edge_index[i]] += 1 154 | self.norm_loss_train[self.subgraphs_remaining_nodes[i]] += 1 155 | assert self.norm_loss_train[self.node_val].sum() + self.norm_loss_train[self.node_test].sum() == 0 156 | for v in range(self.adj_train.shape[0]): 157 | i_s = self.adj_train.indptr[v] 158 | i_e = self.adj_train.indptr[v + 1] 159 | val = np.clip(self.norm_loss_train[v] / self.norm_aggr_train[i_s : i_e], 0, 1e4) 160 | val[np.isnan(val)] = 0.1 161 | self.norm_aggr_train[i_s : i_e] = val 162 | self.norm_loss_train[np.where(self.norm_loss_train==0)[0]] = 0.1 163 | self.norm_loss_train[self.node_val] = 0 164 | self.norm_loss_train[self.node_test] = 0 165 | self.norm_loss_train[self.node_train] = num_subg / self.norm_loss_train[self.node_train] / self.node_train.size 166 | self.norm_loss_train = torch.from_numpy(self.norm_loss_train.astype(np.float32)) 167 | if self.use_cuda: 168 | self.norm_loss_train = self.norm_loss_train.cuda() 169 | 170 | def par_graph_sample(self,phase): 171 | """ 172 | Perform graph sampling in parallel. A wrapper function for graph_samplers.py 173 | """ 174 | t0 = time.time() 175 | _indptr, _indices, _data, _v, _edge_index = self.graph_sampler.par_sample(phase) 176 | t1 = time.time() 177 | print('sampling 200 subgraphs: time = {:.3f} sec'.format(t1 - t0), end="\r") 178 | self.subgraphs_remaining_indptr.extend(_indptr) 179 | self.subgraphs_remaining_indices.extend(_indices) 180 | self.subgraphs_remaining_data.extend(_data) 181 | self.subgraphs_remaining_nodes.extend(_v) 182 | self.subgraphs_remaining_edge_index.extend(_edge_index) 183 | 184 | def one_batch(self, mode='train'): 185 | """ 186 | Generate one minibatch for trainer. In the 'train' mode, one minibatch corresponds 187 | to one subgraph of the training graph. In the 'val' or 'test' mode, one batch 188 | corresponds to the full graph (i.e., full-batch rather than minibatch evaluation 189 | for validation / test sets). 190 | 191 | Inputs: 192 | mode str, can be 'train', 'val', 'test' or 'valtest' 193 | 194 | Outputs: 195 | node_subgraph np array, IDs of the subgraph / full graph nodes 196 | adj scipy CSR, adj matrix of the subgraph / full graph 197 | norm_loss np array, loss normalization coefficients. In 'val' or 198 | 'test' modes, we don't need to normalize, and so the values 199 | in this array are all 1. 
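        Typical usage (cf. train.py / text_model.train_step): each training iteration runs
            node_subgraph, adj, norm_loss = minibatch.one_batch(mode='train')
            preds, labels, labels_converted = model(node_subgraph, adj)
            loss = model._loss(preds, labels_converted, norm_loss)
        so the returned triplet feeds the model's forward pass and loss directly.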
200 | """ 201 | if mode in ['val','test','valtest']: 202 | self.node_subgraph = np.arange(self.adj_full_norm.shape[0]) 203 | adj = self.adj_full_norm 204 | else: 205 | assert mode == 'train' 206 | if len(self.subgraphs_remaining_nodes) == 0: 207 | self.par_graph_sample('train') 208 | print() 209 | 210 | self.node_subgraph = self.subgraphs_remaining_nodes.pop() 211 | self.size_subgraph = len(self.node_subgraph) 212 | adj = sp.csr_matrix( 213 | ( 214 | self.subgraphs_remaining_data.pop(), 215 | self.subgraphs_remaining_indices.pop(), 216 | self.subgraphs_remaining_indptr.pop()), 217 | shape=(self.size_subgraph,self.size_subgraph, 218 | ) 219 | ) 220 | adj_edge_index = self.subgraphs_remaining_edge_index.pop() 221 | #print("{} nodes, {} edges, {} degree".format(self.node_subgraph.size,adj.size,adj.size/self.node_subgraph.size)) 222 | norm_aggr(adj.data, adj_edge_index, self.norm_aggr_train, num_proc=args_global.num_cpu_core) 223 | # adj.data[:] = self.norm_aggr_train[adj_edge_index][:] # this line is interchangable with the above line 224 | adj = adj_norm(adj, deg=self.deg_train[self.node_subgraph]) 225 | adj = _coo_scipy2torch(adj.tocoo()) 226 | if self.use_cuda: 227 | adj = adj.cuda() 228 | self.batch_num += 1 229 | norm_loss = self.norm_loss_test if mode in ['val','test', 'valtest'] else self.norm_loss_train 230 | norm_loss = norm_loss[self.node_subgraph] 231 | return self.node_subgraph, adj, norm_loss 232 | 233 | 234 | def num_training_batches(self): 235 | return math.ceil(self.node_train.shape[0] / float(self.size_subg_budget)) 236 | 237 | def shuffle(self): 238 | self.node_train = np.random.permutation(self.node_train) 239 | self.batch_num = -1 240 | 241 | def end(self): 242 | return (self.batch_num + 1) * self.size_subg_budget >= self.node_train.shape[0] 243 | -------------------------------------------------------------------------------- /utility/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | import pdb 4 | import scipy.sparse 5 | from sklearn.preprocessing import StandardScaler 6 | import os 7 | import yaml 8 | import scipy.sparse as sp 9 | from utility.globals import * 10 | from torch.autograd import Variable 11 | import networkx as nx 12 | import torch 13 | 14 | def get_source_news(prefix,fold): 15 | role = json.load(open('./{}/{}/text_role.json'.format(prefix,fold))) 16 | source = role["source news"] 17 | return source 18 | 19 | def load_text_data(prefix, fold, normalize=True): 20 | """ 21 | Load the various data files residing in the `prefix` directory. 22 | Files to be loaded: 23 | adj_full.npz sparse matrix in CSR format, stored as scipy.sparse.csr_matrix 24 | The shape is N by N. Non-zeros in the matrix correspond to all 25 | the edges in the full graph. It doesn't matter if the two nodes 26 | connected by an edge are training, validation or test nodes. 27 | For unweighted graph, the non-zeros are all 1. 28 | adj_train.npz sparse matrix in CSR format, stored as a scipy.sparse.csr_matrix 29 | The shape is also N by N. However, non-zeros in the matrix only 30 | correspond to edges connecting two training nodes. The graph 31 | sampler only picks nodes/edges from this adj_train, not adj_full. 32 | Therefore, neither the attribute information nor the structural 33 | information are revealed during training. Also, note that only 34 | a x N rows and cols of adj_train contains non-zeros. For 35 | unweighted graph, the non-zeros are all 1. 36 | role.json a dict of three keys. 
Key 'tr' corresponds to the list of all 37 | 'tr': list of all training node indices 38 | 'va': list of all validation node indices 39 | 'te': list of all test node indices 40 | Note that in the raw data, nodes may have string-type ID. You 41 | need to re-assign numerical ID (0 to N-1) to the nodes, so that 42 | you can index into the matrices of adj, features and class labels. 43 | class_map.json a dict of length N. Each key is a node index, and each value is 44 | either a length C binary list (for multi-class classification) 45 | or an integer scalar (0 to C-1, for single-class classification). 46 | feats.npz a numpy array of shape N by F. Row i corresponds to the attribute 47 | vector of node i. 48 | 49 | Inputs: 50 | prefix string, directory containing the above graph related files 51 | normalize bool, whether or not to normalize the node features 52 | 53 | Outputs: 54 | adj_full scipy sparse CSR (shape N x N, |E| non-zeros), the adj matrix of 55 | the full graph, with N being total num of train + val + test nodes. 56 | adj_train scipy sparse CSR (shape N x N, |E'| non-zeros), the adj matrix of 57 | the training graph. While the shape is the same as adj_full, the 58 | rows/cols corresponding to val/test nodes in adj_train are all-zero. 59 | feats np array (shape N x f), the node feature matrix, with f being the 60 | length of each node feature vector. 61 | class_map dict, where key is the node ID and value is the classes this node 62 | belongs to. 63 | role dict, where keys are: 'tr' for train, 'va' for validation and 'te' 64 | for test nodes. The value is the list of IDs of nodes belonging to 65 | the train/val/test sets. 66 | """ 67 | adj_full = scipy.sparse.load_npz('./{}/text_adj_full.npz'.format(prefix)).astype(np.bool) 68 | adj_train = scipy.sparse.load_npz('./{}/{}/text_adj_train.npz'.format(prefix,fold)).astype(np.bool) 69 | role = json.load(open('./{}/{}/text_role.json'.format(prefix,fold))) 70 | feats = np.load('./{}/text_feats.npy'.format(prefix)) 71 | class_map = json.load(open('./{}/class_map.json'.format(prefix))) 72 | class_map = {int(k):v for k,v in class_map.items()} 73 | assert len(class_map) == feats.shape[0] 74 | # ---- normalize feats ---- 75 | train_nodes = np.array(list(set(adj_train.nonzero()[0]))) 76 | train_feats = feats[train_nodes] 77 | scaler = StandardScaler() 78 | scaler.fit(train_feats) 79 | feats = scaler.transform(feats) 80 | # ------------------------- 81 | return adj_full, adj_train, feats, class_map, role 82 | 83 | 84 | def load_user_data(prefix, fold, normalize=True): 85 | user_adj_full = scipy.sparse.load_npz('./{}/user_adj_full.npz'.format(prefix)).astype(np.bool) 86 | user_adj_train = scipy.sparse.load_npz('./{}/{}/user_adj_train.npz'.format(prefix,fold)).astype(np.bool) 87 | user_role = json.load(open('./{}/{}/user_role.json'.format(prefix,fold))) 88 | user_feats = np.load('./{}/user_feats.npy'.format(prefix)) 89 | # ---- normalize feats ---- 90 | train_nodes = np.array(list(set(user_adj_train.nonzero()[0]))) 91 | train_feats = user_feats[train_nodes] 92 | scaler = StandardScaler() 93 | scaler.fit(train_feats) 94 | user_feats = scaler.transform(user_feats) 95 | # ------------------------- 96 | return user_adj_full, user_adj_train, user_feats, user_role 97 | 98 | 99 | def process_graph_data(adj_full, adj_train, feats, class_map, role): 100 | """ 101 | setup vertex property map for output classes, train/val/test masks, and feats 102 | """ 103 | num_vertices = adj_full.shape[0] 104 | if isinstance(list(class_map.values())[0],list): 105 | 
num_classes = len(list(class_map.values())[0]) 106 | class_arr = np.zeros((num_vertices, num_classes)) 107 | for k,v in class_map.items(): 108 | class_arr[k] = v 109 | else: 110 | num_classes = max(class_map.values()) - min(class_map.values()) + 1 111 | class_arr = np.zeros((num_vertices, num_classes)) 112 | offset = min(class_map.values()) 113 | for k,v in class_map.items(): 114 | class_arr[k][v-offset] = 1 115 | return adj_full, adj_train, feats, class_arr, role 116 | 117 | 118 | def parse_layer_yml(arch_gcn,dim_input): 119 | """ 120 | Parse the *.yml config file to retrieve the GNN structure. 121 | """ 122 | num_layers = len(arch_gcn['arch'].split('-')) 123 | # set default values, then update by arch_gcn 124 | bias_layer = [arch_gcn['bias']]*num_layers 125 | act_layer = [arch_gcn['act']]*num_layers 126 | aggr_layer = [arch_gcn['aggr']]*num_layers 127 | dims_layer = [arch_gcn['dim']]*num_layers 128 | order_layer = [int(o) for o in arch_gcn['arch'].split('-')] 129 | return [dim_input]+dims_layer,order_layer,act_layer,bias_layer,aggr_layer 130 | 131 | 132 | def parse_n_prepare(flags): 133 | with open(flags.train_config) as f_train_config: 134 | train_config = yaml.safe_load(f_train_config) 135 | arch_gcn = { 136 | 'dim': -1, 137 | 'aggr': 'concat', 138 | 'loss': 'softmax', 139 | 'arch': '1', 140 | 'act': 'I', 141 | 'bias': 'norm' 142 | } 143 | arch_gcn.update(train_config['network'][0]) 144 | train_params = { 145 | 'lr': 0.01, 146 | 'weight_decay': 0., 147 | 'norm_loss': True, 148 | 'norm_aggr': True, 149 | 'q_threshold': 50, 150 | 'q_offset': 0 151 | } 152 | train_params.update(train_config['params'][0]) 153 | train_phases = train_config['phase'] 154 | for ph in train_phases: 155 | assert 'end' in ph 156 | assert 'sampler' in ph 157 | print("Loading training data..") 158 | return train_params,train_phases,arch_gcn 159 | 160 | 161 | def data_prepare(flags): 162 | temp_data = load_text_data(flags.data_prefix, flags.fold) 163 | text_train_data = process_graph_data(*temp_data) 164 | print("Done loading training data of text..") 165 | user_train_data = load_user_data(flags.data_prefix, flags.fold) 166 | print("Done loading training data of users..") 167 | return text_train_data, user_train_data 168 | 169 | 170 | def load_user_post_graph(prefix): 171 | user_post_graph = nx.read_edgelist('./{}/user_post_graph.txt'.format(prefix), create_using=nx.DiGraph()) 172 | edges = nx.edges(user_post_graph) 173 | relation = {} 174 | for edge in edges: 175 | relation[int(edge[1])] = int(edge[0]) 176 | return relation 177 | 178 | 179 | def emb_concatenation(text_emb, user_emb, node_subgraph, flags): 180 | posting_relation = load_user_post_graph(flags.data_prefix) 181 | rows, cols = text_emb.shape 182 | num_users, dim_feats = user_emb.shape 183 | concate_emb = [] 184 | for i in range(rows): 185 | if i in posting_relation.keys(): 186 | u = posting_relation[i] 187 | if u in node_subgraph: 188 | u_index = np.where(node_subgraph == u) 189 | concate_emb.append(torch.add(text_emb[i], user_emb[u_index[0][0]], alpha=1)) 190 | # concate_emb.append(torch.mul(text_emb[i], user_emb[u_index[0][0]])) 191 | else: 192 | concate_emb.append(text_emb[i]) 193 | else: 194 | concate_emb.append(text_emb[i]) 195 | concate_emb = torch.from_numpy(np.array([item.cpu().detach().numpy() for item in concate_emb])) 196 | return concate_emb 197 | 198 | 199 | def log_dir(f_train_config,prefix,git_branch,git_rev,timestamp): 200 | import getpass 201 | log_dir = args_global.dir_log+"/log_train/" + prefix.split("/")[-1] 202 | log_dir += 
"/{ts}-{model}-{gitrev:s}/".format( 203 | model='graphsaint', 204 | gitrev=git_rev.strip(), 205 | ts=timestamp) 206 | if not os.path.exists(log_dir): 207 | os.makedirs(log_dir) 208 | if f_train_config != '': 209 | from shutil import copyfile 210 | copyfile(f_train_config,'{}/{}'.format(log_dir,f_train_config.split('/')[-1])) 211 | return log_dir 212 | 213 | def sess_dir(dims,train_config,prefix,git_branch,git_rev,timestamp): 214 | import getpass 215 | log_dir = "saved_models/" + prefix.split("/")[-1] 216 | log_dir += "/{ts}-{model}-{gitrev:s}-{layer}/".format( 217 | model='graphsaint', 218 | gitrev=git_rev.strip(), 219 | layer='-'.join(dims), 220 | ts=timestamp) 221 | if not os.path.exists(log_dir): 222 | os.makedirs(log_dir) 223 | return sess_dir 224 | 225 | 226 | def adj_norm(adj, deg=None, sort_indices=True): 227 | """ 228 | Normalize adj according to the method of rw normalization. 229 | Note that sym norm is used in the original GCN paper (kipf), 230 | while rw norm is used in GraphSAGE and some other variants. 231 | Here we don't perform sym norm since it doesn't seem to 232 | help with accuracy improvement. 233 | 234 | # Procedure: 235 | # 1. adj add self-connection --> adj' 236 | # 2. D' deg matrix from adj' 237 | # 3. norm by D^{-1} x adj' 238 | if sort_indices is True, we re-sort the indices of the returned adj 239 | Note that after 'dot' the indices of a node would be in descending order 240 | rather than ascending order 241 | """ 242 | diag_shape = (adj.shape[0],adj.shape[1]) 243 | D = adj.sum(1).flatten() if deg is None else deg 244 | np.seterr(divide='ignore', invalid='ignore') 245 | norm_diag = sp.dia_matrix((1/D,0),shape=diag_shape) 246 | adj_norm = norm_diag.dot(adj) 247 | if sort_indices: 248 | adj_norm.sort_indices() 249 | return adj_norm 250 | 251 | def to_numpy(x): 252 | if isinstance(x, Variable): 253 | x = x.data 254 | return x.cpu().numpy() if x.is_cuda else x.numpy() 255 | 256 | 257 | 258 | ################## 259 | # PRINTING UTILS # 260 | #----------------# 261 | 262 | _bcolors = {'header': '\033[95m', 263 | 'blue': '\033[94m', 264 | 'green': '\033[92m', 265 | 'yellow': '\033[93m', 266 | 'red': '\033[91m', 267 | 'bold': '\033[1m', 268 | 'underline': '\033[4m'} 269 | 270 | 271 | def printf(msg,style=''): 272 | if not style or style == 'black': 273 | print(msg) 274 | else: 275 | print("{color1}{msg}{color2}".format(color1=_bcolors[style],msg=msg,color2='\033[0m')) 276 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utility.globals import * 2 | from embedding.text_model import GraphSAINT 3 | from embedding.user_model import UnGraphSAINT 4 | from embedding.minibatch import Minibatch 5 | from utility.utils import * 6 | from utility.metric import * 7 | import torch 8 | import time 9 | import torch.nn.functional as F 10 | 11 | 12 | def evaluate_full_batch(model, minibatch, loss, preds, labels, mode=''): 13 | """ 14 | Full batch evaluation: for validation and test sets only. 15 | When calculating the F1 score, we will mask the relevant root nodes 16 | (e.g., those belonging to the val / test sets). 
17 | """ 18 | if mode == 'val': 19 | node_target = [minibatch.node_val] 20 | elif mode == 'test': 21 | node_target = [minibatch.node_test] 22 | else: 23 | assert mode == 'valtest' 24 | node_target = [minibatch.node_val, minibatch.node_test] 25 | acc, pre, rec, f1 = [], [], [], [] 26 | for n in node_target: 27 | results = evaluation(to_numpy(labels[n]), to_numpy(preds[n]), model.sigmoid_loss) 28 | acc.append(results[0]) 29 | pre.append(results[1]) 30 | rec.append(results[2]) 31 | f1.append(results[3]) 32 | acc = acc[0] if len(acc) == 1 else acc 33 | pre = pre[0] if len(pre) == 1 else pre 34 | rec = rec[0] if len(rec) == 1 else rec 35 | f1 = f1[0] if len(f1) == 1 else f1 36 | # loss is not very accurate in this case, since loss is also contributed by training nodes 37 | # on the other hand, for val / test, we mostly care about their accuracy only. 38 | # so the loss issue is not a problem. 39 | return loss, acc, pre, rec, f1 40 | 41 | 42 | def evaluate_source_news(model, minibatch, preds_vate, labels_vate, mode=''): 43 | role = get_source_news(args_global.data_prefix,args_global.fold) 44 | assert mode == 'valtest' 45 | node_target = [minibatch.node_val, minibatch.node_test] 46 | source_target = [] 47 | for nodes in node_target: 48 | souce = [] 49 | for n in nodes: 50 | if n in role: 51 | souce.append(n) 52 | source_target.append(souce) 53 | print(source_target) 54 | 55 | acc, pre, rec, f1 = [], [], [], [] 56 | for n in source_target: 57 | results = evaluation(to_numpy(labels_vate[n]), to_numpy(preds_vate[n]), model.sigmoid_loss) 58 | acc.append(results[0]) 59 | pre.append(results[1]) 60 | rec.append(results[2]) 61 | f1.append(results[3]) 62 | acc = acc[0] if len(acc) == 1 else acc 63 | pre = pre[0] if len(pre) == 1 else pre 64 | rec = rec[0] if len(rec) == 1 else rec 65 | f1 = f1[0] if len(f1) == 1 else f1 66 | 67 | acc_val, acc_test = acc 68 | pre_val, pre_test = pre 69 | rec_val, rec_test = rec 70 | f1_val, f1_test = f1 71 | return acc_val, acc_test, pre_val, pre_test, rec_val, rec_test, f1_val, f1_test 72 | 73 | 74 | def text_prepare(train_data,train_params,arch_gcn): 75 | """ 76 | Prepare some data structure and initialize model / minibatch handler before 77 | the actual iterative training taking place. 
78 | """ 79 | adj_full, adj_train, feat_full, class_arr,role = train_data 80 | adj_full = adj_full.astype(np.int32) 81 | adj_train = adj_train.astype(np.int32) 82 | adj_full_norm = adj_norm(adj_full) 83 | num_classes = class_arr.shape[1] 84 | 85 | minibatch = Minibatch(adj_full_norm, adj_train, role, train_params) 86 | model = GraphSAINT(num_classes, arch_gcn, train_params, feat_full, class_arr) 87 | printf("TOTAL NUM OF PARAMS in text model = {}".format(sum(p.numel() for p in model.parameters())), style="yellow") 88 | minibatch_eval=Minibatch(adj_full_norm, adj_train, role, train_params, cpu_eval=True) 89 | model_eval=GraphSAINT(num_classes, arch_gcn, train_params, feat_full, class_arr, cpu_eval=True) 90 | if args_global.gpu >= 0: 91 | model = model.cuda() 92 | return model, minibatch, minibatch_eval, model_eval 93 | 94 | 95 | def user_prepare(train_data,train_params,arch_gcn): 96 | adj_full, adj_train, feat_full, role = train_data 97 | adj_full = adj_full.astype(np.int32) 98 | adj_train = adj_train.astype(np.int32) 99 | adj_full_norm = adj_norm(adj_full) 100 | 101 | minibatch = Minibatch(adj_full_norm, adj_train, role, train_params) 102 | model = UnGraphSAINT(arch_gcn, train_params, feat_full) 103 | printf("TOTAL NUM OF PARAMS in user model = {}".format(sum(p.numel() for p in model.parameters())), style="yellow") 104 | minibatch_eval=Minibatch(adj_full_norm, adj_train, role, train_params, cpu_eval=True) 105 | model_eval=UnGraphSAINT(arch_gcn, train_params, feat_full, cpu_eval=True) 106 | if args_global.gpu >= 0: 107 | model = model.cuda() 108 | return model, minibatch, minibatch_eval, model_eval 109 | 110 | 111 | def train_step(model_t, minibatch_t, model_u, minibatch_u): 112 | node_subgraph_t, adj_subgraph_t, norm_loss_subgraph_t = minibatch_t.one_batch(mode='train') 113 | node_subgraph_u, adj_subgraph_u, norm_loss_subgraph_u = minibatch_u.one_batch(mode='train') 114 | 115 | model_t.train() 116 | model_u.train() 117 | model_t.optimizer.zero_grad() 118 | model_u.optimizer.zero_grad() 119 | 120 | t_preds_train, labels_train, labels_converted = model_t.forward(node_subgraph_t, adj_subgraph_t) 121 | loss_t = model_t._loss(t_preds_train, labels_converted, norm_loss_subgraph_t) 122 | 123 | u_preds_train = model_u.forward(node_subgraph_u, adj_subgraph_u) 124 | loss_u = model_u.rec_loss(adj_subgraph_u, u_preds_train, norm_loss_subgraph_u) 125 | if args_global.gpu >= 0: 126 | loss_u = loss_u.cuda() 127 | loss_train = loss_u + loss_t 128 | 129 | loss_train.backward() 130 | torch.nn.utils.clip_grad_norm_(model_t.parameters(), 5) 131 | torch.nn.utils.clip_grad_norm_(model_u.parameters(), 5) 132 | model_t.optimizer.step() 133 | model_u.optimizer.step() 134 | pred_train = emb_concatenation(t_preds_train, u_preds_train, node_subgraph_u, args_global) 135 | preds_train = F.softmax(pred_train, dim=1) 136 | 137 | return loss_train, preds_train, labels_train 138 | 139 | 140 | def eval_step(model_eval_t, minibatch_t, model_eval_u, minibatch_u, mode=''): 141 | node_subgraph_eval_t, adj_subgraph_eval_t, norm_loss_subgraph_eval_t = minibatch_t.one_batch(mode=mode) 142 | node_subgraph_eval_u, adj_subgraph_eval_u, norm_loss_subgraph_eval_u = minibatch_u.one_batch(mode=mode) 143 | 144 | model_eval_t.eval() 145 | model_eval_u.eval() 146 | with torch.no_grad(): 147 | t_preds_eval, labels_eval, labels_converted_eval = model_eval_t.forward(node_subgraph_eval_t, adj_subgraph_eval_t) 148 | loss_t_eval = model_eval_t._loss(t_preds_eval, labels_converted_eval, norm_loss_subgraph_eval_t) 149 | 150 | u_preds_eval = 
model_eval_u.forward(node_subgraph_eval_u, adj_subgraph_eval_u) 151 | loss_u_eval = model_eval_u.rec_loss(adj_subgraph_eval_u, u_preds_eval, norm_loss_subgraph_eval_u) 152 | loss_eval = loss_t_eval + loss_u_eval 153 | pred_eval = emb_concatenation(t_preds_eval, u_preds_eval, node_subgraph_eval_u, args_global) 154 | preds_eval = F.softmax(pred_eval, dim=1) 155 | return loss_eval, preds_eval, labels_eval 156 | 157 | 158 | def train(train_phases, t_model, t_minibatch, t_minibatch_eval, t_model_eval, u_model, u_minibatch, u_minibatch_eval, u_model_eval, eval_val_every): 159 | if not args_global.cpu_eval: 160 | t_minibatch_eval=t_minibatch 161 | u_minibatch_eval=u_minibatch 162 | epoch_ph_start = 0 163 | acc_best, ep_best = 0, -1 164 | time_train = 0 165 | dir_saver = '{}/saved_models'.format(args_global.dir_log) 166 | path_saver_t = '{}/saved_models/saved_model_text_{}.pkl'.format(args_global.dir_log, timestamp) 167 | path_saver_u = '{}/saved_models/saved_model_user_{}.pkl'.format(args_global.dir_log, timestamp) 168 | 169 | # for ip, phase in enumerate(train_phases): 170 | # printf('START PHASE {:4d}'.format(ip),style='underline') 171 | phase = train_phases[0] 172 | t_minibatch.set_sampler(phase) 173 | u_minibatch.set_sampler(phase) 174 | t_num_batches = t_minibatch.num_training_batches() 175 | u_num_batches = u_minibatch.num_training_batches() 176 | for e in range(epoch_ph_start, int(phase['end'])): 177 | printf('Epoch {:4d}'.format(e),style='bold') 178 | t_minibatch.shuffle() 179 | u_minibatch.shuffle() 180 | l_loss_tr, l_acc_tr, l_pre_tr, l_rec_tr, l_f1_tr = [], [], [], [], [] 181 | time_train_ep = 0 182 | while not t_minibatch.end(): 183 | t1 = time.time() 184 | loss_train,preds_train,labels_train = train_step(t_model,t_minibatch,u_model,u_minibatch) 185 | time_train_ep += time.time() - t1 186 | if not t_minibatch.batch_num % args_global.eval_train_every: 187 | acc, pre, rec, f1 = evaluation(to_numpy(labels_train),to_numpy(preds_train),t_model.sigmoid_loss) 188 | l_loss_tr.append(loss_train) 189 | l_acc_tr.append(acc) 190 | l_pre_tr.append(pre) 191 | l_rec_tr.append(rec) 192 | l_f1_tr.append(f1) 193 | if (e+1)%eval_val_every == 0: 194 | if args_global.cpu_eval: 195 | torch.save(t_model.state_dict(),'t_tmp.pkl') 196 | torch.save(u_model.state_dict(), 'u_tmp.pkl') 197 | t_model_eval.load_state_dict(torch.load('t_tmp.pkl',map_location=lambda storage, loc: storage)) 198 | u_model_eval.load_state_dict(torch.load('u_tmp.pkl', map_location=lambda storage, loc: storage)) 199 | else: 200 | t_model_eval = t_model 201 | u_model_eval = u_model 202 | 203 | loss_eval, preds_eval, labels_eval = eval_step(t_model_eval, t_minibatch_eval, u_model_eval, u_minibatch_eval, mode='val') 204 | loss_val, acc_val, pre_val, rec_val, f1_val = evaluate_full_batch(t_model_eval, t_minibatch_eval, loss_eval, 205 | preds_eval, labels_eval, mode='val') 206 | printf('TRAIN (Ep avg): loss = {:.4f}\taccuracy = {:.4f}\tprecision = {:.4f}\trecall = {:.4f}\tF1 = {:.4f}\ttrain time = {:.4f} sec'\ 207 | .format(f_mean(l_loss_tr), f_mean(l_acc_tr), f_mean(l_pre_tr), f_mean(l_rec_tr), f_mean(l_f1_tr), time_train_ep)) 208 | printf('VALIDATION: loss = {:.4f}\taccuracy = {:.4f}\tprecision = {:.4f}\trecall = {:.4f}\tF1 = {:.4f}'\ 209 | .format(loss_val, acc_val, pre_val, rec_val, f1_val), style='yellow') 210 | if acc_val > acc_best: 211 | acc_best, ep_best = acc_val, e 212 | if not os.path.exists(dir_saver): 213 | os.makedirs(dir_saver) 214 | printf(' Saving model ...', style='yellow') 215 | torch.save(t_model.state_dict(), 
path_saver_t) 216 | torch.save(u_model.state_dict(), path_saver_u) 217 | time_train += time_train_ep 218 | epoch_ph_start = int(phase['end']) 219 | printf("Optimization Finished!", style="yellow") 220 | if ep_best >= 0: 221 | if args_global.cpu_eval: 222 | t_model_eval.load_state_dict(torch.load(path_saver_t, map_location=lambda storage, loc: storage)) 223 | u_model_eval.load_state_dict(torch.load(path_saver_u, map_location=lambda storage, loc: storage)) 224 | else: 225 | t_model.load_state_dict(torch.load(path_saver_t)) 226 | u_model.load_state_dict(torch.load(path_saver_u)) 227 | t_model_eval=t_model 228 | u_model_eval=u_model 229 | printf(' Restoring model ...', style='yellow') 230 | 231 | loss_vate, preds_vate, labels_vate = eval_step(t_model_eval,t_minibatch_eval,u_model_eval,u_minibatch_eval,mode='valtest') 232 | acc_val, acc_test, pre_val, pre_test, rec_val, rec_test, f1_val, f1_test = evaluate_source_news(t_model_eval, t_minibatch_eval, 233 | preds_vate, labels_vate, mode='valtest') 234 | 235 | printf("Full validation (Epoch {:4d}): \n Accuracy = {:.4f}\tPrecision = {:.4f}\tRecall = {:.4f}\tF1 = {:.4f}"\ 236 | .format(ep_best, acc_val, pre_val, rec_val, f1_val), style='red') 237 | printf("Full test stats: \n Accuracy = {:.4f}\tPrecision = {:.4f}\tRecall = {:.4f}\tF1 = {:.4f}"\ 238 | .format(acc_test, pre_test, rec_test, f1_test), style='red') 239 | 240 | 241 | if __name__ == '__main__': 242 | log_dir(args_global.train_config, args_global.data_prefix, git_branch, git_rev, timestamp) 243 | train_params, train_phases, arch_gcn = parse_n_prepare(args_global) 244 | text_train_data, user_train_data = data_prepare(args_global) 245 | if 'eval_val_every' not in train_params: 246 | train_params['eval_val_every'] = EVAL_VAL_EVERY_EP 247 | 248 | t_model, t_minibatch, t_minibatch_eval, t_model_eval = text_prepare(text_train_data, train_params, arch_gcn) 249 | u_model, u_minibatch, u_minibatch_eval, u_model_eval = user_prepare(user_train_data, train_params, arch_gcn) 250 | train(train_phases, t_model, t_minibatch, t_minibatch_eval, t_model_eval, u_model, u_minibatch, u_minibatch_eval, u_model_eval, train_params['eval_val_every']) 251 | -------------------------------------------------------------------------------- /graphsaint/cython_sampler.pyx: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | # distutils: language=c++ 3 | # distutils: extra_compile_args = -fopenmp -std=c++11 4 | # distutils: extra_link_args = -fopenmp 5 | cimport cython 6 | from cython.parallel import prange,parallel 7 | from cython.operator import dereference as deref, preincrement as inc 8 | from cython cimport Py_buffer 9 | from libcpp.vector cimport vector 10 | from libcpp.algorithm cimport sort, unique, lower_bound 11 | from libcpp.map cimport map 12 | from libcpp.unordered_map cimport unordered_map 13 | from libcpp.utility cimport pair 14 | import numpy as np 15 | cimport numpy as np 16 | from libc.stdio cimport printf 17 | from libcpp cimport bool 18 | import time,math 19 | import random 20 | from libc.stdlib cimport rand 21 | cdef extern from "stdlib.h": 22 | int RAND_MAX 23 | 24 | cimport graphsaint.cython_utils as cutils 25 | import graphsaint.cython_utils as cutils 26 | 27 | 28 | 29 | cdef class Sampler: 30 | cdef int num_proc,num_sample_per_proc 31 | cdef vector[int] adj_indptr_vec 32 | cdef vector[int] adj_indices_vec 33 | cdef vector[int] node_train_vec 34 | cdef vector[vector[int]] node_sampled 35 | cdef vector[vector[int]] ret_indptr 36 | cdef 
vector[vector[int]] ret_indices 37 | cdef vector[vector[int]] ret_indices_orig 38 | cdef vector[vector[float]] ret_data 39 | cdef vector[vector[int]] ret_edge_index 40 | 41 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 42 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 43 | np.ndarray[int,ndim=1,mode='c'] node_train, 44 | int num_proc, int num_sample_per_proc,*argv): 45 | cutils.npy2vec_int(adj_indptr,self.adj_indptr_vec) 46 | cutils.npy2vec_int(adj_indices,self.adj_indices_vec) 47 | cutils.npy2vec_int(node_train,self.node_train_vec) 48 | self.num_proc = num_proc 49 | self.num_sample_per_proc = num_sample_per_proc 50 | self.node_sampled = vector[vector[int]](num_proc*num_sample_per_proc) 51 | self.ret_indptr = vector[vector[int]](num_proc*num_sample_per_proc) 52 | self.ret_indices = vector[vector[int]](num_proc*num_sample_per_proc) 53 | self.ret_indices_orig = vector[vector[int]](num_proc*num_sample_per_proc) 54 | self.ret_data = vector[vector[float]](num_proc*num_sample_per_proc) 55 | self.ret_edge_index = vector[vector[int]](num_proc*num_sample_per_proc) 56 | 57 | cdef void adj_extract(self, int p) nogil: 58 | """ 59 | Extract a subg adj matrix from the original training adj matrix 60 | ret_indices_orig: the indices vector corresponding to node id in original G. 61 | """ 62 | cdef int r = 0 63 | cdef int idx_g = 0 64 | cdef int i, i_end, v, j 65 | cdef int num_v_orig, num_v_sub 66 | cdef int start_neigh, end_neigh 67 | cdef vector[int] _arr_bit 68 | cdef int cumsum 69 | num_v_orig = self.adj_indptr_vec.size()-1 70 | while r < self.num_sample_per_proc: 71 | _arr_bit = vector[int](num_v_orig,-1) 72 | idx_g = p*self.num_sample_per_proc+r 73 | num_v_sub = self.node_sampled[idx_g].size() 74 | self.ret_indptr[idx_g] = vector[int](num_v_sub+1,0) 75 | self.ret_indices[idx_g] = vector[int]() 76 | self.ret_indices_orig[idx_g] = vector[int]() 77 | self.ret_data[idx_g] = vector[float]() 78 | self.ret_edge_index[idx_g]=vector[int]() 79 | i_end = num_v_sub 80 | i = 0 81 | while i < i_end: 82 | _arr_bit[self.node_sampled[idx_g][i]] = i 83 | i = i + 1 84 | i = 0 85 | while i < i_end: 86 | v = self.node_sampled[idx_g][i] 87 | start_neigh = self.adj_indptr_vec[v] 88 | end_neigh = self.adj_indptr_vec[v+1] 89 | j = start_neigh 90 | while j < end_neigh: 91 | if _arr_bit[self.adj_indices_vec[j]] > -1: 92 | self.ret_indices[idx_g].push_back(_arr_bit[self.adj_indices_vec[j]]) 93 | self.ret_indices_orig[idx_g].push_back(self.adj_indices_vec[j]) 94 | self.ret_edge_index[idx_g].push_back(j) 95 | self.ret_indptr[idx_g][_arr_bit[v]+1] = self.ret_indptr[idx_g][_arr_bit[v]+1] + 1 96 | self.ret_data[idx_g].push_back(1.) 97 | j = j + 1 98 | i = i + 1 99 | cumsum = self.ret_indptr[idx_g][0] 100 | i = 0 101 | while i < i_end: 102 | cumsum = cumsum + self.ret_indptr[idx_g][i+1] 103 | self.ret_indptr[idx_g][i+1] = cumsum 104 | i = i + 1 105 | r = r + 1 106 | 107 | def get_return(self): 108 | """ 109 | Convert the subgraph related data structures from C++ to python. So that cython 110 | can return them to the PyTorch trainer. 111 | 112 | Inputs: 113 | None 114 | 115 | Outputs: 116 | see outputs of the `par_sample()` function. 
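        The per-subgraph C++ vectors are first concatenated into flat buffers (the
        `offset_*` lists record where each subgraph starts), exposed to numpy via
        `np.frombuffer`, and then sliced back into per-subgraph views. For illustration:
        with 2 subgraphs of 5 and 7 sampled nodes, `offset_nodes` becomes [0, 5, 12], and
        subgraph r's node array is `ret_nodes_np[offset_nodes[r] : offset_nodes[r+1]]`.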
117 | """ 118 | num_subg = self.num_proc*self.num_sample_per_proc 119 | l_subg_indptr = [] 120 | l_subg_indices = [] 121 | l_subg_data = [] 122 | l_subg_nodes = [] 123 | l_subg_edge_index = [] 124 | offset_nodes = [0] 125 | offset_indptr = [0] 126 | offset_indices = [0] 127 | offset_data = [0] 128 | offset_edge_index = [0] 129 | for r in range(num_subg): 130 | offset_nodes.append(offset_nodes[r]+self.node_sampled[r].size()) 131 | offset_indptr.append(offset_indptr[r]+self.ret_indptr[r].size()) 132 | offset_indices.append(offset_indices[r]+self.ret_indices[r].size()) 133 | offset_data.append(offset_data[r]+self.ret_data[r].size()) 134 | offset_edge_index.append(offset_edge_index[r]+self.ret_edge_index[r].size()) 135 | cdef vector[int] ret_nodes_vec = vector[int]() 136 | cdef vector[int] ret_indptr_vec = vector[int]() 137 | cdef vector[int] ret_indices_vec = vector[int]() 138 | cdef vector[int] ret_edge_index_vec = vector[int]() 139 | cdef vector[float] ret_data_vec = vector[float]() 140 | ret_nodes_vec.reserve(offset_nodes[num_subg]) 141 | ret_indptr_vec.reserve(offset_indptr[num_subg]) 142 | ret_indices_vec.reserve(offset_indices[num_subg]) 143 | ret_data_vec.reserve(offset_data[num_subg]) 144 | ret_edge_index_vec.reserve(offset_edge_index[num_subg]) 145 | for r in range(num_subg): 146 | ret_nodes_vec.insert(ret_nodes_vec.end(),self.node_sampled[r].begin(),self.node_sampled[r].end()) 147 | ret_indptr_vec.insert(ret_indptr_vec.end(),self.ret_indptr[r].begin(),self.ret_indptr[r].end()) 148 | ret_indices_vec.insert(ret_indices_vec.end(),self.ret_indices[r].begin(),self.ret_indices[r].end()) 149 | ret_edge_index_vec.insert(ret_edge_index_vec.end(),self.ret_edge_index[r].begin(),self.ret_edge_index[r].end()) 150 | ret_data_vec.insert(ret_data_vec.end(),self.ret_data[r].begin(),self.ret_data[r].end()) 151 | 152 | cdef cutils.array_wrapper_int wint_indptr = cutils.array_wrapper_int() 153 | cdef cutils.array_wrapper_int wint_indices = cutils.array_wrapper_int() 154 | cdef cutils.array_wrapper_int wint_nodes = cutils.array_wrapper_int() 155 | cdef cutils.array_wrapper_float wfloat_data = cutils.array_wrapper_float() 156 | cdef cutils.array_wrapper_int wint_edge_index = cutils.array_wrapper_int() 157 | 158 | wint_indptr.set_data(ret_indptr_vec) 159 | ret_indptr_np = np.frombuffer(wint_indptr,dtype=np.int32) 160 | wint_indices.set_data(ret_indices_vec) 161 | ret_indices_np = np.frombuffer(wint_indices,dtype=np.int32) 162 | wint_nodes.set_data(ret_nodes_vec) 163 | ret_nodes_np = np.frombuffer(wint_nodes,dtype=np.int32) 164 | wfloat_data.set_data(ret_data_vec) 165 | ret_data_np = np.frombuffer(wfloat_data,dtype=np.float32) 166 | wint_edge_index.set_data(ret_edge_index_vec) 167 | ret_edge_index_np = np.frombuffer(wint_edge_index,dtype=np.int32) 168 | 169 | for r in range(num_subg): 170 | l_subg_nodes.append(ret_nodes_np[offset_nodes[r]:offset_nodes[r+1]]) 171 | l_subg_indptr.append(ret_indptr_np[offset_indptr[r]:offset_indptr[r+1]]) 172 | l_subg_indices.append(ret_indices_np[offset_indices[r]:offset_indices[r+1]]) 173 | l_subg_data.append(ret_data_np[offset_data[r]:offset_data[r+1]]) 174 | l_subg_edge_index.append(ret_edge_index_np[offset_indices[r]:offset_indices[r+1]]) 175 | 176 | return l_subg_indptr,l_subg_indices,l_subg_data,l_subg_nodes,l_subg_edge_index 177 | 178 | cdef void sample(self, int p) nogil: 179 | pass 180 | 181 | @cython.boundscheck(False) 182 | @cython.wraparound(False) 183 | def par_sample(self): 184 | """ 185 | The main function for the sampler class. 
It launches multiple independent samplers 186 | in parallel (task parallelism by openmp), where the serial sampling function is defined 187 | in the corresponding sub-class. Then it returns node-induced subgraph by `_adj_extract()`, 188 | and convert C++ vectors to python lists / numpy arrays by `_get_return()`. 189 | 190 | Suppose we sample P subgraphs in parallel. Each subgraph has n nodes and e edges. 191 | 192 | Inputs: 193 | None 194 | 195 | Outputs (elements in the list of `ret`): 196 | l_subg_indptr list of np array, length of list = P and length of each array is n+1 197 | l_subg_indices list of np array, length of list = P and length of each array is m. 198 | node IDs in the array are renamed to be subgraph ID (range: 0 ~ n-1) 199 | l_subg_data list of np array, length of list = P and length of each array is m. 200 | Normally, values in the array should be all 1. 201 | l_subg_nodes list of np array, length of list = P and length of each array is n. 202 | Element i in the array shows the training graph node ID of the i-th 203 | subgraph node. 204 | l_subg_edge_index list of np array, length of list = P and length of each array is m. 205 | Element i in the array shows the training graph edge index of the 206 | i-the subgraph edge. 207 | """ 208 | cdef int p = 0 209 | with nogil, parallel(num_threads=self.num_proc): 210 | for p in prange(self.num_proc,schedule='dynamic'): 211 | self.sample(p) 212 | self.adj_extract(p) 213 | ret = self.get_return() 214 | _len = self.num_proc*self.num_sample_per_proc 215 | self.node_sampled.swap(vector[vector[int]](_len)) 216 | self.ret_indptr.swap(vector[vector[int]](_len)) 217 | self.ret_indices.swap(vector[vector[int]](_len)) 218 | self.ret_indices_orig.swap(vector[vector[int]](_len)) 219 | self.ret_data.swap(vector[vector[float]](_len)) 220 | self.ret_edge_index.swap(vector[vector[int]](_len)) 221 | return ret 222 | 223 | 224 | # ---------------------------------------------------- 225 | 226 | cdef class MRW(Sampler): 227 | cdef int size_frontier,size_subg 228 | cdef int avg_deg 229 | cdef vector[int] arr_deg_vec 230 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 231 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 232 | np.ndarray[int,ndim=1,mode='c'] node_train, 233 | int num_proc, int num_sample_per_proc, 234 | np.ndarray[int,ndim=1,mode='c'] p_dist, 235 | int max_deg, int size_frontier, int size_subg): 236 | self.size_frontier = size_frontier 237 | self.size_subg = size_subg 238 | _arr_deg = np.clip(p_dist,0,max_deg) 239 | cutils.npy2vec_int(_arr_deg,self.arr_deg_vec) 240 | self.avg_deg = _arr_deg.mean() 241 | 242 | cdef void sample(self, int p) nogil: 243 | cdef vector[int] frontier 244 | cdef int i = 0 245 | cdef int num_train_node = self.node_train_vec.size() 246 | cdef int r = 0 247 | cdef int alpha = 2 248 | cdef vector[int] arr_ind0 249 | cdef vector[int] arr_ind1 250 | cdef vector[int].iterator it 251 | arr_ind0.reserve(alpha*self.avg_deg) 252 | arr_ind1.reserve(alpha*self.avg_deg) 253 | cdef int c, cnt, j, k 254 | cdef int v, vidx, vpop, vneigh, offset, vnext 255 | cdef int idx_begin, idx_end 256 | cdef int num_neighs_pop, num_neighs_next 257 | while r < self.num_sample_per_proc: 258 | # prepare initial frontier 259 | arr_ind0.clear() 260 | arr_ind1.clear() 261 | frontier.clear() 262 | i = 0 263 | while i < self.size_frontier: # NB: here we don't care if a node appear twice 264 | frontier.push_back(self.node_train_vec[rand()%num_train_node]) 265 | i = i + 1 266 | # init indicator array 267 | it = frontier.begin() 268 | 
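            # arr_ind0 / arr_ind1 implement a degree-weighted "slot table" over the frontier:
            # each frontier node v occupies arr_deg_vec[v] consecutive slots with arr_ind0[k] = v,
            # so drawing a uniformly random slot picks a node with probability proportional to
            # its (clipped) degree. arr_ind1[k] holds the slot's offset within its block, except
            # the first slot of a block, which stores the negative block length so that the whole
            # block can be located and invalidated once the node is popped from the frontier.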
while it != frontier.end(): 269 | v = deref(it) 270 | cnt = arr_ind0.size() 271 | c = cnt 272 | while c < cnt + self.arr_deg_vec[v]: 273 | arr_ind0.push_back(v) 274 | arr_ind1.push_back(c-cnt) 275 | c = c + 1 276 | arr_ind1[cnt] = -self.arr_deg_vec[v] 277 | inc(it) 278 | # iteratively update frontier 279 | j = self.size_frontier 280 | while j < self.size_subg: 281 | # select next node to pop out of frontier 282 | while True: 283 | vidx = rand()%arr_ind0.size() 284 | vpop = arr_ind0[vidx] 285 | if vpop >= 0: 286 | break 287 | # prepare to update arr_ind* 288 | offset = arr_ind1[vidx] 289 | if offset < 0: 290 | idx_begin = vidx 291 | idx_end = idx_begin - offset 292 | else: 293 | idx_begin = vidx - offset 294 | idx_end = idx_begin - arr_ind1[idx_begin] 295 | # cleanup 1: invalidate entries 296 | k = idx_begin 297 | while k < idx_end: 298 | arr_ind0[k] = -1 299 | arr_ind1[k] = 0 300 | k = k + 1 301 | # cleanup 2: add new entries 302 | num_neighs_pop = self.adj_indptr_vec[vpop+1] - self.adj_indptr_vec[vpop] 303 | vnext = self.adj_indices_vec[self.adj_indptr_vec[vpop]+rand()%num_neighs_pop] 304 | self.node_sampled[p*self.num_sample_per_proc+r].push_back(vnext) 305 | num_neighs_next = self.arr_deg_vec[vnext] 306 | cnt = arr_ind0.size() 307 | c = cnt 308 | while c < cnt + num_neighs_next: 309 | arr_ind0.push_back(vnext) 310 | arr_ind1.push_back(c-cnt) 311 | c = c + 1 312 | arr_ind1[cnt] = -num_neighs_next 313 | j = j + 1 314 | self.node_sampled[p*self.num_sample_per_proc+r].insert(self.node_sampled[p*self.num_sample_per_proc+r].end(),frontier.begin(),frontier.end()) 315 | sort(self.node_sampled[p*self.num_sample_per_proc+r].begin(),self.node_sampled[p*self.num_sample_per_proc+r].end()) 316 | self.node_sampled[p*self.num_sample_per_proc+r].erase(unique(self.node_sampled[p*self.num_sample_per_proc+r].begin(),\ 317 | self.node_sampled[p*self.num_sample_per_proc+r].end()),self.node_sampled[p*self.num_sample_per_proc+r].end()) 318 | r = r + 1 319 | 320 | 321 | 322 | # ---------------------------------------------------- 323 | 324 | cdef class RW(Sampler): 325 | cdef int size_root, size_depth 326 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 327 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 328 | np.ndarray[int,ndim=1,mode='c'] node_train, 329 | int num_proc, int num_sample_per_proc, 330 | int size_root, int size_depth): 331 | self.size_root = size_root 332 | self.size_depth = size_depth 333 | 334 | cdef void sample(self, int p) nogil: 335 | cdef int iroot = 0 336 | cdef int idepth = 0 337 | cdef int r = 0 338 | cdef int idx_subg 339 | cdef int v 340 | cdef int num_train_node = self.node_train_vec.size() 341 | while r < self.num_sample_per_proc: 342 | idx_subg = p*self.num_sample_per_proc+r 343 | # sample root 344 | iroot = 0 345 | while iroot < self.size_root: 346 | v = self.node_train_vec[rand()%num_train_node] 347 | self.node_sampled[idx_subg].push_back(v) 348 | # sample random walk 349 | idepth = 0 350 | while idepth < self.size_depth: 351 | if (self.adj_indptr_vec[v+1]-self.adj_indptr_vec[v]>0): 352 | v = self.adj_indices_vec[self.adj_indptr_vec[v]+rand()%(self.adj_indptr_vec[v+1]-self.adj_indptr_vec[v])] 353 | self.node_sampled[idx_subg].push_back(v) 354 | idepth = idepth + 1 355 | iroot = iroot + 1 356 | r = r + 1 357 | sort(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()) 358 | self.node_sampled[idx_subg].erase(unique(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()),self.node_sampled[idx_subg].end()) 359 | 360 | 361 | 362 | 363 | 
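# Illustrative usage sketch for the RW sampler above (assumptions: `adj` is a
# scipy.sparse.csr_matrix over the training graph, `node_train` is an array of training
# node IDs, 20 parallel samplers producing 1 subgraph each, and root/depth values taken
# from train_config/toy.yml):
#
#   import numpy as np
#   from graphsaint.cython_sampler import RW
#   sampler = RW(adj.indptr.astype(np.int32), adj.indices.astype(np.int32),
#                node_train.astype(np.int32), 20, 1, 3000, 2)
#   indptr, indices, data, nodes, edge_index = sampler.par_sample()
#   # each returned item is a list with one numpy array per sampled subgraph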
# ---------------------------------------------------- 364 | 365 | cdef class Edge(Sampler): 366 | cdef vector[int] row_train_vec 367 | cdef vector[int] col_train_vec 368 | cdef vector[float] prob_edge_vec 369 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 370 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 371 | np.ndarray[int,ndim=1,mode='c'] node_train, 372 | int num_proc, int num_sample_per_proc, 373 | np.ndarray[int,ndim=1,mode='c'] row_train, 374 | np.ndarray[int,ndim=1,mode='c'] col_train, 375 | np.ndarray[float,ndim=1,mode='c'] prob_edge,*argv): 376 | cutils.npy2vec_int(row_train,self.row_train_vec) 377 | cutils.npy2vec_int(col_train,self.col_train_vec) 378 | cutils.npy2vec_float(prob_edge,self.prob_edge_vec) 379 | 380 | cdef void sample(self, int p) nogil: 381 | cdef int num_edge = self.row_train_vec.size() 382 | cdef int i=0 383 | cdef float ran=0. 384 | cdef int g=0 385 | cdef int idx_subg 386 | while g < self.num_sample_per_proc: 387 | idx_subg = p*self.num_sample_per_proc+g 388 | i = 0 389 | while i < num_edge: 390 | ran = ( rand()) / RAND_MAX 391 | if ran > self.prob_edge_vec[i]: 392 | # edge not selected 393 | i = i + 1 394 | continue 395 | self.node_sampled[idx_subg].push_back(self.row_train_vec[i]) 396 | self.node_sampled[idx_subg].push_back(self.col_train_vec[i]) 397 | i = i + 1 398 | sort(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()) 399 | self.node_sampled[idx_subg].erase(unique(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()),self.node_sampled[idx_subg].end()) 400 | g = g + 1 401 | 402 | 403 | cdef class Edge2(Sampler): 404 | """ 405 | approximate version of the above Edge class 406 | """ 407 | cdef vector[int] row_train_vec 408 | cdef vector[int] col_train_vec 409 | cdef vector[float] p_dist_cumsum_vec 410 | cdef int size_subg_e 411 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 412 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 413 | np.ndarray[int,ndim=1,mode='c'] node_train, 414 | int num_proc, int num_sample_per_proc, 415 | np.ndarray[int,ndim=1,mode='c'] row_train, 416 | np.ndarray[int,ndim=1,mode='c'] col_train, 417 | np.ndarray[float,ndim=1,mode='c'] p_dist_cumsum, 418 | int size_subg_e): 419 | self.size_subg_e = size_subg_e 420 | cutils.npy2vec_int(row_train,self.row_train_vec) 421 | cutils.npy2vec_int(col_train,self.col_train_vec) 422 | cutils.npy2vec_float(p_dist_cumsum,self.p_dist_cumsum_vec) 423 | 424 | cdef void sample(self, int p) nogil: 425 | cdef int i = 0 426 | cdef int r = 0 427 | cdef int e 428 | cdef int idx_subg 429 | cdef float ran = 0. 
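        # Inverse-transform sampling over edges: `ran` is uniform in [0, p_dist_cumsum[-1]]
        # and lower_bound() returns the first edge whose cumulative probability is >= ran,
        # so edge e is chosen with probability proportional to its (unnormalized) mass.
        # Both endpoints of every chosen edge are added to the subgraph node set, which is
        # then sorted and deduplicated.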
430 | cdef float ran_range = self.p_dist_cumsum_vec[self.p_dist_cumsum_vec.size()-1] 431 | while r < self.num_sample_per_proc: 432 | idx_subg = p*self.num_sample_per_proc+r 433 | i = 0 434 | while i < self.size_subg_e: 435 | ran = ( rand()) / RAND_MAX * ran_range 436 | e = lower_bound(self.p_dist_cumsum_vec.begin(),self.p_dist_cumsum_vec.end(),ran)-self.p_dist_cumsum_vec.begin() 437 | self.node_sampled[idx_subg].push_back(self.row_train_vec[e]) 438 | self.node_sampled[idx_subg].push_back(self.col_train_vec[e]) 439 | i = i + 1 440 | sort(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()) 441 | self.node_sampled[idx_subg].erase(unique(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()),self.node_sampled[idx_subg].end()) 442 | r = r + 1 443 | 444 | # ---------------------------------------------------- 445 | 446 | cdef class Node(Sampler): 447 | cdef int size_subg 448 | cdef vector[int] p_dist_cumsum_vec 449 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 450 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 451 | np.ndarray[int,ndim=1,mode='c'] node_train, 452 | int num_proc, int num_sample_per_proc, 453 | np.ndarray[int,ndim=1,mode='c'] p_dist_cumsum, 454 | int size_subg): 455 | self.size_subg = size_subg 456 | cutils.npy2vec_int(p_dist_cumsum,self.p_dist_cumsum_vec) 457 | 458 | cdef void sample(self, int p) nogil: 459 | cdef int i = 0 460 | cdef int r = 0 461 | cdef int idx_subg 462 | cdef int sample 463 | cdef int rand_range = self.p_dist_cumsum_vec[self.node_train_vec.size()-1] 464 | while r < self.num_sample_per_proc: 465 | idx_subg = p*self.num_sample_per_proc+r 466 | i = 0 467 | while i < self.size_subg: 468 | sample = rand()%rand_range 469 | self.node_sampled[idx_subg].push_back(self.node_train_vec[lower_bound(self.p_dist_cumsum_vec.begin(),self.p_dist_cumsum_vec.end(),sample)-self.p_dist_cumsum_vec.begin()]) 470 | i = i + 1 471 | r = r + 1 472 | sort(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()) 473 | self.node_sampled[idx_subg].erase(unique(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()),self.node_sampled[idx_subg].end()) 474 | 475 | # ----------------------------------------------------- 476 | 477 | cdef class FullBatch(Sampler): 478 | def __cinit__(self, np.ndarray[int,ndim=1,mode='c'] adj_indptr, 479 | np.ndarray[int,ndim=1,mode='c'] adj_indices, 480 | np.ndarray[int,ndim=1,mode='c'] node_train, 481 | int num_proc, int num_sample_per_proc): 482 | pass 483 | 484 | cdef void sample(self, int p) nogil: 485 | cdef int i = 0 486 | cdef int r = 0 487 | cdef int idx_subg 488 | cdef int sample 489 | while r < self.num_sample_per_proc: 490 | idx_subg = p*self.num_sample_per_proc+r 491 | i = 0 492 | while i < self.node_train_vec.size(): 493 | sample = i 494 | self.node_sampled[idx_subg].push_back(self.node_train_vec[sample]) 495 | i = i + 1 496 | r = r + 1 497 | sort(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()) 498 | self.node_sampled[idx_subg].erase(unique(self.node_sampled[idx_subg].begin(),self.node_sampled[idx_subg].end()),self.node_sampled[idx_subg].end()) 499 | -------------------------------------------------------------------------------- /embedding/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import scipy.sparse as sp 4 | 5 | 6 | F_ACT = {'relu': nn.ReLU(), 7 | 'I': lambda x:x} 8 | 9 | """ 10 | NOTE 11 | For the various GNN layers, we optionally support batch 
normalization. Yet, due to the
12 | non-IID nature of GNN samplers (whether graph-sampling or layer-sampling based),
13 | we may need some modifications to the standard batch-norm layer operations to achieve
14 | optimal accuracy on graphs.
15 | 
16 | The study of optimal GNN-based batch-norm is out of scope for the current version of
17 | GraphSAINT. As a compromise, we provide multiple implementations of the batch-norm
18 | layer, which can be optionally inserted at the output of the GNN layers.
19 | 
20 | Specifically, the 'bias' field of the layer classes supports the following choices:
21 | 'bias' means no batch-norm is applied. Only add the bias to the hidden
22 | features of the GNN layer.
23 | 'norm' means we calculate the mean and variance of the GNN hidden features,
24 | and then scale the hidden features manually by the mean and variance.
25 | In this case, we need to explicitly create the 'offset' and 'scale' params.
26 | 'norm-nn' means we use the torch.nn.BatchNorm1d layer implemented by torch.
27 | In this case, there is no need to explicitly maintain the BN internal params.
28 | """
29 | 
30 | 
31 | class HighOrderAggregator(nn.Module):
32 |     def __init__(self, dim_in, dim_out, dropout=0., act='relu', \
33 |                  order=1, aggr='mean', bias='norm-nn', **kwargs):
34 |         """
35 |         The layer implemented here combines the GraphSAGE-mean [1] layer with the MixHop [2] layer.
36 |         We define the concept of `order`: an order-k layer aggregates neighbor information
37 |         from 0-hop all the way to k-hop. The operation is approximately:
38 |             X W_0 [+] A X W_1 [+] ... [+] A^k X W_k
39 |         where [+] is some aggregation operation such as addition or concatenation.
40 | 
41 |         Special cases:
42 |             Order = 0 --> standard MLP layer
43 |             Order = 1 --> standard GraphSAGE layer
44 | 
45 |         [1]: https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf
46 |         [2]: https://arxiv.org/abs/1905.00067
47 | 
48 |         Inputs:
49 |             dim_in      int, feature dimension for input nodes
50 |             dim_out     int, feature dimension for output nodes
51 |             dropout     float, dropout on weight matrices W_0 to W_k
52 |             act         str, activation function.
See F_ACT at the top of this file 53 | order int, see definition above 54 | aggr str, if 'mean' then [+] operation adds features of various hops 55 | if 'concat' then [+] concatenates features of various hops 56 | bias str, if 'bias' then apply a bias vector to features of each hop 57 | if 'norm' then perform batch-normalization on output features 58 | 59 | Outputs: 60 | None 61 | """ 62 | super(HighOrderAggregator,self).__init__() 63 | assert bias in ['bias', 'norm', 'norm-nn'] 64 | self.order, self.aggr = order, aggr 65 | self.act, self.bias = F_ACT[act], bias 66 | self.dropout = dropout 67 | self.f_lin, self.f_bias = [], [] 68 | self.offset, self.scale = [], [] 69 | self.num_param = 0 70 | for o in range(self.order + 1): 71 | self.f_lin.append(nn.Linear(dim_in, dim_out, bias=False)) 72 | nn.init.xavier_uniform_(self.f_lin[-1].weight) 73 | self.f_bias.append(nn.Parameter(torch.zeros(dim_out))) 74 | self.num_param += dim_in * dim_out 75 | self.num_param += dim_out 76 | self.offset.append(nn.Parameter(torch.zeros(dim_out))) 77 | self.scale.append(nn.Parameter(torch.ones(dim_out))) 78 | if self.bias == 'norm' or self.bias == 'norm-nn': 79 | self.num_param += 2 * dim_out 80 | self.f_lin = nn.ModuleList(self.f_lin) 81 | self.f_dropout = nn.Dropout(p=self.dropout) 82 | self.params = nn.ParameterList(self.f_bias + self.offset + self.scale) 83 | self.f_bias = self.params[:self.order + 1] 84 | if self.bias == 'norm': 85 | self.offset = self.params[self.order + 1 : 2 * self.order + 2] 86 | self.scale = self.params[2 * self.order + 2 : ] 87 | elif self.bias == 'norm-nn': 88 | final_dim_out = dim_out * ((aggr=='concat') * (order + 1) + (aggr=='mean')) 89 | self.f_norm = nn.BatchNorm1d(final_dim_out, eps=1e-9, track_running_stats=True) 90 | self.num_param = int(self.num_param) 91 | 92 | def _spmm(self, adj_norm, _feat): 93 | """ sparce feature matrix multiply dense feature matrix """ 94 | # alternative ways: use geometric.propagate or torch.mm 95 | return torch.sparse.mm(adj_norm, _feat) 96 | 97 | def _f_feat_trans(self, _feat, _id): 98 | feat = self.act(self.f_lin[_id](_feat) + self.f_bias[_id]) 99 | if self.bias == 'norm': 100 | mean = feat.mean(dim=1).view(feat.shape[0],1) 101 | var = feat.var(dim=1, unbiased=False).view(feat.shape[0], 1) + 1e-9 102 | feat_out = (feat - mean) * self.scale[_id] * torch.rsqrt(var) + self.offset[_id] 103 | else: 104 | feat_out = feat 105 | return feat_out 106 | 107 | def forward(self, inputs): 108 | """ 109 | Inputs:. 
110 |             adj_norm    normalized adj matrix of the subgraph
111 |             feat_in     2D matrix of input node features
112 | 
113 |         Outputs:
114 |             adj_norm    same as input (to facilitate nn.Sequential)
115 |             feat_out    2D matrix of output node features
116 |         """
117 |         adj_norm, feat_in = inputs
118 |         feat_in = self.f_dropout(feat_in)
119 |         feat_hop = [feat_in]
120 |         # generate A^i X
121 |         for o in range(self.order):
122 |             # propagate(edge_index, x=x, norm=norm)
123 |             feat_hop.append(self._spmm(adj_norm, feat_hop[-1]))
124 |         feat_partial = [self._f_feat_trans(ft, idf) for idf, ft in enumerate(feat_hop)]
125 |         if self.aggr == 'mean':
126 |             feat_out = feat_partial[0]
127 |             for o in range(len(feat_partial) - 1):
128 |                 feat_out += feat_partial[o + 1]
129 |         elif self.aggr == 'concat':
130 |             feat_out = torch.cat(feat_partial, 1)
131 |         else:
132 |             raise NotImplementedError
133 |         if self.bias == 'norm-nn':
134 |             feat_out = self.f_norm(feat_out)
135 |         return adj_norm, feat_out  # return adj_norm to support Sequential
136 | 
137 | 
138 | class JumpingKnowledge(nn.Module):
139 |     def __init__(self):
140 |         """
141 |         To be added soon. For now please see the tensorflow version for JK layers.
142 |         """
143 |         pass
144 | 
145 | 
146 | class AttentionAggregator(nn.Module):
147 |     """
148 |     This layer follows the design of Graph Attention Network (GAT: https://arxiv.org/abs/1710.10903).
149 |     We extend GAT to higher order as well (see the HighOrderAggregator class above), even though most
150 |     of the time an order-1 layer should be sufficient. The enhancement over the SAGE-mean architecture
151 |     is that GAT performs *weighted* aggregation on neighbor features. The edge weight is generated by
152 |     an additional learnable MLP layer. Such weights serve as "attention". GAT proposed multi-head
153 |     attention so that there can be multiple weights for each edge. The k-head attention can be
154 |     specified by the `mulhead` parameter.
155 | 
156 |     Note that
157 |     1. In GraphSAINT minibatch training, we remove the softmax normalization across the neighbors.
158 |        Reason: since the minibatch does not see the full neighborhood, softmax does not make much
159 |        sense now. We see significant accuracy improvement by removing the softmax step. See also
160 |        Equations 8 and 9, Appendix C.3 of GraphSAINT (https://arxiv.org/pdf/1907.04931.pdf).
161 |     2. For order > 1, we obtain attention from neighbors of lower order up to higher order.
162 | 
163 |     Inputs:
164 |         dim_in      int, feature dimension for input nodes
165 |         dim_out     int, feature dimension for output nodes
166 |         dropout     float, dropout on weight matrices W_0 to W_k
167 |         act         str, activation function.
See F_ACT at the top of this file 168 | order int, see definition in HighOrderAggregator 169 | aggr str, if 'mean' then [+] operation adds features of various hops 170 | if 'concat' then [+] concatenates features of various hops 171 | bias str, if 'bias' then apply a bias vector to features of each hop 172 | if 'norm' then perform batch-normalization on output features 173 | mulhead int, the number of heads for attention 174 | 175 | Outputs: 176 | None 177 | """ 178 | def __init__(self, dim_in, dim_out, dropout=0., act='relu', \ 179 | order=1, aggr='mean', bias='norm', mulhead=1): 180 | super(AttentionAggregator,self).__init__() 181 | assert bias in ['bias', 'norm', 'norm-nn'] 182 | self.num_param = 0 183 | self.mulhead = mulhead 184 | self.order, self.aggr = order, aggr 185 | self.act, self.bias = F_ACT[act], bias 186 | self.att_act = nn.LeakyReLU(negative_slope=0.2) 187 | self.dropout = dropout 188 | self._f_lin = [] 189 | self._offset, self._scale = [], [] 190 | self._attention = [] 191 | # mostly we have order = 1 for GAT 192 | # "+1" since we do batch norm of order-0 and order-1 outputs separately 193 | for o in range(self.order+1): 194 | for i in range(self.mulhead): 195 | self._f_lin.append(nn.Linear(dim_in, int(dim_out / self.mulhead), bias=True)) 196 | nn.init.xavier_uniform_(self._f_lin[-1].weight) 197 | # _offset and _scale are for 'norm' type of batch norm 198 | self._offset.append(nn.Parameter(torch.zeros(int(dim_out / self.mulhead)))) 199 | self._scale.append(nn.Parameter(torch.ones(int(dim_out / self.mulhead)))) 200 | self.num_param += dim_in * dim_out / self.mulhead + 2 * dim_out / self.mulhead 201 | if o < self.order: 202 | self._attention.append(nn.Parameter(torch.ones(1, int(dim_out / self.mulhead * 2)))) 203 | nn.init.xavier_uniform_(self._attention[-1]) 204 | self.num_param += dim_out / self.mulhead * 2 205 | self.mods = nn.ModuleList(self._f_lin) 206 | self.f_dropout = nn.Dropout(p=self.dropout) 207 | self.params = nn.ParameterList(self._offset + self._scale + self._attention) 208 | self.f_lin = [] 209 | self.offset, self.scale = [], [] 210 | self.attention = [] 211 | # We need traverse order and mulhead the second time, just because we want to support 212 | # higher order. Reason: if we have torch parameters in a python list, i.e.: 213 | # [nn.Parameter(), nn.Parameter(), ...] 214 | # PyTorch cannot automically add these parameters into the learnable parameters. 
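        # self.params is laid out as [all offsets | all scales | all attentions], each block
        # ordered first by hop order o and then by head i. The loop below only builds nested
        # [order][head] views into self.mods / self.params, so that forward() can address the
        # weights as f_lin[o][i], offset[o][i], scale[o][i] and attention[o-1][i].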
215 | for o in range(self.order+1): 216 | self.f_lin.append([]) 217 | self.offset.append([]) 218 | self.scale.append([]) 219 | self.attention.append([]) 220 | for i in range(self.mulhead): 221 | self.f_lin[-1].append(self.mods[o * self.mulhead + i]) 222 | if self.bias == 'norm': # not used in 'norm-nn' mode 223 | self.offset[-1].append(self.params[o * self.mulhead + i]) 224 | self.scale[-1].append(self.params[len(self._offset) + o * self.mulhead + i]) 225 | if o < self.order: # excluding the order-0 part 226 | self.attention[-1].append(self.params[len(self._offset) * 2 + o * self.mulhead + i]) 227 | if self.bias == 'norm-nn': 228 | final_dim_out = dim_out*((aggr=='concat')*(order+1) + (aggr=='mean')) 229 | self.f_norm = nn.BatchNorm1d(final_dim_out, eps=1e-9, track_running_stats=True) 230 | self.num_param = int(self.num_param) 231 | 232 | def _spmm(self, adj_norm, _feat): 233 | return torch.sparse.mm(adj_norm, _feat) 234 | 235 | def _f_feat_trans(self, _feat, f_lin): 236 | feat_out = [] 237 | for i in range(self.mulhead): 238 | feat_out.append(self.act(f_lin[i](_feat))) 239 | return feat_out 240 | 241 | def _aggregate_attention(self, adj, feat_neigh, feat_self, attention): 242 | attention_self = self.att_act(attention[:, : feat_self.shape[1]].mm(feat_self.t())).squeeze() 243 | attention_neigh = self.att_act(attention[:, feat_neigh.shape[1] :].mm(feat_neigh.t())).squeeze() 244 | attention_norm = (attention_self[adj._indices()[0]] + attention_neigh[adj._indices()[1]]) * adj._values() 245 | att_adj = torch.sparse.FloatTensor(adj._indices(), attention_norm, torch.Size(adj.shape)) 246 | return self._spmm(att_adj, feat_neigh) 247 | 248 | def forward(self, inputs): 249 | """ 250 | Inputs: 251 | inputs tuple / list of two elements: 252 | 1. feat_in: 2D matrix of node features input to the layer 253 | 2. adj_norm: normalized subgraph adj. Normalization should 254 | consider both the node degree and aggregation normalization 255 | 256 | Outputs: 257 | feat_out 2D matrix of features for output nodes of the layer 258 | adj_norm normalized adj same as the input. 
We have to return it to 259 | support nn.Sequential called in models.py 260 | """ 261 | adj_norm, feat_in = inputs 262 | feat_in = self.f_dropout(feat_in) 263 | # generate A^i X 264 | feat_partial = [] 265 | for o in range(self.order + 1): 266 | feat_partial.append(self._f_feat_trans(feat_in, self.f_lin[o])) 267 | for o in range(1,self.order + 1): 268 | for s in range(o): 269 | for i in range(self.mulhead): 270 | feat_partial[o][i] = self._aggregate_attention( 271 | adj_norm, 272 | feat_partial[o][i], 273 | feat_partial[o - s - 1][i], 274 | self.attention[o-1][i], 275 | ) 276 | if self.bias == 'norm': 277 | # normalize per-order, per-head 278 | for o in range(self.order + 1): 279 | for i in range(self.mulhead): 280 | mean = feat_partial[o][i].mean(dim=1).unsqueeze(1) 281 | var = feat_partial[o][i].var(dim=1, unbiased=False).unsqueeze(1) + 1e-9 282 | feat_partial[o][i] = (feat_partial[o][i] - mean) \ 283 | * self.scale[o][i] * torch.rsqrt(var) + self.offset[o][i] 284 | 285 | for o in range(self.order + 1): 286 | feat_partial[o] = torch.cat(feat_partial[o], 1) 287 | if self.aggr == 'mean': 288 | feat_out = feat_partial[0] 289 | for o in range(len(feat_partial) - 1): 290 | feat_out += feat_partial[o + 1] 291 | elif self.aggr == 'concat': 292 | feat_out = torch.cat(feat_partial, 1) 293 | else: 294 | raise NotImplementedError 295 | if self.bias == 'norm-nn': 296 | feat_out = self.f_norm(feat_out) 297 | return adj_norm, feat_out 298 | 299 | 300 | class GatedAttentionAggregator(nn.Module): 301 | """ 302 | Gated attentionn network (GaAN: https://arxiv.org/pdf/1803.07294.pdf). 303 | The general idea of attention is similar to GAT. The main difference is that GaAN adds 304 | a gated weight for each attention head. Therefore, we can selectively pick important 305 | heads for better expressive power. Note that this layer is quite expensive to execute, 306 | since the operations to compute attention are complicated. Therefore, we only support 307 | order <= 1 (See HighOrderAggregator for definition of order). 308 | 309 | Inputs: 310 | dim_in int, feature dimension for input nodes 311 | dim_out int, feature dimension for output nodes 312 | dropout float, dropout on weight matrices W_0 to W_k 313 | act str, activation function. 
See F_ACT at the top of this file 314 | order int, see definition in HighOrderAggregator 315 | aggr str, if 'mean' then [+] operation adds features of various hops 316 | if 'concat' then [+] concatenates features of various hops 317 | bias str, if 'bias' then apply a bias vector to features of each hop 318 | if 'norm' then perform batch-normalization on output features 319 | mulhead int, the number of heads for attention 320 | dim_gate int, output dimension of theta_m during gate value calculation 321 | 322 | Outputs: 323 | None 324 | """ 325 | 326 | def __init__( 327 | self, 328 | dim_in, 329 | dim_out, 330 | dropout=0.0, 331 | act="relu", 332 | order=1, 333 | aggr="mean", 334 | bias="norm", 335 | mulhead=1, 336 | dim_gate=64, 337 | ): 338 | super(GatedAttentionAggregator, self).__init__() 339 | self.num_param = 0 # TODO: update param count 340 | self.multi_head = mulhead 341 | assert self.multi_head > 0 and dim_out % self.multi_head == 0 342 | self.order, self.aggr = order, aggr 343 | self.act, self.bias = F_ACT[act], bias 344 | self.att_act = nn.LeakyReLU(negative_slope=0.2) 345 | self.dropout = dropout 346 | self.dim_gate = dim_gate 347 | self._f_lin = [] 348 | self._offset, self._scale = [], [] 349 | self._attention = [] 350 | for i in range(self.order + 1): 351 | self._offset.append(nn.Parameter(torch.zeros(dim_out))) 352 | self._scale.append(nn.Parameter(torch.ones(dim_out))) 353 | for _j in range(self.multi_head): 354 | self._f_lin.append( 355 | nn.Linear(dim_in, int(dim_out / self.multi_head), bias=True) 356 | ) 357 | nn.init.xavier_uniform_(self._f_lin[-1].weight) 358 | if i < self.order: 359 | self._attention.append( 360 | nn.Parameter(torch.ones(1, int(dim_out / self.multi_head * 2))) 361 | ) 362 | nn.init.xavier_uniform_(self._attention[-1]) 363 | self._weight_gate = nn.Parameter( 364 | torch.ones(dim_in * 2 + dim_gate, self.multi_head) 365 | ) 366 | nn.init.xavier_uniform_(self._weight_gate) 367 | self._weight_pool_gate = nn.Parameter(torch.ones(dim_in, dim_gate)) 368 | nn.init.xavier_uniform_(self._weight_pool_gate) 369 | self.mods = nn.ModuleList(self._f_lin) 370 | self.f_dropout = nn.Dropout(p=self.dropout) 371 | self.params = nn.ParameterList( 372 | self._offset 373 | + self._scale 374 | + self._attention 375 | + [self._weight_gate, self._weight_pool_gate] 376 | ) 377 | self.f_lin = [] 378 | self.offset, self.scale = [], [] 379 | self.attention = [] 380 | for i in range(self.order + 1): 381 | self.f_lin.append([]) 382 | self.attention.append([]) 383 | self.offset.append(self.params[i]) 384 | self.scale.append(self.params[len(self._offset) + i]) 385 | for j in range(self.multi_head): 386 | self.f_lin[-1].append(self.mods[i * self.multi_head + j]) 387 | if i < self.order: 388 | self.attention[-1].append( 389 | self.params[len(self._offset) * 2 + i * self.multi_head + j] 390 | ) 391 | self.weight_gate = self.params[-2] 392 | self.weight_pool_gate = self.params[-1] 393 | 394 | def _spmm(self, adj_norm, _feat): 395 | return torch.sparse.mm(adj_norm, _feat) 396 | 397 | def _f_feat_trans(self, _feat, f_lin): 398 | feat_out = [] 399 | for i in range(self.multi_head): 400 | feat_out.append(self.act(f_lin[i](_feat))) 401 | return feat_out 402 | 403 | def _aggregate_attention(self, adj, feat_neigh, feat_self, attention): 404 | attention_self = self.att_act( 405 | attention[:, : feat_self.shape[1]].mm(feat_self.t()) 406 | ).squeeze() 407 | attention_neigh = self.att_act( 408 | attention[:, feat_neigh.shape[1] :].mm(feat_neigh.t()) 409 | ).squeeze() 410 | att_adj = 
torch.sparse.FloatTensor( 411 | adj._indices(), 412 | (attention_self[adj._indices()[0]] + attention_neigh[adj._indices()[1]]) 413 | * adj._values(), 414 | torch.Size(adj.shape), 415 | ) 416 | return self._spmm(att_adj, feat_neigh) 417 | 418 | def _batch_norm(self, feat): 419 | for i in range(self.order + 1): 420 | mean = feat[i].mean(dim=1).unsqueeze(1) 421 | var = feat[i].var(dim=1, unbiased=False).unsqueeze(1) + 1e-9 422 | feat[i] = (feat[i] - mean) * self.scale[i] * torch.rsqrt(var) \ 423 | + self.offset[i] 424 | return feat 425 | 426 | def _compute_gate_value(self, adj, feat, adj_sp_csr): 427 | """ 428 | See equation (3) of the GaAN paper. Gate value is applied in front of each head. 429 | Symbols such as zj follows the equations in the paper. 430 | """ 431 | zj = feat.mm(self.weight_pool_gate) 432 | neigh_zj = [] 433 | # use loop instead since torch does not support sparse tensor slice 434 | for i in range(adj.shape[0]): 435 | if adj_sp_csr.indptr[i] < adj_sp_csr.indptr[i + 1]: 436 | neigh_zj.append( 437 | torch.max( 438 | zj[ 439 | adj_sp_csr.indices[ 440 | adj_sp_csr.indptr[i] : adj_sp_csr.indptr[i + 1] 441 | ] 442 | ], 443 | 0, 444 | )[0].unsqueeze(0) 445 | ) 446 | else: 447 | if zj.is_cuda: 448 | neigh_zj.append(torch.zeros(1, self.dim_gate).cuda()) 449 | else: 450 | neigh_zj.append(torch.zeros(1, self.dim_gate)) 451 | neigh_zj = torch.cat(neigh_zj, 0) 452 | neigh_mean = self._spmm(adj, feat) 453 | gate_feat = torch.cat([feat, neigh_zj, neigh_mean], 1) 454 | return gate_feat.mm(self.weight_gate) 455 | 456 | def forward(self, inputs): 457 | """ 458 | Inputs: 459 | inputs tuple / list of two elements: 460 | 1. adj_norm: normalized subgraph adj. Normalization should 461 | consider both the node degree and aggregation normalization 462 | 2. feat_in: 2D matrix of node features input to the layer 463 | 464 | Outputs: 465 | feat_out 2D matrix of features for output nodes of the layer 466 | adj_norm normalized adj same as the input. 
We have to return it to 467 | support nn.Sequential called in models.py 468 | """ 469 | adj_norm, feat_in = inputs 470 | feat_in = self.f_dropout(feat_in) 471 | # compute gate value 472 | adj_norm_cpu = adj_norm.cpu() 473 | adj_norm_sp_csr = sp.coo_matrix( 474 | ( 475 | adj_norm_cpu._values().numpy(), 476 | ( 477 | adj_norm_cpu._indices()[0].numpy(), 478 | adj_norm_cpu._indices()[1].numpy(), 479 | ), 480 | ), 481 | shape=(adj_norm.shape[0], adj_norm.shape[0]), 482 | ).tocsr() 483 | gate_value = self._compute_gate_value(adj_norm, feat_in, adj_norm_sp_csr) 484 | feat_partial = [] 485 | for i in range(self.order + 1): 486 | feat_partial.append(self._f_feat_trans(feat_in, self.f_lin[i])) 487 | for i in range(1, self.order + 1): 488 | for j in range(i): 489 | for k in range(self.multi_head): 490 | feat_partial[i][k] = self._aggregate_attention( 491 | adj_norm, 492 | feat_partial[i][k], 493 | feat_partial[i - j - 1][k], 494 | self.attention[i - 1][k], 495 | ) 496 | feat_partial[i][k] *= gate_value[:, k].unsqueeze(1) 497 | for i in range(self.order + 1): 498 | feat_partial[i] = torch.cat(feat_partial[i], 1) 499 | # if norm before concatenation, gate value vanishes 500 | if self.bias == "norm": 501 | feat_partial = self._batch_norm(feat_partial) 502 | if self.aggr == "mean": 503 | feat_out = feat_partial[0] 504 | for i in range(len(feat_partial) - 1): 505 | feat_out += feat_partial[i + 1] 506 | elif self.aggr == "concat": 507 | feat_out = torch.cat(feat_partial, 1) 508 | else: 509 | raise NotImplementedError 510 | return adj_norm, feat_out 511 | --------------------------------------------------------------------------------
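Both HighOrderAggregator and the attention-based layers above consume and return an (adj_norm, feat) pair, which is what allows them to be chained with nn.Sequential. A minimal, self-contained sketch of that wiring (hypothetical toy tensors; dim 512, aggr 'concat', bias 'norm' and dropout 0.1 mirror train_config/toy.yml, while the real model assembly presumably lives in embedding/text_model.py and embedding/user_model.py):

import torch
from torch import nn
from embedding.layers import HighOrderAggregator

# Tiny hypothetical inputs: 8 nodes, 300-dim features, and a sparse row-normalized adjacency.
num_nodes, dim_in, dim_hid = 8, 300, 512
feat = torch.randn(num_nodes, dim_in)
idx = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
                    [1, 0, 3, 2, 5, 4, 7, 6]])
adj_norm = torch.sparse_coo_tensor(idx, torch.ones(8), (num_nodes, num_nodes)).coalesce()

model = nn.Sequential(
    HighOrderAggregator(dim_in, dim_hid, dropout=0.1, act='relu',
                        order=1, aggr='concat', bias='norm'),
    # an order-1 'concat' layer outputs 2 * dim_hid features
    HighOrderAggregator(2 * dim_hid, dim_hid, dropout=0.1, act='relu',
                        order=1, aggr='concat', bias='norm'),
)
_, emb = model((adj_norm, feat))  # every layer returns (adj_norm, feat_out)
print(emb.shape)                  # torch.Size([8, 1024])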