├── requirements.txt ├── Data └── cora.npz ├── NOTICE ├── run.sh ├── CONTRIBUTING.md ├── README.md ├── args.py ├── sample.py ├── utils.py ├── test.py ├── model.py └── LICENSE /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-gpu>=1.10 2 | numpy 3 | scipy 4 | scikit-learn 5 | -------------------------------------------------------------------------------- /Data/cora.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/linkedin/PASS-GNN/HEAD/Data/cora.npz -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright 2021 LinkedIn Corporation 2 | All Rights Reserved. 3 | 4 | Licensed under the CC BY-NC 4.0 License (the "License"). 5 | See LICENSE in the project root for license information. 6 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | DATASETS=("cora" "citeseer" "pubmed" "amazon_computer" "amazon_photo" "ms_cs" "ms_physic") 2 | SCOPES=(5 5 5 20 20 5 10) 3 | SAMPLES=(1 1 1 1 1 1 1) 4 | 5 | length=${#DATASETS[@]} 6 | 7 | for ((i=0;i<$length;i++)) 8 | do 9 | echo "DATASET: " ${DATASETS[$i]} 10 | python test.py --dataset "${DATASETS[$i]}" --sample_scope ${SCOPES[$i]} --sample_num ${SAMPLES[$i]} 11 | done 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Contribution Agreement 2 | ====================== 3 | 4 | As a contributor, you represent that the code you submit is your 5 | original work or that of your employer (in which case you represent 6 | you have the right to bind your employer).
By submitting code, you 7 | (and, if applicable, your employer) are licensing the submitted code 8 | to LinkedIn and the open source community subject to the CC BY-NC 4.0 9 | license. 10 | 11 | Responsible Disclosure of Security Vulnerabilities 12 | ================================================== 13 | 14 | Please do not file reports on Github for security issues. Please 15 | review the guidelines on at (link to more info). Reports should be 16 | encrypted using PGP (link to PGP key) and sent to 17 | security@linkedin.com preferably with the title "Github 18 | linkedin/ - ". 19 | 20 | Tips for Getting Your Pull Request Accepted 21 | =========================================== 22 | 23 | *Note: These are suggestions. Customize as needed.* 24 | 25 | 1. Make sure all new features are tested and the tests pass. 26 | 2. Bug fixes must include a test case demonstrating the error that it 27 | fixes. 28 | 3. Open an issue first and seek advice for your change before 29 | submitting a pull request. Large features which have never been 30 | discussed are unlikely to be accepted. **You have been warned.** 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PASS: Performance-Adaptive Sampling Strategy Towards Fast and Accurate Graph Neural Networks 2 | 3 | PASS is a neighborhood sampler for graph neural network models. 4 | PASS samples neighbors informative for a target task by optimizing a sampling policy directly towards task performance. 5 | 6 | You can see our KDD 2021 paper ["Performance-Adaptive Sampling Strategy Towards Fast and Accurate Graph Neural Networks"](https://minjiyoon.xyz/Paper/PASS.pdf) for more details. 7 | This implementation is based on Pytorch. 
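As a rough illustration of what "sampling neighbors from a policy" means, the sketch below draws neighbors from a categorical distribution over a node's candidate neighbors, similar in spirit to how **model.py** builds a `torch.distributions.Categorical` in `get_policy`. It is a simplified, dependency-free sketch and not the repository's implementation: the function name `sample_neighbors`, its arguments, and the example scores are assumptions made for illustration.

```python
import random


def sample_neighbors(candidates, scores, sample_num, rng=random):
    """Draw sample_num neighbors (with replacement) from a categorical policy.

    candidates: candidate neighbor ids (hypothetical example input)
    scores: non-negative, unnormalized importance scores, one per candidate
    """
    if len(candidates) != len(scores):
        raise ValueError("candidates and scores must have the same length")
    total = sum(scores)
    if total <= 0:
        raise ValueError("at least one score must be positive")
    # Normalize the scores into sampling probabilities.
    probs = [s / total for s in scores]
    # Draw i.i.d. samples; higher-probability neighbors are picked more often.
    picked = rng.choices(candidates, weights=probs, k=sample_num)
    return picked, probs


if __name__ == "__main__":
    rng = random.Random(0)
    picked, probs = sample_neighbors([10, 11, 12], [0.2, 0.2, 1.6], 5, rng)
    print(picked, probs)
```

In PASS the scores are not fixed. They come from a learned attention function whose parameters are updated from the task loss, so neighbors that help the target task should become more likely to be sampled.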
8 | 9 | ## Requirements 10 | 11 | Use the package manager pip to install the requirements: 12 | 13 | ```bash 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Dataset 18 | 19 | We use the open-source datasets from [GNN-benchmark](https://github.com/shchur/gnn-benchmark) for our experiments. 20 | Our code reads graph datasets in npz format. 21 | 22 | 23 | ## Usage 24 | 25 | In **args.py**, you can find a list of hyperparameters. 26 | Some of them relate to neural network training, some to the GNN structure, and the rest to our sampling strategy. 27 | Each hyperparameter is described in its help string in **args.py**. 28 | 29 | Here is an example command to run PASS. 30 | ```bash 31 | python test.py --dataset cora --sample_num 5 32 | ``` 33 | Once you have downloaded all the npz-format datasets listed in **run.sh** into the ./Data/ directory, you can simply run **run.sh** to test all datasets with the sampling settings defined there. 34 | 35 | 36 | ## Citation 37 | 38 | Please consider citing the following paper when using our code for your application.
39 | ```bibtex 40 | @inproceedings{yoon2021performance, 41 | title={Performance-Adaptive Sampling Strategy Towards Fast and Accurate Graph Neural Networks}, 42 | author={Yoon, Minji and Gervet, Th{\'e}ophile and Shi, Baoxu and Niu, Sufeng and He, Qi and Yang, Jaewon}, 43 | booktitle={Proceedings of the 27th ACM SIGKDD Conference on Knowledge Discovery \& Data Mining}, 44 | pages={2046--2056}, 45 | year={2021} 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args(): 4 | parser = argparse.ArgumentParser() 5 | # Training-related hyperparameters 6 | parser.add_argument('--epochs', type=int, default=300, 7 | help='Number of epochs to train.') 8 | parser.add_argument('--batch_size', type=int, default=64, 9 | help='Size of batch.') 10 | parser.add_argument('--lr', type=float, default=0.001, 11 | help='Initial learning rate.') 12 | parser.add_argument('--weight_decay', type=float, default=5e-4, 13 | help='Weight decay (L2 loss on parameters).') 14 | parser.add_argument('--dropout', type=float, default=0.0, 15 | help='Dropout rate (1 - keep probability).') 16 | parser.add_argument('--early_stopping', type=int, default=10, 17 | help='Number of epochs to wait before early stop.') 18 | 19 | # Dataset-related hyperparameters 20 | parser.add_argument('--data_dir', type=str, default="./Data/", 21 | help='Dataset location.') 22 | parser.add_argument('--dataset', type=str, default="cora", 23 | help='Dataset to use.') 24 | 25 | # GCN structure-related hyperparameters 26 | parser.add_argument('--hidden_dim', type=int, default=64, 27 | help='Hidden dimension') 28 | parser.add_argument('--step_num', type=int, default=2, 29 | help='Number of message-passing steps') 30 | parser.add_argument('--nonlinear', dest='nonlinear', action='store_true') 31 | parser.add_argument('--linear', dest='nonlinear',
action='store_false') 32 | parser.set_defaults(nonlinear=True) 33 | 34 | # Sampling-related hyperparameters 35 | parser.add_argument('--sample_scope', type=int, default=32, 36 | help='Number of candidates for sampling') 37 | parser.add_argument('--sample_num', type=int, default=5, 38 | help='Number of sampled neighbors') 39 | 40 | args, _ = parser.parse_known_args() 41 | return args 42 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as f 4 | 5 | 6 | # Decide neighborhood for sampling 7 | def sampleScope(sample_scope, adj, edge, weight): 8 | # Graph is presented as edge/weight lists of fixed length (sample_scope) 9 | # Edge list maintains the indices of neighborhood 10 | # Weight list maintains the weight(attention) of neighborhood 11 | 12 | num_data = adj.shape[0] 13 | edge.data *= num_data 14 | for v in range(num_data): 15 | neighbors = torch.from_numpy(np.nonzero(adj[v, :])[1]) 16 | # Number of neighbors of node v 17 | len_neighbors = neighbors.shape[0] 18 | if len_neighbors == 0: 19 | edge.data[v, 0] = v 20 | weight.data[v, 0] = 1 21 | elif len_neighbors > sample_scope: 22 | perm = torch.randperm(len_neighbors)[:sample_scope] 23 | neighbors = neighbors[perm] 24 | edge.data[v, :sample_scope] = neighbors 25 | weight.data[v, :sample_scope] = 1./sample_scope 26 | else: 27 | edge.data[v, :len_neighbors] = neighbors 28 | weight.data[v, :len_neighbors] = 1./len_neighbors 29 | return 30 | 31 | # Generate sub-adjacency matrix 32 | def computeSubGraph(edge, weight, in_nodes, out_nodes): 33 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 34 | # Collect nodes participated in the sub-adjacency matrix 35 | adj = torch.gather(edge[in_nodes], dim=1, index=out_nodes).type(torch.int64) 36 | # Unique indices for nodes to generate a new sub-adjacency matrix 37 |
unique, index = torch.unique(adj.view(-1), return_inverse=True) 38 | adj_shape = adj.shape 39 | del adj 40 | 41 | row = (torch.arange(index.shape[0]) // adj_shape[1]).type(torch.int64).to(device) 42 | col = index 43 | attention = torch.ones(index.size()).to(device) 44 | indices = torch.stack([row, col]) 45 | dense_shape = torch.Size([adj_shape[0], unique.shape[0]]) 46 | new_adj = torch.sparse_coo_tensor(indices, attention, dense_shape).to_dense() 47 | del indices 48 | del attention 49 | 50 | # Unweighted attention: same as normalized binary adjacency matrix 51 | indices = torch.nonzero(new_adj, as_tuple=True) 52 | new_adj[indices] = 1 53 | del indices 54 | # Normalize 55 | row_sum = torch.sum(new_adj, dim=1) 56 | row_sum = torch.diag(1/row_sum) 57 | new_adj = torch.spmm(row_sum, new_adj) 58 | return unique, new_adj 59 | 60 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | import scipy.sparse as sp 5 | 6 | # Ratio for data split (training/validation/test) 7 | train_ratio = 0.5 8 | val_ratio = 0.1 9 | test_ratio = 0.4 10 | 11 | # Split dataset into training/validation/test sets 12 | def split_dataset(labels): 13 | label_min = labels.min() 14 | label_max = labels.max() + 1 15 | idx = {} 16 | for i in range(label_min, label_max): 17 | idx[i] = [] 18 | for i in range(len(labels)): 19 | idx[labels[i]].append(i) 20 | train_idx = [] 21 | val_idx = [] 22 | test_idx = [] 23 | for i in range(label_min, label_max): 24 | train_num = int(train_ratio*len(idx[i])) 25 | for j in range(train_num): 26 | train_idx.append(idx[i][j]) 27 | for i in range(label_min, label_max): 28 | train_num = int(train_ratio*len(idx[i])) 29 | val_num = int(val_ratio*len(idx[i])) 30 | for j in range(train_num, train_num+val_num): 31 | val_idx.append(idx[i][j]) 32 | for i in range(label_min, label_max): 33 |
train_val_num = int((train_ratio + val_ratio)*len(idx[i])) 34 | test_idx.extend(idx[i][train_val_num:]) 35 | random.shuffle(train_idx) 36 | random.shuffle(val_idx) 37 | random.shuffle(test_idx) 38 | return train_idx, val_idx, test_idx 39 | 40 | # Load npz-format files 41 | def load_npz(device, datadir='./Data', dataset="ms_cs"): 42 | dataset = datadir + "/" + dataset + ".npz" 43 | with np.load(dataset, allow_pickle = True) as loader: 44 | loader = dict(loader) 45 | graph = sp.csr_matrix((loader['adj_data'], loader['adj_indices'], loader['adj_indptr']), shape=loader['adj_shape']) 46 | # Undirected graph 47 | graph = graph + graph.transpose() 48 | # Add self-loops (set the diagonal to 1) 49 | for i in range(graph.shape[0]): 50 | graph[i,i] = 1 51 | # Feature matrix 52 | if 'attr_data' in loader: 53 | # Attributes are stored as a sparse CSR matrix 54 | attr_matrix = sp.csr_matrix((loader['attr_data'], loader['attr_indices'], loader['attr_indptr']), shape=loader['attr_shape']) 55 | features = torch.FloatTensor(sp.csr_matrix.toarray(attr_matrix)) 56 | elif 'attr_matrix' in loader: 57 | # Attributes are stored as a (dense) np.ndarray 58 | attr_matrix = loader['attr_matrix'] 59 | features = torch.FloatTensor(attr_matrix) 60 | else: 61 | attr_matrix = None 62 | features = None 63 | # Labels 64 | if 'labels_data' in loader: 65 | # Labels are stored as a CSR matrix 66 | labels = sp.csr_matrix((loader['labels_data'], loader['labels_indices'], loader['labels_indptr']), shape=loader['labels_shape']) 67 | elif 'labels' in loader: 68 | # Labels are stored as a numpy array 69 | labels = loader['labels'] 70 | else: 71 | labels = None 72 | 73 | features = torch.nn.functional.normalize(features, p=2, dim=1) 74 | 75 | train_idx, val_idx, test_idx = split_dataset(labels) 76 | label_num = labels.max() - labels.min() + 1 77 | labels = np.append(labels, [-1]) 78 | labels = torch.LongTensor(labels).to(device) 79 | 80 | idx = {} 81 | idx["train"] = torch.LongTensor(train_idx).to(device) 82 | idx["val"] =
torch.LongTensor(val_idx).to(device) 83 | idx["test"] = torch.LongTensor(test_idx).to(device) 84 | 85 | return graph, features, labels, label_num, idx 86 | 87 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.optim as optim 4 | 5 | from math import sqrt 6 | from time import perf_counter 7 | 8 | from args import get_args 9 | from model import sampleGCN 10 | from utils import load_npz 11 | 12 | import warnings 13 | warnings.filterwarnings('ignore') 14 | 15 | # Train GCN models 16 | def train_model(model, labels, idx, args): 17 | # Register trainable parameters to the optimizer 18 | ml = list() 19 | for index, module in enumerate(model.W): 20 | if (index == 0): 21 | ml.append({'params': module.parameters(), 'weight_decay': model.weight_decay}) 22 | else: 23 | ml.append({'params': module.parameters()}) 24 | ml.append({'params': model.sample_W}) 25 | ml.append({'params': model.sample_W2}) 26 | ml.append({'params': model.sample_a}) 27 | 28 | optimizer = optim.Adam(ml, lr=args.lr) 29 | 30 | start_time = perf_counter() 31 | patient = 0 32 | loss = np.inf 33 | train_idx = idx["train"] 34 | val_idx = idx["val"] 35 | for _ in range(args.epochs): 36 | total_train_loss = 0 37 | total_sample_loss = 0 38 | # Batch training 39 | for b in range(train_idx.shape[0] // args.batch_size): 40 | batch_idx = train_idx[b * args.batch_size : (b+1) * args.batch_size] 41 | model.train() 42 | optimizer.zero_grad() 43 | output = model(batch_idx) 44 | train_loss = model.calc_loss(output, labels[batch_idx]) 45 | total_train_loss = total_train_loss + train_loss.item() 46 | train_loss.backward() 47 | 48 | # Loss for sampling probability function 49 | # Gradient of intermediate tensor 50 | chain_grad = model.X1.grad 51 | # Compute intermediate loss for sampling probability parameters 52 | sample_loss = 
model.sample_loss(chain_grad.detach()) 53 | total_sample_loss = total_sample_loss + sample_loss.item() 54 | sample_loss.backward() 55 | 56 | optimizer.step() 57 | torch.cuda.empty_cache() 58 | 59 | with torch.no_grad(): 60 | model.eval() 61 | output = model(val_idx) 62 | new_loss = model.calc_loss(output, labels[val_idx]) 63 | if new_loss >= loss: 64 | patient = patient + 1 65 | else: 66 | patient = 0 67 | loss = new_loss 68 | if patient == args.early_stopping: 69 | break 70 | 71 | train_time = perf_counter() - start_time 72 | return train_time 73 | 74 | 75 | # Test GCN models 76 | def test_model(model, labels, idx, args): 77 | start_time = perf_counter() 78 | test_idx = idx["test"] 79 | model.eval() 80 | total_acc = 0 81 | for b in range(test_idx.shape[0] // args.batch_size): 82 | batch_idx = test_idx[b * args.batch_size : (b+1) * args.batch_size] 83 | output = model(batch_idx) 84 | np_pred = output.cpu() 85 | np_target = labels[batch_idx].cpu() 86 | acc_mic, acc_mac = model.calc_f1(np_pred, np_target) 87 | total_acc = total_acc + acc_mic 88 | total_acc = total_acc / (test_idx.shape[0] // args.batch_size) 89 | test_time = perf_counter() - start_time 90 | return total_acc, test_time 91 | 92 | 93 | # Train, test, and compute average performance of GCN models 94 | def run_gnn(device, args, graph, features, feat_size, labels, label_max, idx): 95 | trial = 3 96 | total_acc = 0 97 | total_acc2 = 0 98 | total_train_time = 0 99 | total_test_time = 0 100 | for _ in range(trial): 101 | model = sampleGCN(feat_size, label_max, args.hidden_dim, args.step_num, 102 | graph, features, args.sample_scope, args.sample_num, 103 | args.nonlinear, args.dropout, args.weight_decay) 104 | # Move to GPUs (or CPUs) 105 | model = model.to(device) 106 | model.device = device 107 | train_time = train_model(model, labels, idx, args) 108 | total_train_time = total_train_time + train_time 109 | 110 | mic_acc, test_time = test_model(model, labels, idx, args) 111 | total_acc = total_acc + mic_acc 
112 | total_acc2 = total_acc2 + mic_acc**2 113 | total_test_time = total_test_time + test_time 114 | del model 115 | 116 | total_acc = total_acc/trial 117 | total_acc2 = sqrt(total_acc2/trial - total_acc**2) 118 | total_train_time = total_train_time/trial 119 | total_test_time = total_test_time/trial 120 | 121 | print("Train time: {:.4f},\tTest time: {:.4f},\tAccuracy: {:.4f},\tstd: {:.8f}".format(total_train_time, total_test_time, total_acc, total_acc2)) 122 | 123 | def main(): 124 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 125 | args = get_args() 126 | 127 | graph, features, labels, label_max, idx = load_npz(device, args.data_dir, args.dataset) 128 | feat_size = features.size(1) 129 | run_gnn(device, args, graph, features, feat_size, labels, label_max, idx) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | 135 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from sample import sampleScope, computeSubGraph 7 | from sklearn import metrics 8 | 9 | # GCN template 10 | class GCN(nn.Module): 11 | 12 | def __init__(self, input_dim, output_dim, hidden_dim, step_num, 13 | nonlinear, dropout, weight_decay): 14 | super(GCN, self).__init__() 15 | 16 | self.input_dim = input_dim 17 | self.output_dim = output_dim 18 | self.hidden_dim = hidden_dim 19 | self.step_num = step_num 20 | 21 | self.W = nn.ModuleList([nn.Linear(input_dim, hidden_dim, bias=False)]) 22 | for _ in range(step_num-2): 23 | self.W.append(nn.Linear(hidden_dim, hidden_dim, bias=False)) 24 | self.W.append(nn.Linear(hidden_dim, output_dim, bias=False)) 25 | 26 | for w in self.W: 27 | nn.init.xavier_uniform_(w.weight) 28 | 29 | self.nonlinear = nonlinear 30 | self.dropout = nn.Dropout(dropout) 31 | self.weight_decay = weight_decay 32 | 
33 | def initialize(self): 34 | raise NotImplementedError 35 | 36 | def forward(self, ids): 37 | raise NotImplementedError 38 | 39 | def calc_loss(self, y_pred, y_true): 40 | loss_train = F.cross_entropy(y_pred, y_true) 41 | return loss_train 42 | 43 | def calc_f1(self, y_pred, y_true): 44 | y_pred = torch.argmax(y_pred, dim=1) 45 | return metrics.f1_score(y_true, y_pred, average="micro"),\ 46 | metrics.f1_score(y_true, y_pred, average="macro") 47 | 48 | # GCN models generating sub-adjacency matrices for every batch 49 | class sampleGCN(GCN): 50 | 51 | def __init__(self, input_dim, output_dim, hidden_dim, step_num, 52 | graph, feature, sample_scope, sample_num, 53 | nonlinear, dropout, weight_decay): 54 | super(sampleGCN, self).__init__(input_dim, output_dim, hidden_dim, step_num, 55 | nonlinear, dropout, weight_decay) 56 | 57 | self.node_num = graph.shape[0] 58 | self.sample_scope = sample_scope 59 | self.sample_num = sample_num 60 | 61 | # Indices of neighbors 62 | self.edge = nn.Parameter(torch.ones(self.node_num+1, sample_scope, dtype=torch.float)) 63 | self.edge.requires_grad = False 64 | 65 | # Weights (attention) of neighbors 66 | self.weight = torch.zeros(self.node_num+1, sample_scope, dtype=torch.float) 67 | self.weight.requires_grad = False 68 | 69 | # Decide the neighborhood for sampling 70 | sampleScope(self.sample_scope, graph, self.edge, self.weight) 71 | 72 | # Trainable parameters of sampling probability function 73 | self.sample_W = nn.Parameter(torch.zeros(size=(input_dim, hidden_dim))) 74 | self.sample_W.requires_grad = True 75 | nn.init.xavier_uniform_(self.sample_W.data, gain=1.414) 76 | self.sample_W2 = nn.Parameter(torch.zeros(size=(input_dim, hidden_dim))) 77 | self.sample_W2.requires_grad = True 78 | nn.init.xavier_uniform_(self.sample_W2.data, gain=1.414) 79 | self.sample_a = nn.Parameter(torch.FloatTensor(np.array([[10e-3], [10e-3], [10e-1]]))) 80 | self.sample_a.requires_grad = True 81 | self.softmax_a = nn.Softmax(dim=0) 82 | 83 | #
Feature matrix 84 | self.feature = torch.cat([feature, torch.zeros(1, input_dim)], dim=0) 85 | 86 | # Loss for sampling probability parameters 87 | def sample_loss(self, loss_up): 88 | # batch_sampler: nodes from upper layer sampled their neighbors 89 | # batch_sampled: nodes from lower layer were sampled by their parents 90 | # log probability for "batch_sampler" to sample "batch_sampled" 91 | logp = self.get_policy(self.batch_sampler).log_prob(self.batch_sampled.transpose(0,1)).transpose(0,1) 92 | batch_sampled_node = torch.gather(self.edge[self.batch_sampler], dim=1, index=self.batch_sampled).type(torch.int64) 93 | X = logp.unsqueeze(2)*self.feature[batch_sampled_node].to(self.device) 94 | X = X.mean(dim=1) 95 | # Chain rule 96 | batch_loss = torch.bmm(loss_up.unsqueeze(1), X.unsqueeze(2)) 97 | return batch_loss.mean() 98 | 99 | # A helper function to compute node attention 100 | def _node_attention(self, source, target, weight): 101 | # source: embedding of source nodes 102 | # target: embedding of target nodes 103 | # weight: weight for the attention transformation. 104 | ss = torch.mm(source.reshape(-1, self.input_dim), weight) 105 | tt = torch.mm(target.reshape(-1, self.input_dim), weight) 106 | att = torch.bmm(ss.unsqueeze(1), tt.unsqueeze(2)).squeeze(2) 107 | return att 108 | 109 | # Compute sampling probability given source node. 110 | def weight_function(self, source_idx): 111 | # source_idx: parent node we use to compute the sampling weights. 
112 | # 1st head of importance sampling 113 | neighbor_idx = self.edge[source_idx].type(torch.int64) 114 | source_ = self.feature[source_idx].unsqueeze(1).expand(-1, self.sample_scope, -1).to(self.device) 115 | target_ = self.feature[neighbor_idx].to(self.device) 116 | att1 = self._node_attention(source_, target_, self.sample_W) 117 | # 2nd head of importance sampling 118 | att2 = self._node_attention(source_, target_, self.sample_W2) 119 | # Random sampling 120 | att3 = self.weight[source_idx].view(-1).unsqueeze(1).to(self.device) 121 | # Attention of Attentions 122 | att = torch.cat([att1, att2, att3], dim=1) 123 | att = F.relu(torch.mm(att, self.softmax_a(self.sample_a))) 124 | att = att + 10e-10*torch.ones_like(att) 125 | att = att.reshape(-1, self.sample_scope) 126 | return att 127 | 128 | # Get sampling probability distributions (= policy) of each node 129 | def get_policy(self, target_nodes): 130 | probs = self.weight_function(target_nodes) 131 | return torch.distributions.Categorical(probs=probs) 132 | 133 | # Sample neighbors (=action) from sampling probability distributions 134 | def get_action(self, target_nodes): 135 | policy = self.get_policy(target_nodes) 136 | return policy.sample((self.sample_num,)).transpose(0,1) 137 | 138 | # Sample neighbors and generate sub-adjacency matrices for the given batch 139 | def sampleNodes(self, nodes): 140 | all_adj = [[]] * self.step_num 141 | all_feats = [[]] * self.step_num 142 | # Top-1 layer = batch node 143 | in_nodes = nodes 144 | # Top-down sampling from top-1 layer to the input layer 145 | for i in range(self.step_num): 146 | layer = self.step_num - i - 1 147 | # Neighbors are sampled dynamically 148 | out_nodes = self.get_action(in_nodes).type(torch.int64) 149 | if layer == 0: 150 | self.batch_sampler = in_nodes 151 | self.batch_sampled = out_nodes 152 | # Generate sub-adjacency matrix 153 | out_nodes, adj = computeSubGraph(self.edge, self.weight, in_nodes, out_nodes) 154 | if layer == 0: 155 | start_feat =
self.feature[out_nodes].to(self.device) 156 | all_feats[layer] = self.feature[in_nodes].to(self.device) 157 | all_adj[layer] = adj 158 | in_nodes = out_nodes 159 | return start_feat, all_feats, all_adj 160 | 161 | def forward(self, ids): 162 | # Sample neighbors and generate sub-adjacency matrices for the given batch 163 | X, feat_list, adj_list = self.sampleNodes(ids) 164 | for idx, (adj, feat, w) in enumerate(zip(adj_list, feat_list, self.W)): 165 | X = torch.sparse.mm(adj, X) 166 | # For loss computation: consider only 1st layer nodes 167 | if idx == 0: 168 | self.X1 = nn.Parameter(X) 169 | X = self.X1 170 | X = w(X) 171 | # Softmax of the final layer will be taken in loss function later 172 | if idx < self.step_num - 1: 173 | X = self.dropout(X) 174 | if self.nonlinear: 175 | X = F.relu(X) 176 | return X 177 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | Copyright 2021 LinkedIn Corporation 4 | All Rights Reserved. 5 | 6 | ======================================================================= 7 | 8 | Creative Commons Corporation ("Creative Commons") is not a law firm and 9 | does not provide legal services or legal advice. Distribution of 10 | Creative Commons public licenses does not create a lawyer-client or 11 | other relationship. Creative Commons makes its licenses and related 12 | information available on an "as-is" basis. Creative Commons gives no 13 | warranties regarding its licenses, any material licensed under their 14 | terms and conditions, or any related information. Creative Commons 15 | disclaims all liability for damages resulting from their use to the 16 | fullest extent possible. 
17 | 18 | Using Creative Commons Public Licenses 19 | 20 | Creative Commons public licenses provide a standard set of terms and 21 | conditions that creators and other rights holders may use to share 22 | original works of authorship and other material subject to copyright 23 | and certain other rights specified in the public license below. The 24 | following considerations are for informational purposes only, are not 25 | exhaustive, and do not form part of our licenses. 26 | 27 | Considerations for licensors: Our public licenses are 28 | intended for use by those authorized to give the public 29 | permission to use material in ways otherwise restricted by 30 | copyright and certain other rights. Our licenses are 31 | irrevocable. Licensors should read and understand the terms 32 | and conditions of the license they choose before applying it. 33 | Licensors should also secure all rights necessary before 34 | applying our licenses so that the public can reuse the 35 | material as expected. Licensors should clearly mark any 36 | material not subject to the license. This includes other CC- 37 | licensed material, or material used under an exception or 38 | limitation to copyright. More considerations for licensors: 39 | wiki.creativecommons.org/Considerations_for_licensors 40 | 41 | Considerations for the public: By using one of our public 42 | licenses, a licensor grants the public permission to use the 43 | licensed material under specified terms and conditions. If 44 | the licensor's permission is not necessary for any reason--for 45 | example, because of any applicable exception or limitation to 46 | copyright--then that use is not regulated by the license. Our 47 | licenses grant only permissions under copyright and certain 48 | other rights that a licensor has authority to grant. Use of 49 | the licensed material may still be restricted for other 50 | reasons, including because others have copyright or other 51 | rights in the material. 
A licensor may make special requests, 52 | such as asking that all changes be marked or described. 53 | Although not required by our licenses, you are encouraged to 54 | respect those requests where reasonable. More considerations 55 | for the public: 56 | wiki.creativecommons.org/Considerations_for_licensees 57 | 58 | ======================================================================= 59 | 60 | Creative Commons Attribution-NonCommercial 4.0 International Public 61 | License 62 | 63 | By exercising the Licensed Rights (defined below), You accept and agree 64 | to be bound by the terms and conditions of this Creative Commons 65 | Attribution-NonCommercial 4.0 International Public License ("Public 66 | License"). To the extent this Public License may be interpreted as a 67 | contract, You are granted the Licensed Rights in consideration of Your 68 | acceptance of these terms and conditions, and the Licensor grants You 69 | such rights in consideration of benefits the Licensor receives from 70 | making the Licensed Material available under these terms and 71 | conditions. 72 | 73 | 74 | Section 1 -- Definitions. 75 | 76 | a. Adapted Material means material subject to Copyright and Similar 77 | Rights that is derived from or based upon the Licensed Material 78 | and in which the Licensed Material is translated, altered, 79 | arranged, transformed, or otherwise modified in a manner requiring 80 | permission under the Copyright and Similar Rights held by the 81 | Licensor. For purposes of this Public License, where the Licensed 82 | Material is a musical work, performance, or sound recording, 83 | Adapted Material is always produced where the Licensed Material is 84 | synched in timed relation with a moving image. 85 | 86 | b. Adapter's License means the license You apply to Your Copyright 87 | and Similar Rights in Your contributions to Adapted Material in 88 | accordance with the terms and conditions of this Public License. 89 | 90 | c. 
Copyright and Similar Rights means copyright and/or similar rights 91 | closely related to copyright including, without limitation, 92 | performance, broadcast, sound recording, and Sui Generis Database 93 | Rights, without regard to how the rights are labeled or 94 | categorized. For purposes of this Public License, the rights 95 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 96 | Rights. 97 | d. Effective Technological Measures means those measures that, in the 98 | absence of proper authority, may not be circumvented under laws 99 | fulfilling obligations under Article 11 of the WIPO Copyright 100 | Treaty adopted on December 20, 1996, and/or similar international 101 | agreements. 102 | 103 | e. Exceptions and Limitations means fair use, fair dealing, and/or 104 | any other exception or limitation to Copyright and Similar Rights 105 | that applies to Your use of the Licensed Material. 106 | 107 | f. Licensed Material means the artistic or literary work, database, 108 | or other material to which the Licensor applied this Public 109 | License. 110 | 111 | g. Licensed Rights means the rights granted to You subject to the 112 | terms and conditions of this Public License, which are limited to 113 | all Copyright and Similar Rights that apply to Your use of the 114 | Licensed Material and that the Licensor has authority to license. 115 | 116 | h. Licensor means the individual(s) or entity(ies) granting rights 117 | under this Public License. 118 | 119 | i. NonCommercial means not primarily intended for or directed towards 120 | commercial advantage or monetary compensation. For purposes of 121 | this Public License, the exchange of the Licensed Material for 122 | other material subject to Copyright and Similar Rights by digital 123 | file-sharing or similar means is NonCommercial provided there is 124 | no payment of monetary compensation in connection with the 125 | exchange. 126 | 127 | j. 
Share means to provide material to the public by any means or 128 | process that requires permission under the Licensed Rights, such 129 | as reproduction, public display, public performance, distribution, 130 | dissemination, communication, or importation, and to make material 131 | available to the public including in ways that members of the 132 | public may access the material from a place and at a time 133 | individually chosen by them. 134 | 135 | k. Sui Generis Database Rights means rights other than copyright 136 | resulting from Directive 96/9/EC of the European Parliament and of 137 | the Council of 11 March 1996 on the legal protection of databases, 138 | as amended and/or succeeded, as well as other essentially 139 | equivalent rights anywhere in the world. 140 | 141 | l. You means the individual or entity exercising the Licensed Rights 142 | under this Public License. Your has a corresponding meaning. 143 | 144 | 145 | Section 2 -- Scope. 146 | 147 | a. License grant. 148 | 149 | 1. Subject to the terms and conditions of this Public License, 150 | the Licensor hereby grants You a worldwide, royalty-free, 151 | non-sublicensable, non-exclusive, irrevocable license to 152 | exercise the Licensed Rights in the Licensed Material to: 153 | 154 | a. reproduce and Share the Licensed Material, in whole or 155 | in part, for NonCommercial purposes only; and 156 | 157 | b. produce, reproduce, and Share Adapted Material for 158 | NonCommercial purposes only. 159 | 160 | 2. Exceptions and Limitations. For the avoidance of doubt, where 161 | Exceptions and Limitations apply to Your use, this Public 162 | License does not apply, and You do not need to comply with 163 | its terms and conditions. 164 | 165 | 3. Term. The term of this Public License is specified in Section 166 | 6(a). 167 | 168 | 4. Media and formats; technical modifications allowed. 
The 169 | Licensor authorizes You to exercise the Licensed Rights in 170 | all media and formats whether now known or hereafter created, 171 | and to make technical modifications necessary to do so. The 172 | Licensor waives and/or agrees not to assert any right or 173 | authority to forbid You from making technical modifications 174 | necessary to exercise the Licensed Rights, including 175 | technical modifications necessary to circumvent Effective 176 | Technological Measures. For purposes of this Public License, 177 | simply making modifications authorized by this Section 2(a) 178 | (4) never produces Adapted Material. 179 | 180 | 5. Downstream recipients. 181 | 182 | a. Offer from the Licensor -- Licensed Material. Every 183 | recipient of the Licensed Material automatically 184 | receives an offer from the Licensor to exercise the 185 | Licensed Rights under the terms and conditions of this 186 | Public License. 187 | 188 | b. No downstream restrictions. You may not offer or impose 189 | any additional or different terms or conditions on, or 190 | apply any Effective Technological Measures to, the 191 | Licensed Material if doing so restricts exercise of the 192 | Licensed Rights by any recipient of the Licensed 193 | Material. 194 | 195 | 6. No endorsement. Nothing in this Public License constitutes or 196 | may be construed as permission to assert or imply that You 197 | are, or that Your use of the Licensed Material is, connected 198 | with, or sponsored, endorsed, or granted official status by, 199 | the Licensor or others designated to receive attribution as 200 | provided in Section 3(a)(1)(A)(i). 201 | 202 | b. Other rights. 203 | 204 | 1. 
Moral rights, such as the right of integrity, are not 205 | licensed under this Public License, nor are publicity, 206 | privacy, and/or other similar personality rights; however, to 207 | the extent possible, the Licensor waives and/or agrees not to 208 | assert any such rights held by the Licensor to the limited 209 | extent necessary to allow You to exercise the Licensed 210 | Rights, but not otherwise. 211 | 212 | 2. Patent and trademark rights are not licensed under this 213 | Public License. 214 | 215 | 3. To the extent possible, the Licensor waives any right to 216 | collect royalties from You for the exercise of the Licensed 217 | Rights, whether directly or through a collecting society 218 | under any voluntary or waivable statutory or compulsory 219 | licensing scheme. In all other cases the Licensor expressly 220 | reserves any right to collect such royalties, including when 221 | the Licensed Material is used other than for NonCommercial 222 | purposes. 223 | 224 | 225 | Section 3 -- License Conditions. 226 | 227 | Your exercise of the Licensed Rights is expressly made subject to the 228 | following conditions. 229 | 230 | a. Attribution. 231 | 232 | 1. If You Share the Licensed Material (including in modified 233 | form), You must: 234 | 235 | a. retain the following if it is supplied by the Licensor 236 | with the Licensed Material: 237 | 238 | i. identification of the creator(s) of the Licensed 239 | Material and any others designated to receive 240 | attribution, in any reasonable manner requested by 241 | the Licensor (including by pseudonym if 242 | designated); 243 | 244 | ii. a copyright notice; 245 | 246 | iii. a notice that refers to this Public License; 247 | 248 | iv. a notice that refers to the disclaimer of 249 | warranties; 250 | 251 | v. a URI or hyperlink to the Licensed Material to the 252 | extent reasonably practicable; 253 | 254 | b. 
indicate if You modified the Licensed Material and 255 | retain an indication of any previous modifications; and 256 | 257 | c. indicate the Licensed Material is licensed under this 258 | Public License, and include the text of, or the URI or 259 | hyperlink to, this Public License. 260 | 261 | 2. You may satisfy the conditions in Section 3(a)(1) in any 262 | reasonable manner based on the medium, means, and context in 263 | which You Share the Licensed Material. For example, it may be 264 | reasonable to satisfy the conditions by providing a URI or 265 | hyperlink to a resource that includes the required 266 | information. 267 | 268 | 3. If requested by the Licensor, You must remove any of the 269 | information required by Section 3(a)(1)(A) to the extent 270 | reasonably practicable. 271 | 272 | 4. If You Share Adapted Material You produce, the Adapter's 273 | License You apply must not prevent recipients of the Adapted 274 | Material from complying with this Public License. 275 | 276 | 277 | Section 4 -- Sui Generis Database Rights. 278 | 279 | Where the Licensed Rights include Sui Generis Database Rights that 280 | apply to Your use of the Licensed Material: 281 | 282 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 283 | to extract, reuse, reproduce, and Share all or a substantial 284 | portion of the contents of the database for NonCommercial purposes 285 | only; 286 | 287 | b. if You include all or a substantial portion of the database 288 | contents in a database in which You have Sui Generis Database 289 | Rights, then the database in which You have Sui Generis Database 290 | Rights (but not its individual contents) is Adapted Material; and 291 | 292 | c. You must comply with the conditions in Section 3(a) if You Share 293 | all or a substantial portion of the contents of the database. 
294 | 295 | For the avoidance of doubt, this Section 4 supplements and does not 296 | replace Your obligations under this Public License where the Licensed 297 | Rights include other Copyright and Similar Rights. 298 | 299 | 300 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 301 | 302 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 303 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 304 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 305 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 306 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 307 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 308 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 309 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 310 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 311 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 312 | 313 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 314 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 315 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 316 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 317 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 318 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 319 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 320 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 321 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 322 | 323 | c. The disclaimer of warranties and limitation of liability provided 324 | above shall be interpreted in a manner that, to the extent 325 | possible, most closely approximates an absolute disclaimer and 326 | waiver of all liability. 327 | 328 | 329 | Section 6 -- Term and Termination. 330 | 331 | a. 
This Public License applies for the term of the Copyright and 332 | Similar Rights licensed here. However, if You fail to comply with 333 | this Public License, then Your rights under this Public License 334 | terminate automatically. 335 | 336 | b. Where Your right to use the Licensed Material has terminated under 337 | Section 6(a), it reinstates: 338 | 339 | 1. automatically as of the date the violation is cured, provided 340 | it is cured within 30 days of Your discovery of the 341 | violation; or 342 | 343 | 2. upon express reinstatement by the Licensor. 344 | 345 | For the avoidance of doubt, this Section 6(b) does not affect any 346 | right the Licensor may have to seek remedies for Your violations 347 | of this Public License. 348 | 349 | c. For the avoidance of doubt, the Licensor may also offer the 350 | Licensed Material under separate terms or conditions or stop 351 | distributing the Licensed Material at any time; however, doing so 352 | will not terminate this Public License. 353 | 354 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 355 | License. 356 | 357 | 358 | Section 7 -- Other Terms and Conditions. 359 | 360 | a. The Licensor shall not be bound by any additional or different 361 | terms or conditions communicated by You unless expressly agreed. 362 | 363 | b. Any arrangements, understandings, or agreements regarding the 364 | Licensed Material not stated herein are separate from and 365 | independent of the terms and conditions of this Public License. 366 | 367 | 368 | Section 8 -- Interpretation. 369 | 370 | a. For the avoidance of doubt, this Public License does not, and 371 | shall not be interpreted to, reduce, limit, restrict, or impose 372 | conditions on any use of the Licensed Material that could lawfully 373 | be made without permission under this Public License. 374 | 375 | b. 
To the extent possible, if any provision of this Public License is 376 | deemed unenforceable, it shall be automatically reformed to the 377 | minimum extent necessary to make it enforceable. If the provision 378 | cannot be reformed, it shall be severed from this Public License 379 | without affecting the enforceability of the remaining terms and 380 | conditions. 381 | 382 | c. No term or condition of this Public License will be waived and no 383 | failure to comply consented to unless expressly agreed to by the 384 | Licensor. 385 | 386 | d. Nothing in this Public License constitutes or may be interpreted 387 | as a limitation upon, or waiver of, any privileges and immunities 388 | that apply to the Licensor or You, including from the legal 389 | processes of any jurisdiction or authority. 390 | 391 | ======================================================================= 392 | 393 | Creative Commons is not a party to its public 394 | licenses. Notwithstanding, Creative Commons may elect to apply one of 395 | its public licenses to material it publishes and in those instances 396 | will be considered the “Licensor.” The text of the Creative Commons 397 | public licenses is dedicated to the public domain under the CC0 Public 398 | Domain Dedication. Except for the limited purpose of indicating that 399 | material is shared under a Creative Commons public license or as 400 | otherwise permitted by the Creative Commons policies published at 401 | creativecommons.org/policies, Creative Commons does not authorize the 402 | use of the trademark "Creative Commons" or any other trademark or logo 403 | of Creative Commons without its prior written consent including, 404 | without limitation, in connection with any unauthorized modifications 405 | to any of its public licenses or any other arrangements, 406 | understandings, or agreements concerning use of licensed material. For 407 | the avoidance of doubt, this paragraph does not form part of the 408 | public licenses. 
409 | 410 | Creative Commons may be contacted at creativecommons.org. 411 | --------------------------------------------------------------------------------