├── .gitignore ├── LICENSE ├── README.md ├── heterophilic_graphs ├── data_handling.py ├── models.py └── run_GNN.py ├── imgs └── gradient_gating_scheme2.png ├── node_regression ├── README.md ├── data_handling.py ├── models.py └── run_GNN.py └── synthetic_cora ├── README.md ├── data_handling.py ├── models.py └── train_GNN.py /.gitignore: -------------------------------------------------------------------------------- 1 | code 2 | GraphSAGE/ 3 | GCN/ 4 | GAT/ 5 | plain_GCN/ 6 | plain_GAT/ 7 | plain_GraphSAGE/ 8 | data/ 9 | ray_tune/ 10 | __pycache__/ 11 | src/checkpoint 12 | images/ 13 | ray_results/ 14 | models/ 15 | __pycache__/ 16 | .data/ 17 | .vector_cache/ 18 | .idea/ 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Konstantin Rusch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Gradient Gating for Deep Multi-Rate Learning on Graphs

2 | 3 | This repository contains the implementation to reproduce the numerical experiments of the **ICLR 2023** paper [Gradient Gating for Deep Multi-Rate Learning on Graphs](https://openreview.net/forum?id=JpRExTbl1-) 4 | 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/gradient-gating-for-deep-multi-rate-learning/node-classification-on-arxiv-year)](https://paperswithcode.com/sota/node-classification-on-arxiv-year?p=gradient-gating-for-deep-multi-rate-learning) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/gradient-gating-for-deep-multi-rate-learning/node-classification-on-genius)](https://paperswithcode.com/sota/node-classification-on-genius?p=gradient-gating-for-deep-multi-rate-learning) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/gradient-gating-for-deep-multi-rate-learning/node-classification-on-snap-patents)](https://paperswithcode.com/sota/node-classification-on-snap-patents?p=gradient-gating-for-deep-multi-rate-learning) 8 | 9 |

10 | 11 |

12 | 13 | ### Requirements 14 | Main dependencies (with python >= 3.7):
15 | torch==1.9.0
16 | torch-cluster==1.5.9
17 | torch-geometric==2.0.3
18 | torch-scatter==2.0.9
19 | torch-sparse==0.6.12
20 | torch-spline-conv==1.2.1
21 | 22 | Commands to install all the dependencies in a new conda environment
23 | *(python 3.7 and CUDA 10.2 -- for other CUDA versions, replace `cu102` in the wheel URLs accordingly)* 24 | ``` 25 | conda create --name gradientgating python=3.7 26 | conda activate gradientgating 27 | 28 | pip install torch==1.9.0 29 | 30 | pip install torch-scatter==2.0.9 -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html 31 | pip install torch-sparse==0.6.12 -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html 32 | pip install torch-cluster==1.5.9 -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html 33 | pip install torch-spline-conv==1.2.1 -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html 34 | pip install torch-geometric==2.0.3 35 | pip install scipy 36 | pip install numpy networkx 37 | ``` 38 | 39 | # Citation 40 | If you find our work useful in your research, please cite our paper: 41 | ```bibtex 42 | @inproceedings{rusch2022gradient, 43 | title={Gradient Gating for Deep Multi-Rate Learning on Graphs}, 44 | author={Rusch, T Konstantin and Chamberlain, Benjamin P and Mahoney, Michael W and Bronstein, Michael M and Mishra, Siddhartha}, 45 | booktitle={International Conference on Learning Representations}, 46 | year={2023} 47 | } 48 | ``` 49 | (Also consider starring the project on GitHub.)
# ------------------------------------------------------------------------------
# /heterophilic_graphs/data_handling.py
# ------------------------------------------------------------------------------
from torch_geometric.datasets import WebKB, WikipediaNetwork, Actor
import torch
import numpy as np


def get_data(name, split=0):
    """Load a heterophilic benchmark graph together with one of its split masks.

    The graph comes from the matching torch_geometric dataset class rooted at
    ``../data/<name>``; the boolean masks are read from the dataset's raw
    ``<name>_split_0.6_0.2_<split>.npz`` file.

    Args:
        name: one of 'chameleon', 'squirrel', 'cornell', 'texas', 'wisconsin',
            'film'.
        split: index of the split file to load.

    Returns:
        ``dataset[0]`` with ``train_mask``, ``val_mask`` and ``test_mask`` set
        to the boolean masks stored in the split file.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    path = '../data/' + name
    # Each dataset class keeps its raw split files in a different sub-directory.
    if name in ['chameleon', 'squirrel']:
        dataset = WikipediaNetwork(root=path, name=name)
        split_dir = f'{path}/{name}/geom_gcn/raw'
    elif name in ['cornell', 'texas', 'wisconsin']:
        dataset = WebKB(path, name=name)
        split_dir = f'{path}/{name}/raw'
    elif name == 'film':
        dataset = Actor(root=path)
        split_dir = f'{path}/raw'
    else:
        # Fail with a clear message instead of a NameError on `dataset`.
        raise ValueError(f'unsupported dataset: {name!r}')

    data = dataset[0]
    # The NpzFile is a context manager; closing it releases the file handle.
    with np.load(f'{split_dir}/{name}_split_0.6_0.2_{split}.npz') as splits_file:
        data.train_mask = torch.tensor(splits_file['train_mask'], dtype=torch.bool)
        data.val_mask = torch.tensor(splits_file['val_mask'], dtype=torch.bool)
        data.test_mask = torch.tensor(splits_file['test_mask'], dtype=torch.bool)

    return data


# ------------------------------------------------------------------------------
# /heterophilic_graphs/models.py
# ------------------------------------------------------------------------------
import torch
from torch_scatter import scatter
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv


class G2(nn.Module):
    """Gradient gating (G^2): a per-node, per-channel gate in [0, 1).

    With ``Y = activation(conv(X))`` the gate of node i is
    ``tanh(mean_j |Y_i - Y_j|^p)``, where the mean runs over the edges whose
    ``edge_index[0]`` entry is i.
    """

    def __init__(self, conv, p=2., conv_type='GraphSAGE', activation=nn.ReLU()):
        super(G2, self).__init__()
        self.conv = conv
        self.p = p
        self.activation = activation
        self.conv_type = conv_type

    def forward(self, X, edge_index):
        n_nodes = X.size(0)
        if self.conv_type == 'GAT':
            # NOTE(review): view(n, -1, 4).mean(-1) averages groups of 4 adjacent
            # channels. If PyG lays concatenated heads out head-major, this is not
            # a per-head average -- confirm. Kept to preserve published results.
            X = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = self.activation(self.conv(X, edge_index))
        src, dst = edge_index[0], edge_index[1]
        # No squeeze: with a single hidden channel it dropped the channel axis
        # and the gate then broadcast against X to an (N, N) matrix.
        grad = torch.abs(X[src] - X[dst]) ** self.p
        return torch.tanh(scatter(grad, src, 0, dim_size=n_nodes, reduce='mean'))


class G2_GNN(nn.Module):
    """Linear encoder -> ``nlayers`` gradient-gated conv steps -> linear decoder.

    The same conv module is reused at every step.

    Args:
        nfeat: number of input node features.
        nhid: hidden width.
        nclass: number of output logits.
        nlayers: number of gated message-passing steps.
        conv_type: 'GCN', 'GraphSAGE' or 'GAT' (4 concatenated heads).
        p: exponent of the graph gradient inside G^2.
        drop_in: dropout rate on the raw input features.
        drop: dropout rate after every gated step.
        use_gg_conv: if True the gate gets its own conv, otherwise it shares
            ``self.conv``.

    Raises:
        ValueError: for an unknown ``conv_type``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GraphSAGE', p=2., drop_in=0, drop=0, use_gg_conv=True):
        super(G2_GNN, self).__init__()
        self.conv_type = conv_type
        self.enc = nn.Linear(nfeat, nhid)
        self.dec = nn.Linear(nhid, nclass)
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        # Construction order (conv, then conv_gg) is kept so parameter
        # initialisation consumes the RNG exactly as before.
        if conv_type == 'GCN':
            self.conv = GCNConv(nhid, nhid)
            if use_gg_conv:
                self.conv_gg = GCNConv(nhid, nhid)
        elif conv_type == 'GraphSAGE':
            self.conv = SAGEConv(nhid, nhid)
            if use_gg_conv:
                self.conv_gg = SAGEConv(nhid, nhid)
        elif conv_type == 'GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            if use_gg_conv:
                self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
        else:
            # Raise now rather than print and fail later on a missing attribute.
            raise ValueError(f'specified graph conv not implemented: {conv_type!r}')

        gate_conv = self.conv_gg if use_gg_conv else self.conv
        self.G2 = G2(gate_conv, p, conv_type, activation=nn.ReLU())

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)
        X = torch.relu(self.enc(X))

        for _ in range(self.nlayers):
            if self.conv_type == 'GAT':
                X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X_ = torch.relu(self.conv(X, edge_index))
            # tau near 0 keeps the current state, tau near 1 takes the update.
            tau = self.G2(X, edge_index)
            X = (1 - tau) * X + tau * X_
            X = F.dropout(X, self.drop, training=self.training)

        return self.dec(X)


