├── .gitignore
├── README.md
├── imgs
│   ├── graphCON_figure.pdf
│   └── graphCON_figure.png
└── src
    ├── Superpixels
    │   ├── best_params.py
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    ├── ZINC
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    ├── definitions.py
    ├── heterophilic_graphs
    │   ├── best_params.py
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    └── homophilic_graphs
        ├── GNN.py
        ├── GNN_early.py
        ├── README.md
        ├── base_classes.py
        ├── block_constant.py
        ├── block_transformer_attention.py
        ├── data.py
        ├── early_stop_solver.py
        ├── experiment_configs
        │   ├── gat_pubmed.yaml
        │   ├── gcn_cora.yaml
        │   ├── gcn_depth_random.yaml
        │   ├── gcn_planetoid.yaml
        │   └── run_sweeps.sh
        ├── function_GAT_attention.py
        ├── function_gcn.py
        ├── function_laplacian_diffusion.py
        ├── function_transformer_attention.py
        ├── geometric_integrators.py
        ├── geometric_solvers.py
        ├── good_params_graphCON.py
        ├── model_configurations.py
        ├── odeint_geometric.py
        ├── regularized_ODE_function.py
        ├── run_GNN.py
        ├── run_best_sweeps.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | code
2 | data/
3 | ray_tune/
4 | __pycache__/
5 | src/checkpoint
6 | images/
7 | ray_results/
8 | models/
9 | __pycache__/
10 | .data/
11 | .vector_cache/
12 | .idea/
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-Coupled Oscillator Networks
2 |
3 | This repository contains the implementation to reproduce the numerical experiments
4 | of the **ICML 2022** paper [Graph-Coupled Oscillator Networks](https://arxiv.org/abs/2202.02296)
5 |
6 |
7 |
8 |
9 |
10 | ### Requirements
11 | Main dependencies (with python >= 3.7):
12 | torch==1.9.0
13 | torch-cluster==1.5.9
14 | torch-geometric==2.0.3
15 | torch-scatter==2.0.9
16 | torch-sparse==0.6.12
17 | torch-spline-conv==1.2.1
18 | torchdiffeq==0.2.2
19 |
20 | Commands to install all the dependencies in a new conda environment
21 | *(python 3.7 and cuda 10.2 -- for other cuda versions change accordingly)*
22 | ```
23 | conda create --name graphCON python=3.7
24 | conda activate graphCON
25 |
26 | pip install ogb pykeops
27 | pip install torch==1.9.0
28 | pip install torchdiffeq -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
29 |
30 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
31 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
32 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
33 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
34 | pip install torch-geometric
35 | pip install wandb
36 | ```
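Optionally, you can verify that the core packages import correctly (this is just a quick sanity check, not part of the setup itself):
```
python -c "import torch, torch_geometric, torchdiffeq; print(torch.__version__)"
```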
37 |
38 | ### Run the experiments
39 | To run each experiment, navigate into `src/exp_dir`
40 | (change `exp_dir` to the name of the corresponding experiment directory).
41 | There, simply do
42 | ```
43 | python run_GNN.py --kwargs
44 | ```
45 | where `--kwargs` stands for the command-line arguments defined in each individual `run_GNN.py` file (see the example below).
46 |
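For instance, the heterophilic-graph experiment on the Texas dataset can be launched as follows (the flags are those defined in `src/heterophilic_graphs/run_GNN.py`; the tuned hyperparameters from `best_params.py` are merged in automatically):
```
cd src/heterophilic_graphs
python run_GNN.py --dataset texas --model GraphCON_GCN
```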
47 | ### Dataset and preprocessing
48 | All data gets downloaded and preprocessed automatically and stored in `./data` directory
49 | (which gets automatically created the first time one of the experiments is run).
50 |
51 | ### Usage
52 | GraphCON is a general framework for "stacking" many GNN layers (a.k.a. message-passing mechanisms)
53 | in order to obtain a deep GNN that overcomes the oversmoothing problem.
54 |
55 | Given any standard GNN layer (such as *GCN* or *GAT*),
56 | GraphCON can be implemented using **PyTorch (Geometric)** as simply as follows:
57 | ```python
58 | from torch import nn
59 | import torch
60 | import torch.nn.functional as F
61 |
62 |
63 | class GraphCON(nn.Module):
64 | def __init__(self, GNNs, dt=1., alpha=1., gamma=1., dropout=None):
65 | super(GraphCON, self).__init__()
66 | self.dt = dt
67 | self.alpha = alpha
68 | self.gamma = gamma
69 | self.GNNs = GNNs # list of the individual GNN layers
70 | self.dropout = dropout
71 |
72 | def forward(self, X0, Y0, edge_index):
73 | # set initial values of ODEs
74 | X = X0
75 | Y = Y0
76 | # solve ODEs using simple IMEX scheme
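# (Background) GraphCON casts a deep GNN as a discretization of the coupled-oscillator ODE system
#   Y' = sigma(F_theta(X, t)) - gamma*X - alpha*Y,   X' = Y,
# where sigma is the ReLU and F_theta is the wrapped GNN layer; the loop below is the
# corresponding IMEX (semi-implicit) Euler step with step size dt.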
77 | for gnn in self.GNNs:
78 | Y = Y + self.dt * (torch.relu(gnn(X, edge_index)) - self.alpha * Y - self.gamma * X)
79 | X = X + self.dt * Y
80 |
81 | if (self.dropout is not None):
82 | Y = F.dropout(Y, self.dropout, training=self.training)
83 | X = F.dropout(X, self.dropout, training=self.training)
84 |
85 | return X, Y
86 | ```
87 |
88 | A deep GraphCON model using, for instance, [Kipf & Welling's GCN](https://arxiv.org/abs/1609.02907)
89 | as the underlying message-passing mechanism can then be written as
90 |
91 | ```python
92 | from torch_geometric.nn import GCNConv
93 |
94 |
95 | class deep_GNN(nn.Module):
96 | def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1., dropout=None):
97 | super(deep_GNN, self).__init__()
98 | self.enc = nn.Linear(nfeat, nhid)
99 | self.GNNs = nn.ModuleList()
100 | for _ in range(nlayers):
101 | self.GNNs.append(GCNConv(nhid, nhid))
102 | self.graphcon = GraphCON(self.GNNs, dt, alpha, gamma, dropout)
103 | self.dec = nn.Linear(nhid, nclass)
104 |
105 | def forward(self, x, edge_index):
106 | # compute initial values of ODEs (encode input)
107 | X0 = self.enc(x)
108 | Y0 = X0
109 | # stack GNNs using GraphCON
110 | X, Y = self.graphcon(X0, Y0, edge_index)
111 | # decode X state of GraphCON at final time for output nodes
112 | output = self.dec(X)
113 | return output
114 | ```
115 | This is just a small example to demonstrate the **simple usage of GraphCON**.
116 | You can find the full GraphCON models we used in our experiments in the `src` directory.
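
A minimal sketch of how the `deep_GNN` model above can be called (the node count, feature dimension, random `edge_index`, and hyperparameters below are arbitrary placeholders, not settings from our experiments):

```python
import torch

# assumes the GraphCON and deep_GNN classes from the snippets above are defined

# toy graph: 100 nodes with 16 features each and 400 random directed edges
x = torch.randn(100, 16)
edge_index = torch.randint(0, 100, (2, 400))

model = deep_GNN(nfeat=16, nhid=32, nclass=7, nlayers=10, dropout=0.1)
logits = model(x, edge_index)  # shape [100, 7]: one score per class for every node
```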
117 |
118 | ### Citation
119 | If you find our work useful in your research, please cite our paper:
120 | ```bibtex
121 | @article{graphcon,
122 | title={Graph-Coupled Oscillator Networks},
123 | author={Rusch, T Konstantin and Chamberlain, Benjamin P and Rowbottom, James and Mishra, Siddhartha and Bronstein, Michael M},
124 | journal={arXiv preprint arXiv:2202.02296},
125 | year={2022}
126 | }
127 | ```
128 | (Also consider starring the project on GitHub.)
129 |
--------------------------------------------------------------------------------
/imgs/graphCON_figure.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tk-rusch/GraphCON/3326c48bd3631b9f25417455ded9a1a83ccd3e93/imgs/graphCON_figure.pdf
--------------------------------------------------------------------------------
/imgs/graphCON_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tk-rusch/GraphCON/3326c48bd3631b9f25417455ded9a1a83ccd3e93/imgs/graphCON_figure.png
--------------------------------------------------------------------------------
/src/Superpixels/best_params.py:
--------------------------------------------------------------------------------
1 | best_params_dict = {'GraphCON_GCN': {'lr': 0.00116, 'alpha': 1.396, 'gamma': 0.841, 'nlayers': 10, 'dropout': 0.014151415141514152},
2 | 'GraphCON_GAT': {'lr': 0.00028, 'alpha': 0.972, 'gamma': 0.348, 'nlayers': 14, 'dropout': 0.04}
3 | }
4 |
--------------------------------------------------------------------------------
/src/Superpixels/data_handling.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import MNISTSuperpixels
2 | DATA_PATH = '../../data'
3 |
4 | def get_data(train):
5 | path = '../../data/MNIST'
6 | dataset = MNISTSuperpixels(path,train=train)
7 | return dataset
--------------------------------------------------------------------------------
/src/Superpixels/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch_geometric.nn import GCNConv, GATConv
5 | from torch_geometric.nn import global_mean_pool, global_add_pool
6 |
7 | class GraphCON_GCN(nn.Module):
8 | def __init__(self, nfeat, nhid, nclass, dropout, nlayers, dt=1., alpha=1., gamma=1.):
9 | super(GraphCON_GCN, self).__init__()
10 | self.dropout = dropout
11 | self.nhid = nhid
12 | self.nlayers = nlayers
13 | self.enc = nn.Linear(nfeat,nhid)
14 | self.conv = GCNConv(nhid, nhid)
15 | self.dec = nn.Linear(nhid,nclass)
16 | self.res = nn.Linear(nhid,nhid)
17 | self.dt = dt
18 | self.act_fn = nn.ReLU()
19 | self.alpha = alpha
20 | self.gamma = gamma
21 |
22 | def res_connection(self, X):
23 | res = self.res(X)
24 | return res
25 |
26 | def forward(self, data):
27 | input = torch.cat([data.x,data.pos],dim=-1)
28 | edge_index = data.edge_index
29 | Y = self.enc(input)
30 | X = Y
31 | Y = F.dropout(Y, self.dropout, training=self.training)
32 | X = F.dropout(X, self.dropout, training=self.training)
33 |
34 | for i in range(self.nlayers):
35 | Y = Y + self.dt*(self.act_fn(self.conv(X,edge_index) + self.res_connection(X)) - self.alpha*Y - self.gamma*X)
36 | X = X + self.dt*Y
37 | Y = F.dropout(Y, self.dropout, training=self.training)
38 | X = F.dropout(X, self.dropout, training=self.training)
39 |
40 | X = self.dec(X)
41 | X = global_add_pool(X, data.batch)
42 |
43 | return X.squeeze(-1)
44 |
45 | class GraphCON_GAT(nn.Module):
46 | def __init__(self, nfeat, nhid, nclass, nlayers, dropout, dt=1., alpha=1., gamma=1., nheads=4):
47 | super(GraphCON_GAT, self).__init__()
48 | self.alpha = alpha
49 | self.gamma = gamma
50 | self.dropout = dropout
51 | self.nheads = nheads
52 | self.nhid = nhid
53 | self.nlayers = nlayers
54 | self.act_fn = nn.ReLU()
55 | self.res = nn.Linear(nhid, nheads * nhid)
56 | self.enc = nn.Linear(nfeat,nhid)
57 | self.conv = GATConv(nhid, nhid, heads=nheads)
58 | self.dec = nn.Linear(nhid,nclass)
59 | self.dt = dt
60 |
61 | def res_connection(self, X):
62 | res = self.res(X)
63 | return res
64 |
65 | def forward(self, data):
66 | input = torch.cat([data.x, data.pos], dim=-1)
67 | n_nodes = input.size(0)
68 | edge_index = data.edge_index
69 | Y = self.enc(input)
70 | X = Y
71 | Y = F.dropout(Y, self.dropout, training=self.training)
72 | X = F.dropout(X, self.dropout, training=self.training)
73 |
74 | for i in range(self.nlayers):
75 | Y = Y + self.dt*(F.elu(self.conv(X, edge_index) + self.res_connection(X)).view(n_nodes, -1, self.nheads).mean(dim=-1) - self.alpha*Y - self.gamma*X)
76 | X = X + self.dt*Y
77 | Y = F.dropout(Y, self.dropout, training=self.training)
78 | X = F.dropout(X, self.dropout, training=self.training)
79 |
80 | X = self.dec(X)
81 | X = global_add_pool(X, data.batch)
82 |
83 | return X.squeeze(-1)
--------------------------------------------------------------------------------
/src/Superpixels/run_GNN.py:
--------------------------------------------------------------------------------
1 | from data_handling import get_data
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 | import torch.optim as optim
4 | from models import *
5 | from torch_geometric.data import DataLoader
6 | from best_params import best_params_dict
7 | import argparse
8 |
9 | def train_GNN(opt):
10 | train_dataset = get_data(train=True)
11 | test_dataset = get_data(train=False)
12 |
13 | train_dataset = train_dataset.shuffle()
14 | val_dataset = train_dataset[:5000]
15 | train_dataset = train_dataset[5000:]
16 |
17 | train_loader = DataLoader(train_dataset, batch_size=opt['batch'], shuffle=True)
18 | val_loader = DataLoader(val_dataset, batch_size=opt['batch'], shuffle=False)
19 | test_loader = DataLoader(test_dataset, batch_size=opt['batch'], shuffle=False)
20 |
21 | epochs = opt['epochs']
22 |
23 | if opt['model'] == 'GraphCON_GCN':
24 | model = GraphCON_GCN(nfeat=train_dataset.data.num_features+2,nhid=opt['nhid'],nclass=10,
25 | dropout=opt['drop'],nlayers=opt['nlayers'],dt=1.,
26 | alpha=opt['alpha'],gamma=opt['gamma']).to(opt['device'])
27 | elif opt['model'] == 'GraphCON_GAT':
28 | model = GraphCON_GAT(nfeat=train_dataset.data.num_features+2, nhid=opt['nhid'], nclass=10,
29 | dropout=opt['drop'], nlayers=opt['nlayers'], dt=1.,
30 | alpha=opt['alpha'], gamma=opt['gamma'],nheads=opt['nheads']).to(opt['device'])
31 |
32 | optimizer = optim.Adam(model.parameters(), lr=opt['lr'])
33 | lf = torch.nn.CrossEntropyLoss()
34 | scheduler = ReduceLROnPlateau(optimizer, 'max', factor=opt['reduce_factor'])
35 |
36 | best_eval = 0
37 |
38 | def test(data_loader):
39 | model.eval()
40 | correct = 0
41 | with torch.no_grad():
42 | for i, data in enumerate(data_loader):
43 | data = data.to(opt['device'])
44 | output = model(data)
45 | pred = output.data.max(1, keepdim=True)[1]
46 | correct += pred.eq(data.y.data.view_as(pred)).sum()
47 |
48 | accuracy = 100. * correct / len(data_loader.dataset)
49 | return accuracy.item()
50 |
51 | for epoch in range(epochs):
52 | model.train()
53 | for i, data in enumerate(train_loader):
54 | data = data.to(opt['device'])
55 | optimizer.zero_grad()
56 | out = model(data)
57 | loss = lf(out, data.y)
58 | loss.backward()
59 | optimizer.step()
60 |
61 | val_acc = test(val_loader)
62 | test_acc = test(test_loader)
63 |
64 | if(val_acc > best_eval):
65 | best_eval = val_acc
66 | best_test_acc = test_acc
67 |
68 | log = 'Epoch: {:03d}, Val: {:.4f}, Test: {:.4f}'
69 | print(log.format(epoch, val_acc, test_acc))
70 |
71 | scheduler.step(val_acc)
72 | for param_group in optimizer.param_groups:
73 | curr_lr = param_group['lr']
74 | if(curr_lr<1e-5):
75 | break
76 |
77 | if(epoch > 25 and val_acc < 20.):
78 | break
79 |
80 | print('Final test accuracy: ', best_test_acc)
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser(description='training parameters')
84 | parser.add_argument('--model', type=str, default='GraphCON_GAT',
85 | help='GraphCON_GCN, GraphCON_GAT')
86 | parser.add_argument('--nhid', type=int, default=256,
87 | help='number of hidden node features')
88 | parser.add_argument('--nlayers', type=int, default=5,
89 | help='number of layers')
90 | parser.add_argument('--alpha', type=float, default=1.,
91 | help='alpha parameter of graphCON')
92 | parser.add_argument('--gamma', type=float, default=1.,
93 | help='gamma parameter of graphCON')
94 | parser.add_argument('--nheads', type=int, default=4,
95 | help='number of attention heads for GraphCON-GAT')
96 | parser.add_argument('--epochs', type=int, default=1000,
97 | help='max epochs')
98 | parser.add_argument('--batch', type=int, default=32,
99 | help='batch size')
100 | parser.add_argument('--reduce_factor', type=float, default=0.5,
101 | help='reduce factor')
102 | parser.add_argument('--lr', type=float, default=0.001,
103 | help='learning rate')
104 | parser.add_argument('--drop', type=float, default=0.3,
105 | help='dropout rate')
106 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
107 | help='computing device')
108 |
109 | args = parser.parse_args()
110 | cmd_opt = vars(args)
111 |
112 | best_opt = best_params_dict[cmd_opt['model']]
113 | opt = {**cmd_opt, **best_opt}
114 | print(opt)
115 |
116 | train_GNN(opt)
117 |
118 |
--------------------------------------------------------------------------------
/src/ZINC/data_handling.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch_geometric.datasets import ZINC
3 | DATA_PATH = '../../data'
4 |
5 | def get_zinc_data(split):
6 | path = '../../data/ZINC'
7 | dataset = ZINC(path,subset=True,split=split)
8 | return dataset
--------------------------------------------------------------------------------
/src/ZINC/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch_geometric.nn import GCNConv
4 | from torch_geometric.nn import global_mean_pool, global_add_pool
5 |
6 | class GraphCON_GCN(nn.Module):
7 | def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1.):
8 | super(GraphCON_GCN, self).__init__()
9 | self.nlayers = nlayers
10 | self.enc = nn.Linear(nfeat,nhid)
11 | self.res = nn.Linear(nhid,nhid)
12 | self.conv = GCNConv(nhid, nhid)
13 | self.dec = nn.Linear(nhid,nclass)
14 | self.dt = dt
15 | self.alpha = alpha
16 | self.gamma = gamma
17 |
18 | def res_connection(self, X):
19 | ## residual connection
20 | res = self.res(X) - self.conv.lin(X)
21 | return res
22 |
23 | def forward(self, data):
24 | ## the following encoder gives better results than using nn.embedding()
25 | input = data.x.float()
26 | edge_index = data.edge_index
27 | Y = self.enc(input)
28 | X = Y
29 |
30 | for i in range(self.nlayers):
31 | Y = Y + self.dt * (torch.relu(self.conv(X, edge_index) + self.res_connection(X)) - self.alpha * Y - self.gamma * X)
32 | X = X + self.dt * Y
33 |
34 | X = self.dec(X)
35 | X = global_add_pool(X, data.batch).squeeze(-1)
36 |
37 | return X
--------------------------------------------------------------------------------
/src/ZINC/run_GNN.py:
--------------------------------------------------------------------------------
1 | from data_handling import get_zinc_data
2 | import numpy as np
3 | import torch
4 | import torch.optim as optim
5 | from models import *
6 | from torch.optim.lr_scheduler import ReduceLROnPlateau
7 | from torch_geometric.data import DataLoader
8 |
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser(description='training parameters')
12 |
13 | parser.add_argument('--nhid', type=int, default=220,
14 | help='number of hidden node features')
15 | parser.add_argument('--nlayers', type=int, default=22,
16 | help='number of layers')
17 | parser.add_argument('--alpha', type=float, default=0.215,
18 | help='alpha parameter of graphCON')
19 | parser.add_argument('--gamma', type=float, default=1.115,
20 | help='gamma parameter of graphCON')
21 | parser.add_argument('--epochs', type=int, default=3000,
22 | help='max epochs')
23 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
24 | help='computing device')
25 | parser.add_argument('--batch', type=int, default=128,
26 | help='batch size')
27 | parser.add_argument('--lr', type=float, default=0.00159,
28 | help='learning rate')
29 | parser.add_argument('--reduce_point', type=int, default=20,
30 | help='length of patience')
31 | parser.add_argument('--start_reduce', type=int, default=1000,
32 | help='epoch when to start reducing')
33 | parser.add_argument('--seed', type=int, default=1111,
34 | help='random seed')
35 |
36 | args = parser.parse_args()
37 | print(args)
38 |
39 |
40 | torch.backends.cudnn.deterministic = True
41 | torch.backends.cudnn.benchmark = False
42 | torch.manual_seed(args.seed)
43 | np.random.seed(args.seed)
44 |
45 | train_dataset = get_zinc_data('train')
46 | test_dataset = get_zinc_data('test')
47 | val_dataset = get_zinc_data('val')
48 |
49 | train_loader = DataLoader(train_dataset, batch_size=args.batch,shuffle=True)
50 | test_loader = DataLoader(test_dataset, batch_size=args.batch,shuffle=False)
51 | val_loader = DataLoader(val_dataset, batch_size=args.batch,shuffle=False)
52 |
53 | model = GraphCON_GCN(nfeat=train_dataset.data.num_features,nhid=args.nhid,nclass=1,
54 | nlayers=args.nlayers, dt=1., alpha=args.alpha,
55 | gamma=args.gamma).to(args.device)
56 |
57 | nparams = 0
58 | for p in model.parameters():
59 | nparams += p.numel()
60 | print('number of parameters: ',nparams)
61 |
62 | optimizer = optim.Adam(model.parameters(),lr=args.lr)
63 | lf = torch.nn.L1Loss()
64 |
65 | patience = 0
66 | best_eval = 1000000
67 |
68 | def test(loader):
69 | model.eval()
70 | error = 0
71 | with torch.no_grad():
72 | for data in loader:
73 | data = data.to(args.device)
74 | output = model(data)
75 | error += (output - data.y).abs().sum().item()
76 | return error / len(loader.dataset)
77 |
78 | for epoch in range(args.epochs):
79 | model.train()
80 | for i, data in enumerate(train_loader):
81 | data = data.to(args.device)
82 | optimizer.zero_grad()
83 | out = model(data)
84 | loss = lf(out,data.y)
85 | loss.backward()
86 | optimizer.step()
87 |
88 | val_loss = test(val_loader)
89 |
90 | f = open('zinc_graphcon_gcn_log.txt', 'a')
91 | f.write('validation loss: ' + str(val_loss) + '\n')
92 | f.close()
93 |
94 | print('epoch: ',epoch,'validation loss: ',val_loss)
95 |
96 | if (val_loss < best_eval):
97 | best_eval = val_loss
98 | best_test_loss = test(test_loader)
99 |
100 | elif (val_loss >= best_eval and (epoch + 1) >= args.start_reduce):
101 | patience += 1
102 |
103 | if (epoch + 1) >= args.start_reduce and patience == args.reduce_point:
104 | patience = 0
105 | args.lr /= 2.
106 | for param_group in optimizer.param_groups:
107 | param_group['lr'] = args.lr
108 |
109 | if (epoch > 25):
110 | if (val_loss > 20. or args.lr < 1e-5):
111 | break
112 |
113 | f = open('zinc_graphcon_gcn_log.txt', 'a')
114 | f.write('final test loss: ' + str(best_test_loss) + '\n')
115 | f.close()
116 | print('final test loss: ',best_test_loss)
--------------------------------------------------------------------------------
/src/definitions.py:
--------------------------------------------------------------------------------
1 | import os
2 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
--------------------------------------------------------------------------------
/src/heterophilic_graphs/best_params.py:
--------------------------------------------------------------------------------
1 | best_params_dict = {'cornell': {'model': 'GraphCON_GCN', 'lr': 0.00721, 'nhid': 256, 'alpha': 0, 'gamma': 0, 'nlayers': 1, 'dropout': 0.15, 'weight_decay': 0.0012708787092020595, 'res_version': 1},
2 | 'wisconsin': {'model': 'GraphCON_GCN', 'lr': 0.00356, 'nhid': 64, 'alpha': 0, 'gamma': 0, 'nlayers': 2, 'dropout': 0.23, 'weight_decay': 0.008126619200091946, 'res_version': 2},
3 | 'texas': {'model': 'GraphCON_GCN', 'lr': 0.00155, 'nhid': 256, 'alpha': 0, 'gamma': 0, 'nlayers': 2, 'dropout': 0.68, 'weight_decay': 0.0008549327066268375, 'res_version': 2}
4 | }
--------------------------------------------------------------------------------
/src/heterophilic_graphs/data_handling.py:
--------------------------------------------------------------------------------
1 | from torch_geometric.datasets import WebKB
2 | import torch
3 | import numpy as np
4 |
5 | DATA_PATH = '../../data'
6 |
7 | def get_data(name, split=0):
8 | path = '../../data/'+name
9 | dataset = WebKB(path,name=name)
10 |
11 | data = dataset[0]
12 | splits_file = np.load(f'{path}/{name}/raw/{name}_split_0.6_0.2_{split}.npz')
13 | train_mask = splits_file['train_mask']
14 | val_mask = splits_file['val_mask']
15 | test_mask = splits_file['test_mask']
16 |
17 | data.train_mask = torch.tensor(train_mask, dtype=torch.bool)
18 | data.val_mask = torch.tensor(val_mask, dtype=torch.bool)
19 | data.test_mask = torch.tensor(test_mask, dtype=torch.bool)
20 |
21 | return data
22 |
--------------------------------------------------------------------------------
/src/heterophilic_graphs/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from torch_geometric.nn import GCNConv, GATConv
6 |
7 | class GraphCON_GCN(nn.Module):
8 | def __init__(self, nfeat, nhid, nclass, dropout, nlayers, dt=1., alpha=1., gamma=1., res_version=1):
9 | super(GraphCON_GCN, self).__init__()
10 | self.dropout = dropout
11 | self.nhid = nhid
12 | self.nlayers = nlayers
13 | self.enc = nn.Linear(nfeat,nhid)
14 | self.conv = GCNConv(nhid, nhid)
15 | self.dec = nn.Linear(nhid,nclass)
16 | self.res = nn.Linear(nhid,nhid)
17 | if(res_version==1):
18 | self.residual = self.res_connection_v1
19 | else:
20 | self.residual = self.res_connection_v2
21 | self.dt = dt
22 | self.act_fn = nn.ReLU()
23 | self.alpha = alpha
24 | self.gamma = gamma
25 | self.reset_params()
26 |
27 | def reset_params(self):
28 | for name, param in self.named_parameters():
29 | if 'weight' in name and 'emb' not in name and 'out' not in name:
30 | stdv = 1. / math.sqrt(self.nhid)
31 | param.data.uniform_(-stdv, stdv)
32 |
33 | def res_connection_v1(self, X):
34 | res = - self.res(self.conv.lin(X))
35 | return res
36 |
37 | def res_connection_v2(self, X):
38 | res = - self.conv.lin(X) + self.res(X)
39 | return res
40 |
41 | def forward(self, data):
42 | input = data.x
43 | edge_index = data.edge_index
44 | input = F.dropout(input, self.dropout, training=self.training)
45 | Y = self.act_fn(self.enc(input))
46 | X = Y
47 | Y = F.dropout(Y, self.dropout, training=self.training)
48 | X = F.dropout(X, self.dropout, training=self.training)
49 |
50 | for i in range(self.nlayers):
51 | Y = Y + self.dt*(self.act_fn(self.conv(X,edge_index) + self.residual(X)) - self.alpha*Y - self.gamma*X)
52 | X = X + self.dt*Y
53 | Y = F.dropout(Y, self.dropout, training=self.training)
54 | X = F.dropout(X, self.dropout, training=self.training)
55 |
56 | X = self.dec(X)
57 |
58 | return X
59 |
60 | class GraphCON_GAT(nn.Module):
61 | def __init__(self, nfeat, nhid, nclass, nlayers, dropout, dt=1., alpha=1., gamma=1., nheads=4):
62 | super(GraphCON_GAT, self).__init__()
63 | self.alpha = alpha
64 | self.gamma = gamma
65 | self.dropout = dropout
66 | self.nheads = nheads
67 | self.nhid = nhid
68 | self.nlayers = nlayers
69 | self.act_fn = nn.ReLU()
70 | self.res = nn.Linear(nhid, nheads * nhid)
71 | self.enc = nn.Linear(nfeat,nhid)
72 | self.conv = GATConv(nhid, nhid, heads=nheads)
73 | self.dec = nn.Linear(nhid,nclass)
74 | self.dt = dt
75 |
76 | def res_connection(self, X):
77 | res = self.res(X)
78 | return res
79 |
80 | def forward(self, data):
81 | input = data.x
82 | n_nodes = input.size(0)
83 | edge_index = data.edge_index
84 | input = F.dropout(input, self.dropout, training=self.training)
85 | Y = self.act_fn(self.enc(input))
86 | X = Y
87 | Y = F.dropout(Y, self.dropout, training=self.training)
88 | X = F.dropout(X, self.dropout, training=self.training)
89 |
90 | for i in range(self.nlayers):
91 | Y = Y + self.dt*(F.elu(self.conv(X, edge_index) + self.res_connection(X)).view(n_nodes, -1, self.nheads).mean(dim=-1) - self.alpha*Y - self.gamma*X)
92 | X = X + self.dt*Y
93 | Y = F.dropout(Y, self.dropout, training=self.training)
94 | X = F.dropout(X, self.dropout, training=self.training)
95 |
96 | X = self.dec(X)
97 |
98 | return X
99 |
--------------------------------------------------------------------------------
/src/heterophilic_graphs/run_GNN.py:
--------------------------------------------------------------------------------
1 | from data_handling import get_data
2 | import numpy as np
3 | import torch.optim as optim
4 | from models import *
5 | from torch import nn
6 | from best_params import best_params_dict
7 |
8 | import argparse
9 |
10 | def train_GNN(opt,split):
11 | data = get_data(opt['dataset'],split)
12 |
13 | best_eval = 10000
14 | bad_counter = 0
15 | best_test_acc = 0
16 |
17 | if opt['model'] == 'GraphCON_GCN':
18 | model = GraphCON_GCN(nfeat=data.num_features,nhid=opt['nhid'],nclass=5,
19 | dropout=opt['drop'],nlayers=opt['nlayers'],dt=1.,
20 | alpha=opt['alpha'],gamma=opt['gamma'],res_version=opt['res_version']).to(opt['device'])
21 | elif opt['model'] == 'GraphCON_GAT':
22 | model = GraphCON_GAT(nfeat=data.num_features, nhid=opt['nhid'], nclass=5,
23 | dropout=opt['drop'], nlayers=opt['nlayers'], dt=1.,
24 | alpha=opt['alpha'], gamma=opt['gamma'],nheads=opt['nheads']).to(opt['device'])
25 |
26 | optimizer = optim.Adam(model.parameters(),lr=opt['lr'],weight_decay=opt['weight_decay'])
27 | lf = nn.CrossEntropyLoss()
28 |
29 | @torch.no_grad()
30 | def test(model, data):
31 | model.eval()
32 | logits, accs, losses = model(data), [], []
33 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
34 | loss = lf(logits[mask], data.y.squeeze()[mask])
35 | pred = logits[mask].max(1)[1]
36 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
37 | accs.append(acc)
38 | losses.append(loss.item())
39 | return accs, losses
40 |
41 | for epoch in range(opt['epochs']):
42 | model.train()
43 | optimizer.zero_grad()
44 | out = model(data.to(opt['device']))
45 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask])
46 | loss.backward()
47 | optimizer.step()
48 |
49 | [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model,data)
50 |
51 | if (val_loss < best_eval):
52 | best_eval = val_loss
53 | best_test_acc = test_acc
54 | else:
55 | bad_counter += 1
56 |
57 | if(bad_counter==opt['patience']):
58 | break
59 |
60 | log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
61 | print(log.format(split, epoch, train_acc, val_acc, test_acc))
62 |
63 | return best_test_acc
64 |
65 | if __name__ == '__main__':
66 | parser = argparse.ArgumentParser(description='training parameters')
67 | parser.add_argument('--dataset', type=str, default='texas',
68 | help='cornell, wisconsin, texas')
69 | parser.add_argument('--model', type=str, default='GraphCON_GCN',
70 | help='GraphCON_GCN, GraphCON_GAT')
71 | parser.add_argument('--nhid', type=int, default=64,
72 | help='number of hidden node features')
73 | parser.add_argument('--nlayers', type=int, default=5,
74 | help='number of layers')
75 | parser.add_argument('--alpha', type=float, default=1.,
76 | help='alpha parameter of graphCON')
77 | parser.add_argument('--gamma', type=float, default=1.,
78 | help='gamma parameter of graphCON')
79 | parser.add_argument('--nheads', type=int, default=4,
80 | help='number of attention heads for GraphCON-GAT')
81 | parser.add_argument('--epochs', type=int, default=1500,
82 | help='max epochs')
83 | parser.add_argument('--patience', type=int, default=100,
84 | help='patience')
85 | parser.add_argument('--lr', type=float, default=0.001,
86 | help='learning rate')
87 | parser.add_argument('--drop', type=float, default=0.3,
88 | help='dropout rate')
89 | parser.add_argument('--res_version', type=int, default=1,
90 | help='version of residual connection')
91 | parser.add_argument('--weight_decay', type=float, default=1e-5,
92 | help='weight_decay')
93 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
94 | help='computing device')
95 |
96 | args = parser.parse_args()
97 | cmd_opt = vars(args)
98 |
99 | best_opt = best_params_dict[cmd_opt['dataset']]
100 | opt = {**cmd_opt, **best_opt}
101 | print(opt)
102 |
103 | n_splits = 10
104 |
105 | best = []
106 | for split in range(n_splits):
107 | best.append(train_GNN(opt,split))
108 | print('Mean test accuracy: ', np.mean(np.array(best)*100),'std: ', np.std(np.array(best)*100))
109 |
110 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/GNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from base_classes import BaseGNN
5 | from model_configurations import set_block, set_function
6 |
7 |
8 | # Define the GNN model.
9 | class GNN(BaseGNN):
10 | def __init__(self, opt, dataset, device=torch.device('cpu')):
11 | super(GNN, self).__init__(opt, dataset, device)
12 | self.f = set_function(opt)
13 | block = set_block(opt)
14 | time_tensor = torch.tensor([0, self.T]).to(device)
15 | self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)
16 |
17 | def forward(self, x):
18 | # Encode each node based on its feature.
19 | if self.opt['use_labels']:
20 | y = x[:, self.num_features:]
21 | x = x[:, :self.num_features]
22 | x = F.dropout(x, self.opt['input_dropout'], training=self.training)
23 | x = self.m1(x)
24 | if self.opt['use_mlp']:
25 | x = F.dropout(x, self.opt['dropout'], training=self.training)
26 | x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training)
27 | x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training)
28 | # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper
29 |
30 | if self.opt['use_labels']:
31 | x = torch.cat([x, y], dim=-1)
32 |
33 | if self.opt['batch_norm']:
34 | x = self.bn_in(x)
35 |
36 | # Solve the initial value problem of the ODE.
37 | if self.opt['augment']:
38 | c_aux = torch.zeros(x.shape).to(self.device)
39 | x = torch.cat([x, c_aux], dim=1)
40 |
41 | self.odeblock.set_x0(x)
42 |
43 | if self.training and self.odeblock.nreg > 0:
44 | z, self.reg_states = self.odeblock(x)
45 | else:
46 | z = self.odeblock(x)
47 |
48 | if self.opt['augment']:
49 | z = torch.split(z, x.shape[1] // 2, dim=1)[0]
50 |
51 | # Activation.
52 | z = F.relu(z)
53 |
54 | if self.opt['fc_out']:
55 | z = self.fc(z)
56 | z = F.relu(z)
57 |
58 | # Dropout.
59 | z = F.dropout(z, self.opt['dropout'], training=self.training)
60 |
61 | # Decode each node embedding to get node label.
62 | z = self.m2(z)
63 | return z
64 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/GNN_early.py:
--------------------------------------------------------------------------------
1 | """
2 | A GNN used at test time that supports early stopping during the ODE integration
3 | """
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | import argparse
8 | from torch_geometric.nn import GCNConv, ChebConv # noqa
9 | import time
10 | from data import get_dataset
11 | # from run_GNN import get_optimizer, train, test
12 | from early_stop_solver import EarlyStopInt
13 | from base_classes import BaseGNN
14 | from model_configurations import set_block, set_function
15 |
16 |
17 | class GNNEarly(BaseGNN):
18 | def __init__(self, opt, dataset, device=torch.device('cpu')):
19 | super(GNNEarly, self).__init__(opt, dataset, device)
20 | self.f = set_function(opt)
21 | block = set_block(opt)
22 | time_tensor = torch.tensor([0, self.T]).to(device)
23 | # self.regularization_fns = ()
24 | self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device)
25 | # overwrite the test integrator with this custom one
26 | self.odeblock.test_integrator = EarlyStopInt(self.T, opt, device)
27 | # if opt['adjoint']:
28 | # from torchdiffeq import odeint_adjoint as odeint
29 | # else:
30 | # from torchdiffeq import odeint
31 | # self.odeblock.train_integrator = odeint
32 |
33 | self.set_solver_data(dataset.data)
34 |
35 | def set_solver_m2(self):
36 | if self.odeblock.test_integrator.m2 is None:
37 | self.odeblock.test_integrator.m2 = self.m2
38 | else:
39 | self.odeblock.test_integrator.m2.weight.data = self.m2.weight.data
40 | self.odeblock.test_integrator.m2.bias.data = self.m2.bias.data
41 |
42 | def set_solver_data(self, data):
43 | self.odeblock.test_integrator.data = data
44 |
45 | def forward(self, x):
46 | # Encode each node based on its feature.
47 | if self.opt['use_labels']:
48 | y = x[:, self.num_features:]
49 | x = x[:, :self.num_features]
50 | x = F.dropout(x, self.opt['input_dropout'], training=self.training)
51 | x = self.m1(x)
52 |
53 | if self.opt['use_labels']:
54 | x = torch.cat([x, y], dim=-1)
55 |
56 | if self.opt['batch_norm']:
57 | x = self.bn_in(x)
58 |
59 | # Solve the initial value problem of the ODE.
60 | if self.opt['augment']:
61 | c_aux = torch.zeros(x.shape).to(self.device)
62 | x = torch.cat([x, c_aux], dim=1)
63 |
64 | x = torch.cat([x, x], dim=-1)
65 | self.odeblock.set_x0(x)
66 | self.set_solver_m2()
67 | if self.training and self.odeblock.nreg > 0:
68 | z, self.reg_states = self.odeblock(x)
69 | else:
70 | z = self.odeblock(x)
71 |
72 | if self.opt['augment']:
73 | z = torch.split(z, x.shape[1] // 2, dim=1)[0]
74 |
75 | z = z[:,self.opt['hidden_dim']:]
76 |
77 | # Activation.
78 | z = F.relu(z)
79 |
80 | if self.opt['fc_out']:
81 | z = self.fc(z)
82 | z = F.relu(z)
83 |
84 | # Dropout.
85 | z = F.dropout(z, self.opt['dropout'], training=self.training)
86 |
87 | # Decode each node embedding to get node label.
88 | z = self.m2(z)
89 | return z
90 |
91 |
92 | def main(opt):
93 | dataset = get_dataset(opt, '../data', False)
94 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
95 | model, data = GNNEarly(opt, dataset, device).to(device), dataset.data.to(device)
96 | print(opt)
97 | # todo for some reason the submodule parameters inside the attention module don't show up when running on GPU.
98 | parameters = [p for p in model.parameters() if p.requires_grad]
99 | optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay'])
100 | best_val_acc = test_acc = best_epoch = 0
101 | best_val_acc_int = best_test_acc_int = best_epoch_int = 0
102 | for epoch in range(1, opt['epoch']):
103 | start_time = time.time()
104 | loss = train(model, optimizer, data)
105 | train_acc, val_acc, tmp_test_acc = test(model, data)
106 | val_acc_int = model.odeblock.test_integrator.solver.best_val
107 | tmp_test_acc_int = model.odeblock.test_integrator.solver.best_test
108 | # store best stuff inside integrator forward pass
109 | if val_acc_int > best_val_acc_int:
110 | best_val_acc_int = val_acc_int
111 | test_acc_int = tmp_test_acc_int
112 | best_epoch_int = epoch
113 | # store best stuff at the end of integrator forward pass
114 | if val_acc > best_val_acc:
115 | best_val_acc = val_acc
116 | test_acc = tmp_test_acc
117 | best_epoch = epoch
118 | log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
119 | print(
120 | log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, val_acc, tmp_test_acc))
121 | log = 'Performance inside integrator Val: {:.4f}, Test: {:.4f}'
122 | print(log.format(val_acc_int, tmp_test_acc_int))
123 | # print(
124 | # log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, best_val_acc, test_acc))
125 | print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc, test_acc, best_epoch))
126 | print('best in integrator val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc_int,
127 | test_acc_int,
128 | best_epoch_int))
129 |
130 |
131 | if __name__ == '__main__':
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument('--use_cora_defaults', action='store_true',
134 | help='Whether to run with best params for cora. Overrides the choice of dataset')
135 | parser.add_argument('--dataset', type=str, default='Cora',
136 | help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS')
137 | parser.add_argument('--data_norm', type=str, default='rw',
138 | help='rw for random walk, gcn for symmetric gcn norm')
139 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.')
140 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.')
141 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
142 | parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.')
143 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
144 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization')
145 | parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.')
146 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.')
147 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.')
148 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.')
149 | parser.add_argument('--augment', action='store_true',
150 | help='double the length of the feature vector by appending zeros to stabilise ODE learning')
151 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha')
152 | parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true', help='do not apply sigmoid before multiplying by alpha')
153 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta')
154 | parser.add_argument('--block', type=str, default='constant', help='constant, mixed, attention, SDE')
155 | parser.add_argument('--function', type=str, default='laplacian', help='laplacian, transformer, dorsey, GAT, SDE')
156 | # ODE args
157 | parser.add_argument('--method', type=str, default='dopri5',
158 | help="set the numerical solver: dopri5, euler, rk4, midpoint")
159 | parser.add_argument('--step_size', type=float, default=1, help='fixed step size when using fixed step solvers e.g. rk4')
160 | parser.add_argument('--max_iters', type=int, default=100,
161 | help='maximum number of solver iterations')
162 | parser.add_argument(
163 | "--adjoint_method", type=str, default="adaptive_heun",
164 | help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint"
165 | )
166 | parser.add_argument('--adjoint', dest='adjoint', action='store_true', help='use the adjoint ODE method to reduce memory footprint')
167 | parser.add_argument('--adjoint_step_size', type=float, default=1, help='fixed step size when using fixed step adjoint solvers e.g. rk4')
168 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol')
169 | parser.add_argument("--tol_scale_adjoint", type=float, default=1.0,
170 | help="multiplier for adjoint_atol and adjoint_rtol")
171 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run')
172 | parser.add_argument('--add_source', dest='add_source', action='store_true',
173 | help='If true, get rid of alpha param and the beta*x0 source term')
174 | # SDE args
175 | parser.add_argument('--dt_min', type=float, default=1e-5, help='minimum timestep for the SDE solver')
176 | parser.add_argument('--dt', type=float, default=1e-3, help='fixed step size')
177 | parser.add_argument('--adaptive', dest='adaptive', action='store_true', help='use adaptive step sizes')
178 | # Attention args
179 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2,
180 | help='slope of the negative part of the leaky relu used in attention')
181 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights')
182 | parser.add_argument('--heads', type=int, default=4, help='number of attention heads')
183 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols')
184 | parser.add_argument('--attention_dim', type=int, default=64,
185 | help='the size to project x to before calculating att scores')
186 | parser.add_argument('--mix_features', dest='mix_features', action='store_true',
187 | help='apply a feature transformation xW to the ODE')
188 | parser.add_argument("--max_nfe", type=int, default=1000, help="Maximum number of function evaluations in an epoch. Stiff ODEs will hang if not set.")
189 | parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true', help="multiply attention scores by edge weights before softmax")
190 | # regularisation args
191 | parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2")
192 | parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2")
193 |
194 | parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2")
195 | parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2")
196 |
197 | # rewiring args
198 | parser.add_argument('--rewiring', type=str, default=None, help="two_hop, gdc")
199 | parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff")
200 | parser.add_argument('--gdc_sparsification', type=str, default='topk', help="threshold, topk")
201 | parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk")
202 | parser.add_argument('--gdc_threshold', type=float, default=0.0001, help="above this edge weight, keep edges when using threshold")
203 | parser.add_argument('--gdc_avg_degree', type=int, default=64,
204 | help="if gdc_threshold is not given can be calculated by specifying avg degree")
205 | parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability")
206 | parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernel diffusion for")
207 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model')
208 |
209 | args = parser.parse_args()
210 |
211 | opt = vars(args)
212 |
213 | main(opt)
214 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/README.md:
--------------------------------------------------------------------------------
1 | Code adapted from [GRAND](https://github.com/twitter-research/graph-neural-pde).
2 | To run the experiments, simply do:
3 |
4 | ```
5 | python run_GNN.py --dataset <dataset>
6 | ```
7 |
8 | where `<dataset>` is either `Cora`, `Citeseer` or `Pubmed`.
9 |
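For example, to train on `Cora` with the remaining arguments left at the defaults defined in `run_GNN.py`:
```
python run_GNN.py --dataset Cora
```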
--------------------------------------------------------------------------------
/src/homophilic_graphs/base_classes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.nn.conv import MessagePassing
4 | from utils import Meter
5 | from regularized_ODE_function import RegularizedODEfunc
6 | import regularized_ODE_function as reg_lib
7 | import six
8 |
9 |
10 | REGULARIZATION_FNS = {
11 | "kinetic_energy": reg_lib.quadratic_cost,
12 | "jacobian_norm2": reg_lib.jacobian_frobenius_regularization_fn,
13 | "total_deriv": reg_lib.total_derivative,
14 | "directional_penalty": reg_lib.directional_derivative
15 | }
16 |
17 |
18 | def create_regularization_fns(args):
19 | regularization_fns = []
20 | regularization_coeffs = []
21 |
22 | for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS):
23 | if args[arg_key] is not None:
24 | regularization_fns.append(reg_fn)
25 | regularization_coeffs.append(args[arg_key])
26 |
27 | regularization_fns = regularization_fns
28 | regularization_coeffs = regularization_coeffs
29 | return regularization_fns, regularization_coeffs
30 |
31 |
32 | class ODEblock(nn.Module):
33 | def __init__(self, odefunc, regularization_fns, opt, data, device, t):
34 | super(ODEblock, self).__init__()
35 | self.opt = opt
36 | self.t = t
37 |
38 | self.aug_dim = 2 if opt['augment'] else 1
39 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
40 |
41 | self.nreg = len(regularization_fns)
42 | self.reg_odefunc = RegularizedODEfunc(self.odefunc, regularization_fns)
43 |
44 | if opt['adjoint']:
45 | from torchdiffeq import odeint_adjoint as odeint
46 | else:
47 | from torchdiffeq import odeint
48 | self.train_integrator = odeint
49 | self.test_integrator = None
50 | self.set_tol()
51 |
52 | def set_x0(self, x0):
53 | self.odefunc.x0 = x0.clone().detach()
54 | self.reg_odefunc.odefunc.x0 = x0.clone().detach()
55 |
56 | def set_tol(self):
57 | self.atol = self.opt['tol_scale'] * 1e-7
58 | self.rtol = self.opt['tol_scale'] * 1e-9
59 | if self.opt['adjoint']:
60 | self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7
61 | self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9
62 |
63 | def reset_tol(self):
64 | self.atol = 1e-7
65 | self.rtol = 1e-9
66 | self.atol_adjoint = 1e-7
67 | self.rtol_adjoint = 1e-9
68 |
69 | def set_time(self, time):
70 | self.t = torch.tensor([0, time]).to(self.device)
71 |
72 | def __repr__(self):
73 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
74 | + ")"
75 |
76 |
77 | class ODEFunc(MessagePassing):
78 |
79 | # currently requires in_features = out_features
80 | def __init__(self, opt, data, device):
81 | super(ODEFunc, self).__init__()
82 | self.opt = opt
83 | self.device = device
84 | self.edge_index = None
85 | self.edge_weight = None
86 | self.attention_weights = None
87 | self.attention_weights_2 = None
88 | self.alpha_train = nn.Parameter(torch.tensor(0.0))#nn.Parameter(torch.zeros(opt['hidden_dim']))
89 | self.beta_train = nn.Parameter(torch.tensor(0.0))
90 | self.beta_train2 = nn.Parameter(torch.tensor(0.0))
91 | self.x0 = None
92 | self.nfe = 0
93 | self.alpha_sc = nn.Parameter(torch.ones(1))
94 | self.beta_sc = nn.Parameter(torch.ones(1))
95 |
96 | def __repr__(self):
97 | return self.__class__.__name__
98 |
99 |
100 | class BaseGNN(MessagePassing):
101 | def __init__(self, opt, dataset, device=torch.device('cpu')):
102 | super(BaseGNN, self).__init__()
103 | self.opt = opt
104 | self.T = opt['time']
105 | self.num_classes = dataset.num_classes
106 | self.num_features = dataset.data.num_features
107 | self.device = device
108 | self.fm = Meter()
109 | self.bm = Meter()
110 | self.m1 = nn.Linear(dataset.data.num_features, opt['hidden_dim'])
111 | if self.opt['use_mlp']:
112 | self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
113 | self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
114 | if opt['use_labels']:
115 | # todo - fastest way to propagate this everywhere, but error prone - refactor later
116 | opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes
117 | else:
118 | self.hidden_dim = opt['hidden_dim']
119 | if opt['fc_out']:
120 | self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim'])
121 | self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes)
122 | if self.opt['batch_norm']:
123 | self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim'])
124 | self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim'])
125 |
126 | self.regularization_fns, self.regularization_coeffs = create_regularization_fns(self.opt)
127 |
128 | def getNFE(self):
129 | return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe
130 |
131 | def resetNFE(self):
132 | self.odeblock.odefunc.nfe = 0
133 | self.odeblock.reg_odefunc.odefunc.nfe = 0
134 |
135 | def reset(self):
136 | self.m1.reset_parameters()
137 | self.m2.reset_parameters()
138 |
139 | def __repr__(self):
140 | return self.__class__.__name__
141 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/block_constant.py:
--------------------------------------------------------------------------------
1 | from base_classes import ODEblock
2 | import torch
3 | from utils import get_rw_adj, gcn_norm_fill_val
4 |
5 | class ConstantODEblock(ODEblock):
6 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])):
7 | super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)
8 |
9 | self.aug_dim = 2 if opt['augment'] else 1
10 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
11 | if opt['data_norm'] == 'rw':
12 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
13 | fill_value=opt['self_loop_weight'],
14 | num_nodes=data.num_nodes,
15 | dtype=data.x.dtype)
16 | else:
17 | edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr,
18 | fill_value=opt['self_loop_weight'],
19 | num_nodes=data.num_nodes,
20 | dtype=data.x.dtype)
21 | self.odefunc.edge_index = edge_index.to(device)
22 | self.odefunc.edge_weight = edge_weight.to(device)
23 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight
24 |
25 | if (opt['method'] == 'symplectic_euler' or opt['method'] == 'leapfrog'):
26 | from odeint_geometric import odeint
27 | elif opt['adjoint']:
28 | from torchdiffeq import odeint_adjoint as odeint
29 | else:
30 | from torchdiffeq import odeint
31 |
32 | self.train_integrator = odeint
33 | self.test_integrator = odeint
34 | self.set_tol()
35 |
36 | def forward(self, x):
37 | t = self.t.type_as(x)
38 |
39 | integrator = self.train_integrator if self.training else self.test_integrator
40 |
41 | reg_states = tuple( torch.zeros(x.size(0)).to(x) for i in range(self.nreg) )
42 |
43 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc
44 | state = (x,) + reg_states if self.training and self.nreg > 0 else x
45 |
46 | if self.opt["adjoint"] and self.training:
47 | state_dt = integrator(
48 | func, state, t,
49 | method=self.opt['method'],
50 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
51 | adjoint_method=self.opt['adjoint_method'],
52 | adjoint_options=dict(step_size = self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']),
53 | atol=self.atol,
54 | rtol=self.rtol,
55 | adjoint_atol=self.atol_adjoint,
56 | adjoint_rtol=self.rtol_adjoint)
57 | else:
58 | state_dt = integrator(
59 | func, state, t,
60 | method=self.opt['method'],
61 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']),
62 | atol=self.atol,
63 | rtol=self.rtol)
64 |
65 | if self.training and self.nreg > 0:
66 | z = state_dt[0][1]
67 | reg_states = tuple( st[1] for st in state_dt[1:] )
68 | return z, reg_states
69 | else:
70 | z = state_dt[1]
71 | return z
72 |
73 | def __repr__(self):
74 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
75 | + ")"
76 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/block_transformer_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from function_transformer_attention import SpGraphTransAttentionLayer
3 | from base_classes import ODEblock
4 | from utils import get_rw_adj
5 |
6 |
7 | class AttODEblock(ODEblock):
8 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5):
9 | super(AttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t)
10 |
11 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
12 | # self.odefunc.edge_index, self.odefunc.edge_weight = data.edge_index, edge_weight=data.edge_attr
13 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1,
14 | fill_value=opt['self_loop_weight'],
15 | num_nodes=data.num_nodes,
16 | dtype=data.x.dtype)
17 | self.odefunc.edge_index = edge_index.to(device)
18 | self.odefunc.edge_weight = edge_weight.to(device)
19 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight
20 |
21 | if(opt['method']=='symplectic_euler' or opt['method']=='leapfrog'):
22 | from odeint_geometric import odeint
23 | elif opt['adjoint']:
24 | from torchdiffeq import odeint_adjoint as odeint
25 | else:
26 | from torchdiffeq import odeint
27 | self.train_integrator = odeint
28 | self.test_integrator = odeint
29 | self.set_tol()
30 | # parameter trading off between attention and the Laplacian
31 | self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt,
32 | device, edge_weights=self.odefunc.edge_weight).to(device)
33 |
34 | def get_attention_weights(self, x):
35 | attention, values = self.multihead_att_layer(x, self.odefunc.edge_index)
36 | return attention
37 |
38 | def forward(self, x_all):
39 | x = x_all[:,:self.opt['hidden_dim']]
40 | y = x_all[:, self.opt['hidden_dim']:]
41 | t = self.t.type_as(x)
42 | self.odefunc.attention_weights = self.get_attention_weights(y)
43 | self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights
44 | integrator = self.train_integrator if self.training else self.test_integrator
45 |
46 | reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg))
47 |
48 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc
49 | state = (x_all,) + reg_states if self.training and self.nreg > 0 else x_all
50 |
51 | if self.opt["adjoint"] and self.training:
52 | state_dt = integrator(
53 | func, state, t,
54 | method=self.opt['method'],
55 | options={'step_size': self.opt['step_size']},
56 | adjoint_method=self.opt['adjoint_method'],
57 | adjoint_options={'step_size': self.opt['adjoint_step_size']},
58 | atol=self.atol,
59 | rtol=self.rtol,
60 | adjoint_atol=self.atol_adjoint,
61 | adjoint_rtol=self.rtol_adjoint)
62 | else:
63 | state_dt = integrator(
64 | func, state, t,
65 | method=self.opt['method'],
66 | options={'step_size': self.opt['step_size']},
67 | atol=self.atol,
68 | rtol=self.rtol)
69 |
70 | if self.training and self.nreg > 0:
71 | z = state_dt[0][1]
72 | reg_states = tuple(st[1] for st in state_dt[1:])
73 | return z, reg_states
74 | else:
75 | z = state_dt[1]
76 | return z
77 |
78 | def __repr__(self):
79 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
80 | + ")"
81 |
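82 | # Note on the state layout (added for clarity): x_all concatenates the two halves of the
83 | # graphCON state column-wise. The attention weights are computed on the second half
84 | # (columns from opt['hidden_dim'] onwards), which carries the node features, while the
85 | # first half plays the role of the auxiliary (velocity-like) component. When opt['method']
86 | # is 'symplectic_euler' or 'leapfrog' the geometric odeint from odeint_geometric.py is
87 | # used; otherwise torchdiffeq's odeint (or odeint_adjoint, if opt['adjoint'] is set)
88 | # integrates the stacked state.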
--------------------------------------------------------------------------------
/src/homophilic_graphs/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Code partially copied from 'Diffusion Improves Graph Learning' repo https://github.com/klicperajo/gdc/blob/master/data.py
3 | """
4 |
5 | import os
6 |
7 | import numpy as np
8 |
9 | import torch
10 | from torch_geometric.data import Data, InMemoryDataset
11 | from torch_geometric.datasets import Planetoid, Amazon, Coauthor
12 | from ogb.nodeproppred import PygNodePropPredDataset
13 | import torch_geometric.transforms as T
14 | from torch_geometric.utils import to_undirected
15 |
16 | DATA_PATH = '../../data'
17 |
18 |
19 | def get_dataset(opt: dict, data_dir, use_lcc: bool = False) -> InMemoryDataset:
20 | ds = opt['dataset']
21 | path = os.path.join(data_dir, ds)
22 | if ds in ['Cora', 'Citeseer', 'Pubmed']:
23 | dataset = Planetoid(path, ds)
24 | elif ds in ['Computers', 'Photo']:
25 | dataset = Amazon(path, ds)
26 | elif ds == 'CoauthorCS':
27 | dataset = Coauthor(path, 'CS')
28 | elif ds == 'ogbn-arxiv':
29 | dataset = PygNodePropPredDataset(name=ds, root=path,
30 | transform=T.ToSparseTensor())
31 | use_lcc = False # never need to calculate the lcc with ogb datasets
32 | else:
33 | raise Exception('Unknown dataset.')
34 |
35 | if use_lcc:
36 | lcc = get_largest_connected_component(dataset)
37 |
38 | x_new = dataset.data.x[lcc]
39 | y_new = dataset.data.y[lcc]
40 |
41 | row, col = dataset.data.edge_index.numpy()
42 | edges = [[i, j] for i, j in zip(row, col) if i in lcc and j in lcc]
43 | edges = remap_edges(edges, get_node_mapper(lcc))
44 |
45 | data = Data(
46 | x=x_new,
47 | edge_index=torch.LongTensor(edges),
48 | y=y_new,
49 | train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
50 | test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool),
51 | val_mask=torch.zeros(y_new.size()[0], dtype=torch.bool)
52 | )
53 | dataset.data = data
54 |
55 | train_mask_exists = True
56 | try:
57 | dataset.data.train_mask
58 | except AttributeError:
59 | train_mask_exists = False
60 |
61 | if ds == 'ogbn-arxiv':
62 | split_idx = dataset.get_idx_split()
63 | ei = to_undirected(dataset.data.edge_index)
64 | data = Data(
65 | x=dataset.data.x,
66 | edge_index=ei,
67 | y=dataset.data.y,
68 | train_mask=split_idx['train'],
69 | test_mask=split_idx['test'],
70 | val_mask=split_idx['valid'])
71 | dataset.data = data
72 | train_mask_exists = True
73 |
74 | if use_lcc or not train_mask_exists:
75 | dataset.data = set_train_val_test_split(
76 | 12345,
77 | dataset.data,
78 | num_development=5000 if ds == "CoauthorCS" else 1500)
79 |
80 | return dataset
81 |
82 |
83 | def get_component(dataset: InMemoryDataset, start: int = 0) -> set:
84 | visited_nodes = set()
85 | queued_nodes = set([start])
86 | row, col = dataset.data.edge_index.numpy()
87 | while queued_nodes:
88 | current_node = queued_nodes.pop()
89 | visited_nodes.update([current_node])
90 | neighbors = col[np.where(row == current_node)[0]]
91 | neighbors = [n for n in neighbors if n not in visited_nodes and n not in queued_nodes]
92 | queued_nodes.update(neighbors)
93 | return visited_nodes
94 |
95 |
96 | def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray:
97 | remaining_nodes = set(range(dataset.data.x.shape[0]))
98 | comps = []
99 | while remaining_nodes:
100 | start = min(remaining_nodes)
101 | comp = get_component(dataset, start)
102 | comps.append(comp)
103 | remaining_nodes = remaining_nodes.difference(comp)
104 | return np.array(list(comps[np.argmax(list(map(len, comps)))]))
105 |
106 |
107 | def get_node_mapper(lcc: np.ndarray) -> dict:
108 | mapper = {}
109 | counter = 0
110 | for node in lcc:
111 | mapper[node] = counter
112 | counter += 1
113 | return mapper
114 |
115 |
116 | def remap_edges(edges: list, mapper: dict) -> list:
117 | row = [e[0] for e in edges]
118 | col = [e[1] for e in edges]
119 | row = list(map(lambda x: mapper[x], row))
120 | col = list(map(lambda x: mapper[x], col))
121 | return [row, col]
122 |
123 |
124 | def set_train_val_test_split(
125 | seed: int,
126 | data: Data,
127 | num_development: int = 1500,
128 | num_per_class: int = 20) -> Data:
129 | rnd_state = np.random.RandomState(seed)
130 | num_nodes = data.y.shape[0]
131 | development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
132 | test_idx = [i for i in np.arange(num_nodes) if i not in development_idx]
133 |
134 | train_idx = []
135 | rnd_state = np.random.RandomState(seed)
136 | for c in range(data.y.max() + 1):
137 | class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
138 | train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))
139 |
140 | val_idx = [i for i in development_idx if i not in train_idx]
141 |
142 | def get_mask(idx):
143 | mask = torch.zeros(num_nodes, dtype=torch.bool)
144 | mask[idx] = 1
145 | return mask
146 |
147 | data.train_mask = get_mask(train_idx)
148 | data.val_mask = get_mask(val_idx)
149 | data.test_mask = get_mask(test_idx)
150 |
151 | return data
152 |
153 |
154 | if __name__ == '__main__':
155 | # example for heterophilic datasets
156 | from heterophilic import get_fixed_splits
157 |
158 | opt = {'dataset': 'Cora', 'device': 'cpu'}
159 | dataset = get_dataset(opt, DATA_PATH)
160 | for fold in range(10):
161 | data = dataset[0]
162 | data = get_fixed_splits(data, opt['dataset'], fold)
163 | data = data.to(opt['device'])
164 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/early_stop_solver.py:
--------------------------------------------------------------------------------
1 | import torchdiffeq
2 | from torchdiffeq._impl.dopri5 import _DORMAND_PRINCE_SHAMPINE_TABLEAU, DPS_C_MID
3 | from torchdiffeq._impl.solvers import FixedGridODESolver
4 | import torch
5 | import abc
6 | from geometric_solvers import GeomtricFixedGridODESolver
7 | from torchdiffeq._impl.misc import _check_inputs, _flat_to_shape
8 | import torch.nn.functional as F
9 | import copy
10 | from geometric_integrators import SymplecticEuler_step_func, Leapfrog_step_func
11 |
12 | from torchdiffeq._impl.interp import _interp_evaluate
13 | from torchdiffeq._impl.rk_common import RKAdaptiveStepsizeODESolver, rk4_alt_step_func
14 | from ogb.nodeproppred import Evaluator
15 |
16 |
17 | def run_evaluator(evaluator, data, y_pred):
18 | train_acc = evaluator.eval({
19 | 'y_true': data.y[data.train_mask],
20 | 'y_pred': y_pred[data.train_mask],
21 | })['acc']
22 | valid_acc = evaluator.eval({
23 | 'y_true': data.y[data.val_mask],
24 | 'y_pred': y_pred[data.val_mask],
25 | })['acc']
26 | test_acc = evaluator.eval({
27 | 'y_true': data.y[data.test_mask],
28 | 'y_pred': y_pred[data.test_mask],
29 | })['acc']
30 | return train_acc, valid_acc, test_acc
31 |
32 |
33 | class EarlyStopDopri5(RKAdaptiveStepsizeODESolver):
34 | order = 5
35 | tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU
36 | mid = DPS_C_MID
37 |
38 | def __init__(self, func, y0, rtol, atol, opt, **kwargs):
39 | super(EarlyStopDopri5, self).__init__(func, y0, rtol, atol, **kwargs)
40 |
41 | self.lf = torch.nn.CrossEntropyLoss()
42 | self.m2 = None
43 | self.data = None
44 | self.best_val = 0
45 | self.best_test = 0
46 | self.best_time = 0
47 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
48 | self.dataset = opt['dataset']
49 | if opt['dataset'] == 'ogbn-arxiv':
50 | self.lf = torch.nn.functional.nll_loss
51 | self.evaluator = Evaluator(name=opt['dataset'])
52 |
53 | def set_accs(self, train, val, test, time):
54 | self.best_train = train
55 | self.best_val = val
56 | self.best_test = test
57 | self.best_time = time.item()
58 |
59 | def integrate(self, t):
60 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
61 | solution[0] = self.y0
62 | t = t.to(self.dtype)
63 | self._before_integrate(t)
64 | new_t = t
65 | for i in range(1, len(t)):
66 | new_t, y = self.advance(t[i])
67 | solution[i] = y
68 | return new_t, solution
69 |
70 | def advance(self, next_t):
71 | """
72 | Take steps of size dt until the user-specified time point next_t has been passed, then interpolate back to next_t.
73 | :param next_t:
74 | :return: The state, x(next_t)
75 | """
76 | n_steps = 0
77 | while next_t > self.rk_state.t1:
78 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps)
79 | self.rk_state = self._adaptive_step(self.rk_state)
80 | n_steps += 1
81 | train_acc, val_acc, test_acc = self.evaluate(self.rk_state)
82 | if val_acc > self.best_val:
83 | self.set_accs(train_acc, val_acc, test_acc, self.rk_state.t1)
84 | new_t = next_t
85 | return (new_t, _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t))
86 |
87 | @torch.no_grad()
88 | def test(self, logits):
89 | accs = []
90 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
91 | pred = logits[mask].max(1)[1]
92 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
93 | accs.append(acc)
94 | return accs
95 |
96 | @torch.no_grad()
97 | def test_OGB(self, logits):
98 | evaluator = self.evaluator
99 | data = self.data
100 | y_pred = logits.argmax(dim=-1, keepdim=True)
101 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred)
102 | return [train_acc, valid_acc, test_acc]
103 |
104 | @torch.no_grad()
105 | def evaluate(self, rkstate):
106 | # Activation.
107 | z = rkstate.y1
108 | if not self.m2.in_features == z.shape[1]: # system has been augmented
109 | z = torch.split(z, self.m2.in_features, dim=1)[0]
110 | z = F.relu(z)
111 | z = self.m2(z)
112 | t0, t1 = float(self.rk_state.t0), float(self.rk_state.t1)
113 | if self.dataset == 'ogbn-arxiv':
114 | z = z.log_softmax(dim=-1)
115 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask])
116 | else:
117 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask])
118 | train_acc, val_acc, test_acc = self.ode_test(z)
119 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
120 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc))
121 | return train_acc, val_acc, test_acc
122 |
123 | def set_m2(self, m2):
124 | self.m2 = copy.deepcopy(m2)
125 |
126 | def set_data(self, data):
127 | if self.data is None:
128 | self.data = data
129 |
130 | class EarlyStopRK4(FixedGridODESolver):
131 | order = 4
132 |
133 | def __init__(self, func, y0, opt, eps=0, **kwargs):
134 | super(EarlyStopRK4, self).__init__(func, y0, **kwargs)
135 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device)
136 | self.lf = torch.nn.CrossEntropyLoss()
137 | self.m2 = None
138 | self.data = None
139 | self.best_val = 0
140 | self.best_test = 0
141 | self.best_time = 0
142 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
143 | self.dataset = opt['dataset']
144 | if opt['dataset'] == 'ogbn-arxiv':
145 | self.lf = torch.nn.functional.nll_loss
146 | self.evaluator = Evaluator(name=opt['dataset'])
147 |
148 | def _step_func(self, func, t, dt, t1, y):
149 | ver = torchdiffeq.__version__[0] + torchdiffeq.__version__[2] + torchdiffeq.__version__[4]
150 | if int(ver) >= 22: # '0.2.2'
151 | return rk4_alt_step_func(func, t + self.eps, dt - 2 * self.eps, t1, y)
152 | else:
153 | return rk4_alt_step_func(func, t + self.eps, dt - 2 * self.eps, y)
154 |
155 | def set_accs(self, train, val, test, time):
156 | self.best_train = train
157 | self.best_val = val
158 | self.best_test = test
159 | self.best_time = time.item()
160 |
161 | def integrate(self, t):
162 | time_grid = self.grid_constructor(self.func, self.y0, t)
163 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
164 |
165 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
166 | solution[0] = self.y0
167 |
168 | j = 1
169 | y0 = self.y0
170 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
171 | dy = self._step_func(self.func, t0, t1 - t0, t1, y0)
172 | y1 = y0 + dy
173 | train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1)
174 | if val_acc > self.best_val:
175 | self.set_accs(train_acc, val_acc, test_acc, t1)
176 |
177 | while j < len(t) and t1 >= t[j]:
178 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
179 | j += 1
180 | y0 = y1
181 |
182 | return t1, solution
183 |
184 | @torch.no_grad()
185 | def test(self, logits):
186 | accs = []
187 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
188 | pred = logits[mask].max(1)[1]
189 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
190 | accs.append(acc)
191 | return accs
192 |
193 | @torch.no_grad()
194 | def test_OGB(self, logits):
195 | evaluator = self.evaluator
196 | data = self.data
197 | y_pred = logits.argmax(dim=-1, keepdim=True)
198 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred)
199 | return [train_acc, valid_acc, test_acc]
200 |
201 | @torch.no_grad()
202 | def evaluate(self, z, t0, t1):
203 | # Activation.
204 | if not self.m2.in_features == z.shape[1]: # system has been augmented
205 | z = torch.split(z, self.m2.in_features, dim=1)[0]
206 | z = F.relu(z)
207 | z = self.m2(z)
208 | if self.dataset == 'ogbn-arxiv':
209 | z = z.log_softmax(dim=-1)
210 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask])
211 | else:
212 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask])
213 | train_acc, val_acc, test_acc = self.ode_test(z)
214 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
215 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc))
216 | return train_acc, val_acc, test_acc
217 |
218 | def set_m2(self, m2):
219 | self.m2 = copy.deepcopy(m2)
220 |
221 | def set_data(self, data):
222 | if self.data is None:
223 | self.data = data
224 |
225 |
226 | class EarlyStopSympEuler(GeomtricFixedGridODESolver):
227 | order = 1
228 |
229 | def __init__(self, func, y0, rtol, atol, opt, eps=0, **kwargs):
230 | super(EarlyStopSympEuler, self).__init__(func, y0, step_size=opt['step_size'])
231 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device)
232 | self.lf = torch.nn.CrossEntropyLoss()
233 | self.m2 = None
234 | self.data = None
235 | self.best_val = 0
236 | self.best_test = 0
237 | self.best_time = 0
238 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
239 | self.dataset = opt['dataset']
240 | if opt['dataset'] == 'ogbn-arxiv':
241 | self.lf = torch.nn.functional.nll_loss
242 | self.evaluator = Evaluator(name=opt['dataset'])
243 |
244 | def _step_func(self, func, t0, dt, t1, y0):
245 | return SymplecticEuler_step_func(func, t0, dt, t1, y0)
246 |
247 | def set_accs(self, train, val, test, time):
248 | self.best_train = train
249 | self.best_val = val
250 | self.best_test = test
251 | self.best_time = time.item()
252 |
253 | def integrate(self, t):
254 | time_grid = self.grid_constructor(self.func, self.y0, t)
255 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
256 |
257 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
258 | solution[0] = self.y0
259 |
260 | j = 1
261 | y0 = self.y0
262 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
263 | dt = t1 - t0
264 | y1 = self._step_func(self.func, t0, dt, t1, y0)
265 | train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1)
266 | if val_acc > self.best_val:
267 | self.set_accs(train_acc, val_acc, test_acc, t1)
268 |
269 | while j < len(t) and t1 >= t[j]:
270 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
271 | j += 1
272 | y0 = y1
273 |
274 | return t1, solution
275 |
276 | @torch.no_grad()
277 | def test(self, logits):
278 | accs = []
279 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
280 | pred = logits[mask].max(1)[1]
281 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
282 | accs.append(acc)
283 | return accs
284 |
285 | @torch.no_grad()
286 | def test_OGB(self, logits):
287 | evaluator = self.evaluator
288 | data = self.data
289 | y_pred = logits.argmax(dim=-1, keepdim=True)
290 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred)
291 | return [train_acc, valid_acc, test_acc]
292 |
293 | @torch.no_grad()
294 | def evaluate(self, z, t0, t1):
295 | # Activation.
296 | if not self.m2.in_features == z.shape[1]: # system has been augmented
297 | z = torch.split(z, self.m2.in_features, dim=1)[0]
298 | z = F.relu(z)
299 | z = self.m2(z)
300 | if self.dataset == 'ogbn-arxiv':
301 | z = z.log_softmax(dim=-1)
302 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask])
303 | else:
304 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask])
305 | train_acc, val_acc, test_acc = self.ode_test(z)
306 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
307 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc))
308 | return train_acc, val_acc, test_acc
309 |
310 | def set_m2(self, m2):
311 | self.m2 = copy.deepcopy(m2)
312 |
313 | def set_data(self, data):
314 | if self.data is None:
315 | self.data = data
316 |
317 |
318 | class EarlyStopLepfrog(GeomtricFixedGridODESolver):
319 | order = 2
320 |
321 | def __init__(self, func, y0, rtol, atol, opt, eps=0, **kwargs):
322 | super(EarlyStopLepfrog, self).__init__(func, y0, step_size=opt['step_size'])
323 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device)
324 | self.lf = torch.nn.CrossEntropyLoss()
325 | self.m2 = None
326 | self.data = None
327 | self.best_val = 0
328 | self.best_test = 0
329 | self.best_time = 0
330 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test
331 | self.dataset = opt['dataset']
332 | if opt['dataset'] == 'ogbn-arxiv':
333 | self.lf = torch.nn.functional.nll_loss
334 | self.evaluator = Evaluator(name=opt['dataset'])
335 |
336 | def _step_func(self, func, t0, dt, t1, y0):
337 | return Leapfrog_step_func(func, t0, dt, t1, y0)
338 |
339 | def set_accs(self, train, val, test, time):
340 | self.best_train = train
341 | self.best_val = val
342 | self.best_test = test
343 | self.best_time = time.item()
344 |
345 | def integrate(self, t):
346 | time_grid = self.grid_constructor(self.func, self.y0, t)
347 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
348 |
349 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
350 | solution[0] = self.y0
351 |
352 | j = 1
353 | y0 = self.y0
354 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
355 | dt = t1 - t0
356 | y1 = self._step_func(self.func, t0, dt, t1, y0)
357 | train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1)
358 | if val_acc > self.best_val:
359 | self.set_accs(train_acc, val_acc, test_acc, t1)
360 |
361 | while j < len(t) and t1 >= t[j]:
362 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
363 | j += 1
364 | y0 = y1
365 |
366 | return t1, solution
367 |
368 | @torch.no_grad()
369 | def test(self, logits):
370 | accs = []
371 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'):
372 | pred = logits[mask].max(1)[1]
373 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item()
374 | accs.append(acc)
375 | return accs
376 |
377 | @torch.no_grad()
378 | def test_OGB(self, logits):
379 | evaluator = self.evaluator
380 | data = self.data
381 | y_pred = logits.argmax(dim=-1, keepdim=True)
382 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred)
383 | return [train_acc, valid_acc, test_acc]
384 |
385 | @torch.no_grad()
386 | def evaluate(self, z, t0, t1):
387 | # Activation.
388 | if not self.m2.in_features == z.shape[1]: # system has been augmented
389 | z = torch.split(z, self.m2.in_features, dim=1)[0]
390 | z = F.relu(z)
391 | z = self.m2(z)
392 | if self.dataset == 'ogbn-arxiv':
393 | z = z.log_softmax(dim=-1)
394 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask])
395 | else:
396 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask])
397 | train_acc, val_acc, test_acc = self.ode_test(z)
398 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
399 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc))
400 | return train_acc, val_acc, test_acc
401 |
402 | def set_m2(self, m2):
403 | self.m2 = copy.deepcopy(m2)
404 |
405 | def set_data(self, data):
406 | if self.data is None:
407 | self.data = data
408 |
409 |
410 | SOLVERS = {
411 | 'dopri5': EarlyStopDopri5,
412 | 'rk4': EarlyStopRK4,
413 | 'symplectic_euler': EarlyStopSympEuler,
414 | 'leapfrog': EarlyStopLepfrog
415 | }
416 |
417 |
418 | class EarlyStopInt(torch.nn.Module):
419 | def __init__(self, t, opt, device=None):
420 | super(EarlyStopInt, self).__init__()
421 | self.device = device
422 | self.solver = None
423 | self.data = None
424 | self.m2 = None
425 | self.opt = opt
426 | self.t = torch.tensor([0, opt['earlystopxT'] * t], dtype=torch.float).to(self.device)
427 |
428 | def __call__(self, func, y0, t, method=None, rtol=1e-7, atol=1e-9,
429 | adjoint_method="dopri5", adjoint_atol=1e-9, adjoint_rtol=1e-7, options=None):
430 | """Integrate a system of ordinary differential equations.
431 |
432 | Solves the initial value problem for a non-stiff system of first order ODEs:
433 | ```
434 | dy/dt = func(t, y), y(t[0]) = y0
435 | ```
436 | where y is a Tensor of any shape.
437 |
438 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
439 |
440 | Args:
441 | func: Function that maps a Tensor holding the state `y` and a scalar Tensor
442 | `t` into a Tensor of state derivatives with respect to time.
443 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May
444 | have any floating point or complex dtype.
445 | t: 1-D Tensor holding a sequence of time points for which to solve for
446 | `y`. The initial time point should be the first element of this sequence,
447 | and each time must be larger than the previous time. May have any floating
448 | point dtype. Converted to a Tensor with float64 dtype.
449 | rtol: optional float64 Tensor specifying an upper bound on relative error,
450 | per element of `y`.
451 | atol: optional float64 Tensor specifying an upper bound on absolute error,
452 | per element of `y`.
453 | method: optional string indicating the integration method to use.
454 | options: optional dict of configuring options for the indicated integration
455 | method. Can only be provided if a `method` is explicitly set.
456 | name: Optional name for this operation.
457 |
458 | Returns:
459 | y: Tensor, where the first dimension corresponds to different
460 | time points. Contains the solved value of y for each desired time point in
461 | `t`, with the initial value `y0` being the first element along the first
462 | dimension.
463 |
464 | Raises:
465 | ValueError: if an invalid `method` is provided.
466 | TypeError: if `options` is supplied without `method`, or if `t` or `y0` has
467 | an invalid dtype.
468 | """
469 | method = self.opt['method']
470 | assert method in ['rk4', 'dopri5','symplectic_euler', 'leapfrog'], "Only dopri5, rk4, symplectic_euler and leapfrog implemented with early stopping"
471 |
472 | ver = torchdiffeq.__version__[0] + torchdiffeq.__version__[2] + torchdiffeq.__version__[4]
473 | if int(ver) >= 22: #'0.2.2'
474 | event_fn = None
475 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, self.t, rtol,
476 | atol, method, options,
477 | event_fn, SOLVERS)
478 | else:
479 | shapes, func, y0, t, rtol, atol, method, options = _check_inputs(func, y0, self.t, rtol, atol, method, options,
480 | SOLVERS)
481 |
482 | self.solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, opt=self.opt, **options)
483 | if self.solver.data is None:
484 | self.solver.data = self.data
485 | self.solver.m2 = self.m2
486 | t, solution = self.solver.integrate(t)
487 | if shapes is not None:
488 | solution = _flat_to_shape(solution, (len(t),), shapes)
489 | return solution
490 |
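491 | # Usage sketch (illustrative, not part of the original module; `odefunc`, `x0`, `decoder`
492 | # and `opt` are placeholder names):
493 | #   integrator = EarlyStopInt(t=opt['time'], opt=opt, device=device)
494 | #   integrator.data, integrator.m2 = data, decoder  # decoder: the final readout layer
495 | #   solution = integrator(odefunc, x0, integrator.t)
496 | #   best_val, best_test = integrator.solver.best_val, integrator.solver.best_test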
--------------------------------------------------------------------------------
/src/homophilic_graphs/experiment_configs/gat_pubmed.yaml:
--------------------------------------------------------------------------------
1 | program: run_GNN.py
2 | method: random
3 | parameters:
4 | method:
5 | distribution: categorical
6 | values: [ 'leapfrog' ]
7 | tol_scale:
8 | distribution: log_uniform
9 | min: -1
10 | max: 7
11 | function:
12 | distribution: constant
13 | value: GAT
14 | block:
15 | distribution: constant
16 | value: constant
17 | num_splits:
18 | distribution: constant
19 | value: 3
20 | heads:
21 | distribution: categorical
22 | values: [ 1, 2 ]
23 | epoch:
24 | distribution: constant
25 | value: 70
26 | time:
27 | distribution: uniform
28 | min: 3
29 | max: 25
30 | dataset:
31 | distribution: constant
32 | value: 'Pubmed'
33 | hidden_dim:
34 | distribution: categorical
35 | values: [ 32, 64 , 128 ]
36 | attention_dim:
37 | distribution: categorical
38 | values: [ 32, 64 , 128 ]
39 | input_dropout:
40 | distribution: uniform
41 | min: 0.0
42 | max: 0.5
43 | dropout:
44 | distribution: uniform
45 | min: 0.0
46 | max: 0.5
47 | lr:
48 | distribution: uniform
49 | min: 0.01
50 | max: 0.07
51 | decay:
52 | distribution: uniform
53 | min: 0
54 | max: 0.02
55 | command:
56 | - ${env}
57 | - python3
58 | - ${program}
59 | - ${args}
60 | - --wandb
61 | - --wandb_sweep
62 | - --adjoint
63 | - --add_source
64 | entity: graphcon
65 | project: gat_pubmed
66 |
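67 | # Note (added for clarity): the command section appends the fixed flags --wandb,
68 | # --wandb_sweep, --adjoint and --add_source to every run, in addition to the parameters
69 | # sampled above.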
--------------------------------------------------------------------------------
/src/homophilic_graphs/experiment_configs/gcn_cora.yaml:
--------------------------------------------------------------------------------
1 | program: run_GNN.py
2 | method: random
3 | parameters:
4 | method:
5 | distribution: categorical
6 | values: [ 'dopri5' ]
7 | tol_scale:
8 | distribution: log_uniform
9 | min: 0
10 | max: 7
11 | function:
12 | distribution: constant
13 | value: gcn
14 | num_splits:
15 | distribution: constant
16 | value: 3
17 | epoch:
18 | distribution: constant
19 | value: 50
20 | time:
21 | distribution: uniform
22 | min: 1
23 | max: 20
24 | dataset:
25 | distribution: constant
26 | value: 'Cora'
27 | hidden_dim:
28 | values: [ 32, 64 , 128 ]
29 | input_dropout:
30 | distribution: uniform
31 | min: 0.0
32 | max: 0.5
33 | dropout:
34 | distribution: uniform
35 | min: 0.0
36 | max: 0.5
37 | lr:
38 | distribution: uniform
39 | min: 0.001
40 | max: 0.03
41 | decay:
42 | distribution: uniform
43 | min: 0
44 | max: 0.02
45 | command:
46 | - ${env}
47 | - python3
48 | - ${program}
49 | - ${args}
50 | - --wandb
51 | - --wandb_sweep
52 | - --function
53 | - function_gcn
--------------------------------------------------------------------------------
/src/homophilic_graphs/experiment_configs/gcn_depth_random.yaml:
--------------------------------------------------------------------------------
1 | program: run_GNN.py
2 | method: random
3 | parameters:
4 | method:
5 | value: 'symplectic_euler'
6 | step_size:
7 | value: 1
8 | function:
9 | value: gcn
10 | epoch:
11 | value: 100
12 | hidden_dim:
13 | value: 64
14 | num_splits:
15 | value: 2
16 | dataset:
17 | values: [ 'Cora', 'Citeseer' ]
18 | time:
19 | values: [2, 4, 8, 16, 32, 64]
20 | input_dropout:
21 | distribution: uniform
22 | min: 0.0
23 | max: 0.5
24 | dropout:
25 | distribution: uniform
26 | min: 0.0
27 | max: 0.5
28 | lr:
29 | distribution: uniform
30 | min: 0.01
31 | max: 0.1
32 | decay:
33 | distribution: uniform
34 | min: 0.001
35 | max: 0.1
36 | command:
37 | - ${env}
38 | - python3
39 | - ${program}
40 | - ${args}
41 | - --wandb
42 | - --wandb_sweep
43 | entity: graphcon
44 | project: gcn_depth_random
--------------------------------------------------------------------------------
/src/homophilic_graphs/experiment_configs/gcn_planetoid.yaml:
--------------------------------------------------------------------------------
1 | program: run_GNN.py
2 | method: random
3 | parameters:
4 | method:
5 | distribution: categorical
6 | values: [ 'dopri5', 'symplectic_euler', 'rk4', 'leapfrog' ]
7 | step_size:
8 | distribution: categorical
9 | values: [ 0.25, 0.5, 1 ]
10 | function:
11 | distribution: constant
12 | value: gcn
13 | epoch:
14 | distribution: constant
15 | value: 100
16 | time:
17 | distribution: uniform
18 | min: 4
19 | max: 15
20 | dataset:
21 | distribution: categorical
22 | values: [ 'Cora', 'Citeseer', 'Pubmed' ]
23 | hidden_dim:
24 | values: [ 8, 16, 32, 64 ]
25 | attention_dim:
26 | values: [ 8, 16, 32, 64 ]
27 | input_dropout:
28 | distribution: uniform
29 | min: 0.0
30 | max: 0.5
31 | dropout:
32 | distribution: uniform
33 | min: 0.0
34 | max: 0.5
35 | lr:
36 | distribution: uniform
37 | min: 0.01
38 | max: 0.1
39 | decay:
40 | distribution: uniform
41 | min: 0.001
42 | max: 0.1
43 | command:
44 | - ${env}
45 | - python3
46 | - ${program}
47 | - ${args}
48 | - --wandb
49 | - --wandb_sweep
50 | - --function
51 | - function_gcn
--------------------------------------------------------------------------------
/src/homophilic_graphs/experiment_configs/run_sweeps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | for i in {0..7}
4 | do
5 | CUDA_VISIBLE_DEVICES=$(($i % 8)) wandb agent graphcon/gat_pubmed/$1 &
6 | done
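7 |
8 | # Usage (illustrative): create the sweep first with `wandb sweep gat_pubmed.yaml`, which
9 | # prints a sweep id, then run `./run_sweeps.sh <sweep_id>` to launch eight agents for the
10 | # graphcon/gat_pubmed project, one per GPU (0-7).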
--------------------------------------------------------------------------------
/src/homophilic_graphs/function_GAT_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.utils import softmax
4 | import torch_sparse
5 | import torch.nn.functional as F
6 | from torch_geometric.utils.loop import add_remaining_self_loops
7 | from data import get_dataset
8 | from utils import MaxNFEException
9 | from base_classes import ODEFunc
10 |
11 |
12 | class ODEFuncAtt(ODEFunc):
13 |
14 | def __init__(self, in_features, out_features, opt, data, device):
15 | super(ODEFuncAtt, self).__init__(opt, data, device)
16 |
17 | if opt['self_loop_weight'] > 0:
18 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr,
19 | fill_value=opt['self_loop_weight'])
20 | else:
21 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr
22 |
23 | self.multihead_att_layer = SpGraphAttentionLayer(in_features, out_features, opt,
24 | device).to(device)
25 | try:
26 | self.attention_dim = opt['attention_dim']
27 | except KeyError:
28 | self.attention_dim = out_features
29 |
30 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size"
31 | self.d_k = self.attention_dim // opt['heads']
32 |
33 | def multiply_attention(self, x, attention, wx):
34 | if self.opt['mix_features']:
35 | wx = torch.mean(torch.stack(
36 | [torch_sparse.spmm(self.edge_index, attention[:, idx], wx.shape[0], wx.shape[0], wx) for idx in
37 | range(self.opt['heads'])], dim=0),
38 | dim=0)
39 | ax = torch.mm(wx, self.multihead_att_layer.Wout)
40 | else:
41 | ax = torch.mean(torch.stack(
42 | [torch_sparse.spmm(self.edge_index, attention[:, idx], x.shape[0], x.shape[0], x) for idx in
43 | range(self.opt['heads'])], dim=0),
44 | dim=0)
45 | return ax
46 |
47 | def forward(self, t, x_full): # t is needed when called by the integrator
48 | x = x_full[:, :self.opt['hidden_dim']]
49 | y = x_full[:, self.opt['hidden_dim']:]
50 | if self.nfe > self.opt["max_nfe"]:
51 | raise MaxNFEException
52 | self.nfe += 1
53 | attention, wy = self.multihead_att_layer(y, self.edge_index)
54 | ay = self.multiply_attention(y, attention, wy)
55 | # todo would be nice if this was more efficient
56 |
57 | if not self.opt['no_alpha_sigmoid']:
58 | alpha = torch.sigmoid(self.alpha_train)
59 | else:
60 | alpha = self.alpha_train
61 | f = (ay - y - x)
62 | if self.opt['add_source']:
63 | f = (1. - torch.sigmoid(self.beta_train)) * f + torch.sigmoid(self.beta_train) * self.x0[:, self.opt['hidden_dim']:]
64 | f = torch.cat([f, (1. - torch.sigmoid(self.beta_train2)) * alpha * x + torch.sigmoid(self.beta_train2) * self.x0[:,:self.opt['hidden_dim']]],dim=1)
65 | return f
66 |
67 | def __repr__(self):
68 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
69 |
70 |
71 | class SpGraphAttentionLayer(nn.Module):
72 | """
73 | Sparse version of a GAT layer, similar to https://arxiv.org/abs/1710.10903
74 | """
75 |
76 | def __init__(self, in_features, out_features, opt, device, concat=True):
77 | super(SpGraphAttentionLayer, self).__init__()
78 | self.in_features = in_features
79 | self.out_features = out_features
80 | self.alpha = opt['leaky_relu_slope']
81 | self.concat = concat
82 | self.device = device
83 | self.opt = opt
84 | self.h = opt['heads']
85 |
86 | try:
87 | self.attention_dim = opt['attention_dim']
88 | except KeyError:
89 | self.attention_dim = out_features
90 |
91 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size"
92 | self.d_k = self.attention_dim // opt['heads']
93 |
94 | self.W = nn.Parameter(torch.zeros(size=(in_features, self.attention_dim))).to(device)
95 | nn.init.xavier_normal_(self.W.data, gain=1.414)
96 |
97 | self.Wout = nn.Parameter(torch.zeros(size=(self.attention_dim, self.in_features))).to(device)
98 | nn.init.xavier_normal_(self.Wout.data, gain=1.414)
99 |
100 | self.a = nn.Parameter(torch.zeros(size=(2 * self.d_k, 1, 1))).to(device)
101 | nn.init.xavier_normal_(self.a.data, gain=1.414)
102 |
103 | self.leakyrelu = nn.LeakyReLU(self.alpha)
104 |
105 | def forward(self, x, edge):
106 | wx = torch.mm(x, self.W) # h: N x out
107 | h = wx.view(-1, self.h, self.d_k)
108 | h = h.transpose(1, 2)
109 |
110 | # Self-attention on the nodes - Shared attention mechanism
111 | edge_h = torch.cat((h[edge[0, :], :, :], h[edge[1, :], :, :]), dim=1).transpose(0, 1).to(
112 | self.device) # edge: 2*D x E
113 | edge_e = self.leakyrelu(torch.sum(self.a * edge_h, dim=0)).to(self.device)
114 | attention = softmax(edge_e, edge[self.opt['attention_norm_idx']])
115 | return attention, wx
116 |
117 | def __repr__(self):
118 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
119 |
120 |
121 | if __name__ == '__main__':
122 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
123 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 'K': 10, 'attention_norm_idx': 0,
124 | 'add_source': False, 'alpha_dim': 'sc', 'no_alpha_sigmoid': False, 'max_nfe': 1000, 'mix_features': False}
125 | dataset = get_dataset(opt, '../data', False)
126 | t = 1
127 | func = ODEFuncAtt(dataset.data.num_features, 6, opt, dataset.data, device)
128 | out = func(t, dataset.data.x)
129 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/function_gcn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch_sparse
4 | import torch.nn.functional as F
5 | from base_classes import ODEFunc
6 | from utils import MaxNFEException
7 | from torch_geometric.nn.conv import GCNConv
8 |
9 |
10 | # Define the ODE function.
11 | # Input:
12 | # --- t: A tensor with shape [], meaning the current time.
13 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
14 | # Output:
15 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
16 | class GCNFunc(ODEFunc):
17 |
18 | # currently requires in_features = out_features
19 | def __init__(self, in_features, out_features, opt, data, device):
20 | super(GCNFunc, self).__init__(opt, data, device)
21 |
22 | self.in_features = in_features
23 | self.out_features = out_features
24 | self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
25 | self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
26 | self.alpha_sc = nn.Parameter(torch.ones(1))
27 | self.beta_sc = nn.Parameter(torch.ones(1))
28 | self.conv = GCNConv(in_features, out_features, normalize=True)
29 |
30 | def sparse_multiply(self, x):
31 | if self.opt['block'] in ['attention']: # adj is a multihead attention
32 | # ax = torch.mean(torch.stack(
33 | # [torch_sparse.spmm(self.edge_index, self.attention_weights[:, idx], x.shape[0], x.shape[0], x) for idx in
34 | # range(self.opt['heads'])], dim=0), dim=0)
35 | mean_attention = self.attention_weights.mean(dim=1)
36 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
37 | elif self.opt['block'] in ['mixed', 'hard_attention']: # adj is a torch sparse matrix
38 | ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x)
39 | else: # adj is a torch sparse matrix
40 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
41 | return ax
42 |
43 | def forward(self, t, x_full): # the t param is needed by the ODE solver.
44 | x = x_full[:, :self.opt['hidden_dim']]
45 | y = x_full[:, self.opt['hidden_dim']:]
46 | if self.nfe > self.opt["max_nfe"]:
47 | raise MaxNFEException
48 | self.nfe += 1
49 | # ay = self.sparse_multiply(y)
50 | if not self.opt['no_alpha_sigmoid']:
51 | alpha = torch.sigmoid(self.alpha_train)
52 | else:
53 | alpha = self.alpha_train
54 | f = self.conv(y, self.edge_index) - x - y
55 | # f = (ay - x - y)
56 | if self.opt['add_source']:
57 | f = (1. - F.sigmoid(self.beta_train)) * f + F.sigmoid(self.beta_train) * self.x0[:, self.opt['hidden_dim']:]
58 | f = torch.cat([f, (1. - F.sigmoid(self.beta_train2)) * alpha * x + F.sigmoid(self.beta_train2) *
59 | self.x0[:, :self.opt['hidden_dim']]], dim=1)
60 | return f
61 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/function_laplacian_diffusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch_sparse
4 | import torch.nn.functional as F
5 | from base_classes import ODEFunc
6 | from utils import MaxNFEException
7 |
8 |
9 | # Define the ODE function.
10 | # Input:
11 | # --- t: A tensor with shape [], meaning the current time.
12 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
13 | # Output:
14 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
15 | class LaplacianODEFunc(ODEFunc):
16 |
17 | # currently requires in_features = out_features
18 | def __init__(self, in_features, out_features, opt, data, device):
19 | super(LaplacianODEFunc, self).__init__(opt, data, device)
20 |
21 | self.in_features = in_features
22 | self.out_features = out_features
23 | self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
24 | self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
25 | self.alpha_sc = nn.Parameter(torch.ones(1))
26 | self.beta_sc = nn.Parameter(torch.ones(1))
27 |
28 | def sparse_multiply(self, x):
29 | if self.opt['block'] in ['attention']: # adj is a multihead attention
30 | # ax = torch.mean(torch.stack(
31 | # [torch_sparse.spmm(self.edge_index, self.attention_weights[:, idx], x.shape[0], x.shape[0], x) for idx in
32 | # range(self.opt['heads'])], dim=0), dim=0)
33 | mean_attention = self.attention_weights.mean(dim=1)
34 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
35 | elif self.opt['block'] in ['mixed', 'hard_attention']: # adj is a torch sparse matrix
36 | ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x)
37 | else: # adj is a torch sparse matrix
38 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
39 | return ax
40 |
41 | def forward(self, t, x_full): # the t param is needed by the ODE solver.
42 | x = x_full[:,:self.opt['hidden_dim']]
43 | y = x_full[:,self.opt['hidden_dim']:]
44 | if self.nfe > self.opt["max_nfe"]:
45 | raise MaxNFEException
46 | self.nfe += 1
47 | ay = self.sparse_multiply(y)
48 | if not self.opt['no_alpha_sigmoid']:
49 | alpha = torch.sigmoid(self.alpha_train)
50 | else:
51 | alpha = self.alpha_train
52 | f = (ay - y - x)
53 | if self.opt['add_source']:
54 | f = (1.-F.sigmoid(self.beta_train))*f + F.sigmoid(self.beta_train) * self.x0[:,self.opt['hidden_dim']:]
55 | f = torch.cat([f,(1.-F.sigmoid(self.beta_train2))*alpha*x + F.sigmoid(self.beta_train2) * self.x0[:,:self.opt['hidden_dim']]],dim=1)
56 | return f
57 |
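58 | # Note on the state layout (added for clarity): x_full stacks the two hidden_dim-wide
59 | # halves of the graphCON state. With A the (attention-weighted) adjacency applied by
60 | # sparse_multiply, the first half of the returned derivative is A*y - y - x (optionally
61 | # mixed with the source term x0) and the second half is alpha*x, i.e. the first-order
62 | # form of the coupled oscillator system in which the second half carries the node
63 | # features and the first half their velocities.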
--------------------------------------------------------------------------------
/src/homophilic_graphs/function_transformer_attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.utils import softmax
4 | import torch_sparse
5 | from torch_geometric.utils.loop import add_remaining_self_loops
6 | import numpy as np
7 | from data import get_dataset
8 | from utils import MaxNFEException
9 | from base_classes import ODEFunc
10 |
11 |
12 | class ODEFuncTransformerAtt(ODEFunc):
13 |
14 | def __init__(self, in_features, out_features, opt, data, device):
15 | super(ODEFuncTransformerAtt, self).__init__(opt, data, device)
16 |
17 | if opt['self_loop_weight'] > 0:
18 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr,
19 | fill_value=opt['self_loop_weight'])
20 | else:
21 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr
22 | self.multihead_att_layer = SpGraphTransAttentionLayer(in_features, out_features, opt,
23 | device, edge_weights=self.edge_weight).to(device)
24 |
25 | def multiply_attention(self, x, attention, v=None):
26 | # todo would be nice if this was more efficient
27 | if self.opt['mix_features']:
28 | vx = torch.mean(torch.stack(
29 | [torch_sparse.spmm(self.edge_index, attention[:, idx], v.shape[0], v.shape[0], v[:, :, idx]) for idx in
30 | range(self.opt['heads'])], dim=0),
31 | dim=0)
32 | ax = self.multihead_att_layer.Wout(vx)
33 | else:
34 | mean_attention = attention.mean(dim=1)
35 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
36 | return ax
37 |
38 | def forward(self, t, x): # t is needed when called by the integrator
39 | if self.nfe > self.opt["max_nfe"]:
40 | raise MaxNFEException
41 |
42 | self.nfe += 1
43 | attention, values = self.multihead_att_layer(x, self.edge_index)
44 | ax = self.multiply_attention(x, attention, values)
45 |
46 | if not self.opt['no_alpha_sigmoid']:
47 | alpha = torch.sigmoid(self.alpha_train)
48 | else:
49 | alpha = self.alpha_train
50 | f = alpha * (ax - x)
51 | if self.opt['add_source']:
52 | f = f + self.beta_train * self.x0
53 | return f
54 |
55 | def __repr__(self):
56 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
57 |
58 |
59 | class SpGraphTransAttentionLayer(nn.Module):
60 | """
61 | Sparse transformer-style multi-head attention layer, adapted from the GAT layer of https://arxiv.org/abs/1710.10903
62 | """
63 |
64 | def __init__(self, in_features, out_features, opt, device, concat=True, edge_weights=None):
65 | super(SpGraphTransAttentionLayer, self).__init__()
66 | self.in_features = in_features
67 | self.out_features = out_features
68 | self.alpha = opt['leaky_relu_slope']
69 | self.concat = concat
70 | self.device = device
71 | self.opt = opt
72 | self.h = int(opt['heads'])
73 | self.edge_weights = edge_weights
74 |
75 | try:
76 | self.attention_dim = opt['attention_dim']
77 | except KeyError:
78 | self.attention_dim = out_features
79 |
80 | assert self.attention_dim % self.h == 0, "Number of heads ({}) must be a factor of the dimension size ({})".format(
81 | self.h, self.attention_dim)
82 | self.d_k = self.attention_dim // self.h
83 |
84 | self.Q = nn.Linear(in_features, self.attention_dim)
85 | self.init_weights(self.Q)
86 |
87 | self.V = nn.Linear(in_features, self.attention_dim)
88 | self.init_weights(self.V)
89 |
90 | self.K = nn.Linear(in_features, self.attention_dim)
91 | self.init_weights(self.K)
92 |
93 | self.activation = nn.Sigmoid() # nn.LeakyReLU(self.alpha)
94 |
95 | self.Wout = nn.Linear(self.d_k, in_features)
96 | self.init_weights(self.Wout)
97 |
98 | def init_weights(self, m):
99 | if type(m) == nn.Linear:
100 | # nn.init.xavier_uniform_(m.weight, gain=1.414)
101 | # m.bias.data.fill_(0.01)
102 | nn.init.constant_(m.weight, 1e-5)
103 |
104 | def forward(self, x, edge):
105 | q = self.Q(x)
106 | k = self.K(x)
107 | v = self.V(x)
108 |
109 | # perform linear operation and split into h heads
110 |
111 | k = k.view(-1, self.h, self.d_k)
112 | q = q.view(-1, self.h, self.d_k)
113 | v = v.view(-1, self.h, self.d_k)
114 |
115 | # transpose to get dimensions [n_nodes, d_k, n_heads]
116 |
117 | k = k.transpose(1, 2)
118 | q = q.transpose(1, 2)
119 | v = v.transpose(1, 2)
120 |
121 | src = q[edge[0, :], :, :]
122 | dst_k = k[edge[1, :], :, :]
123 | prods = torch.sum(src * dst_k, dim=1) / np.sqrt(self.d_k)
124 | if self.opt['reweight_attention'] and self.edge_weights is not None:
125 | prods = prods * self.edge_weights.unsqueeze(dim=1)
126 | attention = softmax(prods, edge[self.opt['attention_norm_idx']])
127 | return attention, v
128 |
129 | def __repr__(self):
130 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
131 |
132 |
133 | if __name__ == '__main__':
134 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
135 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'heads': 2, 'K': 10,
136 | 'attention_norm_idx': 0, 'add_source': False,
137 | 'alpha_dim': 'sc', 'beta_dim': 'sc', 'max_nfe': 1000, 'mix_features': False, 'no_alpha_sigmoid': False, 'reweight_attention': False
138 | }
139 | dataset = get_dataset(opt, '../data', False)
140 | t = 1
141 | func = ODEFuncTransformerAtt(dataset.data.num_features, 6, opt, dataset.data, device)
142 | out = func(t, dataset.data.x)
143 |
--------------------------------------------------------------------------------
/src/homophilic_graphs/geometric_integrators.py:
--------------------------------------------------------------------------------
1 | from geometric_solvers import GeomtricFixedGridODESolver
2 | from torchdiffeq._impl.misc import Perturb
3 | import torch
4 |
5 |
6 | class SymplecticEuler(GeomtricFixedGridODESolver):
7 | order = 1
8 |
9 | def _step_func(self, func, t0, dt, t1, y0):
10 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1)
11 | f0x = torch.split(func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0]
12 | x1 = x0 + dt*f0x
13 | f0z = torch.split(func(t0, torch.cat([x1,z0],dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[1]
14 | z1 = z0 + dt*f0z
15 | y1 = torch.cat([x1,z1],dim=1)
16 | return y1
17 |
18 | class Leapfrog(GeomtricFixedGridODESolver):
19 | order = 2
20 |
21 | def _step_func(self, func, t0, dt, t1, y0):
22 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1)
23 | f0x_1 = torch.split(func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0]
24 | x_half = x0 + dt/2.*f0x_1
25 |
26 | f0z = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[1]
27 | z1 = z0 + dt * f0z
28 |
29 | f0x_2 = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[0]
30 | x1 = x_half + dt/2.*f0x_2
31 | y1 = torch.cat([x1,z1],dim=1)
32 | return y1
33 |
34 |
35 | def SymplecticEuler_step_func(func, t0, dt, t1, y0, perturb=False):
36 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1)
37 | f0x = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0]
38 | x1 = x0 + dt*f0x
39 | y0 = torch.cat([x1,z0],dim=1)
40 | f0z = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[1]
41 | z1 = z0 + dt*f0z
42 | y1 = torch.cat([x1,z1],dim=1)
43 | return y1
44 |
45 |
46 | def Leapfrog_step_func(func, t0, dt, t1, y0, perturb=False):
47 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1)
48 | f0x_1 = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[0]
49 | x_half = x0 + dt / 2. * f0x_1
50 |
51 | f0z = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[1]
52 | z1 = z0 + dt * f0z
53 |
54 | f0x_2 = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0]
55 | x1 = x_half + dt / 2. * f0x_2
56 | y1 = torch.cat([x1, z1], dim=1)
57 | return y1
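58 |
59 |
60 | if __name__ == '__main__':
61 |   # Illustrative sanity check (not part of the original module): one symplectic Euler
62 |   # step on a toy oscillator dx/dt = -z, dz/dt = x, with the state stored as y = [x | z].
63 |   def toy_func(t, y, perturb=None):
64 |     x, z = torch.split(y, y.shape[1] // 2, dim=1)
65 |     return torch.cat([-z, x], dim=1)
66 |
67 |   y0 = torch.tensor([[1.0, 0.0]])
68 |   y1 = SymplecticEuler_step_func(toy_func, t0=0.0, dt=0.1, t1=0.1, y0=y0)
69 |   # x is advanced first using the old z, then z is advanced using the updated x,
70 |   # giving y1 = [[1.0, 0.1]] for this initial state.
71 |   print(y1)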
--------------------------------------------------------------------------------
/src/homophilic_graphs/geometric_solvers.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import torch
3 | from torchdiffeq._impl.misc import _handle_unused_kwargs
4 |
5 |
6 | class GeomtricFixedGridODESolver(metaclass=abc.ABCMeta):
7 | order: int
8 |
9 | def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs):
10 | self.func = func
11 | self.y0 = y0
12 | self.dtype = y0.dtype
13 | self.device = y0.device
14 | self.step_size = step_size
15 | self.interp = interp
16 | self.perturb = perturb
17 | self.grid_constructor = self._grid_constructor_from_step_size(step_size)
18 |
19 | @staticmethod
20 | def _grid_constructor_from_step_size(step_size):
21 | def _grid_constructor(func, y0, t):
22 | start_time = t[0]
23 | end_time = t[-1]
24 |
25 | niters = torch.ceil((end_time - start_time) / step_size + 1).item()
26 | t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time
27 | t_infer[-1] = t[-1]
28 |
29 | return t_infer
30 | return _grid_constructor
31 |
32 | @abc.abstractmethod
33 | def _step_func(self, func, t0, dt, t1, y0):
34 | pass
35 |
36 | def integrate(self, t):
37 | time_grid = self.grid_constructor(self.func, self.y0, t)
38 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1]
39 |
40 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device)
41 | solution[0] = self.y0
42 |
43 | j = 1
44 | y0 = self.y0
45 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
46 | dt = t1 - t0
47 | y1 = self._step_func(self.func, t0, dt, t1, y0)
48 | while j < len(t) and t1 >= t[j]:
49 | if self.interp == "linear":
50 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j])
51 | elif self.interp == "cubic":
52 | raise NotImplementedError("Not implemented for geometric integrators")
53 | else:
54 | raise ValueError(f"Unknown interpolation method {self.interp}")
55 | j += 1
56 | y0 = y1
57 |
58 | return solution
59 |
60 | def integrate_until_event(self, t0, event_fn):
61 | raise NotImplementedError("Not implemented for geometric integrators")
62 |
63 | def _cubic_hermite_interp(self, t0, y0, f0, t1, y1, f1, t):
64 | h = (t - t0) / (t1 - t0)
65 | h00 = (1 + 2 * h) * (1 - h) * (1 - h)
66 | h10 = h * (1 - h) * (1 - h)
67 | h01 = h * h * (3 - 2 * h)
68 | h11 = h * h * (h - 1)
69 | dt = (t1 - t0)
70 | return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1
71 |
72 | def _linear_interp(self, t0, t1, y0, y1, t):
73 | if t == t0:
74 | return y0
75 | if t == t1:
76 | return y1
77 | slope = (t - t0) / (t1 - t0)
78 | return y0 + slope * (y1 - y0)
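79 |
80 | # Worked example of the grid construction (added for clarity): with step_size=0.25 and
81 | # t = tensor([0., 1.]), _grid_constructor returns [0.0, 0.25, 0.5, 0.75, 1.0], and the
82 | # last entry is always clamped to t[-1]. Note that integrate() returns only the solution
83 | # tensor, whereas the early-stopping variants in early_stop_solver.py return (t1, solution).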
--------------------------------------------------------------------------------
/src/homophilic_graphs/good_params_graphCON.py:
--------------------------------------------------------------------------------
1 | """
2 | The adjoint method is currently not implemented, so the Pubmed parameters are not exact:
3 | 'adjoint' has been changed to False while all other hyperparameter settings were kept.
4 | """
5 | good_params_dict = {
6 | 'Cora': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False, 'adjoint_method': 'adaptive_heun',
7 | 'adjoint_step_size': 1, 'alpha': 1.0, 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 128,
8 | 'attention_norm_idx': 1, 'attention_rewiring': False, 'attention_type': 'scaled_dot', 'augment': False,
9 | 'baseline': False, 'batch_norm': False, 'beltrami': True, 'beta_dim': 'sc', 'block': 'attention', 'cpus': 1,
10 | 'data_norm': 'rw', 'dataset': 'Cora', 'decay': 0.00507685443154266, 'directional_penalty': None,
11 | 'dropout': 0.17167167167167166, 'dt': 0.001, 'dt_min': 1e-05, 'epoch': 100, 'exact': True, 'fc_out': False,
12 | 'feat_hidden_dim': 64, 'function': 'laplacian', 'gdc_avg_degree': 64, 'gdc_k': 64, 'gdc_method': 'ppr',
13 | 'gdc_sparsification': 'topk', 'gdc_threshold': 0.01, 'gpus': 0.5, 'grace_period': 20, 'heads': 8,
14 | 'heat_time': 3.0, 'hidden_dim': 80, 'input_dropout': 0.5, 'jacobian_norm2': None, 'kinetic_energy': None,
15 | 'label_rate': 0.5, 'leaky_relu_slope': 0.2, 'lr': 0.01334134134134134, 'max_epochs': 1000, 'max_iters': 100,
16 | 'max_nfe': 2000, 'method': 'symplectic_euler', 'metric': 'accuracy', 'mix_features': False,
17 | 'name': 'cora_beltrami_splits', 'new_edges': 'random', 'no_alpha_sigmoid': False, 'not_lcc': True,
18 | 'num_init': 1, 'num_samples': 1000, 'num_splits': 1, 'ode_blocks': 1, 'optimizer': 'adamax', 'patience': 100,
19 | 'pos_enc_hidden_dim': 16, 'pos_enc_orientation': 'row', 'pos_enc_type': 'GDC', 'ppr_alpha': 0.05,
20 | 'reduction_factor': 10, 'regularise': False, 'reweight_attention': False, 'rewire_KNN': False,
21 | 'rewire_KNN_T': 'T0', 'rewire_KNN_epoch': 10, 'rewire_KNN_k': 64, 'rewire_KNN_sym': False, 'rewiring': None,
22 | 'rw_addD': 0.02, 'rw_rmvR': 0.02, 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True,
23 | 'step_size': 1, 'threshold_type': 'addD_rvR', 'time': 18.294754260552843, 'tol_scale': 821.9773048827274,
24 | 'tol_scale_adjoint': 1.0, 'total_deriv': None, 'use_cora_defaults': False, 'use_flux': False,
25 | 'use_labels': False, 'use_lcc': True, 'use_mlp': False},
26 | 'Citeseer': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False,
27 | 'adjoint_method': 'adaptive_heun', 'adjoint_step_size': 1, 'alpha': 1.0,
28 | 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 32, 'attention_norm_idx': 1,
29 | 'attention_rewiring': False, 'attention_type': 'exp_kernel', 'augment': False,
30 | 'baseline': False, 'batch_norm': False, 'beltrami': False, 'beta_dim': 'sc',
31 | 'block': 'attention', 'cpus': 1, 'data_norm': 'rw', 'dataset': 'Citeseer',
32 | 'decay': 0.1, 'directional_penalty': None, 'dropout': 0.8429429429429429, 'dt': 0.001,
33 | 'dt_min': 1e-05, 'epoch': 250, 'exact': True, 'fc_out': False, 'feat_hidden_dim': 64,
34 | 'function': 'laplacian', 'gdc_avg_degree': 64, 'gdc_k': 128, 'gdc_method': 'ppr',
35 | 'gdc_sparsification': 'topk', 'gdc_threshold': 0.01, 'gpus': 1.0, 'grace_period': 20,
36 | 'heads': 8, 'heat_time': 3.0, 'hidden_dim': 80, 'input_dropout': 0.6803233752085334,
37 | 'jacobian_norm2': None, 'kinetic_energy': None, 'label_rate': 0.5,
38 | 'leaky_relu_slope': 0.5825086997804176, 'lr': 0.006036036036036037, 'max_epochs': 1000,
39 | 'max_iters': 100, 'max_nfe': 3000, 'method': 'symplectic_euler', 'metric': 'accuracy',
40 | 'mix_features': False, 'name': 'Citeseer_beltrami_1_KNN', 'new_edges': 'random',
41 | 'no_alpha_sigmoid': False, 'not_lcc': True, 'num_class': 6, 'num_feature': 3703,
42 | 'num_init': 2, 'num_nodes': 2120, 'num_samples': 400, 'num_splits': 1, 'ode_blocks': 1,
43 | 'optimizer': 'adam', 'patience': 100, 'pos_enc_dim': 'row', 'pos_enc_hidden_dim': 16,
44 | 'ppr_alpha': 0.05, 'reduction_factor': 4, 'regularise': False,
45 | 'reweight_attention': False, 'rewire_KNN': False, 'rewire_KNN_epoch': 10,
46 | 'rewire_KNN_k': 64, 'rewire_KNN_sym': False, 'rewiring': None, 'rw_addD': 0.02,
47 | 'rw_rmvR': 0.02, 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True,
48 | 'step_size': 1, 'threshold_type': 'addD_rvR', 'time': 7.874113442879092,
49 | 'tol_scale': 2.9010446330432815, 'tol_scale_adjoint': 1.0, 'total_deriv': None,
50 | 'use_cora_defaults': False, 'use_flux': False, 'use_labels': False, 'use_lcc': True,
51 | 'use_mlp': False},
52 | 'Pubmed': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False,
53 | 'adjoint_method': 'adaptive_heun', 'adjoint_step_size': 1, 'alpha': 1.0,
54 | 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 16, 'attention_norm_idx': 0,
55 | 'attention_rewiring': False, 'attention_type': 'cosine_sim', 'augment': False,
56 | 'baseline': False, 'batch_norm': False, 'beltrami': False, 'beta_dim': 'sc',
57 | 'block': 'attention', 'cpus': 1, 'data_norm': 'rw', 'dataset': 'Pubmed',
58 | 'decay': 0.0018236722171703636, 'directional_penalty': None,
59 | 'dropout': 0.18918918918918917, 'dt': 0.001, 'dt_min': 1e-05, 'epoch': 600,
60 | 'exact': False, 'fc_out': False, 'feat_hidden_dim': 64, 'function': 'laplacian',
61 | 'gdc_avg_degree': 64, 'gdc_k': 64, 'gdc_method': 'ppr', 'gdc_sparsification': 'topk',
62 | 'gdc_threshold': 0.01, 'gpus': 1.0, 'grace_period': 20, 'heads': 1, 'heat_time': 3.0,
63 | 'hidden_dim': 128, 'input_dropout': 0.5, 'jacobian_norm2': None, 'kinetic_energy': None,
64 | 'label_rate': 0.5, 'leaky_relu_slope': 0.2, 'lr': 0.02522522522522523,
65 | 'max_epochs': 1000, 'max_iters': 100, 'max_nfe': 5000, 'method': 'symplectic_euler',
66 | 'metric': 'test_acc', 'mix_features': False, 'name': None, 'new_edges': 'random',
67 | 'no_alpha_sigmoid': False, 'not_lcc': True, 'num_init': 1, 'num_samples': 400,
68 | 'num_splits': 1, 'ode_blocks': 1, 'optimizer': 'adamax', 'patience': 100,
69 | 'pos_enc_dim': 'row', 'pos_enc_hidden_dim': 16, 'ppr_alpha': 0.05,
70 | 'reduction_factor': 10, 'regularise': False, 'reweight_attention': False,
71 | 'rewire_KNN': False, 'rewire_KNN_T': 'T0', 'rewire_KNN_epoch': 10, 'rewire_KNN_k': 64,
72 | 'rewire_KNN_sym': False, 'rewiring': None, 'rw_addD': 0.02, 'rw_rmvR': 0.02,
73 | 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True, 'step_size': 1,
74 | 'threshold_type': 'addD_rvR', 'time': 12.942327880200853,
75 | 'tol_scale': 1991.0688305523001, 'tol_scale_adjoint': 16324.368093998313,
76 | 'total_deriv': None, 'use_cora_defaults': False, 'use_flux': False, 'use_labels': False,
77 | 'use_lcc': True, 'use_mlp': False, 'folder': 'pubmed_linear_att_beltrami_adj2',
78 | 'index': 0, 'run_with_KNN': False, 'change_att_sim_type': False, 'reps': 1,
79 | 'max_test_steps': 100, 'no_early': False, 'earlystopxT': 5.0, 'pos_enc_csv': False,
80 | 'pos_enc_type': 'GDC'}
81 | }
82 |
--------------------------------------------------------------------------------
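`good_params_dict` above holds the tuned hyperparameters for each dataset; `run_GNN.py` looks them up by the `--dataset` value and lets them override the parsed command-line options. A minimal sketch of that lookup, assuming it is run from `src/homophilic_graphs` (the `cmd_opt` values here are illustrative, not from the repository):
```
from good_params_graphCON import good_params_dict

cmd_opt = {'dataset': 'Cora', 'epoch': 100, 'no_early': False}   # illustrative CLI options
opt = {**cmd_opt, **good_params_dict[cmd_opt['dataset']]}         # tuned values take precedence
print(opt['method'], opt['optimizer'])
```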
/src/homophilic_graphs/model_configurations.py:
--------------------------------------------------------------------------------
1 | from function_transformer_attention import ODEFuncTransformerAtt
2 | from function_GAT_attention import ODEFuncAtt
3 | from function_laplacian_diffusion import LaplacianODEFunc
4 | from function_gcn import GCNFunc
5 | from block_transformer_attention import AttODEblock
6 | from block_constant import ConstantODEblock
7 |
8 |
9 | class BlockNotDefined(Exception):
10 | pass
11 |
12 |
13 | class FunctionNotDefined(Exception):
14 | pass
15 |
16 |
17 | def set_block(opt):
18 | ode_str = opt['block']
19 | if ode_str == 'attention':
20 | block = AttODEblock
21 | elif ode_str == 'constant':
22 | block = ConstantODEblock
23 | else:
24 | raise BlockNotDefined
25 | return block
26 |
27 |
28 | def set_function(opt):
29 | ode_str = opt['function']
30 | if ode_str == 'laplacian':
31 | f = LaplacianODEFunc
32 | elif ode_str == 'GAT':
33 | f = ODEFuncAtt
34 | elif ode_str == 'transformer':
35 | f = ODEFuncTransformerAtt
36 | elif ode_str == 'gcn':
37 | f = GCNFunc
38 | else:
39 | raise FunctionNotDefined
40 | return f
41 |
--------------------------------------------------------------------------------
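`set_block` and `set_function` are thin factories that map the string options `opt['block']` and `opt['function']` to the corresponding classes. A minimal sketch of the dispatch, assuming it is run from `src/homophilic_graphs` (the `opt` values are illustrative):
```
from model_configurations import set_block, set_function

opt = {'block': 'attention', 'function': 'laplacian'}
Block, Func = set_block(opt), set_function(opt)
print(Block.__name__, Func.__name__)   # AttODEblock LaplacianODEFunc
```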
/src/homophilic_graphs/odeint_geometric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.functional import vjp
3 | from geometric_integrators import SymplecticEuler, Leapfrog
4 | from torchdiffeq._impl.scipy_wrapper import ScipyWrapperODESolver
5 | from torchdiffeq._impl.misc import _check_inputs, _flat_to_shape
6 |
7 | SOLVERS = {
8 | 'symplectic_euler': SymplecticEuler,
9 | 'leapfrog': Leapfrog
10 | }
11 |
12 |
13 | def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None):
14 | """Integrate a system of ordinary differential equations.
15 |
16 | Solves the initial value problem for a non-stiff system of first order ODEs:
17 | ```
18 | dy/dt = func(t, y), y(t[0]) = y0
19 | ```
20 | where y is a Tensor or tuple of Tensors of any shape.
21 |
22 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`.
23 |
24 | Args:
25 | func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y`
26 | into a Tensor of state derivatives with respect to time. Optionally, `y`
27 | can also be a tuple of Tensors.
28 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. Optionally, `y0`
29 | can also be a tuple of Tensors.
30 | t: 1-D Tensor holding a sequence of time points for which to solve for
31 | `y`, in either increasing or decreasing order. The first element of
32 | this sequence is taken to be the initial time point.
33 | rtol: optional float64 Tensor specifying an upper bound on relative error,
34 | per element of `y`.
35 | atol: optional float64 Tensor specifying an upper bound on absolute error,
36 | per element of `y`.
37 | method: optional string indicating the integration method to use.
38 | options: optional dict of configuring options for the indicated integration
39 | method. Can only be provided if a `method` is explicitly set.
40 | event_fn: Function that maps the state `y` to a Tensor. The solve terminates when
41 | event_fn evaluates to zero. If this is not None, all but the first elements of
42 | `t` are ignored.
43 |
44 | Returns:
45 | y: Tensor, where the first dimension corresponds to different
46 | time points. Contains the solved value of y for each desired time point in
47 | `t`, with the initial value `y0` being the first element along the first
48 | dimension.
49 |
50 | Raises:
51 | ValueError: if an invalid `method` is provided.
52 | """
53 |
54 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
55 |
56 | solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)
57 |
58 | if event_fn is None:
59 | solution = solver.integrate(t)
60 | else:
61 | event_t, solution = solver.integrate_until_event(t[0], event_fn)
62 | event_t = event_t.to(t)
63 | if t_is_reversed:
64 | event_t = -event_t
65 |
66 | if shapes is not None:
67 | solution = _flat_to_shape(solution, (len(t),), shapes)
68 |
69 | if event_fn is None:
70 | return solution
71 | else:
72 | return event_t, solution
--------------------------------------------------------------------------------
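This `odeint` mirrors the torchdiffeq interface but restricts `method` to the geometric solvers registered in `SOLVERS` ('symplectic_euler' and 'leapfrog', implemented in `geometric_integrators.py`). For intuition only, a hand-rolled symplectic Euler step for the first-order form dX/dt = Y, dY/dt = F(X) of a second-order system looks like the sketch below; this is an illustrative toy, not the repository's solver:
```
import torch

def symplectic_euler_step(x, y, f, dt):
    """One illustrative symplectic Euler step for dx/dt = y, dy/dt = f(x)."""
    y_next = y + dt * f(x)    # update the velocity-like variable first ...
    x_next = x + dt * y_next  # ... then advance the position with the updated velocity
    return x_next, y_next

# toy usage on a harmonic oscillator, f(x) = -x
x, y = torch.tensor([1.0]), torch.tensor([0.0])
for _ in range(10):
    x, y = symplectic_euler_step(x, y, lambda z: -z, dt=0.1)
print(x.item(), y.item())
```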
/src/homophilic_graphs/regularized_ODE_function.py:
--------------------------------------------------------------------------------
1 | ## This code has been adapted from https://github.com/cfinlay/ffjord-rnode/
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class RegularizedODEfunc(nn.Module):
8 | def __init__(self, odefunc, regularization_fns):
9 | super(RegularizedODEfunc, self).__init__()
10 | self.odefunc = odefunc
11 | self.regularization_fns = regularization_fns
12 |
13 | def before_odeint(self, *args, **kwargs):
14 | self.odefunc.before_odeint(*args, **kwargs)
15 |
16 | def forward(self, t, state):
17 |
18 | with torch.enable_grad():
19 | x = state[0]
20 | x.requires_grad_(True)
21 | t.requires_grad_(True)
22 | dstate = self.odefunc(t, x)
23 | if len(state) > 1:
24 | dx = dstate
25 | reg_states = tuple(reg_fn(x, t, dx, self.odefunc) for reg_fn in self.regularization_fns)
26 | return (dstate,) + reg_states
27 | else:
28 | return dstate
29 |
30 | @property
31 | def _num_evals(self):
32 | return self.odefunc._num_evals
33 |
34 |
35 | def total_derivative(x, t, dx, unused_context):
36 | del unused_context
37 |
38 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
39 |
40 | try:
41 | u = torch.full_like(dx, 1 / x.numel(), requires_grad=True)
42 | tmp = torch.autograd.grad((u * dx).sum(), t, create_graph=True)[0]
43 | partial_dt = torch.autograd.grad(tmp.sum(), u, create_graph=True)[0]
44 |
45 | total_deriv = directional_dx + partial_dt
46 | except RuntimeError as e:
48 |         if 'One of the differentiated Tensors' not in str(e):
49 |             raise  # unrelated autograd error: re-raise instead of falling through with total_deriv undefined
50 |         raise RuntimeError('No partial derivative with respect to time. Use mathematically equivalent "directional_derivative" regularizer instead') from e
50 |
51 | tdv2 = total_deriv.pow(2).view(x.size(0), -1)
52 |
53 | return 0.5 * tdv2.mean(dim=-1)
54 |
55 |
56 | def directional_derivative(x, t, dx, unused_context):
57 | del t, unused_context
58 |
59 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
60 | ddx2 = directional_dx.pow(2).view(x.size(0), -1)
61 |
62 | return 0.5 * ddx2.mean(dim=-1)
63 |
64 |
65 | def quadratic_cost(x, t, dx, unused_context):
66 | del x, t, unused_context
67 | dx = dx.view(dx.shape[0], -1)
68 | return 0.5 * dx.pow(2).mean(dim=-1)
69 |
70 |
71 | def divergence_bf(dx, x):
72 | sum_diag = 0.
73 | for i in range(x.shape[1]):
74 | sum_diag += torch.autograd.grad(dx[:, i].sum(), x, create_graph=True)[0].contiguous()[:, i].contiguous()
75 | return sum_diag.contiguous()
76 |
77 |
78 | def jacobian_frobenius_regularization_fn(x, t, dx, context):
79 | del t
80 | return divergence_bf(dx, x)
81 |
--------------------------------------------------------------------------------
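The regularisers above all share the signature `(x, t, dx, context)` and return one value per sample; `RegularizedODEfunc` appends their outputs to the ODE state so they are integrated alongside it. A minimal sketch evaluating `quadratic_cost` (the kinetic-energy term) on its own, with illustrative tensors:
```
import torch
from regularized_ODE_function import quadratic_cost

x = torch.randn(4, 8)    # illustrative node states
t = torch.tensor(0.0)
dx = torch.randn(4, 8)   # illustrative time derivative of the states
print(quadratic_cost(x, t, dx, None))   # shape (4,): 0.5 * mean(dx**2) per sample
```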
/src/homophilic_graphs/run_GNN.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import torch
5 | from torch_geometric.nn import GCNConv, ChebConv # noqa
6 | import torch.nn.functional as F
7 | from GNN import GNN
8 | from GNN_early import GNNEarly
9 | import time
10 | from data import get_dataset, set_train_val_test_split
11 | from ogb.nodeproppred import Evaluator
12 | from good_params_graphCON import good_params_dict
13 | import wandb
14 |
15 |
16 | def get_optimizer(name, parameters, lr, weight_decay=0):
17 | if name == 'sgd':
18 | return torch.optim.SGD(parameters, lr=lr, weight_decay=weight_decay)
19 | elif name == 'rmsprop':
20 | return torch.optim.RMSprop(parameters, lr=lr, weight_decay=weight_decay)
21 | elif name == 'adagrad':
22 | return torch.optim.Adagrad(parameters, lr=lr, weight_decay=weight_decay)
23 | elif name == 'adam':
24 | return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay)
25 | elif name == 'adamax':
26 | return torch.optim.Adamax(parameters, lr=lr, weight_decay=weight_decay)
27 | else:
28 | raise Exception("Unsupported optimizer: {}".format(name))
29 |
30 |
31 | def add_labels(feat, labels, idx, num_classes, device):
32 | onehot = torch.zeros([feat.shape[0], num_classes]).to(device)
33 | if idx.dtype == torch.bool:
34 | idx = torch.where(idx)[0] # convert mask to linear index
35 | onehot[idx, labels.squeeze()[idx]] = 1
36 |
37 | return torch.cat([feat, onehot], dim=-1)
38 |
39 |
40 | def get_label_masks(data, mask_rate=0.5):
41 | """
42 |   when using labels as features we need to split the training nodes into label nodes and prediction nodes
43 | """
44 | if data.train_mask.dtype == torch.bool:
45 | idx = torch.where(data.train_mask)[0]
46 | else:
47 | idx = data.train_mask
48 | mask = torch.rand(idx.shape) < mask_rate
49 | train_label_idx = idx[mask]
50 | train_pred_idx = idx[~mask]
51 | return train_label_idx, train_pred_idx
52 |
53 |
54 | def train(model, optimizer, data):
55 | model.train()
56 | optimizer.zero_grad()
57 | feat = data.x
58 | if model.opt['use_labels']:
59 | train_label_idx, train_pred_idx = get_label_masks(data, model.opt['label_rate'])
60 |
61 | feat = add_labels(feat, data.y, train_label_idx, model.num_classes, model.device)
62 | else:
63 | train_pred_idx = data.train_mask
64 |
65 | out = model(feat)
66 | if model.opt['dataset'] == 'ogbn-arxiv':
67 | lf = torch.nn.functional.nll_loss
68 | loss = lf(out.log_softmax(dim=-1)[data.train_mask], data.y.squeeze(1)[data.train_mask])
69 | else:
70 | lf = torch.nn.CrossEntropyLoss()
71 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask])
72 |
73 | if model.opt['wandb_watch_grad']: # Tell wandb to watch what the model gets up to: gradients, weights, and more!
74 | wandb.watch(model, lf, log="all", log_freq=10)
75 |
76 | if model.odeblock.nreg > 0: # add regularisation - slower for small data, but faster and better performance for large data
77 | reg_states = tuple(torch.mean(rs) for rs in model.reg_states)
78 | regularization_coeffs = model.regularization_coeffs
79 |
80 | reg_loss = sum(
81 | reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0
82 | )
83 | loss = loss + reg_loss
84 |
85 | model.fm.update(model.getNFE())
86 | model.resetNFE()
87 | loss.backward()
88 | optimizer.step()
89 | model.bm.update(model.getNFE())
90 | model.resetNFE()
91 | return loss.item()
92 |
93 |
94 | @torch.no_grad()
95 | def test(model, data, opt=None): # opt required for runtime polymorphism
96 | model.eval()
97 | feat = data.x
98 | if model.opt['use_labels']:
99 | feat = add_labels(feat, data.y, data.train_mask, model.num_classes, model.device)
100 | logits, accs = model(feat), []
101 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
102 | pred = logits[mask].max(1)[1]
103 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
104 | accs.append(acc)
105 | return accs
106 |
107 |
108 | def print_model_params(model):
109 | print(model)
110 | for name, param in model.named_parameters():
111 | if param.requires_grad:
112 | print(name)
113 | print(param.data.shape)
114 |
115 |
116 | @torch.no_grad()
117 | def test_OGB(model, data, opt):
118 | if opt['dataset'] == 'ogbn-arxiv':
119 | name = 'ogbn-arxiv'
120 |
121 | feat = data.x
122 | if model.opt['use_labels']:
123 | feat = add_labels(feat, data.y, data.train_mask, model.num_classes, model.device)
124 |
125 | evaluator = Evaluator(name=name)
126 | model.eval()
127 |
128 | out = model(feat).log_softmax(dim=-1)
129 | y_pred = out.argmax(dim=-1, keepdim=True)
130 |
131 | train_acc = evaluator.eval({
132 | 'y_true': data.y[data.train_mask],
133 | 'y_pred': y_pred[data.train_mask],
134 | })['acc']
135 | valid_acc = evaluator.eval({
136 | 'y_true': data.y[data.val_mask],
137 | 'y_pred': y_pred[data.val_mask],
138 | })['acc']
139 | test_acc = evaluator.eval({
140 | 'y_true': data.y[data.test_mask],
141 | 'y_pred': y_pred[data.test_mask],
142 | })['acc']
143 |
144 | return train_acc, valid_acc, test_acc
145 |
146 |
147 | def merge_cmd_args(cmd_opt, opt):
148 | if cmd_opt['function'] is not None:
149 | opt['function'] = cmd_opt['function']
150 | if cmd_opt['block'] is not None:
151 | opt['block'] = cmd_opt['block']
152 | if cmd_opt['self_loop_weight'] is not None:
153 | opt['self_loop_weight'] = cmd_opt['self_loop_weight']
154 | if cmd_opt['method'] is not None:
155 | opt['method'] = cmd_opt['method']
156 | if cmd_opt['step_size'] != 1:
157 | opt['step_size'] = cmd_opt['step_size']
158 | if cmd_opt['time'] != 1:
159 | opt['time'] = cmd_opt['time']
160 | if cmd_opt['epoch'] != 100:
161 | opt['epoch'] = cmd_opt['epoch']
162 |
163 |
164 | def wandb_setup(opt):
165 | if opt['wandb']:
166 | if opt['use_wandb_offline']:
167 | os.environ["WANDB_MODE"] = "offline"
168 | else:
169 | os.environ["WANDB_MODE"] = "run"
170 | if opt['wandb'] and 'wandb_run_name' in opt.keys():
171 | wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'],
172 | name=opt['wandb_run_name'], reinit=True, config=opt)
173 | else:
174 | wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'],
175 | reinit=True, config=opt)
176 | wandb.define_metric("epoch_step") # Customize axes - https://docs.wandb.ai/guides/track/log
177 | if opt['wandb_track_grad_flow']:
178 | wandb.define_metric("grad_flow_step") # Customize axes - https://docs.wandb.ai/guides/track/log
179 | wandb.define_metric("gf_e*", step_metric="grad_flow_step") # grad_flow_epoch*
180 | opt = wandb.config # access all HPs through wandb.config, so logging matches execution!
181 | else:
182 |     os.environ["WANDB_MODE"] = "disabled"  # turns wandb calls into no-ops, so we don't have to guard every call with: if opt['wandb']:
183 | return opt
184 |
185 |
186 | def main(cmd_opt):
187 | best_opt = good_params_dict[cmd_opt['dataset']]
188 | opt = {**cmd_opt, **best_opt}
189 | merge_cmd_args(cmd_opt, opt)
190 |
191 | dataset = get_dataset(opt, '../../data', opt['not_lcc'])
192 |
193 | opt = wandb_setup(opt)
194 |
195 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
196 | results = []
197 | for rep in range(opt['num_splits']):
198 | model = GNN(opt, dataset, device).to(device) if opt["no_early"] else GNNEarly(opt, dataset, device).to(device)
199 | # print(opt)
200 | if not opt['planetoid_split'] and opt['dataset'] in ['Cora', 'Citeseer', 'Pubmed']:
201 | dataset.data = set_train_val_test_split(np.random.randint(0, 1000), dataset.data,
202 | num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500)
203 | # todo for some reason the submodule parameters inside the attention module don't show up when running on GPU.
204 | data = dataset.data.to(device)
205 | parameters = [p for p in model.parameters() if p.requires_grad]
206 | print_model_params(model)
207 | optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay'])
208 | best_time = val_acc = test_acc = train_acc = best_epoch = 0
209 | this_test = test_OGB if opt['dataset'] == 'ogbn-arxiv' else test
210 | for epoch in range(1, opt['epoch']):
211 | start_time = time.time()
212 |
213 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt)
214 | loss = train(model, optimizer, data)
215 |
216 | if tmp_val_acc > val_acc:
217 | best_epoch = epoch
218 | train_acc = tmp_train_acc
219 | val_acc = tmp_val_acc
220 | test_acc = tmp_test_acc
221 | best_time = opt['time']
222 | if not opt['no_early'] and model.odeblock.test_integrator.solver.best_val > val_acc:
223 | best_epoch = epoch
224 | val_acc = model.odeblock.test_integrator.solver.best_val
225 | test_acc = model.odeblock.test_integrator.solver.best_test
226 | train_acc = model.odeblock.test_integrator.solver.best_train
227 | best_time = model.odeblock.test_integrator.solver.best_time
228 |
229 | if opt['wandb'] and ((epoch) % opt['wandb_log_freq']) == 0 and rep == 0:
230 | wandb.log({"loss": loss,
231 | "train_acc": train_acc, "val_acc": val_acc, "test_acc": test_acc,
232 | "epoch_step": epoch}) # , step=epoch) wandb: WARNING Step must only increase in log calls
233 |
234 | log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Best time: {:.4f}'
235 |
236 | print(
237 | log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, val_acc,
238 | test_acc,
239 | best_time))
240 | print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d} and best time {:03f}'.format(val_acc,
241 | test_acc,
242 | best_epoch,
243 | best_time))
244 | if opt['num_splits'] > 1:
245 | results.append([test_acc, val_acc, train_acc])
246 |
247 | if opt['num_splits'] > 1:
248 | test_acc_mean, val_acc_mean, train_acc_mean = np.mean(results, axis=0) * 100
249 | test_acc_std = np.sqrt(np.var(results, axis=0)[0]) * 100
250 | if opt['wandb']:
251 | wandb_results = {'test_mean': test_acc_mean, 'val_mean': val_acc_mean, 'train_mean': train_acc_mean,
252 | 'test_acc_std': test_acc_std}
253 | wandb.log(wandb_results)
254 | if opt['wandb']:
255 | wandb.finish()
256 | return train_acc, val_acc, test_acc
257 |
258 |
259 | if __name__ == '__main__':
260 | parser = argparse.ArgumentParser()
261 | parser.add_argument('--use_cora_defaults', action='store_true',
262 | help='Whether to run with best params for cora. Overrides the choice of dataset')
263 |
264 | parser.add_argument('--id', type=int)
265 | parser.add_argument('--num_splits', type=int, default=1, help='number of random splits to use')
266 | # data args
267 | parser.add_argument('--dataset', type=str, default='Cora',
268 | help='Cora, Citeseer, Pubmed')
269 | parser.add_argument('--data_norm', type=str, default='rw',
270 | help='rw for random walk, gcn for symmetric gcn norm')
271 | parser.add_argument('--self_loop_weight', type=float, help='Weight of self-loops.')
272 | parser.add_argument('--use_labels', dest='use_labels', action='store_true', help='Also diffuse labels')
273 | parser.add_argument('--label_rate', type=float, default=0.5,
274 |                       help='fraction of training labels to use when --use_labels is set.')
275 | parser.add_argument('--planetoid_split', action='store_true',
276 | help='use planetoid splits for Cora/Citeseer/Pubmed')
277 | # GNN args
278 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.')
279 | parser.add_argument('--fc_out', dest='fc_out', action='store_true',
280 | help='Add a fully connected layer to the decoder.')
281 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.')
282 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.')
283 |   parser.add_argument("--batch_norm", dest='batch_norm', action='store_true', help='apply batch normalisation')
284 | parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.')
285 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.')
286 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization')
287 | parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs per iteration.')
288 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.')
289 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha')
290 | parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true',
291 |                       help='do not apply a sigmoid to alpha before multiplying')
292 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta')
293 | parser.add_argument('--block', type=str, help='constant, mixed, attention, hard_attention, SDE')
294 | parser.add_argument('--function', type=str, help='laplacian, transformer, dorsey, GAT, SDE')
295 | parser.add_argument('--use_mlp', dest='use_mlp', action='store_true',
296 | help='Add a fully connected layer to the encoder.')
297 | parser.add_argument('--add_source', dest='add_source', action='store_true',
298 |                       help='If set, get rid of the alpha param and add the beta*x0 source term')
299 |
300 | # ODE args
301 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.')
302 | parser.add_argument('--augment', action='store_true',
303 |                       help='double the length of the feature vector by appending zeros to stabilise ODE learning')
304 | parser.add_argument('--method', type=str,
305 | help="set the numerical solver: dopri5, euler, rk4, midpoint")
306 | parser.add_argument('--step_size', type=float, default=1,
307 | help='fixed step size when using fixed step solvers e.g. rk4')
308 | parser.add_argument('--max_iters', type=float, default=100, help='maximum number of integration steps')
309 | parser.add_argument("--adjoint_method", type=str, default="adaptive_heun",
310 | help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint")
311 | parser.add_argument('--adjoint', dest='adjoint', action='store_true',
312 | help='use the adjoint ODE method to reduce memory footprint')
313 | parser.add_argument('--adjoint_step_size', type=float, default=1,
314 | help='fixed step size when using fixed step adjoint solvers e.g. rk4')
315 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol')
316 | parser.add_argument("--tol_scale_adjoint", type=float, default=1.0,
317 | help="multiplier for adjoint_atol and adjoint_rtol")
318 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run')
319 | parser.add_argument("--max_nfe", type=int, default=10000000,
320 | help="Maximum number of function evaluations in an epoch. Stiff ODEs will hang if not set.")
321 | parser.add_argument("--no_early", action="store_true",
322 | help="Whether or not to use early stopping of the ODE integrator when testing.")
323 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model')
324 | parser.add_argument("--max_test_steps", type=int, default=100,
325 |                       help="Maximum number of steps for the dopri5Early test integrator. "
326 |                            "Used if getting OOM errors at test time")
327 |
328 | # Attention args
329 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2,
330 | help='slope of the negative part of the leaky relu used in attention')
331 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights')
332 | parser.add_argument('--heads', type=int, default=4, help='number of attention heads')
333 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols')
334 | parser.add_argument('--attention_dim', type=int, default=64,
335 | help='the size to project x to before calculating att scores')
336 | parser.add_argument('--mix_features', dest='mix_features', action='store_true',
337 | help='apply a feature transformation xW to the ODE')
338 | parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true',
339 | help="multiply attention scores by edge weights before softmax")
340 | # regularisation args
341 | parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2")
342 | parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2")
343 |
344 | parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2")
345 | parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2")
346 |
347 | # rewiring args
348 | parser.add_argument("--not_lcc", action="store_false", help="don't use the largest connected component")
349 | parser.add_argument('--rewiring', type=str, default=None, help="two_hop, gdc")
350 | parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff")
351 | parser.add_argument('--gdc_sparsification', type=str, default='topk', help="threshold, topk")
352 | parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk")
353 | parser.add_argument('--gdc_threshold', type=float, default=0.0001,
354 |                       help="above this edge weight, keep edges when using threshold sparsification")
355 | parser.add_argument('--gdc_avg_degree', type=int, default=64,
356 |                       help="if gdc_threshold is not given, it can be derived from this average degree")
357 | parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability")
358 |   parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernel diffusion for")
359 | parser.add_argument('--att_samp_pct', type=float, default=1,
360 |                       help="float in (0,1]. The fraction of edges to retain based on attention scores")
361 | parser.add_argument('--use_flux', dest='use_flux', action='store_true',
362 | help='incorporate the feature grad in attention based edge dropout')
363 | parser.add_argument("--exact", action="store_true",
364 |                       help="for small datasets the GDC diffusion can be computed exactly; too expensive for datasets requiring large matrix inversions")
365 |
366 | # wandb logging and tuning
367 | parser.add_argument('--wandb', action='store_true', help="flag if logging to wandb")
368 | parser.add_argument('-wandb_offline', dest='use_wandb_offline',
369 | action='store_true') # https://docs.wandb.ai/guides/technical-faq
370 |
371 | parser.add_argument('--wandb_sweep', action='store_true',
372 |                       help="flag if sweeping")  # if not, it picks up params from good_params_graphCON
373 | parser.add_argument('--wandb_watch_grad', action='store_true', help='allows gradient tracking in train function')
374 | parser.add_argument('--wandb_track_grad_flow', action='store_true')
375 |
376 | parser.add_argument('--wandb_entity', default="graphcon", type=str) # not used as default set in web browser settings
377 | parser.add_argument('--wandb_project', default="graphcon", type=str)
378 | parser.add_argument('--wandb_group', default="testing", type=str, help="testing,tuning,eval")
379 | parser.add_argument('--wandb_run_name', default=None, type=str)
380 | parser.add_argument('--wandb_output_dir', default='./wandb_output',
381 | help='folder to output results, images and model checkpoints')
382 | parser.add_argument('--wandb_log_freq', type=int, default=1, help='Frequency to log metrics.')
383 | parser.add_argument('--wandb_epoch_list', nargs='+', default=[0, 1, 2, 4, 8, 16],
384 | help='list of epochs to log gradient flow')
385 |
386 | args = parser.parse_args()
387 |
388 | opt = vars(args)
389 |
390 | main(opt)
391 |
--------------------------------------------------------------------------------
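`run_GNN.py` pulls the tuned defaults from `good_params_graphCON.py` for the chosen dataset and then lets a handful of flags override them via `merge_cmd_args` (`--function`, `--block`, `--self_loop_weight`, `--method`, `--step_size`, `--time`, `--epoch`). Example invocations, following the pattern in the top-level README (flag values are illustrative):
```
python run_GNN.py --dataset Cora
python run_GNN.py --dataset Citeseer --epoch 300 --method symplectic_euler
```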
/src/homophilic_graphs/run_best_sweeps.py:
--------------------------------------------------------------------------------
1 | """
2 | Extracts the optimal hyperparameters found from a wandb hyperparameter sweep
3 | """
4 |
5 | import yaml
6 | import argparse
7 | import os
8 | import sys
9 | import time
10 |
11 | import wandb
12 | import torch
13 | import numpy as np
14 | from ray import tune
15 | from functools import partial
16 | from ray.tune import CLIReporter
17 |
18 | from good_params_graphCON import good_params_dict
19 | from data import get_dataset, set_train_val_test_split
20 | from GNN import GNN
21 | from GNN_early import GNNEarly
22 | from run_GNN import get_optimizer, test, train
23 | from utils import get_sem, mean_confidence_interval
24 |
25 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
26 | from definitions import ROOT_DIR
27 |
28 |
29 | def main(opt, data_dir):
30 | # todo see if I can initialise wandb runs inside of ray processes
31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32 | dataset = get_dataset(opt, data_dir, opt['not_lcc'])
33 |
34 | if opt["num_splits"] > 0:
35 | dataset.data = set_train_val_test_split(
36 | 23 * np.random.randint(0, opt["num_splits"]),
37 | # random prime 23 to make the splits 'more' random. Could remove
38 | dataset.data,
39 | num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500)
40 |
41 | model = GNN(opt, dataset, device) if opt["no_early"] else GNNEarly(opt, dataset, device)
42 | model, data = model.to(device), dataset.data.to(device)
43 | parameters = [p for p in model.parameters() if p.requires_grad]
44 | optimizer = get_optimizer(opt["optimizer"], parameters, lr=opt["lr"], weight_decay=opt["decay"])
45 |
46 | this_test = test
47 | best_time = best_epoch = train_acc = val_acc = test_acc = 0
48 | for epoch in range(1, opt["epoch"]):
49 | loss = train(model, optimizer, data)
50 | if opt["no_early"]:
51 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt)
52 | best_time = opt['time']
53 | else:
54 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt)
55 | if tmp_val_acc > val_acc:
56 | best_epoch = epoch
57 | train_acc = tmp_train_acc
58 | val_acc = tmp_val_acc
59 | test_acc = tmp_test_acc
60 | if model.odeblock.test_integrator.solver.best_val > val_acc:
61 | best_epoch = epoch
62 | val_acc = model.odeblock.test_integrator.solver.best_val
63 | test_acc = model.odeblock.test_integrator.solver.best_test
64 | train_acc = model.odeblock.test_integrator.solver.best_train
65 | best_time = model.odeblock.test_integrator.solver.best_time
66 | tune.report(loss=loss, val_acc=val_acc, test_acc=test_acc, train_acc=train_acc, best_time=best_time,
67 | best_epoch=best_epoch,
68 | forward_nfe=model.fm.sum, backward_nfe=model.bm.sum)
69 | res_dict = {"loss": loss,
70 | "train_acc": train_acc, "val_acc": val_acc, "test_acc": test_acc, "best_time": best_time,
71 | "best_epoch": best_epoch, "epoch_step": epoch}
72 | print(res_dict)
73 |
74 |
75 | def run_best(opt):
76 | opt['wandb_entity'] = 'bchamberlain'
77 | opt['wandb_project'] = 'waveGNN-src_node_level'
78 | opt['wandb_group'] = None
79 | if 'wandb_run_name' in opt.keys():
80 | wandb_run = wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'],
81 | name=opt['wandb_run_name'], reinit=True, config=opt)
82 | else:
83 | wandb_run = wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'],
84 | reinit=True, config=opt)
85 |
86 | wandb.define_metric("epoch_step")
87 |
88 | yaml_path = f"./wandb/sweep-{opt['sweep']}/config-{opt['run']}.yaml"
89 | with open(yaml_path) as f:
90 | yaml_opt = yaml.load(f, Loader=yaml.FullLoader)
91 | temp_opt = {}
92 | for k, v in yaml_opt.items():
93 | if type(v) == dict:
94 | temp_opt[k] = v['value']
95 | else:
96 | temp_opt[k] = v
97 | yaml_opt = temp_opt
98 | dataset = yaml_opt['dataset']
99 |
100 | opt = {**good_params_dict[dataset], **yaml_opt, **opt}
101 | opt['wandb'] = True
102 | opt['use_wandb_offline'] = False
103 | opt['wandb_best_run_id'] = opt['run']
104 | opt['wandb_track_grad_flow'] = False
105 | opt['wandb_watch_grad'] = False
106 |
107 | reporter = CLIReporter(
108 | metric_columns=["val_acc", "loss", "test_acc", "train_acc", "best_time", "best_epoch"])
109 |
110 | result = tune.run(
111 | partial(main, data_dir=f"{ROOT_DIR}/data"),
112 | name=opt['run'],
113 | resources_per_trial={"cpu": opt['cpus'], "gpu": opt['gpus']},
114 | search_alg=None,
115 | keep_checkpoints_num=3,
116 | checkpoint_score_attr='val_acc',
117 | config=opt,
118 | num_samples=opt['reps'] if opt["num_splits"] == 0 else opt["num_splits"] * opt["reps"],
119 | scheduler=None,
120 | max_failures=1, # early stop solver can't recover from failure as it doesn't own m2.
121 | local_dir='../ray_tune',
122 | progress_reporter=reporter,
123 | raise_on_failed_trial=False)
124 |
125 | df = result.dataframe(metric='test_acc', mode="max").sort_values('test_acc', ascending=False)
126 | try:
127 | df.to_csv('../ray_results/{}_{}.csv'.format(opt['run'], time.strftime("%Y%m%d-%H%M%S")))
128 | except:
129 | pass
130 |
131 | print(df[['val_acc', 'test_acc', 'train_acc', 'best_time', 'best_epoch']])
132 |
133 | test_accs = df['test_acc'].values
134 | val_accs = df['val_acc'].values
135 | train_accs = df['train_acc'].values
136 | print("test accuracy {}".format(test_accs))
137 | log = "mean test {:04f}, test std {:04f}, test sem {:04f}, test 95% conf {:04f}"
138 | log_dic = {'test_acc': test_accs.mean(), 'val_acc': val_accs.mean(), 'train_acc': train_accs.mean(),
139 | 'test_std': np.std(test_accs), 'test_sem': get_sem(test_accs),
140 | 'test_95_conf': mean_confidence_interval(test_accs)}
141 | wandb.log(log_dic)
142 | print(log.format(test_accs.mean(), np.std(test_accs), get_sem(test_accs), mean_confidence_interval(test_accs)))
143 |
144 | wandb_run.finish()
145 |
146 |
147 | if __name__ == '__main__':
148 | parser = argparse.ArgumentParser()
149 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.')
150 | parser.add_argument('--sweep', type=str, default=None, help='sweep folder to read', required=True)
151 |   parser.add_argument('--run', type=str, default=None, help='the run ID to read', required=True)
152 | parser.add_argument('--reps', type=int, default=1, help='the number of random weight initialisations to use')
153 | parser.add_argument('--name', type=str, default=None)
154 | parser.add_argument('--gpus', type=float, default=0, help='number of gpus per trial. Can be fractional')
155 | parser.add_argument('--cpus', type=float, default=1, help='number of cpus per trial. Can be fractional')
156 |   parser.add_argument("--num_splits", type=int, default=0, help="Number of random splits >= 0. 0 for planetoid split")
157 | parser.add_argument("--adjoint", dest='adjoint', action='store_true',
158 | help="use the adjoint ODE method to reduce memory footprint")
159 | parser.add_argument("--max_nfe", type=int, default=5000,
160 |                       help="Maximum number of function evaluations allowed in an epoch.")
161 | parser.add_argument("--no_early", action="store_true",
162 | help="Whether or not to use early stopping of the ODE integrator when testing.")
163 |
164 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model')
165 |
166 | args = parser.parse_args()
167 |
168 | opt = vars(args)
169 | run_best(opt)
170 |
--------------------------------------------------------------------------------
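`run_best_sweeps.py` reads the config of a finished wandb sweep run from `./wandb/sweep-<sweep>/config-<run>.yaml`, merges it with the tuned defaults, and replays it under `ray.tune`. An example invocation, where the sweep and run ids are placeholders:
```
python run_best_sweeps.py --sweep abc123 --run xyz789 --epoch 200 --reps 5 --gpus 1
```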
/src/homophilic_graphs/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | utility functions
3 | """
4 |
5 | import scipy
6 | from scipy.stats import sem
7 | import numpy as np
8 | import torch
9 | from torch_scatter import scatter_add
10 | from torch_geometric.utils import add_remaining_self_loops
11 | from torch_geometric.utils.num_nodes import maybe_num_nodes
12 | from torch_geometric.utils.convert import to_scipy_sparse_matrix
13 | from sklearn.preprocessing import normalize
14 | from torch_geometric.nn.conv.gcn_conv import gcn_norm
15 |
16 |
17 | class MaxNFEException(Exception): pass
18 |
19 |
20 | def print_model_params(model):
21 | total_num_params = 0
22 | print(model)
23 | for name, param in model.named_parameters():
24 | if param.requires_grad:
25 | print(name)
26 | print(param.data.shape)
27 | total_num_params += param.numel()
28 | print("Model has a total of {} params".format(total_num_params))
29 |
30 |
31 | def adjust_learning_rate(optimizer, lr, epoch, burnin=50):
32 | if epoch <= burnin:
33 | for param_group in optimizer.param_groups:
34 | param_group["lr"] = lr * epoch / burnin
35 |
36 |
37 | def gcn_norm_fill_val(edge_index, edge_weight=None, fill_value=0., num_nodes=None, dtype=None):
38 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
39 |
40 | if edge_weight is None:
41 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
42 | device=edge_index.device)
43 |
44 | if not int(fill_value) == 0:
45 | edge_index, tmp_edge_weight = add_remaining_self_loops(
46 | edge_index, edge_weight, fill_value, num_nodes)
47 | assert tmp_edge_weight is not None
48 | edge_weight = tmp_edge_weight
49 |
50 | row, col = edge_index[0], edge_index[1]
51 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
52 | deg_inv_sqrt = deg.pow_(-0.5)
53 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
54 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
55 |
56 |
57 | def coo2tensor(coo, device=None):
58 | indices = np.vstack((coo.row, coo.col))
59 | i = torch.LongTensor(indices)
60 | values = coo.data
61 | v = torch.FloatTensor(values)
62 | shape = coo.shape
63 | print('adjacency matrix generated with shape {}'.format(shape))
64 | # test
65 | return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to(device)
66 |
67 |
68 | def get_sym_adj(data, opt, improved=False):
69 | edge_index, edge_weight = gcn_norm( # yapf: disable
70 | data.edge_index, data.edge_attr, data.num_nodes,
71 | improved, opt['self_loop_weight'] > 0, dtype=data.x.dtype)
72 | coo = to_scipy_sparse_matrix(edge_index, edge_weight)
73 | return coo2tensor(coo)
74 |
75 |
76 | def get_rw_adj_old(data, opt):
77 | if opt['self_loop_weight'] > 0:
78 | edge_index, edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr,
79 | fill_value=opt['self_loop_weight'])
80 | else:
81 | edge_index, edge_weight = data.edge_index, data.edge_attr
82 | coo = to_scipy_sparse_matrix(edge_index, edge_weight)
83 | normed_csc = normalize(coo, norm='l1', axis=0)
84 | return coo2tensor(normed_csc.tocoo())
85 |
86 |
87 | def get_rw_adj(edge_index, edge_weight=None, norm_dim=1, fill_value=0., num_nodes=None, dtype=None):
88 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
89 |
90 | if edge_weight is None:
91 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
92 | device=edge_index.device)
93 |
94 | if not fill_value == 0:
95 | edge_index, tmp_edge_weight = add_remaining_self_loops(
96 | edge_index, edge_weight, fill_value, num_nodes)
97 | assert tmp_edge_weight is not None
98 | edge_weight = tmp_edge_weight
99 |
100 | row, col = edge_index[0], edge_index[1]
101 | indices = row if norm_dim == 0 else col
102 | deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes)
103 |   deg_inv = deg.pow_(-1)  # random-walk normalisation uses 1/deg, not 1/sqrt(deg)
104 |   edge_weight = deg_inv[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv[indices]
105 | return edge_index, edge_weight
106 |
107 |
108 | def mean_confidence_interval(data, confidence=0.95):
109 | """
110 |   As the number of samples will be < 10, use the Student t-distribution for the mean confidence interval
111 | :param data: NDarray of metric means
112 | :param confidence: The desired confidence interval
113 | :return: Float confidence interval
114 | """
115 | if len(data) < 2:
116 | return 0
117 | a = 1.0 * np.array(data)
118 | n = len(a)
119 | _, se = np.mean(a), scipy.stats.sem(a)
120 | h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
121 | return h
122 |
123 |
124 | def sparse_dense_mul(s, d):
125 | i = s._indices()
126 | v = s._values()
127 | return torch.sparse.FloatTensor(i, v * d, s.size())
128 |
129 |
130 | def get_sem(vec):
131 | """
132 | wrapper around the scipy standard error metric
133 | :param vec: List of metric means
134 | :return:
135 | """
136 | if len(vec) > 1:
137 | retval = sem(vec)
138 | else:
139 | retval = 0.
140 | return retval
141 |
142 |
143 | # Counter of forward and backward passes.
144 | class Meter(object):
145 |
146 | def __init__(self):
147 | self.reset()
148 |
149 | def reset(self):
150 | self.val = None
151 | self.sum = 0
152 | self.cnt = 0
153 |
154 | def update(self, val):
155 | self.val = val
156 | self.sum += val
157 | self.cnt += 1
158 |
159 | def get_average(self):
160 | if self.cnt == 0:
161 | return 0
162 | return self.sum / self.cnt
163 |
164 | def get_value(self):
165 | return self.val
166 |
--------------------------------------------------------------------------------
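`get_rw_adj` returns an `edge_index`/`edge_weight` pair with random-walk (inverse-degree) normalisation, optionally adding weighted self-loops through `fill_value`. A minimal sketch on a toy three-node graph, assuming it is run from `src/homophilic_graphs` (values are illustrative):
```
import torch
from utils import get_rw_adj

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
ei, ew = get_rw_adj(edge_index, norm_dim=1, fill_value=1., num_nodes=3)
print(ei)
print(ew)   # with norm_dim=1, the incoming edge weights of every node sum to 1
```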