├── README.md
├── modules.py
├── sampler.py
├── train_sampling.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# GraphSAINT

This DGL example implements the paper: GraphSAINT: Graph Sampling Based Inductive Learning Method.

Paper link: https://arxiv.org/abs/1907.04931

Author's code: https://github.com/GraphSAINT/GraphSAINT

Contributor: Liu Tang ([@lt610](https://github.com/lt610))

## Dependencies

- Python 3.7.0
- PyTorch 1.6.0
- NumPy 1.19.2
- Scikit-learn 0.23.2
- DGL 0.5.3

## Dataset

All datasets used here are provided by the authors' [code](https://github.com/GraphSAINT/GraphSAINT). They are available on [Google Drive](https://drive.google.com/drive/folders/1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [Baidu Wangpan (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg#list/path=%2F)). Once you have downloaded the datasets, rename the extracted `graphsaintdata` folder to `data`. Dataset summary ("m" stands for multi-label classification, "s" for single-label):

| Dataset | Nodes | Edges | Degree | Feature | Classes | Train/Val/Test |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| PPI | 14,755 | 225,270 | 15 | 50 | 121(m) | 0.66/0.12/0.22 |
| Flickr | 89,250 | 899,756 | 10 | 500 | 7(s) | 0.50/0.25/0.25 |
| Reddit | 232,965 | 11,606,919 | 50 | 602 | 41(s) | 0.66/0.10/0.24 |
| Yelp | 716,847 | 6,877,410 | 10 | 300 | 100(m) | 0.75/0.10/0.15 |
| Amazon | 1,598,960 | 132,169,734 | 83 | 200 | 107(m) | 0.85/0.05/0.10 |
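
The training script loads each dataset from `data/<dataset>/` relative to the repository root (see `load_data` in `utils.py`). After renaming, the layout should look like this (shown for `ppi`):

```
data/
└── ppi/
    ├── adj_full.npz      # full graph adjacency (scipy sparse matrix)
    ├── adj_train.npz     # adjacency restricted to training nodes
    ├── role.json         # train/val/test node ids ('tr' / 'va' / 'te')
    ├── feats.npy         # node feature matrix
    └── class_map.json    # node id -> label (list for multi-label datasets)
```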

## Minibatch training

Run with the following commands:
```bash
python train_sampling.py --gpu 0 --dataset ppi --sampler node --node-budget 6000 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0
python train_sampling.py --gpu 0 --dataset ppi --sampler edge --edge-budget 4000 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset ppi --sampler rw --num-roots 3000 --length 2 --num-repeat 50 --n-epochs 1000 --n-hidden 512 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset flickr --sampler node --node-budget 8000 --num-repeat 25 --n-epochs 30 --n-hidden 256 --arch 1-1-0 --dropout 0.2
python train_sampling.py --gpu 0 --dataset flickr --sampler edge --edge-budget 6000 --num-repeat 25 --n-epochs 15 --n-hidden 256 --arch 1-1-0 --dropout 0.2
python train_sampling.py --gpu 0 --dataset flickr --sampler rw --num-roots 6000 --length 2 --num-repeat 25 --n-epochs 15 --n-hidden 256 --arch 1-1-0 --dropout 0.2
python train_sampling.py --gpu 0 --dataset reddit --sampler node --node-budget 8000 --num-repeat 50 --n-epochs 40 --n-hidden 128 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset reddit --sampler edge --edge-budget 6000 --num-repeat 50 --n-epochs 40 --n-hidden 128 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset reddit --sampler rw --num-roots 2000 --length 4 --num-repeat 50 --n-epochs 30 --n-hidden 128 --arch 1-0-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset yelp --sampler node --node-budget 5000 --num-repeat 50 --n-epochs 50 --n-hidden 512 --arch 1-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset yelp --sampler edge --edge-budget 2500 --num-repeat 50 --n-epochs 100 --n-hidden 512 --arch 1-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset yelp --sampler rw --num-roots 1250 --length 2 --num-repeat 50 --n-epochs 75 --n-hidden 512 --arch 1-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset amazon --sampler node --node-budget 4500 --num-repeat 50 --n-epochs 30 --n-hidden 512 --arch 1-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset amazon --sampler edge --edge-budget 2000 --num-repeat 50 --n-epochs 30 --n-hidden 512 --arch 1-1-0 --dropout 0.1
python train_sampling.py --gpu 0 --dataset amazon --sampler rw --num-roots 1500 --length 2 --num-repeat 50 --n-epochs 30 --n-hidden 512 --arch 1-1-0 --dropout 0.1
```
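
The `--arch` flag encodes the hidden layers: each dash-separated digit adds one `GCNLayer`, where `1` is an order-1 layer (self feature plus 1-hop neighbor aggregation) and `0` is an order-0 layer (self feature only); an order-0 output layer is always appended (see `GCNNet` in `modules.py`). A rough sketch of the mapping, using PPI's 50 input features and 121 classes purely for illustration:

```python
# Roughly what `--arch 1-0-1-0 --n-hidden 512` builds (batch-norm/dropout flags omitted here).
from modules import GCNNet

model = GCNNet(in_dim=50, hid_dim=512, out_dim=121, arch="1-0-1-0", aggr="concat")
# -> four hidden GCNLayers with orders [1, 0, 1, 0], followed by an order-0 output layer
```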

## Comparison

* Paper: results from the paper
* Running: results from experiments with the authors' code
* DGL: results from experiments with the DGL example
* OOM: the run exceeded available memory

### F1-micro

#### Random node sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.960±0.001 | 0.507±0.001 | 0.962±0.001 | 0.641±0.000 | 0.782±0.004 |
| Running | 0.9628 | 0.5077 | 0.9622 | 0.6393 | 0.7695 |
| DGL | 0.9618 | 0.4828 | 0.9621 | 0.6360 | 0.7748 |

#### Random edge sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.981±0.007 | 0.510±0.002 | 0.966±0.001 | 0.653±0.003 | 0.807±0.001 |
| Running | 0.9810 | 0.5066 | 0.9656 | 0.6531 | 0.8071 |
| DGL | 0.9818 | 0.5054 | 0.9653 | 0.6517 | OOM |

#### Random walk sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Paper | 0.981±0.004 | 0.511±0.001 | 0.966±0.001 | 0.653±0.003 | 0.815±0.001 |
| Running | 0.9812 | 0.5104 | 0.9648 | 0.6527 | 0.8131 |
| DGL | 0.9818 | 0.5018 | 0.9649 | 0.6516 | 0.8150 |

### Sampling time (seconds)

#### Random node sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Sampling(Running) | 0.77 | 0.65 | 7.46 | 26.29 | 571.42 |
| Sampling(DGL) | 0.24 | 0.57 | 5.06 | 30.04 | 163.75 |
| Normalization(Running) | 0.69 | 2.84 | 11.54 | 32.72 | 407.20 |
| Normalization(DGL) | 1.04 | 0.41 | 21.05 | 68.63 | 2006.94 |

#### Random edge sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Sampling(Running) | 0.72 | 0.56 | 4.46 | 12.38 | 101.76 |
| Sampling(DGL) | 0.50 | 0.72 | 53.88 | 254.63 | OOM |
| Normalization(Running) | 0.68 | 2.62 | 9.42 | 26.64 | 62.59 |
| Normalization(DGL) | 0.61 | 0.38 | 14.69 | 23.63 | OOM |

#### Random walk sampler

| Method | PPI | Flickr | Reddit | Yelp | Amazon |
| --- | --- | --- | --- | --- | --- |
| Sampling(Running) | 0.83 | 1.22 | 6.69 | 18.84 | 209.83 |
| Sampling(DGL) | 0.28 | 0.63 | 4.02 | 22.01 | 55.09 |
| Normalization(Running) | 0.87 | 2.60 | 10.28 | 24.41 | 145.85 |
| Normalization(DGL) | 0.70 | 0.42 | 18.34 | 32.16 | 683.96 |

--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
import torch as th
import dgl.function as fn


class GCNLayer(nn.Module):
    def __init__(self, in_dim, out_dim, order=1, act=None,
                 dropout=0, batch_norm=False, aggr="concat"):
        super(GCNLayer, self).__init__()
        self.lins = nn.ModuleList()
        self.bias = nn.ParameterList()
        for _ in range(order + 1):
            self.lins.append(nn.Linear(in_dim, out_dim, bias=False))
            self.bias.append(nn.Parameter(th.zeros(out_dim)))

        self.order = order
        self.act = act
        self.dropout = nn.Dropout(dropout)

        self.batch_norm = batch_norm
        if batch_norm:
            self.offset, self.scale = nn.ParameterList(), nn.ParameterList()
            for _ in range(order + 1):
                self.offset.append(nn.Parameter(th.zeros(out_dim)))
                self.scale.append(nn.Parameter(th.ones(out_dim)))

        self.aggr = aggr
        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.lins:
            nn.init.xavier_normal_(lin.weight)

    def feat_trans(self, features, idx):
        h = self.lins[idx](features) + self.bias[idx]

        if self.act is not None:
            h = self.act(h)

        if self.batch_norm:
            mean = h.mean(dim=1).view(h.shape[0], 1)
            var = h.var(dim=1, unbiased=False).view(h.shape[0], 1) + 1e-9
            h = (h - mean) * self.scale[idx] * th.rsqrt(var) + self.offset[idx]

        return h

    def forward(self, graph, features):
        g = graph.local_var()
        h_in = self.dropout(features)
        h_hop = [h_in]

        D_norm = g.ndata['train_D_norm'] if 'train_D_norm' in g.ndata else g.ndata['full_D_norm']
        for _ in range(self.order):
            g.ndata['h'] = h_hop[-1]
            if 'w' not in g.edata:
                g.edata['w'] = th.ones((g.num_edges(), )).to(features.device)
            g.update_all(fn.u_mul_e('h', 'w', 'm'),
                         fn.sum('m', 'h'))
            h = g.ndata.pop('h')
            h = h * D_norm
            h_hop.append(h)

        h_part = [self.feat_trans(ft, idx) for idx, ft in enumerate(h_hop)]
        if self.aggr == "mean":
            h_out = h_part[0]
            for i in range(len(h_part) - 1):
                h_out = h_out + h_part[i + 1]
        elif self.aggr == "concat":
            h_out = th.cat(h_part, 1)
        else:
            raise NotImplementedError

        return h_out


class GCNNet(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, arch="1-1-0",
                 act=F.relu, dropout=0, batch_norm=False, aggr="concat"):
        super(GCNNet, self).__init__()
        self.gcn = nn.ModuleList()

        orders = list(map(int, arch.split('-')))
        self.gcn.append(GCNLayer(in_dim=in_dim, out_dim=hid_dim, order=orders[0],
                                 act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
        pre_out = ((aggr == "concat") * orders[0] + 1) * hid_dim

        for i in range(1, len(orders) - 1):
            self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[i],
                                     act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
            pre_out = ((aggr == "concat") * orders[i] + 1) * hid_dim

        self.gcn.append(GCNLayer(in_dim=pre_out, out_dim=hid_dim, order=orders[-1],
                                 act=act, dropout=dropout, batch_norm=batch_norm, aggr=aggr))
        pre_out = ((aggr == "concat") * orders[-1] + 1) * hid_dim

        self.out_layer = GCNLayer(in_dim=pre_out, out_dim=out_dim, order=0,
                                  act=None, dropout=dropout, batch_norm=False, aggr=aggr)

    def forward(self, graph):
        h = graph.ndata['feat']

        for layer in self.gcn:
            h = layer(graph, h)

        h = F.normalize(h, p=2, dim=1)
        h = self.out_layer(graph, h)

        return h

--------------------------------------------------------------------------------
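
`GCNNet` can be exercised on its own; the only non-obvious requirement is that `GCNLayer.forward` looks up a per-node degree normalization under `train_D_norm` or `full_D_norm`, which `SAINTSampler` normally provides. A minimal smoke test, not part of the original example, assuming DGL 0.5-style APIs:

```python
import dgl
import torch as th
from modules import GCNNet

# Toy directed graph with self-loops and random 16-dimensional features.
g = dgl.add_self_loop(dgl.graph(([0, 1, 2, 3, 0, 2], [1, 2, 3, 0, 2, 0])))
g.ndata['feat'] = th.randn(g.num_nodes(), 16)
# GCNLayer expects 'train_D_norm' or 'full_D_norm'; outside the sampler we set it by hand.
g.ndata['full_D_norm'] = 1. / g.in_degrees().float().clamp(min=1).unsqueeze(1)

model = GCNNet(in_dim=16, hid_dim=32, out_dim=7, arch="1-0-1-0")
logits = model(g)
print(logits.shape)  # torch.Size([4, 7])
```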
/sampler.py:
--------------------------------------------------------------------------------
import math
import os
import time
import torch as th
import random
import numpy as np
import dgl.function as fn
import dgl
from dgl.sampling import random_walk, pack_traces


# The base class of all samplers.
# TODO: online sampling
class SAINTSampler(object):
    def __init__(self, dn, g, train_nid, node_budget, num_repeat=50):
        """
        :param dn: name of dataset
        :param g: full graph
        :param train_nid: ids of training nodes
        :param node_budget: expected number of sampled nodes per subgraph
        :param num_repeat: expected number of times each training node is sampled during pre-sampling
        """
        self.g = g
        self.train_g: dgl.DGLGraph = g.subgraph(train_nid)
        self.dn, self.num_repeat = dn, num_repeat
        self.node_counter = th.zeros((self.train_g.num_nodes(),))
        self.edge_counter = th.zeros((self.train_g.num_edges(),))
        self.prob = None

        graph_fn, norm_fn = self.__generate_fn__()

        if os.path.exists(graph_fn):
            self.subgraphs = np.load(graph_fn, allow_pickle=True)
            aggr_norm, loss_norm = np.load(norm_fn, allow_pickle=True)
        else:
            os.makedirs('./subgraphs/', exist_ok=True)

            self.subgraphs = []
            self.N, sampled_nodes = 0, 0

            t = time.perf_counter()
            while sampled_nodes <= self.train_g.num_nodes() * num_repeat:
                subgraph = self.__sample__()
                self.subgraphs.append(subgraph)
                sampled_nodes += subgraph.shape[0]
                self.N += 1
            print(f'Sampling time: [{time.perf_counter() - t:.2f}s]')
            np.save(graph_fn, self.subgraphs)

            t = time.perf_counter()
            self.__counter__()
            aggr_norm, loss_norm = self.__compute_norm__()
            print(f'Normalization time: [{time.perf_counter() - t:.2f}s]')
            np.save(norm_fn, (aggr_norm, loss_norm))

        self.train_g.ndata['l_n'] = th.Tensor(loss_norm)
        self.train_g.edata['w'] = th.Tensor(aggr_norm)
        self.__compute_degree_norm()

        self.num_batch = math.ceil(self.train_g.num_nodes() / node_budget)
        random.shuffle(self.subgraphs)
        self.__clear__()
        print("The number of subgraphs is: ", len(self.subgraphs))
        print("The size of subgraphs is about: ", len(self.subgraphs[-1]))

    def __clear__(self):
        self.prob = None
        self.node_counter = None
        self.edge_counter = None
        self.g = None

    def __counter__(self):

        for sampled_nodes in self.subgraphs:
            sampled_nodes = th.from_numpy(sampled_nodes)
            self.node_counter[sampled_nodes] += 1

            subg = self.train_g.subgraph(sampled_nodes)
            sampled_edges = subg.edata[dgl.EID]
            self.edge_counter[sampled_edges] += 1

    def __generate_fn__(self):
        raise NotImplementedError

    def __compute_norm__(self):
        self.node_counter[self.node_counter == 0] = 1
        self.edge_counter[self.edge_counter == 0] = 1

        loss_norm = self.N / self.node_counter / self.train_g.num_nodes()

        self.train_g.ndata['n_c'] = self.node_counter
        self.train_g.edata['e_c'] = self.edge_counter
        self.train_g.apply_edges(fn.v_div_e('n_c', 'e_c', 'a_n'))
        aggr_norm = self.train_g.edata.pop('a_n')

        self.train_g.ndata.pop('n_c')
        self.train_g.edata.pop('e_c')

        return aggr_norm.numpy(), loss_norm.numpy()

    def __compute_degree_norm(self):

        self.train_g.ndata['train_D_norm'] = 1. / self.train_g.in_degrees().float().clamp(min=1).unsqueeze(1)
        self.g.ndata['full_D_norm'] = 1. / self.g.in_degrees().float().clamp(min=1).unsqueeze(1)

    def __sample__(self):
        raise NotImplementedError

    def __len__(self):
        return self.num_batch

    def __iter__(self):
        self.n = 0
        return self

    def __next__(self):
        if self.n < self.num_batch:
            result = self.train_g.subgraph(self.subgraphs[self.n])
            self.n += 1
            return result
        else:
            random.shuffle(self.subgraphs)
            raise StopIteration()


class SAINTNodeSampler(SAINTSampler):
    def __init__(self, node_budget, dn, g, train_nid, num_repeat=50):
        self.node_budget = node_budget
        super(SAINTNodeSampler, self).__init__(dn, g, train_nid, node_budget, num_repeat)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_Node_{}_{}.npy'.format(self.dn, self.node_budget,
                                                                       self.num_repeat))
        norm_fn = os.path.join('./subgraphs/{}_Node_{}_{}_norm.npy'.format(self.dn, self.node_budget,
                                                                           self.num_repeat))
        return graph_fn, norm_fn

    def __sample__(self):
        if self.prob is None:
            self.prob = self.train_g.in_degrees().float().clamp(min=1)

        sampled_nodes = th.multinomial(self.prob, num_samples=self.node_budget, replacement=True).unique()
        return sampled_nodes.numpy()


class SAINTEdgeSampler(SAINTSampler):
    def __init__(self, edge_budget, dn, g, train_nid, num_repeat=50):
        self.edge_budget = edge_budget
        super(SAINTEdgeSampler, self).__init__(dn, g, train_nid, edge_budget * 2, num_repeat)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_Edge_{}_{}.npy'.format(self.dn, self.edge_budget,
                                                                       self.num_repeat))
        norm_fn = os.path.join('./subgraphs/{}_Edge_{}_{}_norm.npy'.format(self.dn, self.edge_budget,
                                                                           self.num_repeat))
        return graph_fn, norm_fn

    def __sample__(self):
        if self.prob is None:
            src, dst = self.train_g.edges()
            src_degrees, dst_degrees = self.train_g.in_degrees(src).float().clamp(min=1),\
                                       self.train_g.in_degrees(dst).float().clamp(min=1)
            self.prob = 1. / src_degrees + 1. / dst_degrees

        sampled_edges = th.multinomial(self.prob, num_samples=self.edge_budget, replacement=True).unique()

        sampled_src, sampled_dst = self.train_g.find_edges(sampled_edges)
        sampled_nodes = th.cat([sampled_src, sampled_dst]).unique()
        return sampled_nodes.numpy()


class SAINTRandomWalkSampler(SAINTSampler):
    def __init__(self, num_roots, length, dn, g, train_nid, num_repeat=50):
        self.num_roots, self.length = num_roots, length
        super(SAINTRandomWalkSampler, self).__init__(dn, g, train_nid, num_roots * length, num_repeat)

    def __generate_fn__(self):
        graph_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}.npy'.format(self.dn, self.num_roots,
                                                                        self.length, self.num_repeat))
        norm_fn = os.path.join('./subgraphs/{}_RW_{}_{}_{}_norm.npy'.format(self.dn, self.num_roots,
                                                                            self.length, self.num_repeat))
        return graph_fn, norm_fn

    def __sample__(self):
        sampled_roots = th.randint(0, self.train_g.num_nodes(), (self.num_roots, ))
        traces, types = random_walk(self.train_g, nodes=sampled_roots, length=self.length)
        sampled_nodes, _, _, _ = pack_traces(traces, types)
        sampled_nodes = sampled_nodes.unique()
        return sampled_nodes.numpy()

--------------------------------------------------------------------------------
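
The samplers are plain Python iterables and can be driven outside of `train_sampling.py`; each iteration yields a node-induced `DGLGraph` over the training graph with the GraphSAINT normalization coefficients attached (`ndata['l_n']`, `edata['w']`). A usage sketch, assuming the `ppi` data folder is already in place:

```python
import argparse
from sampler import SAINTNodeSampler
from utils import load_data

# load_data only needs an object with a `.dataset` attribute.
data = load_data(argparse.Namespace(dataset='ppi'), multilabel=True)

sampler = SAINTNodeSampler(node_budget=6000, dn='ppi', g=data.g,
                           train_nid=data.train_nid, num_repeat=50)
print(len(sampler), 'minibatches per epoch')

for subg in sampler:
    # Each subgraph inherits 'feat', 'label', 'l_n' and 'train_D_norm' node data.
    print(subg.num_nodes(), subg.num_edges())
    break
```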
/train_sampling.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
from sampler import SAINTNodeSampler, SAINTEdgeSampler, SAINTRandomWalkSampler
from modules import GCNNet
from utils import Logger, evaluate, save_log_dir, load_data


def main(args):

    multilabel_data = set(['ppi', 'yelp', 'amazon'])
    multilabel = args.dataset in multilabel_data

    # load and preprocess dataset
    data = load_data(args, multilabel)
    g = data.g
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    labels = g.ndata['label']

    train_nid = data.train_nid

    in_feats = g.ndata['feat'].shape[1]
    n_classes = data.num_classes
    n_nodes = g.num_nodes()
    n_edges = g.num_edges()

    n_train_samples = train_mask.int().sum().item()
    n_val_samples = val_mask.int().sum().item()
    n_test_samples = test_mask.int().sum().item()

    print("""----Data statistics------
      #Nodes %d
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_nodes, n_edges, n_classes,
           n_train_samples,
           n_val_samples,
           n_test_samples))

    # load sampler
    if args.sampler == "node":
        subg_iter = SAINTNodeSampler(args.node_budget, args.dataset, g,
                                     train_nid, args.num_repeat)
    elif args.sampler == "edge":
        subg_iter = SAINTEdgeSampler(args.edge_budget, args.dataset, g,
                                     train_nid, args.num_repeat)
    elif args.sampler == "rw":
        subg_iter = SAINTRandomWalkSampler(args.num_roots, args.length, args.dataset, g,
                                           train_nid, args.num_repeat)

    # set device for dataset tensors
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()
        g = g.to(args.gpu)

    print('labels shape:', g.ndata['label'].shape)
    print("features shape:", g.ndata['feat'].shape)

    model = GCNNet(
        in_dim=in_feats,
        hid_dim=args.n_hidden,
        out_dim=n_classes,
        arch=args.arch,
        dropout=args.dropout,
        batch_norm=not args.no_batch_norm,
        aggr=args.aggr
    )

    if cuda:
        model.cuda()

    # logger and so on
    log_dir = save_log_dir(args)
    logger = Logger(os.path.join(log_dir, 'loggings'))
    logger.write(args)

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr)

    # set train_nids to cuda tensor
    if cuda:
        train_nid = torch.from_numpy(train_nid).cuda()
        print("GPU memory allocated before training (MB)",
              torch.cuda.memory_allocated(device=train_nid.device) / 1024 / 1024)
    start_time = time.time()
    best_f1 = -1

    for epoch in range(args.n_epochs):
        for j, subg in enumerate(subg_iter):
            # sync with upper level training graph
            if cuda:
                subg = subg.to(torch.cuda.current_device())
            model.train()
            # forward
            pred = model(subg)
            batch_labels = subg.ndata['label']

            if multilabel:
                loss = F.binary_cross_entropy_with_logits(pred, batch_labels, reduction='sum',
                                                          weight=subg.ndata['l_n'].unsqueeze(1))
            else:
                loss = F.cross_entropy(pred, batch_labels, reduction='none')
                loss = (subg.ndata['l_n'] * loss).sum()

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            if j == len(subg_iter) - 1:
                print(f"epoch:{epoch+1}/{args.n_epochs}, Iteration {j+1}/"
                      f"{len(subg_iter)}: training loss", loss.item())

        # evaluate
        if epoch % args.val_every == 0:
            val_f1_mic, val_f1_mac = evaluate(
                model, g, labels, val_mask, multilabel)
            print(
                "Val F1-mic {:.4f}, Val F1-mac {:.4f}".format(val_f1_mic, val_f1_mac))
            if val_f1_mic > best_f1:
                best_f1 = val_f1_mic
                print('new best val f1:', best_f1)
                torch.save(model.state_dict(), os.path.join(
                    log_dir, 'best_model.pkl'))

    end_time = time.time()
    print(f'Training took {end_time - start_time:.2f} seconds')

    # test
    if args.use_val:
        model.load_state_dict(torch.load(os.path.join(
            log_dir, 'best_model.pkl')))
    test_f1_mic, test_f1_mac = evaluate(
        model, g, labels, test_mask, multilabel)
    print("Test F1-mic {:.4f}, Test F1-mac {:.4f}".format(test_f1_mic, test_f1_mac))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GraphSAINT')
    # data source params
    parser.add_argument("--dataset", type=str, choices=['ppi', 'flickr', 'reddit', 'yelp', 'amazon'], default='ppi',
                        help="Name of dataset.")

    # cuda params
    parser.add_argument("--gpu", type=int, default=-1,
                        help="GPU index. Default: -1, using CPU.")

    # sampler params
    parser.add_argument("--sampler", type=str, default="node", choices=['node', 'edge', 'rw'],
                        help="Type of sampler")
    parser.add_argument("--node-budget", type=int, default=6000,
                        help="Expected number of sampled nodes when using the node sampler")
    parser.add_argument("--edge-budget", type=int, default=4000,
                        help="Expected number of sampled edges when using the edge sampler")
    parser.add_argument("--num-roots", type=int, default=3000,
                        help="Expected number of sampled root nodes when using the random walk sampler")
    parser.add_argument("--length", type=int, default=2,
                        help="The length of each random walk when using the random walk sampler")
    parser.add_argument("--num-repeat", type=int, default=50,
                        help="Expected number of times each node is sampled during subgraph pre-sampling "
                             "(used to estimate node / edge sampling probabilities)")

    # model params
    parser.add_argument("--n-hidden", type=int, default=512,
                        help="Number of hidden gcn units")
    parser.add_argument("--arch", type=str, default="1-0-1-0",
                        help="Network architecture. 1 means an order-1 layer (self feature plus 1-hop neighbor "
                             "feature), and 0 means an order-0 layer (self feature only)")
    parser.add_argument("--dropout", type=float, default=0,
                        help="Dropout rate")
    parser.add_argument("--no-batch-norm", action='store_true',
                        help="If set, disable batch norm (batch norm is enabled by default)")
    parser.add_argument("--aggr", type=str, default="concat", choices=['mean', 'concat'],
                        help="How to aggregate the self feature and neighbor features")

    # training params
    parser.add_argument("--n-epochs", type=int, default=100,
                        help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=0.01,
                        help="Learning rate")
    parser.add_argument("--val-every", type=int, default=1,
                        help="Frequency of evaluation on the validation set, in number of epochs")
    parser.add_argument("--use-val", action='store_true',
                        help="Whether to test with the best model selected on the validation set")
    parser.add_argument("--note", type=str, default='none',
                        help="Note used to name the log directory")

    args = parser.parse_args()

    print(args)

    main(args)

--------------------------------------------------------------------------------
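
The script saves the best validation checkpoint to `./log/<dataset>/<note>/best_model.pkl` (via `save_log_dir`). A sketch of re-evaluating such a checkpoint offline, not part of the original script; the architecture flags must match those used for training, and the `ppi` paths below are just the defaults:

```python
import argparse
import torch
from modules import GCNNet
from utils import evaluate, load_data

data = load_data(argparse.Namespace(dataset='ppi'), multilabel=True)
g = data.g
# Full-graph inference needs the same degree normalization the sampler would have set.
g.ndata['full_D_norm'] = 1. / g.in_degrees().float().clamp(min=1).unsqueeze(1)

# batch_norm=True mirrors the training default (batch norm is on unless --no-batch-norm).
model = GCNNet(in_dim=g.ndata['feat'].shape[1], hid_dim=512,
               out_dim=data.num_classes, arch="1-0-1-0", batch_norm=True)
model.load_state_dict(torch.load('./log/ppi/none/best_model.pkl', map_location='cpu'))

f1_mic, f1_mac = evaluate(model, g, g.ndata['label'], g.ndata['test_mask'], multilabel=True)
print(f"Test F1-mic {f1_mic:.4f}, Test F1-mac {f1_mac:.4f}")
```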
/utils.py:
--------------------------------------------------------------------------------
import json
import os
from functools import namedtuple
import scipy.sparse
from sklearn.preprocessing import StandardScaler
import dgl
import numpy as np
import torch
from sklearn.metrics import f1_score


class Logger(object):
    '''A custom logger that mirrors stdout to a logging file.'''
    def __init__(self, path):
        """Initialize the logger.

        Parameters
        ----------
        path : str
            Path of the file that log messages are written to.
        """
        self.path = path

    def write(self, s):
        with open(self.path, 'a') as f:
            f.write(str(s))
        print(s)
        return


def save_log_dir(args):
    log_dir = './log/{}/{}'.format(args.dataset, args.note)
    os.makedirs(log_dir, exist_ok=True)
    return log_dir


def calc_f1(y_true, y_pred, multilabel):
    if multilabel:
        y_pred[y_pred > 0] = 1
        y_pred[y_pred <= 0] = 0
    else:
        y_pred = np.argmax(y_pred, axis=1)
    return f1_score(y_true, y_pred, average="micro"), \
        f1_score(y_true, y_pred, average="macro")


def evaluate(model, g, labels, mask, multilabel=False):
    model.eval()
    with torch.no_grad():
        logits = model(g)
        logits = logits[mask]
        labels = labels[mask]
        f1_mic, f1_mac = calc_f1(labels.cpu().numpy(),
                                 logits.cpu().numpy(), multilabel)
        return f1_mic, f1_mac


# Load data in GraphSAINT's format and convert it to a DGLGraph.
def load_data(args, multilabel):
    prefix = "data/{}".format(args.dataset)
    DataType = namedtuple('Dataset', ['num_classes', 'train_nid', 'g'])

    adj_full = scipy.sparse.load_npz('./{}/adj_full.npz'.format(prefix)).astype(bool)
    g = dgl.from_scipy(adj_full)
    num_nodes = g.num_nodes()

    adj_train = scipy.sparse.load_npz('./{}/adj_train.npz'.format(prefix)).astype(bool)
    train_nid = np.array(list(set(adj_train.nonzero()[0])))

    role = json.load(open('./{}/role.json'.format(prefix)))
    mask = np.zeros((num_nodes,), dtype=bool)
    train_mask = mask.copy()
    train_mask[role['tr']] = True
    val_mask = mask.copy()
    val_mask[role['va']] = True
    test_mask = mask.copy()
    test_mask[role['te']] = True

    feats = np.load('./{}/feats.npy'.format(prefix))
    scaler = StandardScaler()
    scaler.fit(feats[train_nid])
    feats = scaler.transform(feats)

    class_map = json.load(open('./{}/class_map.json'.format(prefix)))
    class_map = {int(k): v for k, v in class_map.items()}
    if multilabel:
        num_classes = len(list(class_map.values())[0])
        class_arr = np.zeros((num_nodes, num_classes))
        for k, v in class_map.items():
            class_arr[k] = v
    else:
        num_classes = max(class_map.values()) - min(class_map.values()) + 1
        class_arr = np.zeros((num_nodes,))
        for k, v in class_map.items():
            class_arr[k] = v

    g.ndata['feat'] = torch.tensor(feats, dtype=torch.float)
    g.ndata['label'] = torch.tensor(class_arr, dtype=torch.float if multilabel else torch.long)
    g.ndata['train_mask'] = torch.tensor(train_mask, dtype=torch.bool)
    g.ndata['val_mask'] = torch.tensor(val_mask, dtype=torch.bool)
    g.ndata['test_mask'] = torch.tensor(test_mask, dtype=torch.bool)

    data = DataType(g=g, num_classes=num_classes, train_nid=train_nid)
    return data

--------------------------------------------------------------------------------