# ------------------------------------------------------------------------------
# /heterophilic_graphs/run_GNN.py
# ------------------------------------------------------------------------------
from models import *
import torch
import torch.optim as optim
import numpy as np
from data_handling import get_data
import argparse


def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would turn 'False' into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def train(args, split):
    """Train G2_GNN on one split and return the test accuracy at the best epoch.

    The best epoch is chosen by validation accuracy (``args.use_val_acc``) or
    by validation loss. Training stops after ``args.patience`` consecutive
    epochs without improvement, or after ``args.epochs`` epochs.
    """
    # Data.to moves the tensors once; the original re-issued it every epoch.
    data = get_data(args.dataset, split).to(args.device)

    best_eval_acc = 0
    best_eval_loss = 1e5
    bad_counter = 0
    best_test_acc = 0

    nout = 5  # every supported heterophilic dataset has 5 classes

    model = G2_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN, args.G2_exp, args.drop_in, args.drop,
                   args.use_G2_conv).to(args.device)

    lf = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    y = data.y.squeeze()

    @torch.no_grad()
    def test(model, data):
        """Return ([train, val, test] accuracies, [train, val, test] losses) in eval mode."""
        model.eval()
        logits, accs, losses = model(data), [], []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            loss = lf(logits[mask], y[mask])
            pred = logits[mask].max(1)[1]
            acc = pred.eq(y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
            losses.append(loss.item())
        return accs, losses

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = lf(out[data.train_mask], y[data.train_mask])
        loss.backward()
        optimizer.step()

        [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model, data)

        if args.use_val_acc:
            improved = val_acc > best_eval_acc
            if improved:
                best_eval_acc = val_acc
        else:
            improved = val_loss < best_eval_loss
            if improved:
                best_eval_loss = val_loss

        if improved:
            best_test_acc = test_acc
            bad_counter = 0
        else:
            bad_counter += 1
            # Real early stopping; previously training just ended at epoch == patience.
            if bad_counter >= args.patience:
                break

        log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(split, epoch, train_acc, val_acc, test_acc))

    return best_test_acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training parameters')
    parser.add_argument('--dataset', type=str, default='wisconsin',
                        help='dataset name: texas, wisconsin, film, squirrel, chameleon, cornell')
    parser.add_argument('--GNN', type=str, default='GraphSAGE',
                        help='base GNN model used with G^2: GraphSAGE, GCN, GAT')
    parser.add_argument('--nhid', type=int, default=256,
                        help='number of hidden node features')
    parser.add_argument('--nlayers', type=int, default=6,
                        help='number of layers')
    parser.add_argument('--epochs', type=int, default=500,
                        help='max epochs')
    parser.add_argument('--patience', type=int, default=200,
                        help='patience for early stopping')
    parser.add_argument('--lr', type=float, default=0.003,
                        help='learning rate')
    parser.add_argument('--drop_in', type=float, default=0.5,
                        help='input dropout rate')
    parser.add_argument('--drop', type=float, default=0.1,
                        help='dropout rate')
    parser.add_argument('--weight_decay', type=float, default=0.01,
                        help='weight_decay')
    parser.add_argument('--G2_exp', type=float, default=2.5,
                        help='exponent p in G^2')
    parser.add_argument('--use_val_acc', type=_str2bool, default=True,
                        help='use validation accuracy for early stoppping -- otherwise use validation loss')
    parser.add_argument('--use_G2_conv', type=_str2bool, default=False,
                        help='use a different GNN model for the gradient gating method')
    parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                        help='computing device')

    args = parser.parse_args()

    n_splits = 10
    best_results = np.array([train(args, split) for split in range(n_splits)])
    mean_acc = np.mean(best_results)
    std = np.std(best_results)

    log = 'Final test results -- mean: {:.4f}, std: {:.4f}'
    print(log.format(mean_acc, std))

