├── .gitignore ├── figure.png ├── README.md ├── layers.py ├── networks.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .vscode/ 3 | __pycache__ 4 | 5 | -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/inyeoplee77/SAGPool/HEAD/figure.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Pytorch implementation of Self-Attention Graph Pooling 2 | ==== 3 | 4 | PyTorch implementation of [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) 5 | 6 | ![SAGPool](figure.png) 7 | 8 | 9 | ## Requirements 10 | 11 | * torch_geometric 12 | * torch 13 | 14 | ## Usage 15 | 16 | ```python main.py``` 17 | 18 | 19 | ## Cite 20 | ``` 21 | @InProceedings{pmlr-v97-lee19c, 22 | title = {Self-Attention Graph Pooling}, 23 | author = {Lee, Junhyun and Lee, Inyeop and Kang, Jaewoo}, 24 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 25 | year = {2019}, 26 | month = {09--15 Jun} 27 | } 28 | ``` -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.nn import GCNConv 2 | from torch_geometric.nn.pool.topk_pool import topk,filter_adj 3 | from torch.nn import Parameter 4 | import torch 5 | 6 | 7 | class SAGPool(torch.nn.Module): 8 | def __init__(self,in_channels,ratio=0.8,Conv=GCNConv,non_linearity=torch.tanh): 9 | super(SAGPool,self).__init__() 10 | self.in_channels = in_channels 11 | self.ratio = ratio 12 | self.score_layer = Conv(in_channels,1) 13 | self.non_linearity = non_linearity 14 | def forward(self, x, edge_index, edge_attr=None, batch=None): 15 | if batch is None: 16 | batch = edge_index.new_zeros(x.size(0)) 17 | #x = x.unsqueeze(-1) if x.dim() == 1 else x 18 | score = self.score_layer(x,edge_index).squeeze() 19 | 20 | perm = topk(score, self.ratio, batch) 21 | x = x[perm] * self.non_linearity(score[perm]).view(-1, 1) 22 | batch = batch[perm] 23 | edge_index, edge_attr = filter_adj( 24 | edge_index, edge_attr, perm, num_nodes=score.size(0)) 25 | 26 | return x, edge_index, edge_attr, batch, perm -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.nn import GCNConv 3 | from torch_geometric.nn import GraphConv, TopKPooling 4 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp 5 | import torch.nn.functional as F 6 | from layers import SAGPool 7 | 8 | 9 | 10 | 11 | 12 | class Net(torch.nn.Module): 13 | def __init__(self,args): 14 | super(Net, self).__init__() 15 | self.args = args 16 | self.num_features = args.num_features 17 | self.nhid = args.nhid 18 | self.num_classes = args.num_classes 19 | self.pooling_ratio = args.pooling_ratio 20 | self.dropout_ratio = args.dropout_ratio 21 | 22 | self.conv1 = GCNConv(self.num_features, self.nhid) 23 | self.pool1 = SAGPool(self.nhid, ratio=self.pooling_ratio) 24 | self.conv2 = GCNConv(self.nhid, self.nhid) 25 | self.pool2 = SAGPool(self.nhid, ratio=self.pooling_ratio) 26 | self.conv3 = GCNConv(self.nhid, self.nhid) 27 | self.pool3 = SAGPool(self.nhid, ratio=self.pooling_ratio) 28 | 29 | self.lin1 = torch.nn.Linear(self.nhid*2, self.nhid) 30 | self.lin2 = torch.nn.Linear(self.nhid, self.nhid//2) 31 | self.lin3 = torch.nn.Linear(self.nhid//2, self. num_classes) 32 | 33 | def forward(self, data): 34 | x, edge_index, batch = data.x, data.edge_index, data.batch 35 | 36 | x = F.relu(self.conv1(x, edge_index)) 37 | x, edge_index, _, batch, _ = self.pool1(x, edge_index, None, batch) 38 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 39 | 40 | x = F.relu(self.conv2(x, edge_index)) 41 | x, edge_index, _, batch, _ = self.pool2(x, edge_index, None, batch) 42 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 43 | 44 | x = F.relu(self.conv3(x, edge_index)) 45 | x, edge_index, _, batch, _ = self.pool3(x, edge_index, None, batch) 46 | x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 47 | 48 | x = x1 + x2 + x3 49 | 50 | x = F.relu(self.lin1(x)) 51 | x = F.dropout(x, p=self.dropout_ratio, training=self.training) 52 | x = F.relu(self.lin2(x)) 53 | x = F.log_softmax(self.lin3(x), dim=-1) 54 | 55 | return x 56 | 57 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.datasets import TUDataset 3 | from torch_geometric.data import DataLoader 4 | from torch_geometric import utils 5 | from networks import Net 6 | import torch.nn.functional as F 7 | import argparse 8 | import os 9 | from torch.utils.data import random_split 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--seed', type=int, default=777, 13 | help='seed') 14 | parser.add_argument('--batch_size', type=int, default=128, 15 | help='batch size') 16 | parser.add_argument('--lr', type=float, default=0.0005, 17 | help='learning rate') 18 | parser.add_argument('--weight_decay', type=float, default=0.0001, 19 | help='weight decay') 20 | parser.add_argument('--nhid', type=int, default=128, 21 | help='hidden size') 22 | parser.add_argument('--pooling_ratio', type=float, default=0.5, 23 | help='pooling ratio') 24 | parser.add_argument('--dropout_ratio', type=float, default=0.5, 25 | help='dropout ratio') 26 | parser.add_argument('--dataset', type=str, default='DD', 27 | help='DD/PROTEINS/NCI1/NCI109/Mutagenicity') 28 | parser.add_argument('--epochs', type=int, default=100000, 29 | help='maximum number of epochs') 30 | parser.add_argument('--patience', type=int, default=50, 31 | help='patience for earlystopping') 32 | parser.add_argument('--pooling_layer_type', type=str, default='GCNConv', 33 | help='DD/PROTEINS/NCI1/NCI109/Mutagenicity') 34 | 35 | args = parser.parse_args() 36 | args.device = 'cpu' 37 | torch.manual_seed(args.seed) 38 | if torch.cuda.is_available(): 39 | torch.cuda.manual_seed(args.seed) 40 | args.device = 'cuda:0' 41 | dataset = TUDataset(os.path.join('data',args.dataset),name=args.dataset) 42 | args.num_classes = dataset.num_classes 43 | args.num_features = dataset.num_features 44 | 45 | num_training = int(len(dataset)*0.8) 46 | num_val = int(len(dataset)*0.1) 47 | num_test = len(dataset) - (num_training+num_val) 48 | training_set,validation_set,test_set = random_split(dataset,[num_training,num_val,num_test]) 49 | 50 | 51 | 52 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=True) 53 | val_loader = DataLoader(validation_set,batch_size=args.batch_size,shuffle=False) 54 | test_loader = DataLoader(test_set,batch_size=1,shuffle=False) 55 | model = Net(args).to(args.device) 56 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 57 | 58 | 59 | def test(model,loader): 60 | model.eval() 61 | correct = 0. 62 | loss = 0. 63 | for data in loader: 64 | data = data.to(args.device) 65 | out = model(data) 66 | pred = out.max(dim=1)[1] 67 | correct += pred.eq(data.y).sum().item() 68 | loss += F.nll_loss(out,data.y,reduction='sum').item() 69 | return correct / len(loader.dataset),loss / len(loader.dataset) 70 | 71 | 72 | min_loss = 1e10 73 | patience = 0 74 | 75 | for epoch in range(args.epochs): 76 | model.train() 77 | for i, data in enumerate(train_loader): 78 | data = data.to(args.device) 79 | out = model(data) 80 | loss = F.nll_loss(out, data.y) 81 | print("Training loss:{}".format(loss.item())) 82 | loss.backward() 83 | optimizer.step() 84 | optimizer.zero_grad() 85 | val_acc,val_loss = test(model,val_loader) 86 | print("Validation loss:{}\taccuracy:{}".format(val_loss,val_acc)) 87 | if val_loss < min_loss: 88 | torch.save(model.state_dict(),'latest.pth') 89 | print("Model saved at epoch{}".format(epoch)) 90 | min_loss = val_loss 91 | patience = 0 92 | else: 93 | patience += 1 94 | if patience > args.patience: 95 | break 96 | 97 | model = Net(args).to(args.device) 98 | model.load_state_dict(torch.load('latest.pth')) 99 | test_acc,test_loss = test(model,test_loader) 100 | print("Test accuarcy:{}".format(test_acc)) 101 | --------------------------------------------------------------------------------