├── README.md ├── collator.py ├── lr.py ├── main.py ├── model.py ├── preprocess_data.py └── start.sh /README.md: -------------------------------------------------------------------------------- 1 | # Graph_Transformer 2 | A PyTorch implementation of a Graph Transformer for node classification. 3 | 4 | Our implementation is based on "Do Transformers Really Perform Badly for Graph Representation?" (NeurIPS'21) [[paper]](https://proceedings.neurips.cc/paper/2021/hash/f1c1592588411002af340cbaedd6fc33-Abstract.html) [[github]](https://github.com/microsoft/Graphormer) and "Gophormer: Ego-Graph Transformer for Node Classification" [[arxiv]](https://arxiv.org/abs/2110.13094). 5 | 6 | ## Run the code 7 | Run preprocess_data.py first to build and cache the preprocessed dataset under ./dataset/ (the dataset name, pubmed by default, is set inside process_data()). 8 | 9 | Then run main.py (or start.sh, which wraps it) to train the graph transformer; pass --dataset_name to match the preprocessed data. 10 | -------------------------------------------------------------------------------- /collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def pad_1d_unsqueeze(x, padlen): 5 | xlen = x.size(0) 6 | if xlen < padlen: 7 | new_x = x.new_zeros([padlen], dtype=x.dtype) 8 | new_x[:xlen] = x 9 | x = new_x 10 | return x.unsqueeze(0) 11 | 12 | 13 | def pad_2d_unsqueeze(x, padlen): 14 | xlen, xdim = x.size() 15 | if xlen < padlen: 16 | new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) 17 | new_x[:xlen, :] = x 18 | x = new_x 19 | return x.unsqueeze(0) 20 | 21 | 22 | def pad_attn_bias_unsqueeze(x, padlen): 23 | xlen = x.size(0) 24 | if xlen < padlen: 25 | new_x = x.new_zeros( 26 | [padlen, padlen], dtype=x.dtype).fill_(float('-inf')) 27 | new_x[:xlen, :xlen] = x 28 | new_x[xlen:, :xlen] = 0 29 | x = new_x 30 | return x.unsqueeze(0) 31 | 32 | 33 | def pad_edge_type_unsqueeze(x, padlen): 34 | xlen = x.size(0) 35 | if xlen < padlen: 36 | new_x = x.new_zeros([padlen, padlen, x.size(-1)], dtype=x.dtype) 37 | new_x[:xlen, :xlen, :] = x 38 | x = new_x 39 | return x.unsqueeze(0) 40 | 41 | 42 | def pad_spatial_pos_unsqueeze(x, padlen): 43 | xlen = x.size(0) 44 | if xlen < padlen: 45 | new_x = x.new_zeros([padlen, padlen], dtype=x.dtype) 46 | new_x[:xlen, :xlen] = x 47 | x = new_x 48 | return x.unsqueeze(0) 49 | 50 | 51 | def pad_3d_unsqueeze(x, padlen1, padlen2, padlen3): 52 | x = x + 1 53 | xlen1, xlen2, xlen3, xlen4 = x.size() 54 | if xlen1 < padlen1 or xlen2 < padlen2 or xlen3 < padlen3: 55 | new_x = x.new_zeros([padlen1, padlen2, padlen3, xlen4], dtype=x.dtype) 56 | new_x[:xlen1, :xlen2, :xlen3, :] = x 57 | x = new_x 58 | return x.unsqueeze(0) 59 | 60 | 61 | class Batch(): 62 | def __init__(self, attn_bias, x, y, ids): 63 | super(Batch, self).__init__() 64 | self.x, self.y = x, y 65 | self.attn_bias = attn_bias 66 | self.ids = ids 67 | 68 | def to(self, device): 69 | self.x, self.y = self.x.to(device), self.y.to(device) 70 | self.attn_bias = self.attn_bias.to(device) 71 | self.ids = self.ids.to(device) 72 | return self 73 | 74 | def __len__(self): 75 | return self.y.size(0) 76 | 77 | 78 | def collator(items, feature, shuffle=False, perturb=False): 79 | batch_list = [] 80 | for item in items: 81 | for x in item: 82 | batch_list.append((x[0], x[1], x[2][0])) 83 | if shuffle: 84 | random.shuffle(batch_list) 85 | attn_biases, xs, ys = zip(*batch_list) 86 | max_node_num = max(i.size(0) for i in xs) 87 | y = torch.cat([i.unsqueeze(0) for i in ys]) 88 | x = torch.cat([pad_2d_unsqueeze(feature[i], max_node_num) for i in xs]) 89 | ids = torch.cat([i.unsqueeze(0) for i in xs]) 90 | if perturb: 91 | x += torch.FloatTensor(x.shape).uniform_(-0.1, 0.1) 92
| attn_bias = torch.cat([i.unsqueeze(0) for i in attn_biases]) 93 | 94 | return Batch( 95 | attn_bias=attn_bias, 96 | x=x, 97 | y=y, 98 | ids=ids, 99 | ) 100 | 101 | -------------------------------------------------------------------------------- /lr.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolynomialDecayLR(_LRScheduler): 5 | 6 | def __init__(self, optimizer, warmup, tot, lr, end_lr, power, last_epoch=-1, verbose=False): 7 | self.warmup = warmup 8 | self.tot = tot 9 | self.lr = lr 10 | self.end_lr = end_lr 11 | self.power = power 12 | super(PolynomialDecayLR, self).__init__(optimizer) 13 | 14 | 15 | def get_lr(self): 16 | if self._step_count <= self.warmup: 17 | self.warmup_factor = self._step_count / float(self.warmup) 18 | lr = self.warmup_factor * self.lr 19 | elif self._step_count >= self.tot: 20 | lr = self.end_lr 21 | else: 22 | warmup = self.warmup 23 | lr_range = self.lr - self.end_lr 24 | pct_remaining = 1 - (self._step_count - warmup) / (self.tot - warmup) 25 | lr = lr_range * pct_remaining ** (self.power) + self.end_lr 26 | 27 | return [lr for group in self.optimizer.param_groups] 28 | 29 | def _get_closed_form_lr(self): 30 | assert False 31 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collator import collator 5 | import random 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | from functools import partial 9 | from model import GT 10 | from lr import PolynomialDecayLR 11 | import argparse 12 | import math 13 | from tqdm import tqdm 14 | from preprocess_data import node_sampling, process_data 15 | from torch.nn.functional import normalize 16 | import scipy.sparse as sp 17 | from numpy.linalg import inv 18 | 19 | 20 | def train(args, model, device, loader, optimizer): 21 | model.train() 22 | 23 | for batch in tqdm(loader, desc="Iteration"): 24 | batch = batch.to(device) 25 | pred = model(batch) 26 | y_true = batch.y.view(-1) 27 | loss = F.nll_loss(pred, y_true) 28 | optimizer.zero_grad() 29 | loss.backward() 30 | optimizer.step() 31 | 32 | 33 | def eval_train(args, model, device, loader): 34 | y_true = [] 35 | y_pred = [] 36 | loss_list = [] 37 | model.eval() 38 | with torch.no_grad(): 39 | for batch in tqdm(loader, desc="Iteration"): 40 | batch = batch.to(device) 41 | pred = model(batch) 42 | loss_list.append(F.nll_loss(pred, batch.y.view(-1)).item()) 43 | y_true.append(batch.y) 44 | y_pred.append(pred.argmax(1)) 45 | 46 | y_pred = torch.cat(y_pred) 47 | y_true = torch.cat(y_true) 48 | correct = (y_pred == y_true).sum() 49 | acc = correct.item() / len(y_true) 50 | 51 | return acc, np.mean(loss_list) 52 | 53 | 54 | def eval(args, model, device, loader): 55 | y_true = [] 56 | y_pred = [] 57 | loss_list = [] 58 | model.eval() 59 | with torch.no_grad(): 60 | for batch in tqdm(loader, desc="Iteration"): 61 | batch = batch.to(device) 62 | pred = model(batch) 63 | loss_list.append(F.nll_loss(pred, batch.y.view(-1)).item()) 64 | y_true.append(batch.y) 65 | y_pred.append(pred.argmax(1)) 66 | 67 | y_pred = torch.cat(y_pred) 68 | y_true = torch.cat(y_true) 69 | 70 | pred_list = [] 71 | for i in torch.split(y_pred, args.num_data_augment, dim=0): 72 | pred_list.append(i.bincount().argmax().unsqueeze(0)) 73 | y_pred = torch.cat(pred_list) 74 | 
y_true = y_true.view(-1, args.num_data_augment)[:, 0] 75 | correct = (y_pred == y_true).sum() 76 | acc = correct.item() / len(y_true) 77 | 78 | return acc, np.mean(loss_list) 79 | 80 | 81 | def random_split(data_list, frac_train, frac_valid, frac_test, seed): 82 | np.testing.assert_almost_equal(frac_train + frac_valid + frac_test, 1.0) 83 | random.seed(seed) 84 | all_idx = np.arange(len(data_list)) 85 | random.shuffle(all_idx) 86 | train_idx = all_idx[:int(frac_train * len(data_list))] 87 | val_idx = all_idx[int(frac_train * len(data_list)):int((frac_train+frac_valid) * len(data_list))] 88 | test_idx = all_idx[int((frac_train+frac_valid) * len(data_list)):] 89 | train_list = [] 90 | test_list = [] 91 | val_list = [] 92 | for i in train_idx: 93 | train_list.append(data_list[i]) 94 | for i in val_idx: 95 | val_list.append(data_list[i]) 96 | for i in test_idx: 97 | test_list.append(data_list[i]) 98 | return train_list, val_list, test_list 99 | 100 | 101 | def main(): 102 | parser = argparse.ArgumentParser(description='PyTorch implementation of graph transformer') 103 | parser.add_argument('--dataset_name', type=str, default='pubmed') 104 | parser.add_argument('--n_layers', type=int, default=4) 105 | parser.add_argument('--num_heads', type=int, default=8) 106 | parser.add_argument('--hidden_dim', type=int, default=128) 107 | parser.add_argument('--ffn_dim', type=int, default=128) 108 | parser.add_argument('--attn_bias_dim', type=int, default=6) 109 | parser.add_argument('--intput_dropout_rate', type=float, default=0.1) 110 | parser.add_argument('--dropout_rate', type=float, default=0.3) 111 | parser.add_argument('--weight_decay', type=float, default=0.01) 112 | parser.add_argument('--attention_dropout_rate', type=float, default=0.5) 113 | parser.add_argument('--checkpoint_path', type=str, default='') 114 | parser.add_argument('--warmup_epochs', type=int, default=50) 115 | parser.add_argument('--epochs', type=int, default=500) 116 | parser.add_argument('--peak_lr', type=float, default=2e-4) 117 | parser.add_argument('--end_lr', type=float, default=1e-9) 118 | parser.add_argument('--validate', action='store_true', default=False) 119 | parser.add_argument('--test', action='store_true', default=False) 120 | parser.add_argument('--num_data_augment', type=int, default=8) 121 | parser.add_argument('--num_global_node', type=int, default=1) 122 | parser.add_argument('--batch_size', type=int, default=32) 123 | parser.add_argument('--seed', type=int, default=42) 124 | parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataset loading') 125 | parser.add_argument('--device', type=int, default=0, help='which gpu to use if any (default: 0)') 126 | parser.add_argument('--perturb_feature', action='store_true', default=False) 127 | args = parser.parse_args() 128 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 129 | 130 | data_list = torch.load('./dataset/'+args.dataset_name+'/data.pt') 131 | feature = torch.load('./dataset/'+args.dataset_name+'/feature.pt') 132 | y = torch.load('./dataset/'+args.dataset_name+'/y.pt') 133 | train_dataset, valid_dataset, test_dataset = random_split(data_list, frac_train=0.6, frac_valid=0.2, frac_test=0.2, seed=args.seed) 134 | print('dataset loaded successfully') 135 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers = args.num_workers, collate_fn=partial(collator, feature=feature, shuffle=True, perturb=args.perturb_feature)) 136 | val_loader =
DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers, collate_fn=partial(collator, feature=feature, shuffle=False)) 137 | test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers = args.num_workers, collate_fn=partial(collator, feature=feature, shuffle=False)) 138 | print(args) 139 | 140 | model = GT( 141 | n_layers=args.n_layers, 142 | num_heads=args.num_heads, 143 | input_dim=feature.shape[1], 144 | hidden_dim=args.hidden_dim, 145 | output_dim=y.max().item()+1, 146 | attn_bias_dim=args.attn_bias_dim, 147 | attention_dropout_rate=args.attention_dropout_rate, 148 | dropout_rate=args.dropout_rate, 149 | intput_dropout_rate=args.intput_dropout_rate, 150 | ffn_dim=args.ffn_dim, 151 | num_global_node=args.num_global_node 152 | ) 153 | if not args.test and not args.validate: 154 | print(model) 155 | print('Total params:', sum(p.numel() for p in model.parameters())) 156 | model.to(device) 157 | 158 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.peak_lr, weight_decay=args.weight_decay) 159 | lr_scheduler = PolynomialDecayLR( 160 | optimizer, 161 | warmup=args.warmup_epochs, 162 | tot=args.epochs, 163 | lr=args.peak_lr, 164 | end_lr=args.end_lr, 165 | power=1.0) 166 | 167 | val_acc_list, test_acc_list = [], [] 168 | for epoch in range(1, args.epochs+1): 169 | print("====epoch " + str(epoch)) 170 | train(args, model, device, train_loader, optimizer) 171 | lr_scheduler.step() 172 | 173 | print("====Evaluation") 174 | train_acc, train_loss = eval_train(args, model, device, train_loader) 175 | 176 | val_acc, val_loss = eval(args, model, device, val_loader) 177 | test_acc, test_loss = eval(args, model, device, test_loader) 178 | 179 | print("train_acc: %f val_acc: %f test_acc: %f" % (train_acc, val_acc, test_acc)) 180 | print("train_loss: %f val_loss: %f test_loss: %f" % (train_loss, val_loss, test_loss)) 181 | val_acc_list.append(val_acc) 182 | test_acc_list.append(test_acc) 183 | 184 | print('best validation acc: ', max(val_acc_list)) 185 | print('best test acc: ', max(test_acc_list)) 186 | print('best acc: ', test_acc_list[val_acc_list.index(max(val_acc_list))]) 187 | np.save('./exps/' + args.dataset_name + '/test_acc_list', np.array(test_acc_list)) 188 | np.save('./exps/' + args.dataset_name + '/val_acc_list', np.array(val_acc_list)) 189 | 190 | 191 | if __name__ == "__main__": 192 | main() 193 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | 7 | def init_params(module, n_layers): 8 | if isinstance(module, nn.Linear): 9 | module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers)) 10 | if module.bias is not None: 11 | module.bias.data.zero_() 12 | if isinstance(module, nn.Embedding): 13 | module.weight.data.normal_(mean=0.0, std=0.02) 14 | 15 | 16 | class FeedForwardNetwork(nn.Module): 17 | def __init__(self, hidden_size, ffn_size, dropout_rate): 18 | super(FeedForwardNetwork, self).__init__() 19 | 20 | self.layer1 = nn.Linear(hidden_size, ffn_size) 21 | self.gelu = nn.GELU() 22 | self.layer2 = nn.Linear(ffn_size, hidden_size) 23 | 24 | def forward(self, x): 25 | x = self.layer1(x) 26 | x = self.gelu(x) 27 | x = self.layer2(x) 28 | return x 29 | 30 | 31 | class MultiHeadAttention(nn.Module): 32 | def __init__(self, hidden_size, attention_dropout_rate, num_heads, 
attn_bias_dim): 33 | super(MultiHeadAttention, self).__init__() 34 | 35 | self.num_heads = num_heads 36 | 37 | self.att_size = att_size = hidden_size // num_heads 38 | self.scale = att_size ** -0.5 39 | 40 | self.linear_q = nn.Linear(hidden_size, num_heads * att_size) 41 | self.linear_k = nn.Linear(hidden_size, num_heads * att_size) 42 | self.linear_v = nn.Linear(hidden_size, num_heads * att_size) 43 | self.linear_bias = nn.Linear(attn_bias_dim, num_heads) 44 | self.att_dropout = nn.Dropout(attention_dropout_rate) 45 | 46 | self.output_layer = nn.Linear(num_heads * att_size, hidden_size) 47 | 48 | def forward(self, q, k, v, attn_bias=None): 49 | orig_q_size = q.size() 50 | 51 | d_k = self.att_size 52 | d_v = self.att_size 53 | batch_size = q.size(0) 54 | 55 | # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i) 56 | q = self.linear_q(q).view(batch_size, -1, self.num_heads, d_k) 57 | k = self.linear_k(k).view(batch_size, -1, self.num_heads, d_k) 58 | v = self.linear_v(v).view(batch_size, -1, self.num_heads, d_v) 59 | attn_bias = self.linear_bias(attn_bias).permute(0, 3, 1, 2) 60 | 61 | q = q.transpose(1, 2) # [b, h, q_len, d_k] 62 | v = v.transpose(1, 2) # [b, h, v_len, d_v] 63 | k = k.transpose(1, 2).transpose(2, 3) # [b, h, d_k, k_len] 64 | 65 | # Scaled Dot-Product Attention. 66 | # Attention(Q, K, V) = softmax((QK^T)/sqrt(d_k))V 67 | q = q * self.scale 68 | x = torch.matmul(q, k) # [b, h, q_len, k_len] 69 | if attn_bias is not None: 70 | x = x + attn_bias 71 | 72 | x = torch.softmax(x, dim=3) 73 | x = self.att_dropout(x) 74 | x = x.matmul(v) # [b, h, q_len, attn] 75 | 76 | x = x.transpose(1, 2).contiguous() # [b, q_len, h, attn] 77 | x = x.view(batch_size, -1, self.num_heads * d_v) 78 | 79 | x = self.output_layer(x) 80 | 81 | assert x.size() == orig_q_size 82 | return x 83 | 84 | 85 | class EncoderLayer(nn.Module): 86 | def __init__(self, hidden_size, ffn_size, dropout_rate, attention_dropout_rate, num_heads, attn_bias_dim): 87 | super(EncoderLayer, self).__init__() 88 | 89 | self.self_attention_norm = nn.LayerNorm(hidden_size) 90 | self.self_attention = MultiHeadAttention( 91 | hidden_size, attention_dropout_rate, num_heads, attn_bias_dim) 92 | self.self_attention_dropout = nn.Dropout(dropout_rate) 93 | 94 | self.ffn_norm = nn.LayerNorm(hidden_size) 95 | self.ffn = FeedForwardNetwork(hidden_size, ffn_size, dropout_rate) 96 | self.ffn_dropout = nn.Dropout(dropout_rate) 97 | 98 | 99 | def forward(self, x, attn_bias=None): 100 | y = self.self_attention_norm(x) 101 | y = self.self_attention(y, y, y, attn_bias) 102 | y = self.self_attention_dropout(y) 103 | x = x + y 104 | 105 | y = self.ffn_norm(x) 106 | y = self.ffn(y) 107 | y = self.ffn_dropout(y) 108 | x = x + y 109 | return x 110 | 111 | 112 | class GT(nn.Module): 113 | def __init__( 114 | self, 115 | n_layers, 116 | num_heads, 117 | input_dim, 118 | hidden_dim, 119 | output_dim, 120 | attn_bias_dim, 121 | dropout_rate, 122 | intput_dropout_rate, 123 | ffn_dim, 124 | num_global_node, 125 | attention_dropout_rate, 126 | ): 127 | super().__init__() 128 | 129 | self.num_heads = num_heads 130 | self.node_encoder = nn.Linear(input_dim, hidden_dim) 131 | self.input_dropout = nn.Dropout(intput_dropout_rate) 132 | encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout_rate, attention_dropout_rate, num_heads, attn_bias_dim) 133 | for _ in range(n_layers)] 134 | self.layers = nn.ModuleList(encoders) 135 | self.n_layers = n_layers 136 | self.final_ln = nn.LayerNorm(hidden_dim) 137 | self.downstream_out_proj = nn.Linear(hidden_dim, output_dim) 138 | 
self.hidden_dim = hidden_dim 139 | self.num_global_node = num_global_node 140 | self.graph_token = nn.Embedding(self.num_global_node, hidden_dim) 141 | self.graph_token_virtual_distance = nn.Embedding(self.num_global_node, attn_bias_dim) 142 | self.apply(lambda module: init_params(module, n_layers=n_layers)) 143 | 144 | def forward(self, batched_data, perturb=None): 145 | attn_bias, x = batched_data.attn_bias, batched_data.x 146 | # graph_attn_bias 147 | n_graph, n_node = x.size()[:2] 148 | graph_attn_bias = attn_bias.clone() 149 | node_feature = self.node_encoder(x) # [n_graph, n_node, n_hidden] 150 | if perturb is not None: 151 | node_feature += perturb 152 | 153 | global_node_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1) 154 | node_feature = torch.cat([node_feature, global_node_feature], dim=1) 155 | 156 | graph_attn_bias = torch.cat([graph_attn_bias, self.graph_token_virtual_distance.weight.unsqueeze(0).unsqueeze(2). 157 | repeat(n_graph, 1, n_node, 1)], dim=1) 158 | graph_attn_bias = torch.cat( 159 | [graph_attn_bias, self.graph_token_virtual_distance.weight.unsqueeze(0).unsqueeze(0). 160 | repeat(n_graph, n_node+self.num_global_node, 1, 1)], dim=2) 161 | 162 | # transfomrer encoder 163 | output = self.input_dropout(node_feature) 164 | 165 | for enc_layer in self.layers: 166 | output = enc_layer(output, graph_attn_bias) 167 | output = self.final_ln(output) 168 | 169 | # output part 170 | output = self.downstream_out_proj(output[:, 0, :]) 171 | return F.log_softmax(output, dim=1) 172 | 173 | 174 | -------------------------------------------------------------------------------- /preprocess_data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os.path as osp 4 | import pickle 5 | from torch_geometric.data import InMemoryDataset, download_url, Data 6 | from torch.nn import functional as F 7 | from torch.utils.data import DataLoader 8 | from functools import partial 9 | import scipy.sparse as sp 10 | from numpy.linalg import inv 11 | from torch_geometric.datasets import Planetoid, Amazon, Actor 12 | from torch.nn.functional import normalize 13 | import torch_geometric.transforms as T 14 | from torch_geometric.utils.undirected import is_undirected, to_undirected 15 | from torch_sparse import coalesce 16 | from tqdm import tqdm 17 | 18 | 19 | def adj_normalize(mx): 20 | "A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2" 21 | mx = mx + sp.eye(mx.shape[0]) 22 | rowsum = np.array(mx.sum(1)) 23 | r_inv = np.power(rowsum, -0.5).flatten() 24 | r_inv[np.isinf(r_inv)] = 0. 25 | r_mat_inv = sp.diags(r_inv) 26 | mx = r_mat_inv.dot(mx).dot(r_mat_inv) 27 | return mx 28 | 29 | 30 | def eigenvector(L): 31 | EigVal, EigVec = np.linalg.eig(L.toarray()) 32 | idx = EigVal.argsort() # increasing order 33 | EigVal, EigVec = EigVal[idx], np.real(EigVec[:, idx]) 34 | return torch.tensor(EigVec[:, 1:11], dtype = torch.float32) 35 | 36 | 37 | def column_normalize(mx): 38 | "A' = A * D^-1 " 39 | rowsum = np.array(mx.sum(1)) 40 | r_inv = np.power(rowsum, -1.0).flatten() 41 | r_inv[np.isinf(r_inv)] = 0. 
42 | r_mat_inv = sp.diags(r_inv) 43 | mx = mx.dot(r_mat_inv) 44 | return mx 45 | 46 | def process_data(p=None): 47 | name = 'pubmed' 48 | dataset = Planetoid(root='./data/', name=name) 49 | #dataset = Actor(root='./data/') 50 | data = dataset[0] 51 | adj = sp.coo_matrix((np.ones(data.edge_index.shape[1]), (data.edge_index[0], data.edge_index[1])), 52 | shape=(data.y.shape[0], data.y.shape[0]), dtype=np.float32) 53 | normalized_adj = adj_normalize(adj) 54 | column_normalized_adj = column_normalize(adj) 55 | sp.save_npz('./dataset/'+name+'/normalized_adj.npz', normalized_adj) 56 | sp.save_npz('./dataset/' + name + '/column_normalized_adj.npz', column_normalized_adj) 57 | c = 0.15 58 | k1 = 15 59 | Samples = 8 # sampled subgraphs for each node 60 | power_adj_list = [normalized_adj] 61 | for m in range(5): 62 | power_adj_list.append(power_adj_list[0]*power_adj_list[m]) 63 | torch.save(data.x, './dataset/' + name + '/x.pt') 64 | torch.save(data.y, './dataset/' + name + '/y.pt') 65 | torch.save(data.edge_index, './dataset/' + name + '/edge_index.pt') 66 | 67 | sampling_matrix = c * inv((sp.eye(adj.shape[0]) - (1 - c) * normalized_adj).toarray()) # power_adj_list[1].toarray() 68 | feature = data.x 69 | 70 | #create subgraph samples 71 | data_list = [] 72 | for id in tqdm(range(data.y.shape[0])): 73 | s = sampling_matrix[id] 74 | s[id] = -1000.0 75 | top_neighbor_index = s.argsort()[-k1:] 76 | 77 | # use sampling matrix for node sampling 78 | # can use random sampling here 79 | s = sampling_matrix[id] 80 | s[id] = 0 81 | s = np.maximum(s, 0) 82 | sample_num1 = np.minimum(k1, (s > 0).sum()) 83 | #create subgraph samples for ensemble 84 | sub_data_list = [] 85 | for _ in range(Samples): 86 | if sample_num1 > 0: 87 | sample_index1 = np.random.choice(a=np.arange(data.y.shape[0]), size=sample_num1, replace=False, p=s/s.sum()) 88 | else: 89 | sample_index1 = np.array([], dtype=int) 90 | 91 | node_feature_id = torch.cat([torch.tensor([id, ]), torch.tensor(sample_index1, dtype=int), torch.tensor(top_neighbor_index[: k1-sample_num1], dtype=int)]) 92 | # create attention bias (positional encoding) 93 | attn_bias = torch.cat([torch.tensor(i[node_feature_id, :][:, node_feature_id].toarray(), dtype=torch.float32).unsqueeze(0) for i in power_adj_list]) 94 | attn_bias = attn_bias.permute(1, 2, 0) 95 | 96 | label = data.y[node_feature_id].long() 97 | feature_id = node_feature_id 98 | assert len(feature_id) == k1+1 99 | sub_data_list.append([attn_bias, feature_id, label]) 100 | data_list.append(sub_data_list) 101 | 102 | torch.save(data_list, './dataset/'+name+'/data.pt') 103 | torch.save(feature, './dataset/'+name+'/feature.pt') 104 | 105 | 106 | def node_sampling(p=None): 107 | print('Sampling Nodes!') 108 | name = 'pubmed' 109 | edge_index = torch.load('./dataset/'+name+'/edge_index.pt') 110 | data_x = torch.load('./dataset/'+name+'/x.pt') 111 | data_y = torch.load('./dataset/'+name+'/y.pt') 112 | adj = sp.coo_matrix((np.ones(edge_index.shape[1]), (edge_index[0], edge_index[1])), 113 | shape=(data_y.shape[0], data_y.shape[0]), dtype=np.float32) 114 | normalized_adj = sp.load_npz('./dataset/'+name+'/normalized_adj.npz') 115 | column_normalized_adj = sp.load_npz('./dataset/' + name + '/column_normalized_adj.npz') 116 | c = 0.15 117 | k1 = 14 118 | Samples = 8 # sampled subgraphs for each node 119 | power_adj_list = [normalized_adj] 120 | for m in range(5): 121 | power_adj_list.append(power_adj_list[0]*power_adj_list[m]) 122 | 123 | sampling_matrix = c * inv((sp.eye(adj.shape[0]) - (1 - c) * 
column_normalized_adj).toarray()) # power_adj_list[0].toarray() 124 | feature = data_x 125 | 126 | #create subgraph samples 127 | data_list = [] 128 | for id in range(data_y.shape[0]): 129 | s = sampling_matrix[id] 130 | s[id] = -1000.0 131 | top_neighbor_index = s.argsort()[-k1:] 132 | 133 | s = sampling_matrix[id] 134 | s[id] = 0 135 | s = np.maximum(s, 0) 136 | sample_num1 = np.minimum(k1, (s > 0).sum()) 137 | sub_data_list = [] 138 | for _ in range(Samples): 139 | if sample_num1 > 0: 140 | sample_index1 = np.random.choice(a=np.arange(data_y.shape[0]), size=sample_num1, replace=False, p=s/s.sum()) 141 | else: 142 | sample_index1 = np.array([], dtype=int) 143 | 144 | node_feature_id = torch.cat([torch.tensor([id, ]), torch.tensor(sample_index1, dtype=int), torch.tensor(top_neighbor_index[: k1-sample_num1], dtype=int)]) 145 | 146 | attn_bias = torch.cat([torch.tensor(i[node_feature_id, :][:, node_feature_id].toarray(), dtype=torch.float32).unsqueeze(0) for i in power_adj_list]) 147 | attn_bias = attn_bias.permute(1, 2, 0) 148 | 149 | sub_data_list.append([attn_bias, node_feature_id, data_y[node_feature_id].long()]) 150 | data_list.append(sub_data_list) 151 | 152 | return data_list, feature 153 | 154 | 155 | if __name__ == '__main__': 156 | # preprocess data 157 | process_data() 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /start.sh: -------------------------------------------------------------------------------- 1 | [ -z "${exp_name}" ] && exp_name="cora" 2 | [ -z "${epoch}" ] && epoch="1000" 3 | [ -z "${seed}" ] && seed="2022" 4 | [ -z "${arch}" ] && arch="--ffn_dim 128 --hidden_dim 128 --dropout_rate 0.1 --n_layers 4 --peak_lr 2e-4" 5 | [ -z "${batch_size}" ] && batch_size="32" 6 | [ -z "${data_augment}" ] && data_augment="8" 7 | 8 | max_epochs=$((epoch+1)) 9 | echo "=====================================ARGS======================================" 10 | echo "max_epochs: ${max_epochs}" 11 | echo "===============================================================================" 12 | 13 | 14 | echo -e "\n\n" 15 | echo "=====================================ARGS======================================" 16 | echo "arg0: $0" 17 | echo "exp_name: ${exp_name}" 18 | echo "arch: ${arch}" 19 | echo "seed: ${seed}" 20 | echo "batch_size: ${batch_size}" 21 | echo "===============================================================================" 22 | 23 | default_root_dir="./exps/$exp_name/$seed" 24 | mkdir -p $default_root_dir 25 | 26 | python main.py --seed $seed --batch_size $batch_size \ 27 | --dataset_name $exp_name --epochs $epoch\ 28 | $arch \ 29 | --checkpoint_path $default_root_dir\ 30 | --num_data_augment $data_augment 31 | 32 | 33 | --------------------------------------------------------------------------------
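A few illustrative sketches follow; none of this code is part of the repository. First, the subgraph sampler in preprocess_data.py ranks and draws each node's context from a dense personalized-PageRank-style matrix, sampling_matrix = c * (I - (1 - c) * A_norm)^-1. Below is a minimal, self-contained sketch of that step on a toy 4-node path graph; the toy graph, the context size k, and the variable names are illustrative, while the normalization and sampling logic mirror process_data().

```python
import numpy as np
import scipy.sparse as sp
from numpy.linalg import inv

# Toy undirected 4-node path graph (0-1-2-3) as a COO adjacency matrix.
edges = np.array([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
adj = sp.coo_matrix((np.ones(edges.shape[1]), (edges[0], edges[1])),
                    shape=(4, 4), dtype=np.float32)

def adj_normalize(mx):
    # Same symmetric normalization as preprocess_data.adj_normalize:
    # A' = (D + I)^-1/2 (A + I) (D + I)^-1/2
    mx = mx + sp.eye(mx.shape[0])
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -0.5).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    return sp.diags(r_inv).dot(mx).dot(sp.diags(r_inv))

c, k = 0.15, 2  # restart probability and context size (c and k1 in the repo)
normalized_adj = adj_normalize(adj)
# Dense personalized-PageRank-style scores: S = c (I - (1 - c) A')^-1
sampling_matrix = c * inv((sp.eye(adj.shape[0]) - (1 - c) * normalized_adj).toarray())

node = 0
s = sampling_matrix[node].copy()
s[node] = 0.0                 # never sample the centre node itself
s = np.maximum(s, 0)
probs = s / s.sum()
context = np.random.choice(np.arange(4), size=min(k, int((s > 0).sum())),
                           replace=False, p=probs)
print("sampled context of node 0:", context)
```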
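At evaluation time, eval() in main.py scores each node num_data_augment times (once per sampled ego-graph, kept adjacent because the val/test loaders do not shuffle) and then majority-votes the class with bincount().argmax(). A tiny sketch of just that voting step, with made-up predictions:

```python
import torch

num_data_augment = 4  # sampled subgraphs per node (the repo's default is 8)
# Fake per-subgraph class predictions for 2 nodes x 4 augmentations, node-major.
y_pred = torch.tensor([2, 2, 1, 2,   0, 3, 3, 3])

# Majority vote per node, as in eval(): split into per-node chunks,
# count votes with bincount, keep the most frequent class.
voted = torch.cat([chunk.bincount().argmax().unsqueeze(0)
                   for chunk in torch.split(y_pred, num_data_augment, dim=0)])
print(voted)  # tensor([2, 3])
```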
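lr.py implements linear warmup followed by polynomial decay: the learning rate ramps from near zero to lr over the first `warmup` steps, then decays to end_lr by step `tot` (linearly when power=1.0, which is how main.py configures it). A quick sanity-check sketch, assuming the repo root is on the import path and using a throwaway one-parameter optimizer:

```python
import torch
from lr import PolynomialDecayLR

# Dummy parameter/optimizer just to drive the scheduler.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=2e-4)

scheduler = PolynomialDecayLR(
    optimizer, warmup=50, tot=500, lr=2e-4, end_lr=1e-9, power=1.0)

for epoch in range(1, 501):
    optimizer.step()          # a real run would do a full training epoch here
    scheduler.step()
    if epoch in (1, 25, 50, 250, 500):
        print(epoch, optimizer.param_groups[0]['lr'])
# Expect a roughly linear ramp up to 2e-4 by epoch 50, then a linear decay
# (power=1.0) towards 1e-9 at epoch 500.
```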
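Finally, a condensed quick-start that mirrors main.py: it loads the tensors written by process_data(), collates a few samples, and runs one forward pass through GT. It assumes the pubmed artifacts already exist under ./dataset/pubmed/ (i.e. preprocess_data.py has been run) and that the repo root is on the import path; the hyperparameters are main.py's defaults.

```python
import torch
from functools import partial
from torch.utils.data import DataLoader

from collator import collator
from model import GT

name = 'pubmed'
data_list = torch.load('./dataset/' + name + '/data.pt')    # per-node lists of sampled ego-graphs
feature = torch.load('./dataset/' + name + '/feature.pt')   # [num_nodes, input_dim] node features
y = torch.load('./dataset/' + name + '/y.pt')               # node labels

# A small loader over the first 32 nodes; collator flattens the per-node samples.
loader = DataLoader(data_list[:32], batch_size=4, shuffle=False,
                    collate_fn=partial(collator, feature=feature, shuffle=False))

model = GT(
    n_layers=4, num_heads=8,
    input_dim=feature.shape[1], hidden_dim=128, output_dim=y.max().item() + 1,
    attn_bias_dim=6, attention_dropout_rate=0.5, dropout_rate=0.3,
    intput_dropout_rate=0.1, ffn_dim=128, num_global_node=1)
model.eval()

with torch.no_grad():
    batch = next(iter(loader))
    log_probs = model(batch)   # [batch_size * samples_per_node, num_classes] log-softmax scores
    print(log_probs.shape, batch.y.view(-1).shape)
```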