# ------------------------------------------------------------------------------
# /imgs/gradient_gating_scheme2.png
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/tk-rusch/gradientgating/334bd1536f6d55190a394cd74e9df6600389f153/imgs/gradient_gating_scheme2.png
# ------------------------------------------------------------------------------
# /node_regression/README.md
# ------------------------------------------------------------------------------
# ### Data preparation
#
# The data can be downloaded from https://www.kaggle.com/datasets/andreagarritano/wikipedia-article-networks/download?datasetVersionNumber=1
# (or alternatively if you don't want to log-in to kaggle: https://github.com/benedekrozemberczki/MUSAE/tree/master/input)
#
# Simply unpack it and put the raw .json and .csv files (chameleon and squirrel)
# inside a directory named `data` in this project directory (the scripts read e.g. `data/chameleon_edges.csv`).
#
# In order to access the standard splits, run the heterophilic_graphs experiment for chameleon and squirrel.
# This will automatically download and store the required standard splits.
# ------------------------------------------------------------------------------
# /node_regression/data_handling.py
# ------------------------------------------------------------------------------
import torch
import numpy as np
from numpy import genfromtxt
import networkx
from torch_geometric import utils
from torch_geometric.data import Data
import json

# (number of nodes, number of binary features) of each supported graph.
_GRAPH_SIZES = {'chameleon': (2277, 3132), 'squirrel': (5201, 3148)}


