├── .gitignore
├── LICENSE
├── README.md
├── heterophilic_graphs
│   ├── data_handling.py
│   ├── models.py
│   └── run_GNN.py
├── imgs
│   └── gradient_gating_scheme2.png
├── node_regression
│   ├── README.md
│   ├── data_handling.py
│   ├── models.py
│   └── run_GNN.py
└── synthetic_cora
    ├── README.md
    ├── data_handling.py
    ├── models.py
    └── train_GNN.py
/.gitignore:
--------------------------------------------------------------------------------
1 | code
2 | GraphSAGE/
3 | GCN/
4 | GAT/
5 | plain_GCN/
6 | plain_GAT/
7 | plain_GraphSAGE/
8 | data/
9 | ray_tune/
10 | __pycache__/
11 | src/checkpoint
12 | images/
13 | ray_results/
14 | models/
15 | .data/
16 | .vector_cache/
17 | .idea/
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Konstantin Rusch
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gradient Gating for Deep Multi-Rate Learning on Graphs
2 |
3 | This repository contains the implementation to reproduce the numerical experiments of the **ICLR 2023** paper [Gradient Gating for Deep Multi-Rate Learning on Graphs](https://openreview.net/forum?id=JpRExTbl1-)
4 |
5 | [PapersWithCode SOTA: node classification on arXiv-year](https://paperswithcode.com/sota/node-classification-on-arxiv-year?p=gradient-gating-for-deep-multi-rate-learning)
6 | [PapersWithCode SOTA: node classification on genius](https://paperswithcode.com/sota/node-classification-on-genius?p=gradient-gating-for-deep-multi-rate-learning)
7 | [PapersWithCode SOTA: node classification on snap-patents](https://paperswithcode.com/sota/node-classification-on-snap-patents?p=gradient-gating-for-deep-multi-rate-learning)
8 |
9 |
10 | ![Gradient gating scheme](imgs/gradient_gating_scheme2.png)
11 |
12 |
13 | ### Requirements
14 | Main dependencies (with python >= 3.7):
15 | torch==1.9.0
16 | torch-cluster==1.5.9
17 | torch-geometric==2.0.3
18 | torch-scatter==2.0.9
19 | torch-sparse==0.6.12
20 | torch-spline-conv==1.2.1
21 |
22 | Commands to install all dependencies in a new conda environment
23 | *(python 3.7 and cuda 10.2 -- for other cuda versions, change the wheel URLs accordingly)*
24 | ```
25 | conda create --name gradientgating python=3.7
26 | conda activate gradientgating
27 |
28 | pip install torch==1.9.0
29 |
30 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
31 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
32 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
33 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
34 | pip install torch-geometric
35 | pip install scipy
36 | pip install numpy
37 | ```
38 |
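### Running the experiments
Each experiment directory ships its own training script; an illustrative invocation (flag values here are placeholders, see the argparse block of each script for the full list):
```
cd heterophilic_graphs
python run_GNN.py --dataset chameleon --GNN GraphSAGE --nlayers 6
```
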
39 | ### Citation
40 | If you find our work useful in your research, please cite our paper:
41 | ```bibtex
42 | @inproceedings{rusch2022gradient,
43 | title={Gradient Gating for Deep Multi-Rate Learning on Graphs},
44 | author={Rusch, T Konstantin and Chamberlain, Benjamin P and Mahoney, Michael W and Bronstein, Michael M and Mishra, Siddhartha},
45 | booktitle={International Conference on Learning Representations},
46 | year={2023}
47 | }
48 | ```
49 | (Also consider starring the project on GitHub.)
50 |
--------------------------------------------------------------------------------
/heterophilic_graphs/data_handling.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import WebKB, WikipediaNetwork, Actor
2 | import torch
3 | import numpy as np
4 |
5 | def get_data(name, split=0):
6 |     path = '../data/' + name
7 |     if name in ['chameleon', 'squirrel']:
8 |         dataset = WikipediaNetwork(root=path, name=name)
9 |     if name in ['cornell', 'texas', 'wisconsin']:
10 |         dataset = WebKB(root=path, name=name)
11 |     if name == 'film':
12 |         dataset = Actor(root=path)
13 |
14 | data = dataset[0]
15 | if name in ['chameleon', 'squirrel']:
16 | splits_file = np.load(f'{path}/{name}/geom_gcn/raw/{name}_split_0.6_0.2_{split}.npz')
17 | if name in ['cornell', 'texas', 'wisconsin']:
18 | splits_file = np.load(f'{path}/{name}/raw/{name}_split_0.6_0.2_{split}.npz')
19 | if name == 'film':
20 | splits_file = np.load(f'{path}/raw/{name}_split_0.6_0.2_{split}.npz')
21 |     train_mask = splits_file['train_mask']
22 |     val_mask = splits_file['val_mask']
23 |     test_mask = splits_file['test_mask']
24 |
25 |     data.train_mask = torch.tensor(train_mask, dtype=torch.bool)
26 |     data.val_mask = torch.tensor(val_mask, dtype=torch.bool)
27 |     data.test_mask = torch.tensor(test_mask, dtype=torch.bool)
28 |
29 |     return data
30 |
--------------------------------------------------------------------------------
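A minimal sketch of how `get_data` above is consumed (assuming torch_geometric has downloaded the datasets and the geom-gcn split files on first use):
```python
from data_handling import get_data

# load the Texas graph together with the masks of the first of the ten fixed splits
data = get_data('texas', split=0)
print(data.num_nodes, data.num_node_features)
print(int(data.train_mask.sum()), 'training nodes in split 0')
```
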
/heterophilic_graphs/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import GCNConv, SAGEConv, GATConv
6 |
7 | class G2(nn.Module):
8 | def __init__(self, conv, p=2., conv_type='GraphSAGE', activation=nn.ReLU()):
9 | super(G2, self).__init__()
10 | self.conv = conv
11 | self.p = p
12 | self.activation = activation
13 | self.conv_type = conv_type
14 |
15 | def forward(self, X, edge_index):
16 | n_nodes = X.size(0)
17 | if self.conv_type == 'GAT':
18 |             X = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)  # average over the 4 attention heads
19 | else:
20 | X = self.activation(self.conv(X, edge_index))
21 | gg = torch.tanh(scatter((torch.abs(X[edge_index[0]] - X[edge_index[1]]) ** self.p).squeeze(-1),
22 | edge_index[0], 0,dim_size=X.size(0), reduce='mean'))
23 |
24 | return gg
25 |
26 | class G2_GNN(nn.Module):
27 | def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GraphSAGE', p=2., drop_in=0, drop=0, use_gg_conv=True):
28 | super(G2_GNN, self).__init__()
29 | self.conv_type = conv_type
30 | self.enc = nn.Linear(nfeat, nhid)
31 | self.dec = nn.Linear(nhid, nclass)
32 | self.drop_in = drop_in
33 | self.drop = drop
34 | self.nlayers = nlayers
35 |         if conv_type == 'GCN':
36 |             self.conv = GCNConv(nhid, nhid)
37 |             if use_gg_conv:
38 |                 self.conv_gg = GCNConv(nhid, nhid)
39 |         elif conv_type == 'GraphSAGE':
40 |             self.conv = SAGEConv(nhid, nhid)
41 |             if use_gg_conv:
42 |                 self.conv_gg = SAGEConv(nhid, nhid)
43 |         elif conv_type == 'GAT':
44 |             self.conv = GATConv(nhid, nhid, heads=4, concat=True)
45 |             if use_gg_conv:
46 |                 self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
47 |         else:
48 |             raise NotImplementedError('specified graph conv not implemented')
49 |
50 |         if use_gg_conv:
51 |             self.G2 = G2(self.conv_gg, p, conv_type, activation=nn.ReLU())
52 |         else:
53 |             self.G2 = G2(self.conv, p, conv_type, activation=nn.ReLU())
54 |
55 | def forward(self, data):
56 | X = data.x
57 | n_nodes = X.size(0)
58 | edge_index = data.edge_index
59 | X = F.dropout(X, self.drop_in, training=self.training)
60 | X = torch.relu(self.enc(X))
61 |
62 | for i in range(self.nlayers):
63 | if self.conv_type == 'GAT':
64 |                 X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
65 | else:
66 | X_ = torch.relu(self.conv(X, edge_index))
67 | tau = self.G2(X, edge_index)
68 | X = (1 - tau) * X + tau * X_
69 | X = F.dropout(X, self.drop, training=self.training)
70 |
71 | return self.dec(X)
72 |
--------------------------------------------------------------------------------
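In all three `models.py` files, `GATConv(..., heads=4, concat=True)` returns the four attention heads concatenated along the feature dimension, so recovering an `nhid`-channel tensor requires grouping by head before averaging. A plain-PyTorch shape check of the reshape used above:
```python
import torch

n_nodes, heads, nhid = 5, 4, 8
per_head = torch.randn(n_nodes, heads, nhid)
concat = per_head.reshape(n_nodes, heads * nhid)  # layout emitted by GATConv(concat=True)

mean_heads = concat.view(n_nodes, heads, nhid).mean(dim=1)  # group by head, then average
assert torch.allclose(mean_heads, per_head.mean(dim=1))
```
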
/heterophilic_graphs/run_GNN.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import torch
3 | import torch.optim as optim
4 | import numpy as np
5 | from data_handling import get_data
6 | import argparse
7 |
8 |
9 | def train(args, split):
10 | data = get_data(args.dataset, split)
11 |
12 | best_eval_acc = 0
13 | best_eval_loss = 1e5
14 | bad_counter = 0
15 | best_test_acc = 0
16 |
17 |     nout = 5  # all heterophilic datasets used here have 5 classes
18 |
19 | model = G2_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN, args.G2_exp, args.drop_in, args.drop,
20 | args.use_G2_conv).to(args.device)
21 |
22 | lf = torch.nn.CrossEntropyLoss()
23 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)
24 |
25 | @torch.no_grad()
26 | def test(model, data):
27 | model.eval()
28 | logits, accs, losses = model(data), [], []
29 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
30 | loss = lf(logits[mask], data.y.squeeze()[mask])
31 | pred = logits[mask].max(1)[1]
32 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
33 | accs.append(acc)
34 | losses.append(loss.item())
35 | return accs, losses
36 |
37 | for epoch in range(args.epochs):
38 | model.train()
39 | optimizer.zero_grad()
40 | out = model(data.to(args.device))
41 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask])
42 | loss.backward()
43 | optimizer.step()
44 |
45 | [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model, data)
46 |
47 |         if args.use_val_acc:
48 |             if val_acc > best_eval_acc:
49 |                 best_eval_acc = val_acc
50 |                 best_test_acc = test_acc
51 |                 bad_counter = 0
52 |             else:
53 |                 bad_counter += 1
54 |         else:
55 |             if val_loss < best_eval_loss:
56 |                 best_eval_loss = val_loss
57 |                 best_test_acc = test_acc
58 |                 bad_counter = 0
59 |             else:
60 |                 bad_counter += 1
61 |         if bad_counter == args.patience:
62 |             break
63 |
64 | log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
65 | print(log.format(split, epoch, train_acc, val_acc, test_acc))
66 |
67 | return best_test_acc
68 |
69 | if __name__ == '__main__':
70 | parser = argparse.ArgumentParser(description='training parameters')
71 | parser.add_argument('--dataset', type=str, default='wisconsin',
72 | help='dataset name: texas, wisconsin, film, squirrel, chameleon, cornell')
73 | parser.add_argument('--GNN', type=str, default='GraphSAGE',
74 | help='base GNN model used with G^2: GraphSAGE, GCN, GAT')
75 | parser.add_argument('--nhid', type=int, default=256,
76 | help='number of hidden node features')
77 | parser.add_argument('--nlayers', type=int, default=6,
78 | help='number of layers')
79 | parser.add_argument('--epochs', type=int, default=500,
80 | help='max epochs')
81 | parser.add_argument('--patience', type=int, default=200,
82 | help='patience for early stopping')
83 | parser.add_argument('--lr', type=float, default=0.003,
84 | help='learning rate')
85 | parser.add_argument('--drop_in', type=float, default=0.5,
86 | help='input dropout rate')
87 | parser.add_argument('--drop', type=float, default=0.1,
88 | help='dropout rate')
89 | parser.add_argument('--weight_decay', type=float, default=0.01,
90 | help='weight_decay')
91 | parser.add_argument('--G2_exp', type=float, default=2.5,
92 | help='exponent p in G^2')
93 |     parser.add_argument('--use_val_acc', type=bool, default=True,
94 |                         help='use validation accuracy for early stopping -- otherwise use validation loss')
95 | parser.add_argument('--use_G2_conv', type=bool, default=False,
96 | help='use a different GNN model for the gradient gating method')
97 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
98 | help='computing device')
99 |
100 | args = parser.parse_args()
101 |
102 | n_splits = 10
103 | best_results = []
104 | for split in range(n_splits):
105 | best_results.append(train(args, split))
106 |
107 | best_results = np.array(best_results)
108 | mean_acc = np.mean(best_results)
109 | std = np.std(best_results)
110 |
111 | log = 'Final test results -- mean: {:.4f}, std: {:.4f}'
112 | print(log.format(mean_acc,std))
113 |
--------------------------------------------------------------------------------
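`run_GNN.py` trains on the 10 fixed geom-gcn splits and reports the mean and standard deviation of the best test accuracy. Illustrative invocations (hyperparameters here are placeholders, not the tuned values from the paper):
```
python run_GNN.py --dataset texas --GNN GCN --nlayers 4
python run_GNN.py --dataset squirrel --GNN GAT --lr 0.001 --drop 0.2
```
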
/imgs/gradient_gating_scheme2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tk-rusch/gradientgating/334bd1536f6d55190a394cd74e9df6600389f153/imgs/gradient_gating_scheme2.png
--------------------------------------------------------------------------------
/node_regression/README.md:
--------------------------------------------------------------------------------
1 | ### Data preparation
2 |
3 | The data can be downloaded from https://www.kaggle.com/datasets/andreagarritano/wikipedia-article-networks/download?datasetVersionNumber=1
4 | (or, alternatively, if you don't want to log in to Kaggle: https://github.com/benedekrozemberczki/MUSAE/tree/master/input)
5 |
6 | Simply unpack it and put the raw .json and .csv files (chameleon and squirrel)
7 | into a `data` directory inside this project directory.
8 |
9 | In order to access the standard splits, run the heterophilic_graphs experiment for chameleon and squirrel.
10 | This will automatically download and store the required standard splits.
11 |
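For reference, `data_handling.py` expects the following layout (file names as read by the loader):
```
node_regression/
└── data/
    ├── chameleon_edges.csv
    ├── chameleon_features.json
    ├── chameleon_target.csv
    ├── squirrel_edges.csv
    ├── squirrel_features.json
    └── squirrel_target.csv
```
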
--------------------------------------------------------------------------------
/node_regression/data_handling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from numpy import genfromtxt
4 | import networkx
5 | from torch_geometric import utils
6 | import json
7 |
8 | def get_data(name='chameleon', split=0):
9 | edges = genfromtxt('data/'+name+'_edges.csv', delimiter=',')[1:].astype(int)
10 | G = networkx.Graph()
11 | for edge in edges:
12 | G.add_edge(edge[0],edge[1])
13 |
14 | data = utils.from_networkx(G)
15 | y = genfromtxt('data/'+name+'_target.csv', delimiter=',')[1:,-1].astype(int)
16 | y = y/np.max(y)
17 | data.y = torch.tensor(y).float()
18 |
19 |     with open('data/' + name + '_features.json', 'r') as myfile:
20 |         file_content = myfile.read()
21 |     obj = json.loads(file_content)
22 |
23 | if name == 'chameleon':
24 | x = np.zeros((2277,3132))
25 | for i in range(2277):
26 | feats = np.array(obj[str(i)])
27 | x[i,feats] = 1
28 |
29 | elif name == 'squirrel':
30 | x = np.zeros((5201, 3148))
31 | for i in range(5201):
32 | feats = np.array(obj[str(i)])
33 | x[i, feats] = 1
34 |
35 | data.x = torch.tensor(x).float()
36 |
37 | path = '../data/' + name
38 | splits_file = np.load(f'{path}/{name}/geom_gcn/raw/{name}_split_0.6_0.2_{split}.npz')
39 |
40 | train_mask = splits_file['train_mask']
41 | val_mask = splits_file['val_mask']
42 | test_mask = splits_file['test_mask']
43 |
44 | data.train_mask = torch.tensor(train_mask, dtype=torch.bool)
45 | data.val_mask = torch.tensor(val_mask, dtype=torch.bool)
46 | data.test_mask = torch.tensor(test_mask, dtype=torch.bool)
47 |
48 | return data
--------------------------------------------------------------------------------
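The node features built above are bag-of-words indicator vectors: the MUSAE JSON maps each node id to the list of its active feature indices. A toy version of the same construction:
```python
import numpy as np

obj = {"0": [1, 3], "1": [0]}  # node id -> active feature ids (toy stand-in for the MUSAE JSON)
x = np.zeros((2, 4))
for i in range(2):
    x[i, np.array(obj[str(i)])] = 1
print(x)  # [[0. 1. 0. 1.]
          #  [1. 0. 0. 0.]]
```
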
/node_regression/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import GCNConv, GATConv
6 |
7 | class G2(nn.Module):
8 | def __init__(self, conv, p=2., conv_type='GCN', activation=nn.ReLU()):
9 | super(G2, self).__init__()
10 | self.conv = conv
11 | self.p = p
12 | self.activation = activation
13 | self.conv_type = conv_type
14 |
15 | def forward(self, X, edge_index):
16 | n_nodes = X.size(0)
17 | if self.conv_type == 'GAT':
18 |             X = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)  # average over the 4 attention heads
19 | else:
20 | X = self.activation(self.conv(X, edge_index))
21 | gg = torch.tanh(scatter((torch.abs(X[edge_index[0]] - X[edge_index[1]]) ** self.p).squeeze(-1),
22 | edge_index[0], 0,dim_size=X.size(0), reduce='mean'))
23 |
24 | return gg
25 |
26 | class G2_GNN(nn.Module):
27 | def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', p=2., drop_in=0, drop=0, use_gg_conv=True):
28 | super(G2_GNN, self).__init__()
29 | self.conv_type = conv_type
30 | self.enc = nn.Linear(nfeat, nhid)
31 | self.dec = nn.Linear(nhid, nclass)
32 | self.drop_in = drop_in
33 | self.drop = drop
34 | self.nlayers = nlayers
35 |         if conv_type == 'GCN':
36 |             self.conv = GCNConv(nhid, nhid)
37 |             if use_gg_conv:
38 |                 self.conv_gg = GCNConv(nhid, nhid)
39 |         elif conv_type == 'GAT':
40 |             self.conv = GATConv(nhid, nhid, heads=4, concat=True)
41 |             if use_gg_conv:
42 |                 self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
43 |         else:
44 |             raise NotImplementedError('specified graph conv not implemented')
45 |
46 |         if use_gg_conv:
47 |             self.G2 = G2(self.conv_gg, p, conv_type, activation=nn.ReLU())
48 |         else:
49 |             self.G2 = G2(self.conv, p, conv_type, activation=nn.ReLU())
50 |
51 | def forward(self, data):
52 | X = data.x
53 | n_nodes = X.size(0)
54 | edge_index = data.edge_index
55 | X = F.dropout(X, self.drop_in, training=self.training)
56 | X = torch.relu(self.enc(X))
57 |
58 | for i in range(self.nlayers):
59 | if self.conv_type == 'GAT':
60 |                 X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
61 | else:
62 | X_ = torch.relu(self.conv(X, edge_index))
63 | tau = self.G2(X, edge_index)
64 | X = (1 - tau) * X + tau * X_
65 | X = F.dropout(X, self.drop, training=self.training)
66 |
67 | return torch.relu(self.dec(X))
68 |
69 |
70 | class plain_GNN(nn.Module):
71 | def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', drop_in=0, drop=0):
72 | super(plain_GNN, self).__init__()
73 | self.conv_type = conv_type
74 | self.drop_in = drop_in
75 | self.drop = drop
76 | self.nlayers = nlayers
77 | if conv_type == 'plain_GCN':
78 | self.conv = GCNConv(nhid, nhid)
79 | self.enc = GCNConv(nfeat, nhid)
80 | self.dec = GCNConv(nhid, nclass)
81 | elif conv_type == 'plain_GAT':
82 | self.conv = GATConv(nhid,nhid,heads=4,concat=True)
83 | self.enc = GATConv(nfeat, nhid,heads=4,concat=True)
84 | self.dec = GATConv(nhid, nclass,heads=4,concat=True)
85 |         else:
86 |             raise NotImplementedError('specified graph conv not implemented')
87 |
88 | def forward(self, data):
89 | X = data.x
90 | n_nodes = X.size(0)
91 | edge_index = data.edge_index
92 | X = F.dropout(X, self.drop_in, training=self.training)
93 |
94 | if self.conv_type == 'plain_GAT':
95 |             X = F.elu(self.enc(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
96 | else:
97 | X = torch.relu(self.enc(X,edge_index))
98 |
99 | for i in range(self.nlayers):
100 | if self.conv_type == 'plain_GAT':
101 |                 X = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
102 | else:
103 | X = torch.relu(self.conv(X, edge_index))
104 | X = F.dropout(X, self.drop, training=self.training)
105 |
106 | if self.conv_type == 'plain_GAT':
107 |             X = torch.relu(self.dec(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
108 | else:
109 | X = torch.relu(self.dec(X,edge_index))
110 |
111 | return X
--------------------------------------------------------------------------------
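A quick smoke test of the regression variant of `G2_GNN` on a toy graph (shapes only; the graph and sizes are made up for illustration):
```python
import torch
from types import SimpleNamespace
from models import G2_GNN

data = SimpleNamespace(
    x=torch.randn(6, 4),
    edge_index=torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 0]]),
)
model = G2_GNN(nfeat=4, nhid=8, nclass=1, nlayers=2, conv_type='GCN')
print(model(data).squeeze(-1).shape)  # torch.Size([6]); outputs are non-negative due to the final relu
```
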
/node_regression/run_GNN.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import torch
3 | import torch.optim as optim
4 | import numpy as np
5 | from data_handling import get_data
6 | import argparse
7 |
8 | def train(args, split):
9 | data = get_data(args.dataset,split)
10 | best_eval_loss = 1e5
11 | bad_counter = 0
12 | best_test_loss = 1e5
13 |     patience = args.patience
14 |
15 |     nout = 1  # scalar regression target per node
16 |
17 | if args.dataset == 'chameleon':
18 | ninp = 3132
19 | elif args.dataset == 'squirrel':
20 | ninp = 3148
21 |
22 | if 'plain' in args.GNN:
23 |         model = plain_GNN(ninp, args.nhid, nout, args.nlayers, args.GNN, args.drop_in, args.drop).to(args.device)
24 | else:
25 | model = G2_GNN(ninp, args.nhid, nout, args.nlayers, args.GNN, args.G2_exp, args.drop_in, args.drop,
26 | args.use_G2_conv).to(args.device)
27 |
28 | lf = torch.nn.MSELoss()
29 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)
30 |
31 | @torch.no_grad()
32 | def test(model, data):
33 | model.eval()
34 | out, losses = model(data).squeeze(-1), []
35 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
36 | loss = lf(out[mask], data.y.squeeze()[mask])/torch.mean(data.y)
37 | losses.append(loss.item())
38 | return losses
39 |
40 | for epoch in range(args.epochs):
41 | model.train()
42 | optimizer.zero_grad()
43 | out = model(data.to(args.device)).squeeze(-1)
44 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask])
45 | loss.backward()
46 | optimizer.step()
47 |
48 | [train_loss, val_loss, test_loss] = test(model, data)
49 |
50 |         if val_loss < best_eval_loss:
51 |             best_eval_loss = val_loss
52 |             best_test_loss = test_loss
53 |             bad_counter = 0
54 |         else:
55 |             bad_counter += 1
56 |         if bad_counter == patience:
57 |             break
58 |
59 | log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
60 | print(log.format(split, epoch, train_loss, val_loss, test_loss))
61 |
62 | return best_test_loss
63 |
64 | if __name__ == '__main__':
65 | parser = argparse.ArgumentParser(description='training parameters')
66 | parser.add_argument('--dataset', type=str, default='squirrel',
67 | help='dataset name: squirrel, chameleon')
68 | parser.add_argument('--GNN', type=str, default='GCN',
69 | help='base GNN model used with G^2: GCN, GAT -- '
70 | 'plain GNN versions: plain_GCN, plain_GAT')
71 | parser.add_argument('--nhid', type=int, default=64,
72 | help='number of hidden node features')
73 | parser.add_argument('--nlayers', type=int, default=3,
74 | help='number of layers')
75 | parser.add_argument('--epochs', type=int, default=500,
76 | help='max epochs')
77 | parser.add_argument('--patience', type=int, default=200,
78 | help='patience for early stopping')
79 | parser.add_argument('--lr', type=float, default=0.002,
80 | help='learning rate')
81 | parser.add_argument('--drop_in', type=float, default=0.2,
82 | help='input dropout rate')
83 | parser.add_argument('--drop', type=float, default=0.3,
84 | help='dropout rate')
85 | parser.add_argument('--weight_decay', type=float, default=0.0001,
86 | help='weight_decay')
87 | parser.add_argument('--G2_exp', type=float, default=5.,
88 | help='exponent p in G^2')
89 | parser.add_argument('--use_G2_conv', type=bool, default=False,
90 | help='use a different GNN model for the gradient gating method')
91 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
92 | help='computing device')
93 |
94 | args = parser.parse_args()
95 |
96 | n_splits = 10
97 | best_results = []
98 | for split in range(n_splits):
99 | best_results.append(train(args, split))
100 |
101 |     best_results = np.array(best_results)
102 |     mean_loss = np.mean(best_results)
103 |     std = np.std(best_results)
104 |
105 |     log = 'Final test results (relative MSE) -- mean: {:.4f}, std: {:.4f}'
106 |     print(log.format(mean_loss, std))
107 |
--------------------------------------------------------------------------------
/synthetic_cora/README.md:
--------------------------------------------------------------------------------
1 | ### Data preparation
2 |
3 | The data can be downloaded from https://drive.google.com/file/d/1TbC-10pF2WlfbYLmu_gPV8TcXEBYeyOl/view?usp=sharing
4 |
5 | Simply unpack it, create a new directory called `data` inside this project directory, and put all .npz files in there.
6 |
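The loader in `data_handling.py` reads files named `h<homophily>-r<graph>.npz`, so the directory should look like:
```
synthetic_cora/
└── data/
    ├── h0.00-r1.npz
    ├── h0.00-r2.npz
    ├── h0.00-r3.npz
    └── ...
```
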
--------------------------------------------------------------------------------
/synthetic_cora/data_handling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import scipy.sparse as sp
4 | from torch_geometric.utils import convert
5 | from torch_geometric.data import Data
6 |
7 |
8 |
9 | def get_data(hom_level=0, graph=1):
10 |     seed = 123456
11 |     # the synthetic graphs come from preprocessed .npz files in ./data (see README)
12 |
13 |     if hom_level == 0:
14 |         loader = np.load('data/h0.00-r' + str(graph) + '.npz')
15 |     else:
16 |         loader = np.load('data/h0.' + str(round(hom_level, 3)) + '-r' + str(graph) + '.npz')
17 |
18 | adj = sp.csr_matrix((loader['adj_data'], loader['adj_indices'],
19 | loader['adj_indptr']), shape=loader['adj_shape'])
20 |
21 | features = sp.csr_matrix((loader['attr_data'], loader['attr_indices'],
22 | loader['attr_indptr']), shape=loader['attr_shape'])
23 |
24 | y = torch.tensor(loader.get('labels')).long()
25 | x = torch.tensor(sp.csr_matrix.todense(features)).float()
26 | edge_index, edge_features = convert.from_scipy_sparse_matrix(adj)
27 |
28 | data = Data(
29 | x=x,
30 | edge_index=edge_index,
31 | y=y,
32 | train_mask=torch.zeros(y.size()[0], dtype=torch.bool),
33 | test_mask=torch.zeros(y.size()[0], dtype=torch.bool),
34 | val_mask=torch.zeros(y.size()[0], dtype=torch.bool)
35 | )
36 |
37 |
38 | num_nodes = data.y.shape[0]
39 |
40 | rnd_state = np.random.RandomState(seed)
41 |
42 | def get_mask(idx):
43 | mask = torch.zeros(num_nodes, dtype=torch.bool)
44 | mask[idx] = 1
45 | return mask
46 |
47 | idx = rnd_state.choice(num_nodes, size=num_nodes, replace=False)
48 | idx_train = idx[:int(0.5*num_nodes)]
49 | idx_val = idx[int(0.5*num_nodes):int(0.75*num_nodes)]
50 | idx_test = idx[int(0.75*num_nodes):]
51 |
52 |     data.train_mask = get_mask(idx_train)
53 |     data.val_mask = get_mask(idx_val)
54 |     data.test_mask = get_mask(idx_test)
55 |
56 |     return data
57 |
--------------------------------------------------------------------------------
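Each `.npz` stores the adjacency and attribute matrices in CSR form (`*_data`, `*_indices`, `*_indptr`, `*_shape`). A toy reconstruction of such a triple:
```python
import numpy as np
import scipy.sparse as sp

# a 2-node graph with one undirected edge, stored as CSR arrays
vals = np.array([1.0, 1.0])
indices = np.array([1, 0])
indptr = np.array([0, 1, 2])
adj = sp.csr_matrix((vals, indices, indptr), shape=(2, 2))
print(adj.toarray())  # [[0. 1.]
                      #  [1. 0.]]
```
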
/synthetic_cora/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch_scatter import scatter
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch_geometric.nn import GCNConv, GATConv
6 |
7 | class G2(nn.Module):
8 | def __init__(self, conv, p=2., conv_type='GCN', activation=nn.ReLU()):
9 | super(G2, self).__init__()
10 | self.conv = conv
11 | self.p = p
12 | self.activation = activation
13 | self.conv_type = conv_type
14 |
15 | def forward(self, X, edge_index):
16 | n_nodes = X.size(0)
17 | if self.conv_type == 'GAT':
18 |             X = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)  # average over the 4 attention heads
19 | else:
20 | X = self.activation(self.conv(X, edge_index))
21 | gg = torch.tanh(scatter((torch.abs(X[edge_index[0]] - X[edge_index[1]]) ** self.p).squeeze(-1),
22 | edge_index[0], 0,dim_size=X.size(0), reduce='mean'))
23 |
24 | return gg
25 |
26 | class G2_GNN(nn.Module):
27 | def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', p=2., drop_in=0, drop=0, use_gg_conv=True):
28 | super(G2_GNN, self).__init__()
29 | self.conv_type = conv_type
30 | self.enc = nn.Linear(nfeat, nhid)
31 | self.dec = nn.Linear(nhid, nclass)
32 | self.drop_in = drop_in
33 | self.drop = drop
34 | self.nlayers = nlayers
35 |         if conv_type == 'GCN':
36 |             self.conv = GCNConv(nhid, nhid)
37 |             if use_gg_conv:
38 |                 self.conv_gg = GCNConv(nhid, nhid)
39 |         elif conv_type == 'GAT':
40 |             self.conv = GATConv(nhid, nhid, heads=4, concat=True)
41 |             if use_gg_conv:
42 |                 self.conv_gg = GATConv(nhid, nhid, heads=4, concat=True)
43 |         else:
44 |             raise NotImplementedError('specified graph conv not implemented')
45 |
46 |         if use_gg_conv:
47 |             self.G2 = G2(self.conv_gg, p, conv_type, activation=nn.ReLU())
48 |         else:
49 |             self.G2 = G2(self.conv, p, conv_type, activation=nn.ReLU())
50 |
51 | def forward(self, data):
52 | X = data.x
53 | n_nodes = X.size(0)
54 | edge_index = data.edge_index
55 | X = F.dropout(X, self.drop_in, training=self.training)
56 | X = torch.relu(self.enc(X))
57 |
58 | for i in range(self.nlayers):
59 | if self.conv_type == 'GAT':
60 |                 X_ = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
61 | else:
62 | X_ = torch.relu(self.conv(X, edge_index))
63 | tau = self.G2(X, edge_index)
64 | X = (1 - tau) * X + tau * X_
65 | X = F.dropout(X, self.drop, training=self.training)
66 |
67 | return self.dec(X)
68 |
69 |
70 | class plain_GNN(nn.Module):
71 | def __init__(self, nfeat, nhid, nclass, nlayers, conv_type='GCN', drop_in=0, drop=0):
72 | super(plain_GNN, self).__init__()
73 | self.conv_type = conv_type
74 | self.drop_in = drop_in
75 | self.drop = drop
76 | self.nlayers = nlayers
77 | if conv_type == 'plain_GCN':
78 | self.conv = GCNConv(nhid, nhid)
79 | self.enc = GCNConv(nfeat, nhid)
80 | self.dec = GCNConv(nhid, nclass)
81 | elif conv_type == 'plain_GAT':
82 | self.conv = GATConv(nhid,nhid,heads=4,concat=True)
83 | self.enc = GATConv(nfeat, nhid,heads=4,concat=True)
84 | self.dec = GATConv(nhid, nclass,heads=4,concat=True)
85 |         else:
86 |             raise NotImplementedError('specified graph conv not implemented')
87 |
88 | def forward(self, data):
89 | X = data.x
90 | n_nodes = X.size(0)
91 | edge_index = data.edge_index
92 | X = F.dropout(X, self.drop_in, training=self.training)
93 |
94 | if self.conv_type == 'plain_GAT':
95 |             X = F.elu(self.enc(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
96 | else:
97 | X = torch.relu(self.enc(X,edge_index))
98 |
99 | for i in range(self.nlayers):
100 | if self.conv_type == 'plain_GAT':
101 |                 X = F.elu(self.conv(X, edge_index)).view(n_nodes, 4, -1).mean(dim=1)
102 | else:
103 | X = torch.relu(self.conv(X, edge_index))
104 | X = F.dropout(X, self.drop, training=self.training)
105 |
106 | if self.conv_type == 'plain_GAT':
107 |             X = self.dec(X, edge_index).view(n_nodes, 4, -1).mean(dim=1)
108 | else:
109 | X = self.dec(X,edge_index)
110 |
111 | return X
--------------------------------------------------------------------------------
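The layer update `X = (1 - tau) * X + tau * X_` is what makes the scheme multi-rate: `tau` holds per-node, per-channel rates in [0, 1), so nodes whose neighbourhood gradients vanish barely move while others update at full rate. In isolation:
```python
import torch

tau = torch.tensor([[0.0], [0.9]])  # gates for two nodes, broadcast over channels
X = torch.ones(2, 3)                # current hidden state
X_new = torch.zeros(2, 3)           # candidate update from the graph convolution
X = (1 - tau) * X + tau * X_new
print(X)  # first node unchanged, second node moved 90% of the way to the update
```
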
/synthetic_cora/train_GNN.py:
--------------------------------------------------------------------------------
1 | from models import *
2 | import torch
3 | import torch.optim as optim
4 | import numpy as np
5 | from data_handling import get_data
6 | import argparse
7 |
8 |
9 | def train(args, graph_id):
10 | data = get_data(args.hom_level, graph_id)
11 |
12 | best_eval_acc = 0
13 | best_eval_loss = 1e5
14 | bad_counter = 0
15 | best_test_acc = 0
16 |     patience = args.patience
17 |
18 |     nout = 7  # Cora has 7 classes
19 |
20 | if 'plain' in args.GNN:
21 | model = plain_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN,
22 | args.drop_in, args.drop).to(args.device)
23 | else:
24 | model = G2_GNN(data.num_node_features, args.nhid, nout, args.nlayers, args.GNN,
25 | args.G2_exp, args.drop_in, args.drop, args.use_G2_conv).to(args.device)
26 |
27 | lf = torch.nn.CrossEntropyLoss()
28 | optimizer = optim.Adam(model.parameters(),lr=args.lr,weight_decay=args.weight_decay)
29 |
30 | @torch.no_grad()
31 | def test(model, data):
32 | model.eval()
33 | logits, accs, losses = model(data), [], []
34 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
35 |             loss = lf(logits[mask], data.y.squeeze()[mask])
36 | pred = logits[mask].max(1)[1]
37 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
38 | accs.append(acc)
39 | losses.append(loss.item())
40 | return accs, losses
41 |
42 | for epoch in range(args.epochs):
43 | model.train()
44 | optimizer.zero_grad()
45 | out = model(data.to(args.device))
46 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask])
47 | loss.backward()
48 | optimizer.step()
49 |
50 | [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model, data)
51 |
52 |         if args.use_val_acc:
53 |             if val_acc > best_eval_acc:
54 |                 best_eval_acc = val_acc
55 |                 best_test_acc = test_acc
56 |                 bad_counter = 0
57 |             else:
58 |                 bad_counter += 1
59 |         else:
60 |             if val_loss < best_eval_loss:
61 |                 best_eval_loss = val_loss
62 |                 best_test_acc = test_acc
63 |                 bad_counter = 0
64 |             else:
65 |                 bad_counter += 1
66 |         if bad_counter == patience:
67 |             break
68 |
69 | log = 'Graph: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
70 | print(log.format(graph_id, epoch, train_acc, val_acc, test_acc))
71 |
72 | return best_test_acc
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser(description='training parameters')
76 | parser.add_argument('--hom_level', type=int, default=0,
77 |                         help='level of true label homophily in percent, e.g., 0, 10, 20, ..., 99, 100')
78 | parser.add_argument('--GNN', type=str, default='GCN',
79 | help='base GNN model used with G^2: GCN, GAT -- '
80 | 'plain GNN versions: plain_GCN, plain_GAT')
81 | parser.add_argument('--nhid', type=int, default=128,
82 | help='number of hidden node features')
83 | parser.add_argument('--nlayers', type=int, default=13,
84 | help='number of layers')
85 | parser.add_argument('--epochs', type=int, default=500,
86 | help='max epochs')
87 | parser.add_argument('--patience', type=int, default=200,
88 | help='patience for early stopping')
89 | parser.add_argument('--lr', type=float, default=0.008,
90 | help='learning rate')
91 | parser.add_argument('--drop_in', type=float, default=0.7,
92 | help='input dropout rate')
93 | parser.add_argument('--drop', type=float, default=0.2,
94 | help='dropout rate')
95 | parser.add_argument('--weight_decay', type=float, default=0.001,
96 | help='weight_decay')
97 | parser.add_argument('--G2_exp', type=float, default=2.,
98 | help='exponent p in G^2')
99 |     parser.add_argument('--use_val_acc', type=bool, default=True,
100 |                         help='use validation accuracy for early stopping -- otherwise use validation loss')
101 | parser.add_argument('--use_G2_conv', type=bool, default=False,
102 | help='use a different GNN model for the gradient gating method')
103 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
104 | help='computing device')
105 |
106 | args = parser.parse_args()
107 |
108 |
109 | n_graphs = 3
110 | best_results = []
111 | for graph_id in range(1,n_graphs+1):
112 | best_results.append(train(args, graph_id))
113 |
114 | best_results = np.array(best_results)
115 | mean_acc = np.mean(best_results)
116 | std = np.std(best_results)
117 |
118 | log = 'Final test results -- mean: {:.4f}, std: {:.4f}'
119 | print(log.format(mean_acc,std))
120 |
--------------------------------------------------------------------------------
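To sweep homophily levels, `train_GNN.py` can be invoked per level; each run averages over the three provided graphs. Illustrative commands:
```
for h in 0 50 100; do python train_GNN.py --hom_level $h --GNN GCN; done
python train_GNN.py --hom_level 0 --GNN plain_GCN
```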