├── __pycache__
│   └── util.cpython-37.pyc
├── for_gin
│   ├── __pycache__
│   │   └── util.cpython-37.pyc
│   ├── models
│   │   ├── __pycache__
│   │   │   ├── mlp.cpython-36.pyc
│   │   │   ├── mlp.cpython-37.pyc
│   │   │   ├── graphcnn.cpython-36.pyc
│   │   │   ├── graphcnn.cpython-37.pyc
│   │   │   ├── graphcnn_pooled.cpython-37.pyc
│   │   │   └── graphcnn_pooled_multilayer.cpython-37.pyc
│   │   ├── mlp.py
│   │   └── graphcnn_pooled_multilayer.py
│   ├── train_powerfulgnn_oneenc.py
│   ├── util.py
│   └── train_powerfulgnn_twoenc.py
├── for_hgp-sl
│   ├── __pycache__
│   │   ├── util.cpython-37.pyc
│   │   ├── layers.cpython-37.pyc
│   │   ├── models.cpython-37.pyc
│   │   └── sparse_softmax.cpython-37.pyc
│   ├── models.py
│   ├── sparse_softmax.py
│   ├── train_hgp-sl_twoenc.py
│   ├── train_hgp-sl_oneenc.py
│   ├── util.py
│   └── layers.py
└── README.md
/README.md:
--------------------------------------------------------------------------------
 1 | 
 2 | This is the implementation of the paper:
 3 | 
 4 | > Label Contrastive Coding based Graph Neural Network for Graph Classification
 5 | 
 6 | ### Requirements
 7 | The code is implemented in Python 3.7. The packages used for development are listed below.
 8 | ```
 9 | networkx
10 | numpy
11 | scipy
12 | torch == 1.4.0
13 | torch_geometric == 1.6.0
14 | ```
15 | 
16 | 
17 | ### Instructions for running the code
18 | 
19 | For LCGNN with different encoders, the training scripts are kept in separate directories (e.g., ./for_gin).
20 | 
21 | 
22 | 1. Enter the for_gin directory
23 | ```
24 | cd ./for_gin
25 | ```
26 | 
27 | 2. Run
28 | ```
29 | python3 train_powerfulgnn_oneenc.py
30 | ```
31 | for the momentum weight $\alpha = 0$ setting, or run
32 | 
33 | ```
34 | python3 train_powerfulgnn_twoenc.py
35 | ```
36 | for all other settings.
37 | 
38 | 
39 | 
40 | ### Notes
41 | 
42 | 1. The default setting uses the GPU.
43 | 2. To change the model configuration (e.g., to train for NUMBER epochs), add the corresponding flag, such as `--epochs NUMBER`; an example invocation is given after these notes.
44 | 3. Because of GitHub's file size limit, the datasets are hosted externally; download them from https://www.dropbox.com/sh/kc7xf42kz4lqx9a/AAC9wKim768TBNocN1JNPudFa?dl=0
45 | 
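For example, an illustrative set of runs (the flag names below are the ones defined in the training scripts; the specific values are only suggestions, not the tuned settings from the paper):
```
# GIN encoder, two-encoder (momentum) variant, on MUTAG
cd ./for_gin
python3 train_powerfulgnn_twoenc.py --dataset MUTAG --epochs 350 --batch_size 32 --contraloss_weight 0.5 --temperature 0.07

# HGP-SL encoder, single-encoder variant, on PROTEINS
cd ../for_hgp-sl
python3 train_hgp-sl_oneenc.py --dataset PROTEINS --epochs 350 --batch_size 128
```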
--------------------------------------------------------------------------------
/for_gin/models/mlp.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torch.nn.functional as F
 4 | 
 5 | ###MLP with linear output
 6 | class MLP(nn.Module):
 7 |     def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
 8 |         '''
 9 |             num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to linear model.
10 |             input_dim: dimensionality of input features
11 |             hidden_dim: dimensionality of hidden units at ALL layers
12 |             output_dim: number of classes for prediction
13 |             device: which device to use
14 |         '''
15 | 
16 |         super(MLP, self).__init__()
17 | 
18 |         self.linear_or_not = True #default is linear model
19 |         self.num_layers = num_layers
20 | 
21 |         if num_layers < 1:
22 |             raise ValueError("number of layers should be positive!")
23 |         elif num_layers == 1:
24 |             #Linear model
25 |             self.linear = nn.Linear(input_dim, output_dim)
26 |         else:
27 |             #Multi-layer model
28 |             self.linear_or_not = False
29 |             self.linears = torch.nn.ModuleList()
30 |             self.batch_norms = torch.nn.ModuleList()
31 | 
32 |             self.linears.append(nn.Linear(input_dim, hidden_dim))
33 |             for layer in range(num_layers - 2):
34 |                 self.linears.append(nn.Linear(hidden_dim, hidden_dim))
35 |             self.linears.append(nn.Linear(hidden_dim, output_dim))
36 | 
37 |             for layer in range(num_layers - 1):
38 |                 self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
39 | 
40 |     def forward(self, x):
41 |         if self.linear_or_not:
42 |             #If linear model
43 |             return self.linear(x)
44 |         else:
45 |             #If MLP
46 |             h = x
47 |             for layer in range(self.num_layers - 1):
48 |                 h = F.relu(self.batch_norms[layer](self.linears[layer](h)))
49 |             return self.linears[self.num_layers - 1](h)
--------------------------------------------------------------------------------
/for_hgp-sl/models.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
 4 | from torch_geometric.nn import GCNConv
 5 | 
 6 | from layers import GCN, HGPSLPool
 7 | 
 8 | 
 9 | class Model(torch.nn.Module):
10 |     def __init__(self, args):
11 |         super(Model, self).__init__()
12 |         self.args = args
13 |         self.num_features = args.num_features
14 |         self.nhid = args.nhid
15 |         self.num_classes = args.num_classes
16 |         self.pooling_ratio = args.pooling_ratio
17 |         self.dropout_ratio = args.dropout_ratio
18 |         self.sample = args.sample_neighbor
19 |         self.sparse = args.sparse_attention
20 |         self.sl = args.structure_learning
21 |         self.lamb = args.lamb
22 | 
23 |         self.conv1 = GCNConv(self.num_features, self.nhid)
24 |         self.conv2 = GCN(self.nhid, self.nhid)
25 |         self.conv3 = GCN(self.nhid, self.nhid)
26 | 
27 |         self.pool1 = HGPSLPool(self.nhid, self.pooling_ratio, self.sample, self.sparse, self.sl, self.lamb)
28 |         self.pool2 = HGPSLPool(self.nhid, self.pooling_ratio, self.sample, self.sparse, self.sl, self.lamb)
29 | 
30 |         self.lin1 = torch.nn.Linear(self.nhid * 2, self.nhid)
31 |         self.lin2 = torch.nn.Linear(self.nhid, self.nhid // 2)
32 |         self.lin3 = torch.nn.Linear(self.nhid // 2, self.num_classes)
33 | 
34 |     def forward(self, data, data_x=None):
35 |         if data_x is None:
36 |             x, edge_index, batch = data.x,
data.edge_index, data.batch# x get 1006*37 37 | else: 38 | x, edge_index, batch = data_x, data.edge_index, data.batch 39 | edge_attr = None 40 | 41 | hidden_feats_dict = {} 42 | 43 | x = F.relu(self.conv1(x, edge_index, edge_attr))# x get 1006*128 44 | x, edge_index, edge_attr, batch = self.pool1(x, edge_index, edge_attr, batch)# x get 510, 128 45 | x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 46 | x = F.relu(self.conv2(x, edge_index, edge_attr)) 47 | x, edge_index, edge_attr, batch = self.pool2(x, edge_index, edge_attr, batch) 48 | x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 49 | x = F.relu(self.conv3(x, edge_index, edge_attr)) 50 | x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1) 51 | 52 | x = F.relu(x1) + F.relu(x2) + F.relu(x3) 53 | # hidden_feats_dict[0] = x 54 | x = F.relu(self.lin1(x)) 55 | x = F.dropout(x, p=self.dropout_ratio, training=self.training) 56 | hidden_feats_dict[1] = x 57 | x = F.relu(self.lin2(x)) 58 | x = F.dropout(x, p=self.dropout_ratio, training=self.training) 59 | hidden_feats_dict[2] = x 60 | x = F.log_softmax(self.lin3(x), dim=-1) 61 | hidden_feats_dict[3] = x 62 | return x, hidden_feats_dict 63 | -------------------------------------------------------------------------------- /for_hgp-sl/sparse_softmax.py: -------------------------------------------------------------------------------- 1 | """ 2 | An original implementation of sparsemax (Martins & Astudillo, 2016) is available at 3 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py. 4 | See `From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification, ICML 2016` 5 | for detailed description. 6 | 7 | We make some modifications to make it work at scatter operation scenarios, e.g., calculate softmax according to batch 8 | indicators. 
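For reference, sparsemax maps a score vector z onto the probability simplex via a Euclidean
projection: sparsemax(z)_i = max(z_i - tau(z), 0), where the threshold tau(z) is chosen so that
the nonzero outputs sum to one. In this scatter version the projection is applied independently
within each group defined by `batch`; `_threshold_and_support` below computes tau and the
support size for every group.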
9 | 10 | Usage: 11 | >> x = torch.tensor([ 1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 12 | 0.1831, -0.0061]) 13 | >> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) 14 | >> sparse_attention = Sparsemax() 15 | >> res = sparse_attention(x, batch) 16 | >> print(res) 17 | tensor([0.5343, 0.0000, 0.0000, 0.4657, 0.0612, 0.0000, 0.0000, 0.0000, 0.5640, 18 | 0.3748]) 19 | 20 | """ 21 | import torch 22 | import torch.nn as nn 23 | from torch.autograd import Function 24 | from torch_scatter import scatter_add, scatter_max 25 | 26 | 27 | def scatter_sort(x, batch, fill_value=-1e16): 28 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0) 29 | batch_size, max_num_nodes = num_nodes.size(0), num_nodes.max().item() 30 | 31 | cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0) 32 | 33 | index = torch.arange(batch.size(0), dtype=torch.long, device=x.device) 34 | index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes) 35 | 36 | dense_x = x.new_full((batch_size * max_num_nodes,), fill_value) 37 | dense_x[index] = x 38 | dense_x = dense_x.view(batch_size, max_num_nodes) 39 | 40 | sorted_x, _ = dense_x.sort(dim=-1, descending=True) 41 | cumsum_sorted_x = sorted_x.cumsum(dim=-1) 42 | cumsum_sorted_x = cumsum_sorted_x.view(-1) 43 | 44 | sorted_x = sorted_x.view(-1) 45 | filled_index = sorted_x != fill_value 46 | 47 | sorted_x = sorted_x[filled_index] 48 | cumsum_sorted_x = cumsum_sorted_x[filled_index] 49 | 50 | return sorted_x, cumsum_sorted_x 51 | 52 | 53 | def _make_ix_like(batch): 54 | num_nodes = scatter_add(batch.new_ones(batch.size(0)), batch, dim=0) 55 | idx = [torch.arange(1, i + 1, dtype=torch.long, device=batch.device) for i in num_nodes] 56 | idx = torch.cat(idx, dim=0) 57 | 58 | return idx 59 | 60 | 61 | def _threshold_and_support(x, batch): 62 | """Sparsemax building block: compute the threshold 63 | Args: 64 | x: input tensor to apply the sparsemax 65 | batch: group indicators 66 | Returns: 67 | the threshold value 68 | """ 69 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0) 70 | cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0) 71 | 72 | sorted_input, input_cumsum = scatter_sort(x, batch) 73 | input_cumsum = input_cumsum - 1.0 74 | rhos = _make_ix_like(batch).to(x.dtype) 75 | support = rhos * sorted_input > input_cumsum 76 | 77 | support_size = scatter_add(support.to(batch.dtype), batch) 78 | # mask invalid index, for example, if batch is not start from 0 or not continuous, it may result in negative index 79 | idx = support_size + cum_num_nodes - 1 80 | mask = idx < 0 81 | idx[mask] = 0 82 | tau = input_cumsum.gather(0, idx) 83 | tau /= support_size.to(x.dtype) 84 | 85 | return tau, support_size 86 | 87 | 88 | class SparsemaxFunction(Function): 89 | 90 | @staticmethod 91 | def forward(ctx, x, batch): 92 | """sparsemax: normalizing sparse transform 93 | Parameters: 94 | ctx: context object 95 | x (Tensor): shape (N, ) 96 | batch: group indicator 97 | Returns: 98 | output (Tensor): same shape as input 99 | """ 100 | max_val, _ = scatter_max(x, batch) 101 | x -= max_val[batch] 102 | tau, supp_size = _threshold_and_support(x, batch) 103 | output = torch.clamp(x - tau[batch], min=0) 104 | ctx.save_for_backward(supp_size, output, batch) 105 | 106 | return output 107 | 108 | @staticmethod 109 | def backward(ctx, grad_output): 110 | supp_size, output, batch = ctx.saved_tensors 111 | grad_input = grad_output.clone() 112 | grad_input[output == 0] = 0 113 | 114 
| v_hat = scatter_add(grad_input, batch) / supp_size.to(output.dtype) 115 | grad_input = torch.where(output != 0, grad_input - v_hat[batch], grad_input) 116 | 117 | return grad_input, None 118 | 119 | 120 | sparsemax = SparsemaxFunction.apply 121 | 122 | 123 | class Sparsemax(nn.Module): 124 | 125 | def __init__(self): 126 | super(Sparsemax, self).__init__() 127 | 128 | def forward(self, x, batch): 129 | return sparsemax(x, batch) 130 | 131 | 132 | if __name__ == '__main__': 133 | sparse_attention = Sparsemax() 134 | input_x = torch.tensor([1.7301, 0.6792, -1.0565, 1.6614, -0.3196, -0.7790, -0.3877, -0.4943, 0.1831, -0.0061]) 135 | input_batch = torch.cat([torch.zeros(4, dtype=torch.long), torch.ones(6, dtype=torch.long)], dim=0) 136 | res = sparse_attention(input_x, input_batch) 137 | print(res) 138 | -------------------------------------------------------------------------------- /for_hgp-sl/train_hgp-sl_twoenc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import time 5 | import pickle 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import copy 10 | from models import Model 11 | from torch.utils.data import random_split 12 | from torch_geometric.data import DataLoader 13 | from torch_geometric.datasets import TUDataset 14 | from util import contrastive_loss_labelwise_winslide, dequeue_and_enqueue_HGPSL, momentum_update 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument('--seed', type=int, default=777, help='random seed') 19 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 20 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 21 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay') 22 | parser.add_argument('--nhid', type=int, default=128, help='hidden size') 23 | parser.add_argument('--sample_neighbor', type=bool, default=True, help='whether sample neighbors') 24 | parser.add_argument('--sparse_attention', type=bool, default=True, help='whether use sparse attention') 25 | parser.add_argument('--structure_learning', type=bool, default=True, help='whether perform structure learning') 26 | parser.add_argument('--pooling_ratio', type=float, default=0.5, help='pooling ratio') 27 | parser.add_argument('--dropout_ratio', type=float, default=0.0, help='dropout ratio') 28 | parser.add_argument('--lamb', type=float, default=1.0, help='trade-off parameter') 29 | parser.add_argument('--dataset', type=str, default='PROTEINS', help='DD/PROTEINS/NCI1') 30 | parser.add_argument('--device', type=str, default='cuda:0', help='specify cuda devices') 31 | parser.add_argument('--epochs', type=int, default=350, help='maximum number of epochs') 32 | parser.add_argument('--patience', type=int, default=100, help='patience for early stopping') 33 | parser.add_argument('--contraloss_weight', type=float, default=0.5, help='The weight of contrastive loss term.') 34 | parser.add_argument('--temperature', type=float, default=0.07, help='The temperature for contrastive loss.') 35 | 36 | args = parser.parse_args() 37 | torch.manual_seed(args.seed) 38 | if torch.cuda.is_available(): 39 | torch.cuda.manual_seed(args.seed) 40 | 41 | dataset = TUDataset(os.path.join('data', args.dataset), name=args.dataset, use_node_attr=True) 42 | 43 | args.num_classes = dataset.num_classes 44 | args.num_features = dataset.num_features 45 | 46 | num_training = int(len(dataset) * 0.8) 47 | num_val = 
int(len(dataset) * 0.1) 48 | num_test = len(dataset) - (num_training + num_val) 49 | 50 | 51 | def train(model, model_k, optimizer, queue, train_loader, val_loader, train_idx_by_label): 52 | min_loss = 1e10 53 | patience_cnt = 0 54 | val_loss_values = [] 55 | best_epoch = 0 56 | 57 | t = time.time() 58 | model.train() 59 | model_k.train() 60 | 61 | val_acc, train_acc, train_celoss, train_contraloss= [],[],[],[] 62 | for epoch in range(args.epochs): 63 | celoss_train = 0.0 64 | contraloss_train = 0.0 65 | correct = 0 66 | start_ptr = 0 67 | for i, data in enumerate(train_loader): 68 | optimizer.zero_grad() 69 | data = data.to(args.device) 70 | out, _ = model(data) 71 | celoss = F.nll_loss(out, data.y) 72 | _, hidden_feats_dict = model_k(data) 73 | # update memory bank 74 | queue = dequeue_and_enqueue_HGPSL(hidden_feats_dict, start_ptr, start_ptr+len(data.y), queue) 75 | start_ptr += len(data.y) 76 | # compute label-wise contrastive loss 77 | batch_idx_by_label = {} 78 | for i in range(args.num_classes): 79 | batch_idx_by_label[i] = [idx for idx in range(len(data.y)) if data.y[idx] == i] 80 | 81 | contraloss = 0.0 82 | for layer in hidden_feats_dict: 83 | contraloss += contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label, 84 | hidden_feats_dict[layer], queue[layer].detach().clone()) 85 | 86 | loss = celoss + args.contraloss_weight*contraloss 87 | loss.backward() 88 | optimizer.step() 89 | # update model_k by momentum 90 | model_k = momentum_update(model, model_k, 0.999) 91 | 92 | celoss_train += celoss.item() 93 | contraloss_train += contraloss 94 | pred = out.max(dim=1)[1] 95 | correct += pred.eq(data.y).sum().item() 96 | acc_train = correct / len(train_loader.dataset) 97 | acc_val, loss_val = compute_test(test_loader) 98 | print('Epoch: {:04d}'.format(epoch + 1), 'celoss_train: {:.6f}'.format(celoss_train), 'contraloss_train: {:.6f}'.format(contraloss_train), 99 | 'acc_train: {:.6f}'.format(acc_train), 'loss_val: {:.6f}'.format(loss_val), 100 | 'acc_val: {:.6f}'.format(acc_val), 'time: {:.6f}s'.format(time.time() - t)) 101 | 102 | train_celoss.append(celoss_train) 103 | train_contraloss.append(contraloss_train) 104 | train_acc.append(acc_train) 105 | val_acc.append(acc_val) 106 | val_loss_values.append(loss_val) 107 | 108 | print('Optimization Finished! 
Total time elapsed: {:.6f}'.format(time.time() - t)) 109 | 110 | return best_epoch, train_celoss, train_contraloss, train_acc, val_acc, val_loss_values 111 | 112 | 113 | def compute_test(loader): 114 | model.eval() 115 | correct = 0.0 116 | loss_test = 0.0 117 | for data in loader: 118 | data = data.to(args.device) 119 | out, _ = model(data) 120 | pred = out.max(dim=1)[1] 121 | correct += pred.eq(data.y).sum().item() 122 | loss_test += F.nll_loss(out, data.y).item() 123 | return correct / len(loader.dataset), loss_test 124 | 125 | 126 | if __name__ == '__main__': 127 | Train_celoss, Train_contraloss, Train_acc, Val_loss, Val_acc = {},{},{},{},{} 128 | for i in range(10): 129 | # prepare data 130 | training_set, validation_set, test_set = random_split(dataset, [num_training, num_val, num_test]) 131 | #training_set, test_set = random_split(dataset, [num_training, num_test]) 132 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=False) 133 | val_loader = DataLoader(validation_set, batch_size=args.batch_size, shuffle=False) 134 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) 135 | 136 | train_idx_by_label = {} 137 | for i in range(args.num_classes): 138 | train_idx_by_label[i] = [idx for idx in range(num_training) if training_set[idx].y == i] 139 | 140 | model = Model(args).to(args.device) 141 | model_k = copy.deepcopy(Model(args)).to(args.device) 142 | for param in model_k.parameters(): 143 | param.requires_grad = False 144 | 145 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 146 | # Model training 147 | queue = {1:F.normalize(torch.randn(num_training, args.nhid), dim=1).to(args.device), 148 | 2:F.normalize(torch.randn(num_training, args.nhid//2), dim=1).to(args.device), 149 | 3:F.normalize(torch.randn(num_training, args.num_classes), dim=1).to(args.device)} 150 | 151 | best_model, train_celoss, train_contraloss, train_acc, val_acc, val_loss = train(model, model_k, optimizer, queue, train_loader, val_loader, train_idx_by_label) 152 | Train_celoss[i], Train_contraloss[i], Train_acc[i], Val_loss[i], Val_acc[i] = train_celoss, train_contraloss, train_acc, val_loss, val_acc 153 | 154 | Best_acc.append(max(val_acc)) 155 | show_results += np.array(val_acc) 156 | 157 | -------------------------------------------------------------------------------- /for_hgp-sl/train_hgp-sl_oneenc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import time 5 | import pickle 6 | import torch 7 | import torch.nn.functional as F 8 | import numpy as np 9 | from models import Model 10 | from torch.utils.data import random_split 11 | from torch_geometric.data import DataLoader 12 | from torch_geometric.datasets import TUDataset 13 | from util import contrastive_loss_labelwise_winslide, dequeue_and_enqueue_HGPSL 14 | 15 | parser = argparse.ArgumentParser() 16 | 17 | parser.add_argument('--seed', type=int, default=777, help='random seed') 18 | parser.add_argument('--batch_size', type=int, default=512, help='batch size') 19 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 20 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight decay') 21 | parser.add_argument('--nhid', type=int, default=128, help='hidden size') 22 | parser.add_argument('--sample_neighbor', type=bool, default=True, help='whether sample neighbors') 23 | parser.add_argument('--sparse_attention', type=bool, 
default=True, help='whether use sparse attention') 24 | parser.add_argument('--structure_learning', type=bool, default=True, help='whether perform structure learning') 25 | parser.add_argument('--pooling_ratio', type=float, default=0.5, help='pooling ratio') 26 | parser.add_argument('--dropout_ratio', type=float, default=0.0, help='dropout ratio') 27 | parser.add_argument('--lamb', type=float, default=1.0, help='trade-off parameter') 28 | parser.add_argument('--dataset', type=str, default='PROTEINS', help='DD/PROTEINS/NCI1/') 29 | parser.add_argument('--device', type=str, default='cuda:0', help='specify cuda devices') 30 | parser.add_argument('--epochs', type=int, default=1000, help='maximum number of epochs') 31 | parser.add_argument('--patience', type=int, default=100, help='patience for early stopping') 32 | parser.add_argument('--contraloss_weight', type=float, default=0.5, help='The weight of contrastive loss term.') 33 | parser.add_argument('--temperature', type=float, default=0.07, help='The temperature for contrastive loss.') 34 | 35 | args = parser.parse_args() 36 | torch.manual_seed(args.seed) 37 | if torch.cuda.is_available(): 38 | torch.cuda.manual_seed(args.seed) 39 | 40 | dataset = TUDataset(os.path.join('data', args.dataset), name=args.dataset, use_node_attr=True) 41 | 42 | args.num_classes = dataset.num_classes 43 | args.num_features = dataset.num_features 44 | 45 | num_training = int(len(dataset) * 0.8) 46 | num_val = int(len(dataset) * 0.1) 47 | num_test = len(dataset) - (num_training + num_val) 48 | 49 | 50 | def train(model, optimizer, queue, train_loader, val_loader, train_idx_by_label): 51 | min_loss = 1e10 52 | patience_cnt = 0 53 | val_loss_values = [] 54 | best_epoch = 0 55 | 56 | t = time.time() 57 | model.train() 58 | 59 | val_acc, train_acc, train_celoss, train_contraloss= [],[],[],[] 60 | for epoch in range(args.epochs): 61 | celoss_train = 0.0 62 | contraloss_train = 0.0 63 | correct = 0 64 | start_ptr = 0 65 | for i, data in enumerate(train_loader): 66 | optimizer.zero_grad() 67 | data = data.to(args.device) 68 | out, hidden_feats_dict = model(data) 69 | celoss = F.nll_loss(out, data.y) 70 | 71 | # update memory bank 72 | queue = dequeue_and_enqueue_HGPSL(hidden_feats_dict, start_ptr, start_ptr+len(data.y), queue) 73 | start_ptr += len(data.y) 74 | # compute label-wise contrastive loss 75 | batch_idx_by_label = {} 76 | for i in range(args.num_classes): 77 | batch_idx_by_label[i] = [idx for idx in range(len(data.y)) if data.y[idx] == i] 78 | 79 | contraloss = 0.0 80 | for layer in hidden_feats_dict: 81 | contraloss += contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label, 82 | hidden_feats_dict[layer], queue[layer].detach().clone()) 83 | 84 | loss = celoss + args.contraloss_weight*contraloss 85 | loss.backward() 86 | optimizer.step() 87 | celoss_train += celoss.item() 88 | contraloss_train += contraloss 89 | pred = out.max(dim=1)[1] 90 | correct += pred.eq(data.y).sum().item() 91 | acc_train = correct / len(train_loader.dataset) 92 | acc_val, loss_val = compute_test(val_loader) 93 | print('Epoch: {:04d}'.format(epoch + 1), 'celoss_train: {:.6f}'.format(celoss_train), 'contraloss_train: {:.6f}'.format(contraloss_train), 94 | 'acc_train: {:.6f}'.format(acc_train), 'loss_val: {:.6f}'.format(loss_val), 95 | 'acc_val: {:.6f}'.format(acc_val), 'time: {:.6f}s'.format(time.time() - t)) 96 | 97 | train_celoss.append(celoss_train) 98 | train_contraloss.append(contraloss_train) 99 | train_acc.append(acc_train) 100 | val_acc.append(acc_val) 101 
| val_loss_values.append(loss_val) 102 | torch.save(model.state_dict(), '{}.winpth'.format(epoch)) 103 | if val_loss_values[-1] < min_loss: 104 | min_loss = val_loss_values[-1] 105 | best_epoch = epoch 106 | patience_cnt = 0 107 | else: 108 | patience_cnt += 1 109 | 110 | if patience_cnt == args.patience: 111 | break 112 | 113 | files = glob.glob('*.winpth') 114 | for f in files: 115 | epoch_nb = int(f.split('.')[0]) 116 | if epoch_nb < best_epoch: 117 | os.remove(f) 118 | 119 | files = glob.glob('*.winpth') 120 | for f in files: 121 | epoch_nb = int(f.split('.')[0]) 122 | if epoch_nb > best_epoch: 123 | os.remove(f) 124 | print('Optimization Finished! Total time elapsed: {:.6f}'.format(time.time() - t)) 125 | 126 | return best_epoch, train_celoss, train_contraloss, train_acc, val_acc, val_loss_values 127 | 128 | 129 | def compute_test(loader): 130 | model.eval() 131 | correct = 0.0 132 | loss_test = 0.0 133 | for data in loader: 134 | data = data.to(args.device) 135 | out, _ = model(data) 136 | pred = out.max(dim=1)[1] 137 | correct += pred.eq(data.y).sum().item() 138 | loss_test += F.nll_loss(out, data.y).item() 139 | return correct / len(loader.dataset), loss_test 140 | 141 | 142 | if __name__ == '__main__': 143 | Train_celoss, Train_contraloss, Train_acc, Val_loss, Val_acc = {},{},{},{},{} 144 | for i in range(10): 145 | 146 | # prepare data 147 | training_set, validation_set, test_set = random_split(dataset, [num_training, num_val, num_test]) 148 | train_loader = DataLoader(training_set, batch_size=args.batch_size, shuffle=False) 149 | val_loader = DataLoader(validation_set, batch_size=args.batch_size, shuffle=False) 150 | test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False) 151 | 152 | train_idx_by_label = {} 153 | for i in range(args.num_classes): 154 | train_idx_by_label[i] = [idx for idx in range(num_training) if training_set[idx].y == i] 155 | 156 | model = Model(args).to(args.device) 157 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 158 | # Model training 159 | queue = {1:F.normalize(torch.randn(num_training, args.nhid), dim=1).to(args.device), 160 | 2:F.normalize(torch.randn(num_training, args.nhid//2), dim=1).to(args.device), 161 | 3:F.normalize(torch.randn(num_training, args.num_classes), dim=1).to(args.device)} 162 | 163 | best_model, train_celoss, train_contraloss, train_acc, val_acc, val_loss = train(model, optimizer, queue, train_loader, val_loader, train_idx_by_label) 164 | Train_celoss[i], Train_contraloss[i], Train_acc[i], Val_loss[i], Val_acc[i] = train_celoss, train_contraloss, train_acc, val_loss, val_acc 165 | # Restore best model for test set 166 | model.load_state_dict(torch.load('{}.winpth'.format(best_model))) 167 | test_acc, test_loss = compute_test(test_loader) 168 | Test_loss.append(test_loss) 169 | Test_acc.append(test_acc) 170 | 171 | show_results += np.array(val_acc) 172 | files = glob.glob('*.winpth') 173 | for f in files: 174 | epoch_nb = int(f.split('.')[0]) 175 | if epoch_nb == best_model: 176 | os.remove(f) 177 | 178 | 179 | 180 | -------------------------------------------------------------------------------- /for_gin/train_powerfulgnn_oneenc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import numpy as np 7 | import pickle 8 | from tqdm import tqdm 9 | 10 | from util import load_data, separate_data, 
to_cuda, contrastive_loss_labelwise_winslide, dequeue_and_enqueue_multiLayer 11 | from models.graphcnn_pooled_multilayer import GraphCNN 12 | 13 | criterion = nn.CrossEntropyLoss() 14 | 15 | def train(args, model, device, train_graphs, queue, num_classes, optimizer, epoch): 16 | train_idx_by_label = {} 17 | for i in range(num_classes): 18 | train_idx_by_label[i] = [idx for idx in range(len(train_graphs)) if train_graphs[idx].label == i] 19 | 20 | model.train() 21 | 22 | pbar = tqdm(range(args.iters_per_epoch), unit='batch') 23 | 24 | celoss_accum, contraloss_accum = 0, 0 25 | for pos in pbar: 26 | selected_batch_idx = np.random.permutation(len(train_graphs))[:args.batch_size] 27 | batch_graph = [train_graphs[idx] for idx in selected_batch_idx] 28 | 29 | output, hidden_feats_dict = model(batch_graph) 30 | 31 | labels_batch = torch.LongTensor([graph.label for graph in batch_graph]).to(device) 32 | #compute cross-entropy loss 33 | celoss = criterion(output, labels_batch) 34 | 35 | # update memory bank 36 | queue = dequeue_and_enqueue_multiLayer(hidden_feats_dict, selected_batch_idx, queue) 37 | # compute label-wise contrastive loss 38 | batch_i_by_label = {} 39 | for i in range(num_classes): 40 | batch_i_by_label[i] = [idx for idx in range(len(batch_graph)) if batch_graph[idx].label == i] 41 | 42 | contraloss = 0 43 | for layer in hidden_feats_dict: 44 | contraloss += contrastive_loss_labelwise_winslide(args, batch_i_by_label, train_idx_by_label, 45 | hidden_feats_dict[layer], queue[layer].detach().clone()) 46 | #backprop 47 | loss = celoss + args.contraloss_weight * contraloss 48 | optimizer.zero_grad() 49 | loss.backward() 50 | optimizer.step() 51 | 52 | celoss_accum += celoss.detach().cpu().numpy() 53 | contraloss_accum += contraloss.detach().cpu().numpy() 54 | 55 | 56 | #report 57 | pbar.set_description('epoch: %d' % (epoch)) 58 | 59 | avg_celoss = celoss_accum/args.iters_per_epoch 60 | avg_contraloss = contraloss_accum/args.iters_per_epoch 61 | print("training celoss: %4f, contraloss: %4f" % (avg_celoss, avg_contraloss)) 62 | 63 | return avg_celoss, avg_contraloss, queue 64 | 65 | ###pass data to model with minibatch during testing to avoid memory overflow (does not perform backpropagation) 66 | def pass_data_iteratively(model, graphs, minibatch_size = 64): 67 | model.eval() 68 | output = [] 69 | idx = np.arange(len(graphs)) 70 | for i in range(0, len(graphs), minibatch_size): 71 | sampled_idx = idx[i:i+minibatch_size] 72 | if len(sampled_idx) == 0: 73 | continue 74 | output.append(model([graphs[j] for j in sampled_idx])[0].detach()) 75 | return torch.cat(output, 0) 76 | 77 | def test(args, model, device, train_graphs, test_graphs, epoch): 78 | model.eval() 79 | 80 | output = pass_data_iteratively(model, train_graphs) 81 | pred = output.max(1, keepdim=True)[1] 82 | labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device) 83 | correct = pred.eq(labels.view_as(pred)).sum().cpu().item() 84 | acc_train = correct / float(len(train_graphs)) 85 | 86 | output = pass_data_iteratively(model, test_graphs) 87 | pred = output.max(1, keepdim=True)[1] 88 | labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device) 89 | correct = pred.eq(labels.view_as(pred)).sum().cpu().item() 90 | acc_test = correct / float(len(test_graphs)) 91 | 92 | print("accuracy train: %f test: %f" % (acc_train, acc_test)) 93 | 94 | return acc_train, acc_test 95 | 96 | def main(): 97 | # Training settings 98 | # Note: Hyper-parameters need to be tuned in order to obtain results reported in 
the paper. 99 | parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification') 100 | parser.add_argument('--dataset', type=str, default="MUTAG", help='name of dataset') 101 | parser.add_argument('--device', type=int, default=2, help='which gpu to use if any') 102 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') 103 | parser.add_argument('--iters_per_epoch', type=int, default=50, help='number of iterations per each epoch') 104 | parser.add_argument('--epochs', type=int, default=1000, help='number of epochs to train') 105 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate') 106 | parser.add_argument('--seed', type=int, default=0, help='random seed for splitting the dataset into 10') 107 | parser.add_argument('--fold_idx', type=int, default=0, help='the index of fold in 10-fold validation. Should be less then 10.') 108 | parser.add_argument('--num_layers', type=int, default=5, help='number of layers INCLUDING the input one') 109 | parser.add_argument('--num_mlp_layers', type=int, default=2, help='number of layers for MLP EXCLUDING the input one. 1 means linear model.') 110 | parser.add_argument('--hidden_dim', type=int, default=64, help='number of hidden units') 111 | parser.add_argument('--final_dropout', type=float, default=0.5, help='final layer dropout') 112 | parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"], help='Pooling for over nodes in a graph: sum or average') 113 | parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"], help='Pooling for over neighboring nodes: sum, average or max') 114 | parser.add_argument('--learn_eps', action="store_true", help='Whether to learn epsilon weighting forcenter nodes. Do not affect training accuracy.') 115 | parser.add_argument('--degree_as_tag', action="store_true", help='let the input node features be the degree of nodes (heuristics for unlabeled graph)') 116 | parser.add_argument('--filename', type = str, default = "", help='output file') 117 | parser.add_argument('--contraloss_weight', type=float, default=0.5, help='The weight of contrastive loss term.') 118 | parser.add_argument('--temperature', type=float, default=0.07, help='The temperature for contrastive loss.') 119 | args = parser.parse_args() 120 | 121 | #set up seeds and gpu device 122 | torch.manual_seed(0) 123 | np.random.seed(0) 124 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 125 | if torch.cuda.is_available(): 126 | torch.cuda.manual_seed_all(0) 127 | 128 | graphs, num_classes = load_data(args.dataset, args.degree_as_tag) 129 | 130 | Val_results, Train_results, Train_celoss, Train_contraloss = {},{},{},{} 131 | ##10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx. 
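# A possible way to summarize the 10-fold results gathered by the loop below (an illustrative
# sketch only, not part of the original script; it assumes the loop has finished populating
# Val_results):
#   fold_best = [max(accs) for accs in Val_results.values()]
#   print('mean best test acc over 10 folds: %.4f +/- %.4f' % (np.mean(fold_best), np.std(fold_best)))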
132 | for fold_idx in range(10): 133 | train_graphs, test_graphs = separate_data(graphs, args.seed, fold_idx) 134 | 135 | model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device) 136 | 137 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 138 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5) 139 | 140 | # initialize the memory bank 141 | _, queue = model(train_graphs) 142 | 143 | for i in queue: 144 | queue[i] = nn.functional.normalize(queue[i], dim=1) 145 | 146 | val_acc, train_acc, train_celoss, train_contraloss = [],[],[],[] 147 | for epoch in range(1, args.epochs + 1): 148 | scheduler.step() 149 | 150 | avg_celoss, avg_contraloss, queue = train(args, model, device, train_graphs, queue, num_classes, optimizer, epoch) 151 | acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch) 152 | val_acc.append(acc_test) 153 | train_acc.append(acc_train) 154 | train_celoss.append(avg_celoss) 155 | train_contraloss.append(avg_contraloss) 156 | Val_results[fold_idx], Train_results[fold_idx], Train_celoss[fold_idx] = val_acc, train_acc, train_celoss 157 | Train_contraloss[fold_idx] = train_contraloss 158 | 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | -------------------------------------------------------------------------------- /for_hgp-sl/util.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import pickle 7 | from sklearn.model_selection import StratifiedKFold 8 | 9 | class S2VGraph(object): 10 | def __init__(self, g, label, node_tags=None, node_features=None): 11 | ''' 12 | g: a networkx graph 13 | label: an integer graph label 14 | node_tags: a list of integer node tags 15 | node_features: a torch float tensor, one-hot representation of the tag that is used as input to neural nets 16 | edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor 17 | neighbors: list of neighbors (without self-loop) 18 | ''' 19 | self.label = label 20 | self.g = g 21 | self.node_tags = node_tags 22 | self.neighbors = [] 23 | self.node_features = 0 24 | self.edge_mat = 0 25 | 26 | self.max_neighbor = 0 27 | 28 | def load_data(dataset, degree_as_tag=False): 29 | ''' 30 | dataset: name of dataset 31 | test_proportion: ratio of test train split 32 | seed: random seed for random splitting of dataset 33 | ''' 34 | print('loading data') 35 | g_list = [] 36 | label_dict = {} 37 | feat_dict = {} 38 | 39 | with open('../dataset/%s/%s.txt' % (dataset, dataset), 'r') as f: 40 | n_g = int(f.readline().strip()) 41 | for i in range(n_g): 42 | row = f.readline().strip().split() 43 | n, l = [int(w) for w in row] 44 | if not l in label_dict: 45 | mapped = len(label_dict) 46 | label_dict[l] = mapped 47 | g = nx.Graph() 48 | node_tags, node_features = [], [] 49 | n_edges = 0 50 | for j in range(n): 51 | g.add_node(j) 52 | row = f.readline().strip().split() 53 | tmp = int(row[1]) + 2 54 | if tmp == len(row): 55 | # no node attributes 56 | row = [int(w) for w in row] 57 | attr = None 58 | else: 59 | row, attr = [int(w) for w in row[:tmp]], np.array([float(w) for w in row[tmp:]]) 60 | if not row[0] in feat_dict: 61 | mapped = len(feat_dict) 62 | feat_dict[row[0]] = mapped 63 | 
node_tags.append(feat_dict[row[0]]) 64 | 65 | if tmp > len(row): 66 | node_features.append(attr) 67 | 68 | n_edges += row[1] 69 | for k in range(2, len(row)): 70 | g.add_edge(j, row[k]) 71 | 72 | if node_features != []: 73 | node_features = np.stack(node_features) 74 | node_feature_flag = True 75 | else: 76 | node_features, node_feature_flag = None, False 77 | 78 | assert len(g) == n 79 | 80 | g_list.append(S2VGraph(g, l, node_tags)) 81 | 82 | #add labels and edge_mat 83 | for g in g_list: 84 | g.neighbors = [[] for i in range(len(g.g))] 85 | for i, j in g.g.edges(): 86 | g.neighbors[i].append(j) 87 | g.neighbors[j].append(i) 88 | degree_list = [] 89 | for i in range(len(g.g)): 90 | g.neighbors[i] = g.neighbors[i] 91 | degree_list.append(len(g.neighbors[i])) 92 | g.max_neighbor = max(degree_list) 93 | 94 | g.label = label_dict[g.label] 95 | 96 | edges = [list(pair) for pair in g.g.edges()] 97 | edges.extend([[i, j] for j, i in edges]) 98 | 99 | deg_list = list(dict(g.g.degree(range(len(g.g)))).values()) 100 | g.edge_mat = torch.LongTensor(edges).transpose(0,1) 101 | 102 | if degree_as_tag: 103 | for g in g_list: 104 | g.node_tags = list(dict(g.g.degree).values()) 105 | 106 | #Extracting unique tag labels 107 | tagset = set([]) 108 | for g in g_list: 109 | tagset = tagset.union(set(g.node_tags)) 110 | 111 | tagset = list(tagset) 112 | tag2index = {tagset[i]:i for i in range(len(tagset))} 113 | 114 | for g in g_list: 115 | g.node_features = torch.zeros(len(g.node_tags), len(tagset)) 116 | g.node_features[range(len(g.node_tags)), [tag2index[tag] for tag in g.node_tags]] = 1 117 | 118 | 119 | print('# classes: %d' % len(label_dict)) 120 | print('# maximum node tag: %d' % len(tagset)) 121 | print("# data: %d" % len(g_list)) 122 | 123 | return g_list, len(label_dict) 124 | 125 | def separate_data(graph_list, seed, fold_idx): 126 | assert 0 <= fold_idx and fold_idx < 10, "fold_idx must be from 0 to 9." 127 | skf = StratifiedKFold(n_splits=10, shuffle = True, random_state = seed) 128 | 129 | labels = [graph.y for graph in graph_list] 130 | idx_list = [] 131 | for idx in skf.split(np.zeros(len(labels)), labels): 132 | idx_list.append(idx) 133 | train_idx, test_idx = idx_list[fold_idx] 134 | 135 | train_graph_list = [graph_list[int(i)] for i in train_idx] 136 | test_graph_list = [graph_list[int(i)] for i in test_idx] 137 | 138 | return train_graph_list, test_graph_list 139 | 140 | 141 | # def contrastive_loss_labelwise(args, batch_idx_by_label, train_idx_by_label, hidden_feats): 142 | # ''' 143 | # hidden feats must be normalized. 144 | # ''' 145 | # assert(len(batch_idx_by_label) == len(train_idx_by_label)) 146 | # hidden_feats = nn.functional.normalize(hidden_feats, dim=1) 147 | 148 | # loss = 0 149 | # for i in batch_idx_by_label: 150 | # if(len(batch_idx_by_label[i]) == 0): 151 | # continue 152 | # q, k = hidden_feats[batch_idx_by_label[i]], hidden_feats[train_idx_by_label[i]] 153 | # # k_neg = hidden_feats[list(set([i for i in range(len_train_graphs)]) - set(train_idx_by_label[i]))] 154 | # l_pos = torch.sum(torch.exp(torch.mm(q, k.transpose(0,1))/args.temperature), dim=1) 155 | # l_neg = torch.sum(torch.exp(torch.mm(q, hidden_feats.transpose(0,1))/args.temperature), dim=1) 156 | # # print('part loss', l_pos/l_neg) 157 | # loss += torch.sum(-1.0*torch.log(l_pos/l_neg)) 158 | # return loss/args.batch_size 159 | 160 | 161 | def contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label, hidden_feats, queue): 162 | ''' 163 | hidden feats must be normalized. 
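    For each in-batch sample i with label c, the term computed below is
        -log( sum_{j in bank, y_j = c} exp(q_i . k_j / T)  /  sum_{j in bank} exp(q_i . k_j / T) ),
    where q_i is the (re-)normalized hidden feature of sample i, k_j is the j-th row of the memory
    bank `queue`, and T is args.temperature; the batch-summed loss is divided by args.batch_size.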
164 | ''' 165 | assert(len(batch_idx_by_label) == len(train_idx_by_label)) 166 | hidden_feats = nn.functional.normalize(hidden_feats, dim=1) 167 | 168 | loss = 0 169 | for i in batch_idx_by_label: 170 | if(len(batch_idx_by_label) == 0): 171 | continue 172 | q, k = hidden_feats[batch_idx_by_label[i]], queue[train_idx_by_label[i]] 173 | # print('max value', torch.max(torch.mm(q, k.transpose(0,1)))) 174 | l_pos = torch.sum(torch.exp(torch.mm(q, k.transpose(0,1))/args.temperature), dim=1) 175 | l_neg = torch.sum(torch.exp(torch.mm(q, queue.transpose(0,1))/args.temperature), dim=1) 176 | # print('two part size',l_pos.size(), l_neg.size()) 177 | loss += torch.sum(-1.0*torch.log(l_pos/l_neg)) 178 | return loss/args.batch_size 179 | 180 | 181 | @torch.no_grad() 182 | def momentum_update(encoder_q, encoder_k, m=0.999): 183 | """ 184 | encoder_k = m * encoder_k + (1 - m) encoder_q 185 | """ 186 | for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()): 187 | param_k.data = param_k.data * m + param_q.data * (1. - m) 188 | 189 | return encoder_k 190 | 191 | 192 | # def dequeue_and_enqueue(hidden_batch_feats, selected_batch_idx, queue): 193 | # ''' 194 | # update memory bank by batch window slide; hidden_batch_feats must be normalized 195 | # ''' 196 | # assert(hidden_batch_feats.size()[1] == queue.size()[1]) 197 | 198 | 199 | # queue[selected_batch_idx] = nn.functional.normalize(hidden_batch_feats,dim=1) 200 | # return queue 201 | 202 | 203 | # def dequeue_and_enqueue_multiLayer(hidden_feats_dict, selected_batch_idx, queue): 204 | # ''' 205 | # update memory bank by batch window slide; hidden_batch_feats must be normalized 206 | # ''' 207 | # assert(len(hidden_feats_dict) == len(queue)) 208 | 209 | # for i in range(len(queue)): 210 | # queue[i][selected_batch_idx] = nn.functional.normalize(hidden_feats_dict[i],dim=1) 211 | # return queue 212 | 213 | def dequeue_and_enqueue_HGPSL(hidden_feats_dict, start, end, queue): 214 | ''' 215 | update memory bank by batch window slide; hidden_batch_feats must be normalized 216 | ''' 217 | assert(hidden_feats_dict.keys() == queue.keys()) 218 | 219 | for i in queue: 220 | queue[i][start:end] = nn.functional.normalize(hidden_feats_dict[i],dim=1) 221 | return queue 222 | 223 | 224 | def to_cuda(x): 225 | if torch.cuda.is_available(): 226 | x = x.cuda() 227 | return x 228 | 229 | 230 | -------------------------------------------------------------------------------- /for_gin/util.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.nn as nn 6 | import pickle 7 | from sklearn.model_selection import StratifiedKFold 8 | 9 | class S2VGraph(object): 10 | def __init__(self, g, label, node_tags=None, node_features=None): 11 | ''' 12 | g: a networkx graph 13 | label: an integer graph label 14 | node_tags: a list of integer node tags 15 | node_features: a torch float tensor, one-hot representation of the tag that is used as input to neural nets 16 | edge_mat: a torch long tensor, contain edge list, will be used to create torch sparse tensor 17 | neighbors: list of neighbors (without self-loop) 18 | ''' 19 | self.label = label 20 | self.g = g 21 | self.node_tags = node_tags 22 | self.neighbors = [] 23 | self.node_features = 0 24 | self.edge_mat = 0 25 | 26 | self.max_neighbor = 0 27 | 28 | def load_data(dataset, degree_as_tag): 29 | ''' 30 | dataset: name of dataset 31 | test_proportion: ratio of test train split 32 | seed: 
random seed for random splitting of dataset 33 | ''' 34 | print('loading data') 35 | g_list = [] 36 | label_dict = {} 37 | feat_dict = {} 38 | 39 | with open('dataset/%s/%s.txt' % (dataset, dataset), 'r') as f: 40 | n_g = int(f.readline().strip()) 41 | for i in range(n_g): 42 | row = f.readline().strip().split() 43 | n, l = [int(w) for w in row] 44 | if not l in label_dict: 45 | mapped = len(label_dict) 46 | label_dict[l] = mapped 47 | g = nx.Graph() 48 | node_tags, node_features = [], [] 49 | n_edges = 0 50 | for j in range(n): 51 | g.add_node(j) 52 | row = f.readline().strip().split() 53 | tmp = int(row[1]) + 2 54 | if tmp == len(row): 55 | # no node attributes 56 | row = [int(w) for w in row] 57 | attr = None 58 | else: 59 | row, attr = [int(w) for w in row[:tmp]], np.array([float(w) for w in row[tmp:]]) 60 | if not row[0] in feat_dict: 61 | mapped = len(feat_dict) 62 | feat_dict[row[0]] = mapped 63 | node_tags.append(feat_dict[row[0]]) 64 | 65 | if tmp > len(row): 66 | node_features.append(attr) 67 | 68 | n_edges += row[1] 69 | for k in range(2, len(row)): 70 | g.add_edge(j, row[k]) 71 | 72 | if node_features != []: 73 | node_features = np.stack(node_features) 74 | node_feature_flag = True 75 | else: 76 | node_features, node_feature_flag = None, False 77 | 78 | assert len(g) == n 79 | 80 | g_list.append(S2VGraph(g, l, node_tags)) 81 | 82 | #add labels and edge_mat 83 | for g in g_list: 84 | g.neighbors = [[] for i in range(len(g.g))] 85 | for i, j in g.g.edges(): 86 | g.neighbors[i].append(j) 87 | g.neighbors[j].append(i) 88 | degree_list = [] 89 | for i in range(len(g.g)): 90 | g.neighbors[i] = g.neighbors[i] 91 | degree_list.append(len(g.neighbors[i])) 92 | g.max_neighbor = max(degree_list) 93 | 94 | g.label = label_dict[g.label] 95 | 96 | edges = [list(pair) for pair in g.g.edges()] 97 | edges.extend([[i, j] for j, i in edges]) 98 | 99 | deg_list = list(dict(g.g.degree(range(len(g.g)))).values()) 100 | g.edge_mat = torch.LongTensor(edges).transpose(0,1) 101 | 102 | if degree_as_tag: 103 | for g in g_list: 104 | g.node_tags = list(dict(g.g.degree).values()) 105 | 106 | #Extracting unique tag labels 107 | tagset = set([]) 108 | for g in g_list: 109 | tagset = tagset.union(set(g.node_tags)) 110 | 111 | tagset = list(tagset) 112 | tag2index = {tagset[i]:i for i in range(len(tagset))} 113 | 114 | for g in g_list: 115 | g.node_features = torch.zeros(len(g.node_tags), len(tagset)) 116 | g.node_features[range(len(g.node_tags)), [tag2index[tag] for tag in g.node_tags]] = 1 117 | 118 | 119 | print('# classes: %d' % len(label_dict)) 120 | print('# maximum node tag: %d' % len(tagset)) 121 | print("# data: %d" % len(g_list)) 122 | 123 | return g_list, len(label_dict) 124 | 125 | def separate_data(graph_list, seed, fold_idx): 126 | assert 0 <= fold_idx and fold_idx < 10, "fold_idx must be from 0 to 9." 127 | skf = StratifiedKFold(n_splits=10, shuffle = True, random_state = seed) 128 | 129 | labels = [graph.label for graph in graph_list] 130 | idx_list = [] 131 | for idx in skf.split(np.zeros(len(labels)), labels): 132 | idx_list.append(idx) 133 | train_idx, test_idx = idx_list[fold_idx] 134 | 135 | train_graph_list = [graph_list[i] for i in train_idx] 136 | test_graph_list = [graph_list[i] for i in test_idx] 137 | 138 | return train_graph_list, test_graph_list 139 | 140 | 141 | #def contrastive_loss_labelwise(args, batch_idx_by_label, train_idx_by_label, hidden_feats): 142 | # ''' 143 | # hidden feats must be normalized. 
144 | # ''' 145 | # assert(len(batch_idx_by_label) == len(train_idx_by_label)) 146 | # hidden_feats = nn.functional.normalize(hidden_feats, dim=1) 147 | # 148 | # loss = 0 149 | # for i in batch_idx_by_label: 150 | # if(len(batch_idx_by_label[i]) == 0): 151 | # continue 152 | # q, k = hidden_feats[batch_idx_by_label[i]], hidden_feats[train_idx_by_label[i]] 153 | # # k_neg = hidden_feats[list(set([i for i in range(len_train_graphs)]) - set(train_idx_by_label[i]))] 154 | # l_pos = torch.sum(torch.exp(torch.mm(q, k.transpose(0,1))/args.temperature), dim=1) 155 | # l_neg = torch.sum(torch.exp(torch.mm(q, hidden_feats.transpose(0,1))/args.temperature), dim=1) 156 | # # print('part loss', l_pos/l_neg) 157 | # loss += torch.sum(-1.0*torch.log(l_pos/l_neg)) 158 | # return loss/args.batch_size 159 | 160 | 161 | def contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label, hidden_feats, queue): 162 | ''' 163 | hidden feats must be normalized. 164 | ''' 165 | assert(len(batch_idx_by_label) == len(train_idx_by_label)) 166 | hidden_feats = nn.functional.normalize(hidden_feats, dim=1) 167 | 168 | loss = 0 169 | for i in batch_idx_by_label: 170 | if(len(batch_idx_by_label) == 0): 171 | continue 172 | q, k = hidden_feats[batch_idx_by_label[i]], queue[train_idx_by_label[i]] 173 | l_pos = torch.sum(torch.exp(torch.mm(q, k.transpose(0,1))/args.temperature), dim=1) 174 | l_neg = torch.sum(torch.exp(torch.mm(q, queue.transpose(0,1))/args.temperature), dim=1) 175 | loss += torch.sum(-1.0*torch.log(l_pos/l_neg)) 176 | return loss/args.batch_size 177 | 178 | 179 | def contrastive_loss_samplewise_winslide(args, hidden_feats, queue): 180 | ''' 181 | hidden feats must be normalized. 182 | ''' 183 | hidden_feats = nn.functional.normalize(hidden_feats, dim=1) 184 | 185 | loss = 0 186 | # for i in batch_idx_by_label: 187 | 188 | q = hidden_feats 189 | #print('q size', q.size()) 190 | l_pos = torch.sum(torch.exp(q*q/args.temperature), dim=1) 191 | l_neg = torch.sum(torch.exp(torch.mm(q, queue.transpose(0,1))/args.temperature), dim=1) 192 | # print('two part size',l_pos.size(), l_neg.size()) 193 | loss = torch.sum(-1.0*torch.log(l_pos/l_neg)) 194 | return loss/args.batch_size 195 | 196 | 197 | @torch.no_grad() 198 | def momentum_update(encoder_q, encoder_k, m=0.999): 199 | """ 200 | encoder_k = m * encoder_k + (1 - m) encoder_q 201 | """ 202 | for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()): 203 | param_k.data = param_k.data * m + param_q.data * (1. 
- m) 204 | 205 | return encoder_k 206 | 207 | 208 | def dequeue_and_enqueue(hidden_batch_feats, selected_batch_idx, queue): 209 | ''' 210 | update memory bank by batch window slide; hidden_batch_feats must be normalized 211 | ''' 212 | assert(hidden_batch_feats.size()[1] == queue.size()[1]) 213 | 214 | 215 | queue[selected_batch_idx] = nn.functional.normalize(hidden_batch_feats,dim=1) 216 | return queue 217 | 218 | 219 | def dequeue_and_enqueue_multiLayer(hidden_feats_dict, selected_batch_idx, queue): 220 | ''' 221 | update memory bank by batch window slide; hidden_batch_feats must be normalized 222 | ''' 223 | assert(len(hidden_feats_dict) == len(queue)) 224 | 225 | for i in range(len(queue)): 226 | queue[i][selected_batch_idx] = nn.functional.normalize(hidden_feats_dict[i],dim=1) 227 | return queue 228 | 229 | 230 | def to_cuda(x): 231 | if torch.cuda.is_available(): 232 | x = x.cuda() 233 | return x 234 | 235 | 236 | -------------------------------------------------------------------------------- /for_gin/train_powerfulgnn_twoenc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import numpy as np 7 | import pickle 8 | import copy 9 | from tqdm import tqdm 10 | 11 | from util import load_data, separate_data, to_cuda, contrastive_loss_labelwise_winslide, dequeue_and_enqueue_multiLayer, momentum_update 12 | from models.graphcnn_pooled_multilayer import GraphCNN 13 | 14 | criterion = nn.CrossEntropyLoss() 15 | 16 | def train(args, model_q, model_k, device, train_graphs, num_classes, optimizer, queue, epoch): 17 | train_idx_by_label = {} 18 | for i in range(num_classes): 19 | train_idx_by_label[i] = [idx for idx in range(len(train_graphs)) if train_graphs[idx].label == i] 20 | 21 | model_q.train() 22 | model_k.train() 23 | 24 | pbar = tqdm(range(args.iters_per_epoch), unit='batch') 25 | 26 | celoss_accum, contraloss_accum = 0,0 27 | for pos in pbar: 28 | selected_batch_idx = np.random.permutation(len(train_graphs))[:args.batch_size] 29 | batch_graph = [train_graphs[idx] for idx in selected_batch_idx] 30 | 31 | output, _ = model_q(batch_graph) 32 | 33 | batch_labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device) 34 | #compute cross-entropy loss 35 | celoss = criterion(output, batch_labels) 36 | # update memory bank 37 | _, hidden_batch_feats = model_k(batch_graph) 38 | queue = dequeue_and_enqueue_multiLayer(hidden_batch_feats, selected_batch_idx, queue) 39 | # compute label-wise contrastive loss 40 | batch_idx_by_label = {} 41 | for i in range(num_classes): 42 | batch_idx_by_label[i] = [idx for idx in range(len(batch_graph)) if batch_graph[idx].label == i] 43 | 44 | contraloss = 0.0 45 | for layer in hidden_batch_feats: 46 | contraloss += contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label, 47 | hidden_batch_feats[layer], queue[layer].detach().clone()) 48 | # momentum update model k 49 | model_k = momentum_update(model_q, model_k, m=0.999) 50 | # backprop 51 | loss = celoss + args.contraloss_weight * contraloss 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | celoss_accum += celoss.detach().cpu().numpy() 57 | contraloss_accum += contraloss.detach().cpu().numpy() 58 | 59 | 60 | #report 61 | pbar.set_description('epoch: %d' % (epoch)) 62 | 63 | avg_celoss = celoss_accum/args.iters_per_epoch 64 | avg_contraloss = 
contraloss_accum/args.iters_per_epoch 65 | print("training celoss: %4f, contraloss: %4f" % (avg_celoss, avg_contraloss)) 66 | 67 | return avg_celoss, avg_contraloss, queue 68 | 69 | ###pass data to model with minibatch during testing to avoid memory overflow (does not perform backpropagation) 70 | def pass_data_iteratively(model, graphs, minibatch_size = 64): 71 | model.eval() 72 | output = [] 73 | idx = np.arange(len(graphs)) 74 | for i in range(0, len(graphs), minibatch_size): 75 | sampled_idx = idx[i:i+minibatch_size] 76 | if len(sampled_idx) == 0: 77 | continue 78 | output.append(model([graphs[j] for j in sampled_idx])[0].detach()) 79 | return torch.cat(output, 0) 80 | 81 | def test(args, model, device, train_graphs, test_graphs, epoch): 82 | model.eval() 83 | 84 | output = pass_data_iteratively(model, train_graphs) 85 | pred = output.max(1, keepdim=True)[1] 86 | labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device) 87 | correct = pred.eq(labels.view_as(pred)).sum().cpu().item() 88 | acc_train = correct / float(len(train_graphs)) 89 | 90 | output = pass_data_iteratively(model, test_graphs) 91 | pred = output.max(1, keepdim=True)[1] 92 | labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device) 93 | correct = pred.eq(labels.view_as(pred)).sum().cpu().item() 94 | acc_test = correct / float(len(test_graphs)) 95 | 96 | print("accuracy train: %f test: %f" % (acc_train, acc_test)) 97 | 98 | return acc_train, acc_test 99 | 100 | def main(): 101 | # Training settings 102 | # Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper. 103 | parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification') 104 | parser.add_argument('--dataset', type=str, default="MUTAG", help='name of dataset') 105 | parser.add_argument('--device', type=int, default=0, help='which gpu to use if any') 106 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') 107 | parser.add_argument('--iters_per_epoch', type=int, default=50, help='number of iterations per each epoch') 108 | parser.add_argument('--epochs', type=int, default=1000, help='number of epochs to train') 109 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate') 110 | parser.add_argument('--seed', type=int, default=0, help='random seed for splitting the dataset into 10') 111 | parser.add_argument('--fold_idx', type=int, default=0, help='the index of fold in 10-fold validation. Should be less then 10.') 112 | parser.add_argument('--num_layers', type=int, default=5, help='number of layers INCLUDING the input one') 113 | parser.add_argument('--num_mlp_layers', type=int, default=2, help='number of layers for MLP EXCLUDING the input one. 1 means linear model.') 114 | parser.add_argument('--hidden_dim', type=int, default=64, help='number of hidden units') 115 | parser.add_argument('--final_dropout', type=float, default=0.5, help='final layer dropout') 116 | parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"], help='Pooling for over nodes in a graph: sum or average') 117 | parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"], help='Pooling for over neighboring nodes: sum, average or max') 118 | parser.add_argument('--learn_eps', action="store_true", help='Whether to learn epsilon weighting forcenter nodes. 
Do not affect training accuracy.') 119 | parser.add_argument('--degree_as_tag', action="store_true", help='let the input node features be the degree of nodes (heuristics for unlabeled graph)') 120 | parser.add_argument('--filename', type = str, default = "", help='output file') 121 | parser.add_argument('--contraloss_weight', type=float, default=0.5, help='The weight of contrastive loss term.') 122 | parser.add_argument('--temperature', type=float, default=0.07, help='The temperature for contrastive loss.') 123 | args = parser.parse_args() 124 | 125 | #set up seeds and gpu device 126 | torch.manual_seed(0) 127 | np.random.seed(0) 128 | device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu") 129 | if torch.cuda.is_available(): 130 | torch.cuda.manual_seed_all(0) 131 | 132 | graphs, num_classes = load_data(args.dataset, args.degree_as_tag) 133 | 134 | Val_results, Train_results, Train_celoss, Train_contraloss = {},{},{},{} 135 | ##10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx. 136 | for fold_idx in range(10): 137 | train_graphs, test_graphs = separate_data(graphs, args.seed, fold_idx) 138 | 139 | model_q = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device) 140 | model_k = copy.deepcopy(GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device)).to(device) 141 | for param_q, param_k in zip(model_q.parameters(), model_k.parameters()): 142 | param_k.data.copy_(param_q.data) # initialize 143 | param_k.requires_grad = False # not update by gradient 144 | 145 | optimizer = optim.Adam(model_q.parameters(), lr=args.lr) 146 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5) 147 | 148 | # create queue 149 | model_q.train() 150 | queue = model_q(train_graphs)[1] 151 | for i in queue: 152 | queue[i] = F.normalize(queue[i], dim=1) 153 | 154 | val_acc, train_acc, train_celoss, train_contraloss = [],[],[],[] 155 | for epoch in range(1, args.epochs + 1): 156 | scheduler.step() 157 | 158 | avg_celoss, avg_contraloss, queue = train(args, model_q, model_k, device, train_graphs, num_classes, optimizer, queue, epoch) 159 | acc_train, acc_test = test(args, model_q, device, train_graphs, test_graphs, epoch) 160 | val_acc.append(acc_test) 161 | train_acc.append(acc_train) 162 | train_celoss.append(avg_celoss) 163 | train_contraloss.append(avg_contraloss) 164 | 165 | Val_results[fold_idx], Train_results[fold_idx], Train_celoss[fold_idx] = val_acc, train_acc, train_celoss 166 | Train_contraloss[fold_idx] = train_contraloss 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /for_gin/models/graphcnn_pooled_multilayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append("models/") 7 | from mlp import MLP 8 | 9 | class GraphCNN(nn.Module): 10 | def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device): 11 | ''' 12 | num_layers: number of layers in 
the neural networks (INCLUDING the input layer) 13 | num_mlp_layers: number of layers in mlps (EXCLUDING the input layer) 14 | input_dim: dimensionality of input features 15 | hidden_dim: dimensionality of hidden units at ALL layers 16 | output_dim: number of classes for prediction 17 | final_dropout: dropout ratio on the final linear layer 18 | learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether. 19 | neighbor_pooling_type: how to aggregate neighbors (mean, average, or max) 20 | graph_pooling_type: how to aggregate entire nodes in a graph (mean, average) 21 | device: which device to use 22 | ''' 23 | super(GraphCNN, self).__init__() 24 | 25 | self.final_dropout = final_dropout 26 | self.device = device 27 | self.num_layers = num_layers 28 | self.graph_pooling_type = graph_pooling_type 29 | self.neighbor_pooling_type = neighbor_pooling_type 30 | self.learn_eps = learn_eps 31 | self.eps = nn.Parameter(torch.zeros(self.num_layers-1)) 32 | 33 | ###List of MLPs 34 | self.mlps = torch.nn.ModuleList() 35 | 36 | ###List of batchnorms applied to the output of MLP (input of the final prediction linear layer) 37 | self.batch_norms = torch.nn.ModuleList() 38 | 39 | for layer in range(self.num_layers-1): 40 | if layer == 0: 41 | self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)) 42 | else: 43 | self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)) 44 | 45 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 46 | 47 | #Linear function that maps the hidden representation at dofferemt layers into a prediction score 48 | self.linears_prediction = torch.nn.ModuleList() 49 | for layer in range(num_layers): 50 | if layer == 0: 51 | self.linears_prediction.append(nn.Linear(input_dim, output_dim)) 52 | else: 53 | self.linears_prediction.append(nn.Linear(hidden_dim, output_dim)) 54 | 55 | 56 | def __preprocess_neighbors_maxpool(self, batch_graph): 57 | ###create padded_neighbor_list in concatenated graph 58 | 59 | #compute the maximum number of neighbors within the graphs in the current minibatch 60 | max_deg = max([graph.max_neighbor for graph in batch_graph]) 61 | 62 | padded_neighbor_list = [] 63 | start_idx = [0] 64 | 65 | 66 | for i, graph in enumerate(batch_graph): 67 | start_idx.append(start_idx[i] + len(graph.g)) 68 | padded_neighbors = [] 69 | for j in range(len(graph.neighbors)): 70 | #add off-set values to the neighbor indices 71 | pad = [n + start_idx[i] for n in graph.neighbors[j]] 72 | #padding, dummy data is assumed to be stored in -1 73 | pad.extend([-1]*(max_deg - len(pad))) 74 | 75 | #Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether. 76 | if not self.learn_eps: 77 | pad.append(j + start_idx[i]) 78 | 79 | padded_neighbors.append(pad) 80 | padded_neighbor_list.extend(padded_neighbors) 81 | 82 | return torch.LongTensor(padded_neighbor_list) 83 | 84 | 85 | def __preprocess_neighbors_sumavepool(self, batch_graph): 86 | ###create block diagonal sparse matrix 87 | 88 | edge_mat_list = [] 89 | start_idx = [0] 90 | for i, graph in enumerate(batch_graph): 91 | start_idx.append(start_idx[i] + len(graph.g)) 92 | edge_mat_list.append(graph.edge_mat + start_idx[i]) 93 | Adj_block_idx = torch.cat(edge_mat_list, 1) 94 | Adj_block_elem = torch.ones(Adj_block_idx.shape[1]) 95 | 96 | #Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether. 
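# This helper merges the minibatch into one large disconnected graph: the node indices of every
# graph are shifted by start_idx, so Adj_block below is a block-diagonal sparse adjacency matrix
# and neighbor aggregation for the whole batch reduces to a single torch.spmm in the layer updates.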
97 | 98 | if not self.learn_eps: 99 | num_node = start_idx[-1] 100 | self_loop_edge = torch.LongTensor([range(num_node), range(num_node)]) 101 | elem = torch.ones(num_node) 102 | Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1) 103 | Adj_block_elem = torch.cat([Adj_block_elem, elem], 0) 104 | 105 | Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]])) 106 | 107 | return Adj_block.to(self.device) 108 | 109 | 110 | def __preprocess_graphpool(self, batch_graph): 111 | ###create sum or average pooling sparse matrix over entire nodes in each graph (num graphs x num nodes) 112 | 113 | start_idx = [0] 114 | 115 | #compute the padded neighbor list 116 | for i, graph in enumerate(batch_graph): 117 | start_idx.append(start_idx[i] + len(graph.g)) 118 | 119 | idx = [] 120 | elem = [] 121 | for i, graph in enumerate(batch_graph): 122 | ###average pooling 123 | if self.graph_pooling_type == "average": 124 | elem.extend([1./len(graph.g)]*len(graph.g)) 125 | 126 | else: 127 | ###sum pooling 128 | elem.extend([1]*len(graph.g)) 129 | 130 | idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)]) 131 | elem = torch.FloatTensor(elem) 132 | idx = torch.LongTensor(idx).transpose(0,1) 133 | graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]])) 134 | 135 | return graph_pool.to(self.device) 136 | 137 | def maxpool(self, h, padded_neighbor_list): 138 | ###Element-wise minimum will never affect max-pooling 139 | 140 | dummy = torch.min(h, dim = 0)[0] 141 | h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)]) 142 | pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0] 143 | return pooled_rep 144 | 145 | 146 | def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None): 147 | ###pooling neighboring nodes and center nodes separately by epsilon reweighting. 
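# In equation form (sum pooling shown; "average" divides the neighbor sum by the node degree,
# "max" uses maxpool instead):
#   h_v <- ReLU( BatchNorm( MLP_layer( (1 + eps_layer) * h_v + sum_{u in N(v)} h_u ) ) )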
148 | 149 | if self.neighbor_pooling_type == "max": 150 | ##If max pooling 151 | pooled = self.maxpool(h, padded_neighbor_list) 152 | else: 153 | #If sum or average pooling 154 | pooled = torch.spmm(Adj_block, h) 155 | if self.neighbor_pooling_type == "average": 156 | #If average pooling 157 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device)) 158 | pooled = pooled/degree 159 | 160 | #Reweights the center node representation when aggregating it with its neighbors 161 | pooled = pooled + (1 + self.eps[layer])*h 162 | pooled_rep = self.mlps[layer](pooled) 163 | h = self.batch_norms[layer](pooled_rep) 164 | 165 | #non-linearity 166 | h = F.relu(h) 167 | return h 168 | 169 | 170 | def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None): 171 | ###pooling neighboring nodes and center nodes altogether 172 | 173 | if self.neighbor_pooling_type == "max": 174 | ##If max pooling 175 | pooled = self.maxpool(h, padded_neighbor_list) 176 | else: 177 | #If sum or average pooling 178 | pooled = torch.spmm(Adj_block, h) 179 | if self.neighbor_pooling_type == "average": 180 | #If average pooling 181 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device)) 182 | pooled = pooled/degree 183 | 184 | #representation of neighboring and center nodes 185 | pooled_rep = self.mlps[layer](pooled) 186 | 187 | h = self.batch_norms[layer](pooled_rep) 188 | 189 | #non-linearity 190 | h = F.relu(h) 191 | return h 192 | 193 | 194 | def forward(self, batch_graph): 195 | X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device) 196 | graph_pool = self.__preprocess_graphpool(batch_graph) 197 | 198 | if self.neighbor_pooling_type == "max": 199 | padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph) 200 | else: 201 | Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph) 202 | 203 | #list of hidden representation at each layer (including input) 204 | hidden_rep = [X_concat] 205 | h = X_concat 206 | 207 | for layer in range(self.num_layers-1): 208 | if self.neighbor_pooling_type == "max" and self.learn_eps: 209 | h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list) 210 | elif not self.neighbor_pooling_type == "max" and self.learn_eps: 211 | h = self.next_layer_eps(h, layer, Adj_block = Adj_block) 212 | elif self.neighbor_pooling_type == "max" and not self.learn_eps: 213 | h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list) 214 | elif not self.neighbor_pooling_type == "max" and not self.learn_eps: 215 | h = self.next_layer(h, layer, Adj_block = Adj_block) 216 | 217 | hidden_rep.append(h) 218 | 219 | score_over_layer = 0 220 | 221 | #perform pooling over all nodes in each graph in every layer 222 | hidden_feats_dict = {} 223 | for layer, h in enumerate(hidden_rep): 224 | pooled_h = torch.spmm(graph_pool, h) 225 | # self.linears_prediction is for the final classification 226 | score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training) 227 | hidden_feats_dict[layer] = pooled_h 228 | return score_over_layer, hidden_feats_dict 229 | 230 | 231 | 232 | -------------------------------------------------------------------------------- /for_hgp-sl/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sparse_softmax import Sparsemax 5 | from torch.nn import Parameter 6 | 
from torch_geometric.data import Data 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.pool.topk_pool import topk, filter_adj 9 | from torch_geometric.utils import softmax, dense_to_sparse, add_remaining_self_loops 10 | from torch_scatter import scatter_add 11 | from torch_sparse import spspmm, coalesce 12 | 13 | 14 | class TwoHopNeighborhood(object): 15 | def __call__(self, data): 16 | edge_index, edge_attr = data.edge_index, data.edge_attr 17 | n = data.num_nodes 18 | 19 | value = edge_index.new_ones((edge_index.size(1),), dtype=torch.float) 20 | 21 | index, value = spspmm(edge_index, value, edge_index, value, n, n, n) 22 | value.fill_(0) 23 | 24 | edge_index = torch.cat([edge_index, index], dim=1) 25 | if edge_attr is None: 26 | data.edge_index, _ = coalesce(edge_index, None, n, n) 27 | else: 28 | value = value.view(-1, *[1 for _ in range(edge_attr.dim() - 1)]) 29 | value = value.expand(-1, *list(edge_attr.size())[1:]) 30 | edge_attr = torch.cat([edge_attr, value], dim=0) 31 | data.edge_index, edge_attr = coalesce(edge_index, edge_attr, n, n) 32 | data.edge_attr = edge_attr 33 | 34 | return data 35 | 36 | def __repr__(self): 37 | return '{}()'.format(self.__class__.__name__) 38 | 39 | 40 | class GCN(MessagePassing): 41 | def __init__(self, in_channels, out_channels, cached=False, bias=True, **kwargs): 42 | super(GCN, self).__init__(aggr='add', **kwargs) 43 | 44 | self.in_channels = in_channels 45 | self.out_channels = out_channels 46 | self.cached = cached 47 | self.cached_result = None 48 | self.cached_num_edges = None 49 | 50 | self.weight = Parameter(torch.Tensor(in_channels, out_channels)) 51 | nn.init.xavier_uniform_(self.weight.data) 52 | 53 | if bias: 54 | self.bias = Parameter(torch.Tensor(out_channels)) 55 | nn.init.zeros_(self.bias.data) 56 | else: 57 | self.register_parameter('bias', None) 58 | 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | self.cached_result = None 63 | self.cached_num_edges = None 64 | 65 | @staticmethod 66 | def norm(edge_index, num_nodes, edge_weight, dtype=None): 67 | if edge_weight is None: 68 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device) 69 | 70 | row, col = edge_index 71 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 72 | deg_inv_sqrt = deg.pow(-0.5) 73 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 74 | 75 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 76 | 77 | def forward(self, x, edge_index, edge_weight=None): 78 | x = torch.matmul(x, self.weight) 79 | 80 | if self.cached and self.cached_result is not None: 81 | if edge_index.size(1) != self.cached_num_edges: 82 | raise RuntimeError( 83 | 'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1))) 84 | 85 | if not self.cached or self.cached_result is None: 86 | self.cached_num_edges = edge_index.size(1) 87 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype) 88 | self.cached_result = edge_index, norm 89 | 90 | edge_index, norm = self.cached_result 91 | 92 | return self.propagate(edge_index, x=x, norm=norm) 93 | 94 | def message(self, x_j, norm): 95 | return norm.view(-1, 1) * x_j 96 | 97 | def update(self, aggr_out): 98 | if self.bias is not None: 99 | aggr_out = aggr_out + self.bias 100 | return aggr_out 101 | 102 | def __repr__(self): 103 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, self.out_channels) 104 | 105 | 106 | class NodeInformationScore(MessagePassing): 107 | 
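# NodeInformationScore propagates node features with I - D^(-1/2) A D^(-1/2): norm() assigns weight 1
# to the added self-loops and the negated normalized weight to every original edge, so forward() returns
# x minus the normalized neighborhood aggregate. HGPSLPool sums the absolute values of this output per
# node to obtain the information score used for top-k pooling.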
def __init__(self, improved=False, cached=False, **kwargs): 108 | super(NodeInformationScore, self).__init__(aggr='add', **kwargs) 109 | 110 | self.improved = improved 111 | self.cached = cached 112 | self.cached_result = None 113 | self.cached_num_edges = None 114 | 115 | @staticmethod 116 | def norm(edge_index, num_nodes, edge_weight, dtype=None): 117 | if edge_weight is None: 118 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, device=edge_index.device) 119 | 120 | row, col = edge_index 121 | deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) 122 | deg_inv_sqrt = deg.pow(-0.5) 123 | deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 124 | 125 | edge_index, edge_weight = add_remaining_self_loops(edge_index, edge_weight, 0, num_nodes) 126 | 127 | row, col = edge_index 128 | expand_deg = torch.zeros((edge_weight.size(0),), dtype=dtype, device=edge_index.device) 129 | expand_deg[-num_nodes:] = torch.ones((num_nodes,), dtype=dtype, device=edge_index.device) 130 | 131 | return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 132 | 133 | def forward(self, x, edge_index, edge_weight): 134 | if self.cached and self.cached_result is not None: 135 | if edge_index.size(1) != self.cached_num_edges: 136 | raise RuntimeError( 137 | 'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1))) 138 | 139 | if not self.cached or self.cached_result is None: 140 | self.cached_num_edges = edge_index.size(1) 141 | edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype) 142 | self.cached_result = edge_index, norm 143 | 144 | edge_index, norm = self.cached_result 145 | 146 | return self.propagate(edge_index, x=x, norm=norm) 147 | 148 | def message(self, x_j, norm): 149 | return norm.view(-1, 1) * x_j 150 | 151 | def update(self, aggr_out): 152 | return aggr_out 153 | 154 | 155 | class HGPSLPool(torch.nn.Module): 156 | def __init__(self, in_channels, ratio=0.8, sample=False, sparse=False, sl=True, lamb=1.0, negative_slop=0.2): 157 | super(HGPSLPool, self).__init__() 158 | self.in_channels = in_channels 159 | self.ratio = ratio 160 | self.sample = sample 161 | self.sparse = sparse 162 | self.sl = sl 163 | self.negative_slop = negative_slop 164 | self.lamb = lamb 165 | 166 | self.att = Parameter(torch.Tensor(1, self.in_channels * 2)) 167 | nn.init.xavier_uniform_(self.att.data) 168 | self.sparse_attention = Sparsemax() 169 | self.neighbor_augment = TwoHopNeighborhood() 170 | self.calc_information_score = NodeInformationScore() 171 | 172 | def forward(self, x, edge_index, edge_attr, batch=None): 173 | if batch is None: 174 | batch = edge_index.new_zeros(x.size(0)) 175 | 176 | x_information_score = self.calc_information_score(x, edge_index, edge_attr) 177 | score = torch.sum(torch.abs(x_information_score), dim=1) 178 | 179 | # Graph Pooling 180 | original_x = x 181 | perm = topk(score, self.ratio, batch) 182 | x = x[perm] 183 | batch = batch[perm] 184 | induced_edge_index, induced_edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0)) 185 | 186 | # Discard structure learning layer, directly return 187 | if self.sl is False: 188 | return x, induced_edge_index, induced_edge_attr, batch 189 | 190 | # Structure Learning 191 | if self.sample: # run as default 192 | # A fast mode for large graphs. 193 | # In large graphs, learning the possible edge weights between each pair of nodes is time consuming. 
194 | # To accelerate this process, we sample it's K-Hop neighbors for each node and then learn the 195 | # edge weights between them. 196 | k_hop = 3 197 | if edge_attr is None: 198 | edge_attr = torch.ones((edge_index.size(1),), dtype=torch.float, device=edge_index.device) 199 | 200 | hop_data = Data(x=original_x, edge_index=edge_index, edge_attr=edge_attr) 201 | for _ in range(k_hop - 1): 202 | hop_data = self.neighbor_augment(hop_data) 203 | hop_edge_index = hop_data.edge_index 204 | hop_edge_attr = hop_data.edge_attr 205 | new_edge_index, new_edge_attr = filter_adj(hop_edge_index, hop_edge_attr, perm, num_nodes=score.size(0)) 206 | 207 | new_edge_index, new_edge_attr = add_remaining_self_loops(new_edge_index, new_edge_attr, 0, x.size(0)) 208 | row, col = new_edge_index 209 | weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1) 210 | weights = F.leaky_relu(weights, self.negative_slop) + new_edge_attr * self.lamb 211 | adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device) 212 | adj[row, col] = weights 213 | new_edge_index, weights = dense_to_sparse(adj) 214 | row, col = new_edge_index 215 | if self.sparse: 216 | new_edge_attr = self.sparse_attention(weights, row) 217 | else: 218 | new_edge_attr = softmax(weights, row, x.size(0)) 219 | # filter out zero weight edges 220 | adj[row, col] = new_edge_attr 221 | new_edge_index, new_edge_attr = dense_to_sparse(adj) 222 | # release gpu memory 223 | del adj 224 | torch.cuda.empty_cache() 225 | else: 226 | # Learning the possible edge weights between each pair of nodes in the pooled subgraph, relative slower. 227 | if edge_attr is None: 228 | induced_edge_attr = torch.ones((induced_edge_index.size(1),), dtype=x.dtype, 229 | device=induced_edge_index.device) 230 | num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0) 231 | shift_cum_num_nodes = torch.cat([num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0) 232 | cum_num_nodes = num_nodes.cumsum(dim=0) 233 | adj = torch.zeros((x.size(0), x.size(0)), dtype=torch.float, device=x.device) 234 | # Construct batch fully connected graph in block diagonal matirx format 235 | for idx_i, idx_j in zip(shift_cum_num_nodes, cum_num_nodes): 236 | adj[idx_i:idx_j, idx_i:idx_j] = 1.0 237 | new_edge_index, _ = dense_to_sparse(adj) 238 | row, col = new_edge_index 239 | 240 | weights = (torch.cat([x[row], x[col]], dim=1) * self.att).sum(dim=-1) 241 | weights = F.leaky_relu(weights, self.negative_slop) 242 | adj[row, col] = weights 243 | induced_row, induced_col = induced_edge_index 244 | 245 | adj[induced_row, induced_col] += induced_edge_attr * self.lamb 246 | weights = adj[row, col] 247 | if self.sparse: 248 | new_edge_attr = self.sparse_attention(weights, row) 249 | else: 250 | new_edge_attr = softmax(weights, row, x.size(0)) 251 | # filter out zero weight edges 252 | adj[row, col] = new_edge_attr 253 | new_edge_index, new_edge_attr = dense_to_sparse(adj) 254 | # release gpu memory 255 | del adj 256 | torch.cuda.empty_cache() 257 | 258 | return x, new_edge_index, new_edge_attr, batch 259 | --------------------------------------------------------------------------------
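For orientation, the snippet below is a minimal sketch of how the memory-bank helpers fit together outside of the full training scripts. The function names and signatures come from the util.py shown above; the tensor shapes, label splits, and index lists are toy values chosen only for illustration.
```
# Minimal sketch (toy values); run alongside the util.py shown above.
import argparse
import torch
import torch.nn as nn

from util import contrastive_loss_labelwise_winslide, dequeue_and_enqueue

args = argparse.Namespace(temperature=0.07, batch_size=4)

# 10 training graphs with 8-dim graph embeddings; the memory bank holds one row per training graph.
queue = nn.functional.normalize(torch.randn(10, 8), dim=1)
train_idx_by_label = {0: [0, 1, 2, 3, 4], 1: [5, 6, 7, 8, 9]}  # train-set indices per class
batch_idx_by_label = {0: [0, 1], 1: [2, 3]}                     # positions inside the current batch
selected_batch_idx = [0, 1, 5, 6]                               # which training graphs form the batch

hidden_feats = torch.randn(4, 8)                                # stand-in for the key-encoder output
queue = dequeue_and_enqueue(hidden_feats, selected_batch_idx, queue)  # slide the batch into the bank
loss = contrastive_loss_labelwise_winslide(args, batch_idx_by_label, train_idx_by_label,
                                           hidden_feats, queue)
print(loss.item())
```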