def get_data(name='chameleon', split=0):
    """Build the node-regression graph ``name`` with split masks number ``split``.

    Reads ``data/<name>_edges.csv``, ``data/<name>_target.csv`` and
    ``data/<name>_features.json`` relative to the working directory, plus the
    split masks stored under ``../data/<name>/<name>/geom_gcn/raw``.

    Returns:
        ``Data`` with an undirected ``edge_index``, binary features ``x``,
        targets ``y`` divided by their maximum, and boolean train/val/test masks.

    Raises:
        ValueError: if ``name`` is not 'chameleon' or 'squirrel'.
    """
    if name not in _GRAPH_SIZES:
        raise ValueError(f'unsupported dataset: {name!r}')
    n_nodes, n_feats = _GRAPH_SIZES[name]

    # Node ids in the edge csv (header row skipped) index the rows of x and y
    # directly, so they are used as-is. Building a networkx graph and calling
    # from_networkx relabelled nodes in insertion order, which misaligned the
    # edges with the features and targets.
    edges = genfromtxt('data/' + name + '_edges.csv', delimiter=',')[1:].astype(int)
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    # Add both directions and drop duplicates, as the undirected graph did.
    edge_index = utils.to_undirected(edge_index, num_nodes=n_nodes)

    # Regression target: last csv column, divided by its maximum.
    y = genfromtxt('data/' + name + '_target.csv', delimiter=',')[1:, -1].astype(int)
    y = y / np.max(y)

    # obj maps str(node id) -> list of feature ids that are set to 1.
    with open('data/' + name + '_features.json', 'r') as feature_file:
        obj = json.load(feature_file)
    x = np.zeros((n_nodes, n_feats))
    for i in range(n_nodes):
        # dtype=int keeps an empty feature list a valid integer index.
        x[i, np.asarray(obj[str(i)], dtype=int)] = 1

    data = Data(x=torch.tensor(x).float(), edge_index=edge_index, y=torch.tensor(y).float())

    path = '../data/' + name
    with np.load(f'{path}/{name}/geom_gcn/raw/{name}_split_0.6_0.2_{split}.npz') as splits_file:
        data.train_mask = torch.tensor(splits_file['train_mask'], dtype=torch.bool)
        data.val_mask = torch.tensor(splits_file['val_mask'], dtype=torch.bool)
        data.test_mask = torch.tensor(splits_file['test_mask'], dtype=torch.bool)

    return data


# ------------------------------------------------------------------------------
# /node_regression/models.py
# ------------------------------------------------------------------------------
import torch
from torch_scatter import scatter
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv


class G2(nn.Module):
    """Gradient gating (G^2): a per-node, per-channel gate in [0, 1).

    With ``Y = activation(conv(X))`` the gate of node i is
    ``tanh(mean_j |Y_i - Y_j|^p)``, where the mean runs over the edges whose
    ``edge_index[0]`` entry is i.
    """

    def __init__(self, conv, p=2., conv_type='GCN', activation=nn.ReLU()):
        super(G2, self).__init__()
        self.conv = conv
        self.p = p
        self.activation = activation
        self.conv_type = conv_type

    def forward(self, X, edge_index):
        n_nodes = X.size(0)
        if self.conv_type == 'GAT':
            # NOTE(review): averages groups of 4 adjacent channels; if PyG lays
            # concatenated heads out head-major this is not a per-head average
            # -- confirm. Kept to preserve published results.
            X = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = self.activation(self.conv(X, edge_index))
        src, dst = edge_index[0], edge_index[1]
        # No squeeze: with one hidden channel it broadcast the gate to (N, N).
        grad = torch.abs(X[src] - X[dst]) ** self.p
        return torch.tanh(scatter(grad, src, 0, dim_size=n_nodes, reduce='mean'))


class G2_GNN(nn.Module):
    """Linear encoder -> ``nlayers`` gradient-gated conv steps -> ReLU(linear decoder).

    The final ReLU keeps the regression output non-negative.

    Args:
        nfeat: number of input node features.
        nhid: hidden width.
        nclass: output width.
        nlayers: number of gated message-passing steps (the conv is reused).
        conv_type: 'GCN' or 'GAT' (4 concatenated heads).
        p: exponent of the graph gradient inside G^2.
        drop_in: dropout rate on the raw input features.
        drop: dropout rate after every gated step.
        use_gg_conv: if True the gate gets its own conv, otherwise it shares
            ``self.conv``.

    Raises:
        ValueError: for an unknown ``conv_type``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', p=2., drop_in=0, drop=0, use_gg_conv=True):
        super(G2_GNN, self).__init__()
        self.conv_type = conv_type
        self.enc = nn.Linear(nfeat, nhid)
        self.dec = nn.Linear(nhid, nclass)
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        # Construction order (conv, then conv_gg) keeps RNG use as before.
        if conv_type == 'GCN':
            self.conv = GCNConv(nhid, nhid)
            if use_gg_conv:
                self.conv_gg = GCNConv(nhid, nhid)
        elif conv_type == 'GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            if use_gg_conv:
                self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
        else:
            raise ValueError(f'specified graph conv not implemented: {conv_type!r}')

        gate_conv = self.conv_gg if use_gg_conv else self.conv
        self.G2 = G2(gate_conv, p, conv_type, activation=nn.ReLU())

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)
        X = torch.relu(self.enc(X))

        for _ in range(self.nlayers):
            if self.conv_type == 'GAT':
                X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X_ = torch.relu(self.conv(X, edge_index))
            # tau near 0 keeps the current state, tau near 1 takes the update.
            tau = self.G2(X, edge_index)
            X = (1 - tau) * X + tau * X_
            X = F.dropout(X, self.drop, training=self.training)

        return torch.relu(self.dec(X))


class plain_GNN(nn.Module):
    """Baseline without gating.

    Layout: graph-conv encoder, ``nlayers`` shared graph-conv steps, and a
    graph-conv decoder followed by ReLU.

    Args:
        conv_type: 'plain_GCN' or 'plain_GAT'. The bare names 'GCN' and 'GAT'
            (including the default) are accepted as aliases.

    Raises:
        ValueError: for an unknown ``conv_type``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', drop_in=0, drop=0):
        super(plain_GNN, self).__init__()
        # The default 'GCN' was previously always rejected; map it to plain_GCN.
        if conv_type in ('GCN', 'GAT'):
            conv_type = 'plain_' + conv_type
        self.conv_type = conv_type
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        if conv_type == 'plain_GCN':
            self.conv = GCNConv(nhid, nhid)
            self.enc = GCNConv(nfeat, nhid)
            self.dec = GCNConv(nhid, nclass)
        elif conv_type == 'plain_GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            self.enc = GATConv(nfeat, nhid, heads=4, concat=True)
            self.dec = GATConv(nhid, nclass, heads=4, concat=True)
        else:
            raise ValueError(f'specified graph conv not implemented: {conv_type!r}')

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)

        if self.conv_type == 'plain_GAT':
            X = F.elu(self.enc(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = torch.relu(self.enc(X, edge_index))

        for _ in range(self.nlayers):
            if self.conv_type == 'plain_GAT':
                X = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X = torch.relu(self.conv(X, edge_index))
            X = F.dropout(X, self.drop, training=self.training)

        if self.conv_type == 'plain_GAT':
            X = torch.relu(self.dec(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = torch.relu(self.dec(X, edge_index))

        return X


# ------------------------------------------------------------------------------
# /node_regression/run_GNN.py
# ------------------------------------------------------------------------------
from models import *
import torch
import torch.optim as optim
import numpy as np
from data_handling import get_data
import argparse


def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would turn 'False' into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def train(args, split):
    """Train one regressor on one split; return its test loss at the best validation epoch.

    Losses reported by ``test`` are the MSE divided by the mean target over all
    nodes. Training stops after ``args.patience`` consecutive epochs without a
    lower validation loss, or after ``args.epochs`` epochs.
    """
    data = get_data(args.dataset, split).to(args.device)
    best_eval_loss = 1e5
    bad_counter = 0
    best_test_loss = 1e5

    nout = 1
    # The input width follows the loaded features (3132 for chameleon, 3148 for squirrel).
    ninp = data.num_node_features

    if 'plain' in args.GNN:
        model = plain_GNN(ninp, args.nhid, nout, args.nlayers, args.GNN, args.drop_in, args.drop).to(args.device)
    else:
        model = G2_GNN(ninp, args.nhid, nout, args.nlayers, args.GNN, args.G2_exp, args.drop_in, args.drop,
                       args.use_G2_conv).to(args.device)

    lf = torch.nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    y = data.y.squeeze()

    @torch.no_grad()
    def test(model, data):
        """Return the normalised [train, val, test] losses in eval mode."""
        model.eval()
        out, losses = model(data).squeeze(-1), []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            loss = lf(out[mask], y[mask]) / torch.mean(data.y)
            losses.append(loss.item())
        return losses

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data).squeeze(-1)
        loss = lf(out[data.train_mask], y[data.train_mask])
        loss.backward()
        optimizer.step()

        [train_loss, val_loss, test_loss] = test(model, data)

        if val_loss < best_eval_loss:
            best_eval_loss = val_loss
            best_test_loss = test_loss
            bad_counter = 0
        else:
            bad_counter += 1
            # Honour --patience (previously shadowed by a local 200 and compared
            # with the epoch index instead of the no-improvement counter).
            if bad_counter >= args.patience:
                break

        log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(split, epoch, train_loss, val_loss, test_loss))

    return best_test_loss


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training parameters')
    parser.add_argument('--dataset', type=str, default='squirrel',
                        help='dataset name: squirrel, chameleon')
    parser.add_argument('--GNN', type=str, default='GCN',
                        help='base GNN model used with G^2: GCN, GAT -- '
                             'plain GNN versions: plain_GCN, plain_GAT')
    parser.add_argument('--nhid', type=int, default=64,
                        help='number of hidden node features')
    parser.add_argument('--nlayers', type=int, default=3,
                        help='number of layers')
    parser.add_argument('--epochs', type=int, default=500,
                        help='max epochs')
    parser.add_argument('--patience', type=int, default=200,
                        help='patience for early stopping')
    parser.add_argument('--lr', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--drop_in', type=float, default=0.2,
                        help='input dropout rate')
    parser.add_argument('--drop', type=float, default=0.3,
                        help='dropout rate')
    parser.add_argument('--weight_decay', type=float, default=0.0001,
                        help='weight_decay')
    parser.add_argument('--G2_exp', type=float, default=5.,
                        help='exponent p in G^2')
    parser.add_argument('--use_G2_conv', type=_str2bool, default=False,
                        help='use a different GNN model for the gradient gating method')
    parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                        help='computing device')

    args = parser.parse_args()

    n_splits = 10
    best_results = np.array([train(args, split) for split in range(n_splits)])
    mean_loss = np.mean(best_results)
    std = np.std(best_results)

    log = 'Final test results -- mean: {:.4f}, std: {:.4f}'
    print(log.format(mean_loss, std))

# ------------------------------------------------------------------------------
# /synthetic_cora/README.md
# ------------------------------------------------------------------------------
# ### Data preparation
#
# The data can be downloaded from
# https://drive.google.com/file/d/1TbC-10pF2WlfbYLmu_gPV8TcXEBYeyOl/view?usp=sharing
#
# Simply unpack it and create a new directory called `data` inside this project directory and put all .npz files in there.
# ------------------------------------------------------------------------------
# /synthetic_cora/data_handling.py
# ------------------------------------------------------------------------------
from torch_geometric.datasets import Planetoid
import torch
import numpy as np
from torch_geometric.utils import convert
import scipy.sparse as sp
from torch_geometric.data import Data


def get_data(hom_level=0, graph=1, seed=123456):
    """Load synthetic graph ``graph`` with label homophily ``hom_level`` percent.

    Reads ``data/h<hom_level/100 with two decimals>-r<graph>.npz``, for example
    ``data/h0.30-r1.npz``. The file holds a CSR adjacency matrix, CSR node
    features and labels. The function then draws a random 50/25/25
    train/val/test node split.

    Args:
        hom_level: homophily level in percent (0-100).
        graph: index of the generated graph at that level.
        seed: seed of the split permutation (default matches the original).

    Returns:
        ``Data`` with ``x``, ``edge_index``, ``y`` and boolean
        ``train_mask`` / ``val_mask`` / ``test_mask``.
    """
    # Two-decimal fraction: 0 -> h0.00, 30 -> h0.30 (as before). It now also
    # gives 5 -> h0.05 and 100 -> h1.00 instead of the malformed h0.5 / h0.100.
    # NOTE(review): assumes the archive names the 100% level h1.00 -- confirm.
    file_name = f'data/h{hom_level / 100:.2f}-r{graph}.npz'
    with np.load(file_name) as loader:
        adj = sp.csr_matrix((loader['adj_data'], loader['adj_indices'],
                             loader['adj_indptr']), shape=loader['adj_shape'])
        features = sp.csr_matrix((loader['attr_data'], loader['attr_indices'],
                                  loader['attr_indptr']), shape=loader['attr_shape'])
        labels = loader['labels']

    y = torch.tensor(labels).long()
    x = torch.tensor(features.toarray()).float()
    edge_index, _ = convert.from_scipy_sparse_matrix(adj)
    # (The original also downloaded Cora via Planetoid, only to use the
    # dataset object as a container for this Data; that step is dropped.)

    num_nodes = y.shape[0]
    rnd_state = np.random.RandomState(seed)
    perm = rnd_state.choice(num_nodes, size=num_nodes, replace=False)
    n_train, n_train_val = int(0.5 * num_nodes), int(0.75 * num_nodes)

    def get_mask(idx):
        """Boolean node mask that is True exactly at ``idx``."""
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = True
        return mask

    return Data(
        x=x,
        edge_index=edge_index,
        y=y,
        train_mask=get_mask(perm[:n_train]),
        val_mask=get_mask(perm[n_train:n_train_val]),
        test_mask=get_mask(perm[n_train_val:]),
    )


# ------------------------------------------------------------------------------
# /synthetic_cora/models.py
# ------------------------------------------------------------------------------
import torch
from torch_scatter import scatter
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv


class G2(nn.Module):
    """Gradient gating (G^2): a per-node, per-channel gate in [0, 1).

    With ``Y = activation(conv(X))`` the gate of node i is
    ``tanh(mean_j |Y_i - Y_j|^p)``, where the mean runs over the edges whose
    ``edge_index[0]`` entry is i.
    """

    def __init__(self, conv, p=2., conv_type='GCN', activation=nn.ReLU()):
        super(G2, self).__init__()
        self.conv = conv
        self.p = p
        self.activation = activation
        self.conv_type = conv_type

    def forward(self, X, edge_index):
        n_nodes = X.size(0)
        if self.conv_type == 'GAT':
            # NOTE(review): averages groups of 4 adjacent channels; if PyG lays
            # concatenated heads out head-major this is not a per-head average
            # -- confirm. Kept to preserve published results.
            X = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = self.activation(self.conv(X, edge_index))
        src, dst = edge_index[0], edge_index[1]
        # No squeeze: with one hidden channel it broadcast the gate to (N, N).
        grad = torch.abs(X[src] - X[dst]) ** self.p
        return torch.tanh(scatter(grad, src, 0, dim_size=n_nodes, reduce='mean'))


class G2_GNN(nn.Module):
    """Linear encoder -> ``nlayers`` gradient-gated conv steps -> linear decoder (logits).

    Args:
        nfeat: number of input node features.
        nhid: hidden width.
        nclass: number of output logits.
        nlayers: number of gated message-passing steps (the conv is reused).
        conv_type: 'GCN' or 'GAT' (4 concatenated heads).
        p: exponent of the graph gradient inside G^2.
        drop_in: dropout rate on the raw input features.
        drop: dropout rate after every gated step.
        use_gg_conv: if True the gate gets its own conv, otherwise it shares
            ``self.conv``.

    Raises:
        ValueError: for an unknown ``conv_type``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', p=2., drop_in=0, drop=0, use_gg_conv=True):
        super(G2_GNN, self).__init__()
        self.conv_type = conv_type
        self.enc = nn.Linear(nfeat, nhid)
        self.dec = nn.Linear(nhid, nclass)
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        # Construction order (conv, then conv_gg) keeps RNG use as before.
        if conv_type == 'GCN':
            self.conv = GCNConv(nhid, nhid)
            if use_gg_conv:
                self.conv_gg = GCNConv(nhid, nhid)
        elif conv_type == 'GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            if use_gg_conv:
                self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
        else:
            raise ValueError(f'specified graph conv not implemented: {conv_type!r}')

        gate_conv = self.conv_gg if use_gg_conv else self.conv
        self.G2 = G2(gate_conv, p, conv_type, activation=nn.ReLU())

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)
        X = torch.relu(self.enc(X))

        for _ in range(self.nlayers):
            if self.conv_type == 'GAT':
                X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X_ = torch.relu(self.conv(X, edge_index))
            # tau near 0 keeps the current state, tau near 1 takes the update.
            tau = self.G2(X, edge_index)
            X = (1 - tau) * X + tau * X_
            X = F.dropout(X, self.drop, training=self.training)

        return self.dec(X)


class plain_GNN(nn.Module):
    """Baseline without gating.

    Layout: graph-conv encoder, ``nlayers`` shared graph-conv steps, and a
    graph-conv decoder that outputs logits.

    Args:
        conv_type: 'plain_GCN' or 'plain_GAT'. The bare names 'GCN' and 'GAT'
            (including the default) are accepted as aliases.

    Raises:
        ValueError: for an unknown ``conv_type``.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', drop_in=0, drop=0):
        super(plain_GNN, self).__init__()
        # The default 'GCN' was previously always rejected; map it to plain_GCN.
        if conv_type in ('GCN', 'GAT'):
            conv_type = 'plain_' + conv_type
        self.conv_type = conv_type
        self.drop_in = drop_in
        self.drop = drop
        self.nlayers = nlayers
        if conv_type == 'plain_GCN':
            self.conv = GCNConv(nhid, nhid)
            self.enc = GCNConv(nfeat, nhid)
            self.dec = GCNConv(nhid, nclass)
        elif conv_type == 'plain_GAT':
            self.conv = GATConv(nhid, nhid, heads=4, concat=True)
            self.enc = GATConv(nfeat, nhid, heads=4, concat=True)
            self.dec = GATConv(nhid, nclass, heads=4, concat=True)
        else:
            raise ValueError(f'specified graph conv not implemented: {conv_type!r}')

    def forward(self, data):
        X = data.x
        n_nodes = X.size(0)
        edge_index = data.edge_index
        X = F.dropout(X, self.drop_in, training=self.training)

        if self.conv_type == 'plain_GAT':
            X = F.elu(self.enc(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = torch.relu(self.enc(X, edge_index))

        for _ in range(self.nlayers):
            if self.conv_type == 'plain_GAT':
                X = F.elu(self.conv(X, edge_index)).view(n_nodes, -1, 4).mean(dim=-1)
            else:
                X = torch.relu(self.conv(X, edge_index))
            X = F.dropout(X, self.drop, training=self.training)

        if self.conv_type == 'plain_GAT':
            X = self.dec(X, edge_index).view(n_nodes, -1, 4).mean(dim=-1)
        else:
            X = self.dec(X, edge_index)

        return X


# ------------------------------------------------------------------------------
# /synthetic_cora/train_GNN.py
# ------------------------------------------------------------------------------
from models import *
import torch
import torch.optim as optim
import numpy as np
from data_handling import get_data
import argparse


def _str2bool(value):
    """Parse a command-line boolean; ``type=bool`` would turn 'False' into True."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean, got {value!r}')


def train(args, graph_id):
    """Train one model on synthetic graph ``graph_id``; return the test accuracy at the best epoch.

    The best epoch is chosen by validation accuracy (``args.use_val_acc``) or
    by validation loss. Training stops after ``args.patience`` consecutive
    epochs without improvement, or after ``args.epochs`` epochs.
    """
    data = get_data(args.hom_level, graph_id).to(args.device)

    best_eval_acc = 0
    best_eval_loss = 1e5
    bad_counter = 0
    best_test_acc = 0

    # Cora has 7 classes; presumably the synthetic labels stay within them -- confirm.
    nout = 7

    if 'plain' in args.GNN:
        model = plain_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN,
                          args.drop_in, args.drop).to(args.device)
    else:
        model = G2_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN,
                       args.G2_exp, args.drop_in, args.drop, args.use_G2_conv).to(args.device)

    lf = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    y = data.y.squeeze()

    @torch.no_grad()
    def test(model, data):
        """Return ([train, val, test] accuracies, [train, val, test] losses) in eval mode."""
        model.eval()
        logits, accs, losses = model(data), [], []
        for _, mask in data('train_mask', 'val_mask', 'test_mask'):
            # Eval-mode logits; the original read the training-mode `out` of
            # the enclosing loop here, so its losses included dropout noise.
            loss = lf(logits[mask], y[mask])
            pred = logits[mask].max(1)[1]
            acc = pred.eq(y[mask]).sum().item() / mask.sum().item()
            accs.append(acc)
            losses.append(loss.item())
        return accs, losses

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = lf(out[data.train_mask], y[data.train_mask])
        loss.backward()
        optimizer.step()

        [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model, data)

        if args.use_val_acc:
            improved = val_acc > best_eval_acc
            if improved:
                best_eval_acc = val_acc
        else:
            improved = val_loss < best_eval_loss
            if improved:
                best_eval_loss = val_loss

        if improved:
            best_test_acc = test_acc
            bad_counter = 0
        else:
            bad_counter += 1
            # Honour --patience (previously shadowed by a local 200 and compared
            # with the epoch index instead of the no-improvement counter).
            if bad_counter >= args.patience:
                break

        log = 'Graph: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
        print(log.format(graph_id, epoch, train_acc, val_acc, test_acc))

    return best_test_acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training parameters')
    parser.add_argument('--hom_level', type=int, default=0,
                        help='level of true label homophily in percent, i.e., 0, 10, 20,...,99,100')
    parser.add_argument('--GNN', type=str, default='GCN',
                        help='base GNN model used with G^2: GCN, GAT -- '
                             'plain GNN versions: plain_GCN, plain_GAT')
    parser.add_argument('--nhid', type=int, default=128,
                        help='number of hidden node features')
    parser.add_argument('--nlayers', type=int, default=13,
                        help='number of layers')
    parser.add_argument('--epochs', type=int, default=500,
                        help='max epochs')
    parser.add_argument('--patience', type=int, default=200,
                        help='patience for early stopping')
    parser.add_argument('--lr', type=float, default=0.008,
                        help='learning rate')
    parser.add_argument('--drop_in', type=float, default=0.7,
                        help='input dropout rate')
    parser.add_argument('--drop', type=float, default=0.2,
                        help='dropout rate')
    parser.add_argument('--weight_decay', type=float, default=0.001,
                        help='weight_decay')
    parser.add_argument('--G2_exp', type=float, default=2.,
                        help='exponent p in G^2')
    parser.add_argument('--use_val_acc', type=_str2bool, default=True,
                        help='use validation accuracy for early stoppping -- otherwise use validation loss')
    parser.add_argument('--use_G2_conv', type=_str2bool, default=False,
                        help='use a different GNN model for the gradient gating method')
    parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                        help='computing device')

    args = parser.parse_args()

    n_graphs = 3
    best_results = np.array([train(args, graph_id) for graph_id in range(1, n_graphs + 1)])
    mean_acc = np.mean(best_results)
    std = np.std(best_results)

    log = 'Final test results -- mean: {:.4f}, std: {:.4f}'
    print(log.format(mean_acc, std))

# ------------------------------------------------------------------------------