├── .gitignore
├── README.md
├── imgs
│   ├── graphCON_figure.pdf
│   └── graphCON_figure.png
└── src
    ├── Superpixels
    │   ├── best_params.py
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    ├── ZINC
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    ├── definitions.py
    ├── heterophilic_graphs
    │   ├── best_params.py
    │   ├── data_handling.py
    │   ├── models.py
    │   └── run_GNN.py
    └── homophilic_graphs
        ├── GNN.py
        ├── GNN_early.py
        ├── README.md
        ├── base_classes.py
        ├── block_constant.py
        ├── block_transformer_attention.py
        ├── data.py
        ├── early_stop_solver.py
        ├── experiment_configs
        │   ├── gat_pubmed.yaml
        │   ├── gcn_cora.yaml
        │   ├── gcn_depth_random.yaml
        │   ├── gcn_planetoid.yaml
        │   └── run_sweeps.sh
        ├── function_GAT_attention.py
        ├── function_gcn.py
        ├── function_laplacian_diffusion.py
        ├── function_transformer_attention.py
        ├── geometric_integrators.py
        ├── geometric_solvers.py
        ├── good_params_graphCON.py
        ├── model_configurations.py
        ├── odeint_geometric.py
        ├── regularized_ODE_function.py
        ├── run_GNN.py
        ├── run_best_sweeps.py
        └── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | code
2 | data/
3 | ray_tune/
4 | __pycache__/
5 | src/checkpoint
6 | images/
7 | ray_results/
8 | models/
9 | __pycache__/
10 | .data/
11 | .vector_cache/
12 | .idea/
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph-Coupled Oscillator Networks
2 | 3 | This repository contains the implementation to reproduce the numerical experiments 4 | of the **ICML 2022** paper [Graph-Coupled Oscillator Networks](https://arxiv.org/abs/2202.02296) 5 | 6 |

7 | 8 |

9 | 10 | ### Requirements 11 | Main dependencies (with python >= 3.7):
12 | torch==1.9.0
13 | torch-cluster==1.5.9
14 | torch-geometric==2.0.3
15 | torch-scatter==2.0.9
16 | torch-sparse==0.6.12
17 | torch-spline-conv==1.2.1
18 | torchdiffeq==0.2.2 19 | 20 | Commands to install all the dependencies in a new conda environment
21 | *(Python 3.7 and CUDA 10.2 -- for other CUDA versions change accordingly)*
22 | ```
23 | conda create --name graphCON python=3.7
24 | conda activate graphCON
25 | 
26 | pip install ogb pykeops
27 | pip install torch==1.9.0
28 | pip install torchdiffeq -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
29 | 
30 | pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
31 | pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
32 | pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
33 | pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
34 | pip install torch-geometric
35 | pip install wandb
36 | ```
37 | 
38 | ### Run the experiments
39 | To run each experiment, navigate into `src/exp_dir`
40 | (change `exp_dir` to the name of the corresponding experiment directory).
41 | There, simply run
42 | ```
43 | python run_GNN.py --kwargs
44 | ```
45 | where the available keyword arguments (`kwargs`) are specified in each individual `run_GNN.py` file.
46 | 
47 | ### Dataset and preprocessing
48 | All data is downloaded, preprocessed and stored automatically in the `./data` directory
49 | (which is created automatically the first time one of the experiments is run).
50 | 
51 | ### Usage
52 | GraphCON is a general framework for "stacking" many GNN layers (i.e., message-passing mechanisms)
53 | in order to obtain a deep GNN that overcomes the oversmoothing problem.
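For reference, one GraphCON "layer" updates the node features X and the auxiliary variables Y with the IMEX (implicit-explicit) time-stepping scheme below, which is what the forward pass of the `GraphCON` module in the next snippet implements (σ is the nonlinearity, e.g. ReLU, F_θ the chosen GNN layer, and Δt, α, γ correspond to `dt`, `alpha`, `gamma` in the code; n indexes the stacked layers):

$$
\begin{aligned}
Y^{n} &= Y^{n-1} + \Delta t\,\big(\sigma(F_\theta(X^{n-1})) - \alpha Y^{n-1} - \gamma X^{n-1}\big),\\
X^{n} &= X^{n-1} + \Delta t\, Y^{n}.
\end{aligned}
$$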
54 | 
55 | Given any standard GNN layer (such as *GCN* or *GAT*),
56 | GraphCON can be implemented using **PyTorch (Geometric)** as simply as follows:
57 | ```python
58 | from torch import nn
59 | import torch
60 | import torch.nn.functional as F
61 | 
62 | 
63 | class GraphCON(nn.Module):
64 |     def __init__(self, GNNs, dt=1., alpha=1., gamma=1., dropout=None):
65 |         super(GraphCON, self).__init__()
66 |         self.dt = dt
67 |         self.alpha = alpha
68 |         self.gamma = gamma
69 |         self.GNNs = GNNs  # list of the individual GNN layers
70 |         self.dropout = dropout
71 | 
72 |     def forward(self, X0, Y0, edge_index):
73 |         # set initial values of ODEs
74 |         X = X0
75 |         Y = Y0
76 |         # solve ODEs using simple IMEX scheme
77 |         for gnn in self.GNNs:
78 |             Y = Y + self.dt * (torch.relu(gnn(X, edge_index)) - self.alpha * Y - self.gamma * X)
79 |             X = X + self.dt * Y
80 | 
81 |             if (self.dropout is not None):
82 |                 Y = F.dropout(Y, self.dropout, training=self.training)
83 |                 X = F.dropout(X, self.dropout, training=self.training)
84 | 
85 |         return X, Y
86 | ```
87 | 
88 | A deep GraphCON model using, for instance, [Kipf & Welling's GCN](https://arxiv.org/abs/1609.02907)
89 | as the underlying message-passing mechanism can then be written as
90 | 
91 | ```python
92 | from torch_geometric.nn import GCNConv
93 | 
94 | 
95 | class deep_GNN(nn.Module):
96 |     def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1., dropout=None):
97 |         super(deep_GNN, self).__init__()
98 |         self.enc = nn.Linear(nfeat, nhid)
99 |         self.GNNs = nn.ModuleList()
100 |         for _ in range(nlayers):
101 |             self.GNNs.append(GCNConv(nhid, nhid))
102 |         self.graphcon = GraphCON(self.GNNs, dt, alpha, gamma, dropout)
103 |         self.dec = nn.Linear(nhid, nclass)
104 | 
105 |     def forward(self, x, edge_index):
106 |         # compute initial values of ODEs (encode input)
107 |         X0 = self.enc(x)
108 |         Y0 = X0
109 |         # stack GNNs using GraphCON
110 |         X, Y = self.graphcon(X0, Y0, edge_index)
111 |         # decode X state of GraphCON at final time for output nodes
112 |         output = self.dec(X)
113 |         return output
114 | ```
115 | This is just a minimal example to demonstrate the **simple usage of GraphCON**.
116 | You can find the full GraphCON models we used in our experiments in the `src` directory.
117 | 
118 | # Citation
119 | If you found our work useful in your research, please cite our paper:
120 | ```bibtex
121 | @article{graphcon,
122 |   title={Graph-Coupled Oscillator Networks},
123 |   author={Rusch, T Konstantin and Chamberlain, Benjamin P and Rowbottom, James and Mishra, Siddhartha and Bronstein, Michael M},
124 |   journal={arXiv preprint arXiv:2202.02296},
125 |   year={2022}
126 | }
127 | ```
128 | (Also consider starring the project on GitHub.)
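For a quick sanity check of the two snippets above, the example model can be exercised on a toy graph as sketched below (the graph, feature sizes and hyperparameters are made up for illustration and are not taken from the repository):

```python
import torch

# toy graph: 4 nodes with 3 features each and two undirected edges (0-1, 2-3)
x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])

# assumes the GraphCON and deep_GNN classes defined above are in scope
model = deep_GNN(nfeat=3, nhid=8, nclass=2, nlayers=4, dt=1., alpha=1., gamma=1., dropout=0.1)
out = model(x, edge_index)
print(out.shape)  # torch.Size([4, 2]) -- one output vector per node
```

Training then proceeds as with any PyTorch model, e.g. by applying a node-classification loss to `out` and calling `backward()`.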
129 | -------------------------------------------------------------------------------- /imgs/graphCON_figure.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tk-rusch/GraphCON/3326c48bd3631b9f25417455ded9a1a83ccd3e93/imgs/graphCON_figure.pdf -------------------------------------------------------------------------------- /imgs/graphCON_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tk-rusch/GraphCON/3326c48bd3631b9f25417455ded9a1a83ccd3e93/imgs/graphCON_figure.png -------------------------------------------------------------------------------- /src/Superpixels/best_params.py: -------------------------------------------------------------------------------- 1 | best_params_dict = {'GraphCON_GCN': {'lr': 0.00116, 'alpha': 1.396, 'gamma': 0.841, 'nlayers': 10, 'dropout': 0.014151415141514152}, 2 | 'GraphCON_GAT': {'lr': 0.00028, 'alpha': 0.972, 'gamma': 0.348, 'nlayers': 14, 'dropout': 0.04} 3 | } 4 | -------------------------------------------------------------------------------- /src/Superpixels/data_handling.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import MNISTSuperpixels 2 | DATA_PATH = '../../data' 3 | 4 | def get_data(train): 5 | path = '../../data/MNIST' 6 | dataset = MNISTSuperpixels(path,train=train) 7 | return dataset -------------------------------------------------------------------------------- /src/Superpixels/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import GCNConv, GATConv 5 | from torch_geometric.nn import global_mean_pool, global_add_pool 6 | 7 | class GraphCON_GCN(nn.Module): 8 | def __init__(self, nfeat, nhid, nclass, dropout, nlayers, dt=1., alpha=1., gamma=1.): 9 | super(GraphCON_GCN, self).__init__() 10 | self.dropout = dropout 11 | self.nhid = nhid 12 | self.nlayers = nlayers 13 | self.enc = nn.Linear(nfeat,nhid) 14 | self.conv = GCNConv(nhid, nhid) 15 | self.dec = nn.Linear(nhid,nclass) 16 | self.res = nn.Linear(nhid,nhid) 17 | self.dt = dt 18 | self.act_fn = nn.ReLU() 19 | self.alpha = alpha 20 | self.gamma = gamma 21 | 22 | def res_connection(self, X): 23 | res = self.res(X) 24 | return res 25 | 26 | def forward(self, data): 27 | input = torch.cat([data.x,data.pos],dim=-1) 28 | edge_index = data.edge_index 29 | Y = self.enc(input) 30 | X = Y 31 | Y = F.dropout(Y, self.dropout, training=self.training) 32 | X = F.dropout(X, self.dropout, training=self.training) 33 | 34 | for i in range(self.nlayers): 35 | Y = Y + self.dt*(self.act_fn(self.conv(X,edge_index) + self.res_connection(X)) - self.alpha*Y - self.gamma*X) 36 | X = X + self.dt*Y 37 | Y = F.dropout(Y, self.dropout, training=self.training) 38 | X = F.dropout(X, self.dropout, training=self.training) 39 | 40 | X = self.dec(X) 41 | X = global_add_pool(X, data.batch) 42 | 43 | return X.squeeze(-1) 44 | 45 | class GraphCON_GAT(nn.Module): 46 | def __init__(self, nfeat, nhid, nclass, nlayers, dropout, dt=1., alpha=1., gamma=1., nheads=4): 47 | super(GraphCON_GAT, self).__init__() 48 | self.alpha = alpha 49 | self.gamma = gamma 50 | self.dropout = dropout 51 | self.nheads = nheads 52 | self.nhid = nhid 53 | self.nlayers = nlayers 54 | self.act_fn = nn.ReLU() 55 | self.res = nn.Linear(nhid, nheads * nhid) 56 | self.enc = nn.Linear(nfeat,nhid) 
57 | self.conv = GATConv(nhid, nhid, heads=nheads) 58 | self.dec = nn.Linear(nhid,nclass) 59 | self.dt = dt 60 | 61 | def res_connection(self, X): 62 | res = self.res(X) 63 | return res 64 | 65 | def forward(self, data): 66 | input = torch.cat([data.x, data.pos], dim=-1) 67 | n_nodes = input.size(0) 68 | edge_index = data.edge_index 69 | Y = self.enc(input) 70 | X = Y 71 | Y = F.dropout(Y, self.dropout, training=self.training) 72 | X = F.dropout(X, self.dropout, training=self.training) 73 | 74 | for i in range(self.nlayers): 75 | Y = Y + self.dt*(F.elu(self.conv(X, edge_index) + self.res_connection(X)).view(n_nodes, -1, self.nheads).mean(dim=-1) - self.alpha*Y - self.gamma*X) 76 | X = X + self.dt*Y 77 | Y = F.dropout(Y, self.dropout, training=self.training) 78 | X = F.dropout(X, self.dropout, training=self.training) 79 | 80 | X = self.dec(X) 81 | X = global_add_pool(X, data.batch) 82 | 83 | return X.squeeze(-1) -------------------------------------------------------------------------------- /src/Superpixels/run_GNN.py: -------------------------------------------------------------------------------- 1 | from data_handling import get_data 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | import torch.optim as optim 4 | from models import * 5 | from torch_geometric.data import DataLoader 6 | from best_params import best_params_dict 7 | import argparse 8 | 9 | def train_GNN(opt): 10 | train_dataset = get_data(train=True) 11 | test_dataset = get_data(train=False) 12 | 13 | train_dataset = train_dataset.shuffle() 14 | val_dataset = train_dataset[:5000] 15 | train_dataset = train_dataset[5000:] 16 | 17 | train_loader = DataLoader(train_dataset, batch_size=opt['batch'], shuffle=True) 18 | val_loader = DataLoader(val_dataset, batch_size=opt['batch'], shuffle=False) 19 | test_loader = DataLoader(test_dataset, batch_size=opt['batch'], shuffle=False) 20 | 21 | epochs = opt['epochs'] 22 | 23 | if opt['model'] == 'GraphCON_GCN': 24 | model = GraphCON_GCN(nfeat=train_dataset.data.num_features+2,nhid=opt['nhid'],nclass=10, 25 | dropout=opt['drop'],nlayers=opt['nlayers'],dt=1., 26 | alpha=opt['alpha'],gamma=opt['gamma']).to(opt['device']) 27 | elif opt['model'] == 'GraphCON_GAT': 28 | model = GraphCON_GAT(nfeat=train_dataset.data.num_features+2, nhid=opt['nhid'], nclass=10, 29 | dropout=opt['drop'], nlayers=opt['nlayers'], dt=1., 30 | alpha=opt['alpha'], gamma=opt['gamma'],nheads=opt['nheads']).to(opt['device']) 31 | 32 | optimizer = optim.Adam(model.parameters(), lr=opt['lr']) 33 | lf = torch.nn.CrossEntropyLoss() 34 | scheduler = ReduceLROnPlateau(optimizer, 'max', factor=opt['reduce_factor']) 35 | 36 | best_eval = 0 37 | 38 | def test(data_loader): 39 | model.eval() 40 | correct = 0 41 | with torch.no_grad(): 42 | for i, data in enumerate(data_loader): 43 | data = data.to(opt['device']) 44 | output = model(data) 45 | pred = output.data.max(1, keepdim=True)[1] 46 | correct += pred.eq(data.y.data.view_as(pred)).sum() 47 | 48 | accuracy = 100. 
* correct / len(data_loader.dataset) 49 | return accuracy.item() 50 | 51 | for epoch in range(epochs): 52 | model.train() 53 | for i, data in enumerate(train_loader): 54 | data = data.to(opt['device']) 55 | optimizer.zero_grad() 56 | out = model(data) 57 | loss = lf(out, data.y) 58 | loss.backward() 59 | optimizer.step() 60 | 61 | val_acc = test(val_loader) 62 | test_acc = test(test_loader) 63 | 64 | if(val_acc > best_eval): 65 | best_eval = val_acc 66 | best_test_acc = test_acc 67 | 68 | log = 'Epoch: {:03d}, Val: {:.4f}, Test: {:.4f}' 69 | print(log.format(epoch, val_acc, test_acc)) 70 | 71 | scheduler.step(val_acc) 72 | for param_group in optimizer.param_groups: 73 | curr_lr = param_group['lr'] 74 | if(curr_lr<1e-5): 75 | break 76 | 77 | if(epoch > 25 and val_acc < 20.): 78 | break 79 | 80 | print('Final test accuracy: ', best_test_acc) 81 | 82 | if __name__ == '__main__': 83 | parser = argparse.ArgumentParser(description='training parameters') 84 | parser.add_argument('--model', type=str, default='GraphCON_GAT', 85 | help='GraphCON_GCN, GraphCON_GAT') 86 | parser.add_argument('--nhid', type=int, default=256, 87 | help='number of hidden node features') 88 | parser.add_argument('--nlayers', type=int, default=5, 89 | help='number of layers') 90 | parser.add_argument('--alpha', type=float, default=1., 91 | help='alpha parameter of graphCON') 92 | parser.add_argument('--gamma', type=float, default=1., 93 | help='gamma parameter of graphCON') 94 | parser.add_argument('--nheads', type=int, default=4, 95 | help='number of attention heads for GraphCON-GAT') 96 | parser.add_argument('--epochs', type=int, default=1000, 97 | help='max epochs') 98 | parser.add_argument('--batch', type=int, default=32, 99 | help='batch size') 100 | parser.add_argument('--reduce_factor', type=float, default=0.5, 101 | help='reduce factor') 102 | parser.add_argument('--lr', type=float, default=0.001, 103 | help='learning rate') 104 | parser.add_argument('--drop', type=float, default=0.3, 105 | help='dropout rate') 106 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 107 | help='computing device') 108 | 109 | args = parser.parse_args() 110 | cmd_opt = vars(args) 111 | 112 | best_opt = best_params_dict[cmd_opt['model']] 113 | opt = {**cmd_opt, **best_opt} 114 | print(opt) 115 | 116 | train_GNN(opt) 117 | 118 | -------------------------------------------------------------------------------- /src/ZINC/data_handling.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch_geometric.datasets import ZINC 3 | DATA_PATH = '../../data' 4 | 5 | def get_zinc_data(split): 6 | path = '../../data/ZINC' 7 | dataset = ZINC(path,subset=True,split=split) 8 | return dataset -------------------------------------------------------------------------------- /src/ZINC/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch_geometric.nn import GCNConv 4 | from torch_geometric.nn import global_mean_pool, global_add_pool 5 | 6 | class GraphCON_GCN(nn.Module): 7 | def __init__(self, nfeat, nhid, nclass, nlayers, dt=1., alpha=1., gamma=1.): 8 | super(GraphCON_GCN, self).__init__() 9 | self.nlayers = nlayers 10 | self.enc = nn.Linear(nfeat,nhid) 11 | self.res = nn.Linear(nhid,nhid) 12 | self.conv = GCNConv(nhid, nhid) 13 | self.dec = nn.Linear(nhid,nclass) 14 | self.dt = dt 15 | self.alpha = alpha 16 | self.gamma = gamma 17 | 18 | def 
res_connection(self, X): 19 | ## residual connection 20 | res = self.res(X) - self.conv.lin(X) 21 | return res 22 | 23 | def forward(self, data): 24 | ## the following encoder gives better results than using nn.embedding() 25 | input = data.x.float() 26 | edge_index = data.edge_index 27 | Y = self.enc(input) 28 | X = Y 29 | 30 | for i in range(self.nlayers): 31 | Y = Y + self.dt * (torch.relu(self.conv(X, edge_index) + self.res_connection(X)) - self.alpha * Y - self.gamma * X) 32 | X = X + self.dt * Y 33 | 34 | X = self.dec(X) 35 | X = global_add_pool(X, data.batch).squeeze(-1) 36 | 37 | return X -------------------------------------------------------------------------------- /src/ZINC/run_GNN.py: -------------------------------------------------------------------------------- 1 | from data_handling import get_zinc_data 2 | import numpy as np 3 | import torch 4 | import torch.optim as optim 5 | from models import * 6 | from torch.optim.lr_scheduler import ReduceLROnPlateau 7 | from torch_geometric.data import DataLoader 8 | 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description='training parameters') 12 | 13 | parser.add_argument('--nhid', type=int, default=220, 14 | help='number of hidden node features') 15 | parser.add_argument('--nlayers', type=int, default=22, 16 | help='number of layers') 17 | parser.add_argument('--alpha', type=float, default=0.215, 18 | help='alpha parameter of graphCON') 19 | parser.add_argument('--gamma', type=float, default=1.115, 20 | help='gamma parameter of graphCON') 21 | parser.add_argument('--epochs', type=int, default=3000, 22 | help='max epochs') 23 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 24 | help='computing device') 25 | parser.add_argument('--batch', type=int, default=128, 26 | help='batch size') 27 | parser.add_argument('--lr', type=float, default=0.00159, 28 | help='learning rate') 29 | parser.add_argument('--reduce_point', type=int, default=20, 30 | help='length of patience') 31 | parser.add_argument('--start_reduce', type=int, default=1000, 32 | help='epoch when to start reducing') 33 | parser.add_argument('--seed', type=int, default=1111, 34 | help='random seed') 35 | 36 | args = parser.parse_args() 37 | print(args) 38 | 39 | 40 | torch.backends.cudnn.deterministic = True 41 | torch.backends.cudnn.benchmark = False 42 | torch.manual_seed(args.seed) 43 | np.random.seed(args.seed) 44 | 45 | train_dataset = get_zinc_data('train') 46 | test_dataset = get_zinc_data('test') 47 | val_dataset = get_zinc_data('val') 48 | 49 | train_loader = DataLoader(train_dataset, batch_size=args.batch,shuffle=True) 50 | test_loader = DataLoader(test_dataset, batch_size=args.batch,shuffle=False) 51 | val_loader = DataLoader(val_dataset, batch_size=args.batch,shuffle=False) 52 | 53 | model = GraphCON_GCN(nfeat=train_dataset.data.num_features,nhid=args.nhid,nclass=1, 54 | nlayers=args.nlayers, dt=1., alpha=args.alpha, 55 | gamma=args.gamma).to(args.device) 56 | 57 | nparams = 0 58 | for p in model.parameters(): 59 | nparams += p.numel() 60 | print('number of parameters: ',nparams) 61 | 62 | optimizer = optim.Adam(model.parameters(),lr=args.lr) 63 | lf = torch.nn.L1Loss() 64 | 65 | patience = 0 66 | best_eval = 1000000 67 | 68 | def test(loader): 69 | model.eval() 70 | error = 0 71 | with torch.no_grad(): 72 | for data in loader: 73 | data = data.to(args.device) 74 | output = model(data) 75 | error += (output - data.y).abs().sum().item() 76 | return error / len(loader.dataset) 77 | 
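# The training loop below minimizes an L1 loss on the ZINC subset. After `start_reduce` epochs,
# every epoch without a new best validation loss increments a patience counter; once the counter
# reaches `reduce_point`, the learning rate is halved and the counter is reset. Training stops
# after epoch 25 if the validation loss exceeds 20 or the learning rate drops below 1e-5.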
78 | for epoch in range(args.epochs): 79 | model.train() 80 | for i, data in enumerate(train_loader): 81 | data = data.to(args.device) 82 | optimizer.zero_grad() 83 | out = model(data) 84 | loss = lf(out,data.y) 85 | loss.backward() 86 | optimizer.step() 87 | 88 | val_loss = test(val_loader) 89 | 90 | f = open('zinc_graphcon_gcn_log.txt', 'a') 91 | f.write('validation loss: ' + str(val_loss) + '\n') 92 | f.close() 93 | 94 | print('epoch: ',epoch,'validation loss: ',val_loss) 95 | 96 | if (val_loss < best_eval): 97 | best_eval = val_loss 98 | best_test_loss = test(test_loader) 99 | 100 | elif (val_loss >= best_eval and (epoch + 1) >= args.start_reduce): 101 | patience += 1 102 | 103 | if (epoch + 1) >= args.start_reduce and patience == args.reduce_point: 104 | patience = 0 105 | args.lr /= 2. 106 | for param_group in optimizer.param_groups: 107 | param_group['lr'] = args.lr 108 | 109 | if (epoch > 25): 110 | if (val_loss > 20. or args.lr < 1e-5): 111 | break 112 | 113 | f = open('zinc_graphcon_gcn_log.txt', 'a') 114 | f.write('final test loss: ' + str(best_test_loss) + '\n') 115 | f.close() 116 | print('final test loss: ',best_test_loss) -------------------------------------------------------------------------------- /src/definitions.py: -------------------------------------------------------------------------------- 1 | import os 2 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) -------------------------------------------------------------------------------- /src/heterophilic_graphs/best_params.py: -------------------------------------------------------------------------------- 1 | best_params_dict = {'cornell': {'model': 'GraphCON_GCN', 'lr': 0.00721, 'nhid': 256, 'alpha': 0, 'gamma': 0, 'nlayers': 1, 'dropout': 0.15, 'weight_decay': 0.0012708787092020595, 'res_version': 1}, 2 | 'wisconsin': {'model': 'GraphCON_GCN', 'lr': 0.00356, 'nhid': 64, 'alpha': 0, 'gamma': 0, 'nlayers': 2, 'dropout': 0.23, 'weight_decay': 0.008126619200091946, 'res_version': 2}, 3 | 'texas': {'model': 'GraphCON_GCN', 'lr': 0.00155, 'nhid': 256, 'alpha': 0, 'gamma': 0, 'nlayers': 2, 'dropout': 0.68, 'weight_decay': 0.0008549327066268375, 'res_version': 2} 4 | } -------------------------------------------------------------------------------- /src/heterophilic_graphs/data_handling.py: -------------------------------------------------------------------------------- 1 | from torch_geometric.datasets import WebKB 2 | import torch 3 | import numpy as np 4 | 5 | DATA_PATH = '../../data' 6 | 7 | def get_data(name, split=0): 8 | path = '../../data/'+name 9 | dataset = WebKB(path,name=name) 10 | 11 | data = dataset[0] 12 | splits_file = np.load(f'{path}/{name}/raw/{name}_split_0.6_0.2_{split}.npz') 13 | train_mask = splits_file['train_mask'] 14 | val_mask = splits_file['val_mask'] 15 | test_mask = splits_file['test_mask'] 16 | 17 | data.train_mask = torch.tensor(train_mask, dtype=torch.bool) 18 | data.val_mask = torch.tensor(val_mask, dtype=torch.bool) 19 | data.test_mask = torch.tensor(test_mask, dtype=torch.bool) 20 | 21 | return data 22 | -------------------------------------------------------------------------------- /src/heterophilic_graphs/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from torch_geometric.nn import GCNConv, GATConv 6 | 7 | class GraphCON_GCN(nn.Module): 8 | def __init__(self, nfeat, nhid, nclass, dropout, nlayers, dt=1., alpha=1., gamma=1., 
res_version=1): 9 | super(GraphCON_GCN, self).__init__() 10 | self.dropout = dropout 11 | self.nhid = nhid 12 | self.nlayers = nlayers 13 | self.enc = nn.Linear(nfeat,nhid) 14 | self.conv = GCNConv(nhid, nhid) 15 | self.dec = nn.Linear(nhid,nclass) 16 | self.res = nn.Linear(nhid,nhid) 17 | if(res_version==1): 18 | self.residual = self.res_connection_v1 19 | else: 20 | self.residual = self.res_connection_v2 21 | self.dt = dt 22 | self.act_fn = nn.ReLU() 23 | self.alpha = alpha 24 | self.gamma = gamma 25 | self.reset_params() 26 | 27 | def reset_params(self): 28 | for name, param in self.named_parameters(): 29 | if 'weight' in name and 'emb' not in name and 'out' not in name: 30 | stdv = 1. / math.sqrt(self.nhid) 31 | param.data.uniform_(-stdv, stdv) 32 | 33 | def res_connection_v1(self, X): 34 | res = - self.res(self.conv.lin(X)) 35 | return res 36 | 37 | def res_connection_v2(self, X): 38 | res = - self.conv.lin(X) + self.res(X) 39 | return res 40 | 41 | def forward(self, data): 42 | input = data.x 43 | edge_index = data.edge_index 44 | input = F.dropout(input, self.dropout, training=self.training) 45 | Y = self.act_fn(self.enc(input)) 46 | X = Y 47 | Y = F.dropout(Y, self.dropout, training=self.training) 48 | X = F.dropout(X, self.dropout, training=self.training) 49 | 50 | for i in range(self.nlayers): 51 | Y = Y + self.dt*(self.act_fn(self.conv(X,edge_index) + self.residual(X)) - self.alpha*Y - self.gamma*X) 52 | X = X + self.dt*Y 53 | Y = F.dropout(Y, self.dropout, training=self.training) 54 | X = F.dropout(X, self.dropout, training=self.training) 55 | 56 | X = self.dec(X) 57 | 58 | return X 59 | 60 | class GraphCON_GAT(nn.Module): 61 | def __init__(self, nfeat, nhid, nclass, nlayers, dropout, dt=1., alpha=1., gamma=1., nheads=4): 62 | super(GraphCON_GAT, self).__init__() 63 | self.alpha = alpha 64 | self.gamma = gamma 65 | self.dropout = dropout 66 | self.nheads = nheads 67 | self.nhid = nhid 68 | self.nlayers = nlayers 69 | self.act_fn = nn.ReLU() 70 | self.res = nn.Linear(nhid, nheads * nhid) 71 | self.enc = nn.Linear(nfeat,nhid) 72 | self.conv = GATConv(nhid, nhid, heads=nheads) 73 | self.dec = nn.Linear(nhid,nclass) 74 | self.dt = dt 75 | 76 | def res_connection(self, X): 77 | res = self.res(X) 78 | return res 79 | 80 | def forward(self, data): 81 | input = data.x 82 | n_nodes = input.size(0) 83 | edge_index = data.edge_index 84 | input = F.dropout(input, self.dropout, training=self.training) 85 | Y = self.act_fn(self.enc(input)) 86 | X = Y 87 | Y = F.dropout(Y, self.dropout, training=self.training) 88 | X = F.dropout(X, self.dropout, training=self.training) 89 | 90 | for i in range(self.nlayers): 91 | Y = Y + self.dt*(F.elu(self.conv(X, edge_index) + self.res_connection(X)).view(n_nodes, -1, self.nheads).mean(dim=-1) - self.alpha*Y - self.gamma*X) 92 | X = X + self.dt*Y 93 | Y = F.dropout(Y, self.dropout, training=self.training) 94 | X = F.dropout(X, self.dropout, training=self.training) 95 | 96 | X = self.dec(X) 97 | 98 | return X 99 | -------------------------------------------------------------------------------- /src/heterophilic_graphs/run_GNN.py: -------------------------------------------------------------------------------- 1 | from data_handling import get_data 2 | import numpy as np 3 | import torch.optim as optim 4 | from models import * 5 | from torch import nn 6 | from best_params import best_params_dict 7 | 8 | import argparse 9 | 10 | def train_GNN(opt,split): 11 | data = get_data(opt['dataset'],split) 12 | 13 | best_eval = 10000 14 | bad_counter = 0 15 | 
best_test_acc = 0 16 | 17 | if opt['model'] == 'GraphCON_GCN': 18 | model = GraphCON_GCN(nfeat=data.num_features,nhid=opt['nhid'],nclass=5, 19 | dropout=opt['drop'],nlayers=opt['nlayers'],dt=1., 20 | alpha=opt['alpha'],gamma=opt['gamma'],res_version=opt['res_version']).to(opt['device']) 21 | elif opt['model'] == 'GraphCON_GAT': 22 | model = GraphCON_GAT(nfeat=data.num_features, nhid=opt['nhid'], nclass=5, 23 | dropout=opt['drop'], nlayers=opt['nlayers'], dt=1., 24 | alpha=opt['alpha'], gamma=opt['gamma'],nheads=opt['nheads']).to(opt['device']) 25 | 26 | optimizer = optim.Adam(model.parameters(),lr=opt['lr'],weight_decay=opt['weight_decay']) 27 | lf = nn.CrossEntropyLoss() 28 | 29 | @torch.no_grad() 30 | def test(model, data): 31 | model.eval() 32 | logits, accs, losses = model(data), [], [] 33 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 34 | loss = lf(out[mask], data.y.squeeze()[mask]) 35 | pred = logits[mask].max(1)[1] 36 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 37 | accs.append(acc) 38 | losses.append(loss.item()) 39 | return accs, losses 40 | 41 | for epoch in range(opt['epochs']): 42 | model.train() 43 | optimizer.zero_grad() 44 | out = model(data.to(opt['device'])) 45 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask]) 46 | loss.backward() 47 | optimizer.step() 48 | 49 | [train_acc, val_acc, test_acc], [train_loss, val_loss, test_loss] = test(model,data) 50 | 51 | if (val_loss < best_eval): 52 | best_eval = val_loss 53 | best_test_acc = test_acc 54 | else: 55 | bad_counter += 1 56 | 57 | if(bad_counter==opt['patience']): 58 | break 59 | 60 | log = 'Split: {:01d}, Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 61 | print(log.format(split, epoch, train_acc, val_acc, test_acc)) 62 | 63 | return best_test_acc 64 | 65 | if __name__ == '__main__': 66 | parser = argparse.ArgumentParser(description='training parameters') 67 | parser.add_argument('--dataset', type=str, default='texas', 68 | help='cornell, wisconsin, texas') 69 | parser.add_argument('--model', type=str, default='GraphCON_GCN', 70 | help='GraphCON_GCN, GraphCON_GAT') 71 | parser.add_argument('--nhid', type=int, default=64, 72 | help='number of hidden node features') 73 | parser.add_argument('--nlayers', type=int, default=5, 74 | help='number of layers') 75 | parser.add_argument('--alpha', type=float, default=1., 76 | help='alpha parameter of graphCON') 77 | parser.add_argument('--gamma', type=float, default=1., 78 | help='gamma parameter of graphCON') 79 | parser.add_argument('--nheads', type=int, default=4, 80 | help='number of attention heads for GraphCON-GAT') 81 | parser.add_argument('--epochs', type=int, default=1500, 82 | help='max epochs') 83 | parser.add_argument('--patience', type=int, default=100, 84 | help='patience') 85 | parser.add_argument('--lr', type=float, default=0.001, 86 | help='learning rate') 87 | parser.add_argument('--drop', type=float, default=0.3, 88 | help='dropout rate') 89 | parser.add_argument('--res_version', type=int, default=1, 90 | help='version of residual connection') 91 | parser.add_argument('--weight_decay', type=float, default=1e-5, 92 | help='weight_decay') 93 | parser.add_argument('--device', type=str, default=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), 94 | help='computing device') 95 | 96 | args = parser.parse_args() 97 | cmd_opt = vars(args) 98 | 99 | best_opt = best_params_dict[cmd_opt['dataset']] 100 | opt = {**cmd_opt, **best_opt} 101 | print(opt) 102 | 103 | n_splits = 10 104 | 105 | best = [] 106 
| for split in range(n_splits): 107 | best.append(train_GNN(opt,split)) 108 | print('Mean test accuracy: ', np.mean(np.array(best)*100),'std: ', np.std(np.array(best)*100)) 109 | 110 | -------------------------------------------------------------------------------- /src/homophilic_graphs/GNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from base_classes import BaseGNN 5 | from model_configurations import set_block, set_function 6 | 7 | 8 | # Define the GNN model. 9 | class GNN(BaseGNN): 10 | def __init__(self, opt, dataset, device=torch.device('cpu')): 11 | super(GNN, self).__init__(opt, dataset, device) 12 | self.f = set_function(opt) 13 | block = set_block(opt) 14 | time_tensor = torch.tensor([0, self.T]).to(device) 15 | self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device) 16 | 17 | def forward(self, x): 18 | # Encode each node based on its feature. 19 | if self.opt['use_labels']: 20 | y = x[:, self.num_features:] 21 | x = x[:, :self.num_features] 22 | x = F.dropout(x, self.opt['input_dropout'], training=self.training) 23 | x = self.m1(x) 24 | if self.opt['use_mlp']: 25 | x = F.dropout(x, self.opt['dropout'], training=self.training) 26 | x = F.dropout(x + self.m11(F.relu(x)), self.opt['dropout'], training=self.training) 27 | x = F.dropout(x + self.m12(F.relu(x)), self.opt['dropout'], training=self.training) 28 | # todo investigate if some input non-linearity solves the problem with smooth deformations identified in the ANODE paper 29 | 30 | if self.opt['use_labels']: 31 | x = torch.cat([x, y], dim=-1) 32 | 33 | if self.opt['batch_norm']: 34 | x = self.bn_in(x) 35 | 36 | # Solve the initial value problem of the ODE. 37 | if self.opt['augment']: 38 | c_aux = torch.zeros(x.shape).to(self.device) 39 | x = torch.cat([x, c_aux], dim=1) 40 | 41 | self.odeblock.set_x0(x) 42 | 43 | if self.training and self.odeblock.nreg > 0: 44 | z, self.reg_states = self.odeblock(x) 45 | else: 46 | z = self.odeblock(x) 47 | 48 | if self.opt['augment']: 49 | z = torch.split(z, x.shape[1] // 2, dim=1)[0] 50 | 51 | # Activation. 52 | z = F.relu(z) 53 | 54 | if self.opt['fc_out']: 55 | z = self.fc(z) 56 | z = F.relu(z) 57 | 58 | # Dropout. 59 | z = F.dropout(z, self.opt['dropout'], training=self.training) 60 | 61 | # Decode each node embedding to get node label. 
62 | z = self.m2(z) 63 | return z 64 | -------------------------------------------------------------------------------- /src/homophilic_graphs/GNN_early.py: -------------------------------------------------------------------------------- 1 | """ 2 | A GNN used at test time that supports early stopping during the integrator 3 | """ 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | import argparse 8 | from torch_geometric.nn import GCNConv, ChebConv # noqa 9 | import time 10 | from data import get_dataset 11 | # from run_GNN import get_optimizer, train, test 12 | from early_stop_solver import EarlyStopInt 13 | from base_classes import BaseGNN 14 | from model_configurations import set_block, set_function 15 | 16 | 17 | class GNNEarly(BaseGNN): 18 | def __init__(self, opt, dataset, device=torch.device('cpu')): 19 | super(GNNEarly, self).__init__(opt, dataset, device) 20 | self.f = set_function(opt) 21 | block = set_block(opt) 22 | time_tensor = torch.tensor([0, self.T]).to(device) 23 | # self.regularization_fns = () 24 | self.odeblock = block(self.f, self.regularization_fns, opt, dataset.data, device, t=time_tensor).to(device) 25 | # overwrite the test integrator with this custom one 26 | self.odeblock.test_integrator = EarlyStopInt(self.T, opt, device) 27 | # if opt['adjoint']: 28 | # from torchdiffeq import odeint_adjoint as odeint 29 | # else: 30 | # from torchdiffeq import odeint 31 | # self.odeblock.train_integrator = odeint 32 | 33 | self.set_solver_data(dataset.data) 34 | 35 | def set_solver_m2(self): 36 | if self.odeblock.test_integrator.m2 is None: 37 | self.odeblock.test_integrator.m2 = self.m2 38 | else: 39 | self.odeblock.test_integrator.m2.weight.data = self.m2.weight.data 40 | self.odeblock.test_integrator.m2.bias.data = self.m2.bias.data 41 | 42 | def set_solver_data(self, data): 43 | self.odeblock.test_integrator.data = data 44 | 45 | def forward(self, x): 46 | # Encode each node based on its feature. 47 | if self.opt['use_labels']: 48 | y = x[:, self.num_features:] 49 | x = x[:, :self.num_features] 50 | x = F.dropout(x, self.opt['input_dropout'], training=self.training) 51 | x = self.m1(x) 52 | 53 | if self.opt['use_labels']: 54 | x = torch.cat([x, y], dim=-1) 55 | 56 | if self.opt['batch_norm']: 57 | x = self.bn_in(x) 58 | 59 | # Solve the initial value problem of the ODE. 60 | if self.opt['augment']: 61 | c_aux = torch.zeros(x.shape).to(self.device) 62 | x = torch.cat([x, c_aux], dim=1) 63 | 64 | x = torch.cat([x, x], dim=-1) 65 | self.odeblock.set_x0(x) 66 | self.set_solver_m2() 67 | if self.training and self.odeblock.nreg > 0: 68 | z, self.reg_states = self.odeblock(x) 69 | else: 70 | z = self.odeblock(x) 71 | 72 | if self.opt['augment']: 73 | z = torch.split(z, x.shape[1] // 2, dim=1)[0] 74 | 75 | z = z[:,self.opt['hidden_dim']:] 76 | 77 | # Activation. 78 | z = F.relu(z) 79 | 80 | if self.opt['fc_out']: 81 | z = self.fc(z) 82 | z = F.relu(z) 83 | 84 | # Dropout. 85 | z = F.dropout(z, self.opt['dropout'], training=self.training) 86 | 87 | # Decode each node embedding to get node label. 88 | z = self.m2(z) 89 | return z 90 | 91 | 92 | def main(opt): 93 | dataset = get_dataset(opt, '../data', False) 94 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 95 | model, data = GNNEarly(opt, dataset, device).to(device), dataset.data.to(device) 96 | print(opt) 97 | # todo for some reason the submodule parameters inside the attention module don't show up when running on GPU. 
98 | parameters = [p for p in model.parameters() if p.requires_grad] 99 | optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay']) 100 | best_val_acc = test_acc = best_epoch = 0 101 | best_val_acc_int = best_test_acc_int = best_epoch_int = 0 102 | for epoch in range(1, opt['epoch']): 103 | start_time = time.time() 104 | loss = train(model, optimizer, data) 105 | train_acc, val_acc, tmp_test_acc = test(model, data) 106 | val_acc_int = model.odeblock.test_integrator.solver.best_val 107 | tmp_test_acc_int = model.odeblock.test_integrator.solver.best_test 108 | # store best stuff inside integrator forward pass 109 | if val_acc_int > best_val_acc_int: 110 | best_val_acc_int = val_acc_int 111 | test_acc_int = tmp_test_acc_int 112 | best_epoch_int = epoch 113 | # store best stuff at the end of integrator forward pass 114 | if val_acc > best_val_acc: 115 | best_val_acc = val_acc 116 | test_acc = tmp_test_acc 117 | best_epoch = epoch 118 | log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 119 | print( 120 | log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, val_acc, tmp_test_acc)) 121 | log = 'Performance inside integrator Val: {:.4f}, Test: {:.4f}' 122 | print(log.format(val_acc_int, tmp_test_acc_int)) 123 | # print( 124 | # log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, best_val_acc, test_acc)) 125 | print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc, test_acc, best_epoch)) 126 | print('best in integrator val accuracy {:03f} with test accuracy {:03f} at epoch {:d}'.format(best_val_acc_int, 127 | test_acc_int, 128 | best_epoch_int)) 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument('--use_cora_defaults', action='store_true', 134 | help='Whether to run with best params for cora. 
Overrides the choice of dataset') 135 | parser.add_argument('--dataset', type=str, default='Cora', 136 | help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS') 137 | parser.add_argument('--data_norm', type=str, default='rw', 138 | help='rw for random walk, gcn for symmetric gcn norm') 139 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.') 140 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.') 141 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.') 142 | parser.add_argument('--optimizer', type=str, default='adam', help='One from sgd, rmsprop, adam, adagrad, adamax.') 143 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') 144 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization') 145 | parser.add_argument('--self_loop_weight', type=float, default=1.0, help='Weight of self-loops.') 146 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.') 147 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front matrix A.') 148 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.') 149 | parser.add_argument('--augment', action='store_true', 150 | help='double the length of the feature vector by appending zeros to stabilist ODE learning') 151 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha') 152 | parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true', help='apply sigmoid before multiplying by alpha') 153 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta') 154 | parser.add_argument('--block', type=str, default='constant', help='constant, mixed, attention, SDE') 155 | parser.add_argument('--function', type=str, default='laplacian', help='laplacian, transformer, dorsey, GAT, SDE') 156 | # ODE args 157 | parser.add_argument('--method', type=str, default='dopri5', 158 | help="set the numerical solver: dopri5, euler, rk4, midpoint") 159 | parser.add_argument('--step_size', type=float, default=1, help='fixed step size when using fixed step solvers e.g. rk4') 160 | parser.add_argument('--max_iters', type=int, default=100, 161 | help='fixed step size when using fixed step solvers e.g. rk4') 162 | parser.add_argument( 163 | "--adjoint_method", type=str, default="adaptive_heun", 164 | help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint" 165 | ) 166 | parser.add_argument('--adjoint', dest='adjoint', action='store_true', help='use the adjoint ODE method to reduce memory footprint') 167 | parser.add_argument('--adjoint_step_size', type=float, default=1, help='fixed step size when using fixed step adjoint solvers e.g. 
rk4') 168 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol') 169 | parser.add_argument("--tol_scale_adjoint", type=float, default=1.0, 170 | help="multiplier for adjoint_atol and adjoint_rtol") 171 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run') 172 | parser.add_argument('--add_source', dest='add_source', action='store_true', 173 | help='If try get rid of alpha param and the beta*x0 source term') 174 | # SDE args 175 | parser.add_argument('--dt_min', type=float, default=1e-5, help='minimum timestep for the SDE solver') 176 | parser.add_argument('--dt', type=float, default=1e-3, help='fixed step size') 177 | parser.add_argument('--adaptive', dest='adaptive', action='store_true', help='use adaptive step sizes') 178 | # Attention args 179 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2, 180 | help='slope of the negative part of the leaky relu used in attention') 181 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights') 182 | parser.add_argument('--heads', type=int, default=4, help='number of attention heads') 183 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols') 184 | parser.add_argument('--attention_dim', type=int, default=64, 185 | help='the size to project x to before calculating att scores') 186 | parser.add_argument('--mix_features', dest='mix_features', action='store_true', 187 | help='apply a feature transformation xW to the ODE') 188 | parser.add_argument("--max_nfe", type=int, default=1000, help="Maximum number of function evaluations in an epoch. Stiff ODEs will hang if not set.") 189 | parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true', help="multiply attention scores by edge weights before softmax") 190 | # regularisation args 191 | parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2") 192 | parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2") 193 | 194 | parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2") 195 | parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2") 196 | 197 | # rewiring args 198 | parser.add_argument('--rewiring', type=str, default=None, help="two_hop, gdc") 199 | parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff") 200 | parser.add_argument('--gdc_sparsification', type=str, default='topk', help="threshold, topk") 201 | parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk") 202 | parser.add_argument('--gdc_threshold', type=float, default=0.0001, help="obove this edge weight, keep edges when using threshold") 203 | parser.add_argument('--gdc_avg_degree', type=int, default=64, 204 | help="if gdc_threshold is not given can be calculated by specifying avg degree") 205 | parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability") 206 | parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernal diffusion for") 207 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model') 208 | 209 | args = parser.parse_args() 210 | 211 | opt = vars(args) 212 | 213 | main(opt) 214 | 
-------------------------------------------------------------------------------- /src/homophilic_graphs/README.md: -------------------------------------------------------------------------------- 1 | Code adapted from [GRAND](https://github.com/twitter-research/graph-neural-pde). 2 | To run the experiments, simply do: 3 | 4 | ``` 5 | python run_GNN.py --dataset 6 | ``` 7 | 8 | where is either `Cora`, `Citeseer` or `Pubmed`. 9 | -------------------------------------------------------------------------------- /src/homophilic_graphs/base_classes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.nn.conv import MessagePassing 4 | from utils import Meter 5 | from regularized_ODE_function import RegularizedODEfunc 6 | import regularized_ODE_function as reg_lib 7 | import six 8 | 9 | 10 | REGULARIZATION_FNS = { 11 | "kinetic_energy": reg_lib.quadratic_cost, 12 | "jacobian_norm2": reg_lib.jacobian_frobenius_regularization_fn, 13 | "total_deriv": reg_lib.total_derivative, 14 | "directional_penalty": reg_lib.directional_derivative 15 | } 16 | 17 | 18 | def create_regularization_fns(args): 19 | regularization_fns = [] 20 | regularization_coeffs = [] 21 | 22 | for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS): 23 | if args[arg_key] is not None: 24 | regularization_fns.append(reg_fn) 25 | regularization_coeffs.append(args[arg_key]) 26 | 27 | regularization_fns = regularization_fns 28 | regularization_coeffs = regularization_coeffs 29 | return regularization_fns, regularization_coeffs 30 | 31 | 32 | class ODEblock(nn.Module): 33 | def __init__(self, odefunc, regularization_fns, opt, data, device, t): 34 | super(ODEblock, self).__init__() 35 | self.opt = opt 36 | self.t = t 37 | 38 | self.aug_dim = 2 if opt['augment'] else 1 39 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 40 | 41 | self.nreg = len(regularization_fns) 42 | self.reg_odefunc = RegularizedODEfunc(self.odefunc, regularization_fns) 43 | 44 | if opt['adjoint']: 45 | from torchdiffeq import odeint_adjoint as odeint 46 | else: 47 | from torchdiffeq import odeint 48 | self.train_integrator = odeint 49 | self.test_integrator = None 50 | self.set_tol() 51 | 52 | def set_x0(self, x0): 53 | self.odefunc.x0 = x0.clone().detach() 54 | self.reg_odefunc.odefunc.x0 = x0.clone().detach() 55 | 56 | def set_tol(self): 57 | self.atol = self.opt['tol_scale'] * 1e-7 58 | self.rtol = self.opt['tol_scale'] * 1e-9 59 | if self.opt['adjoint']: 60 | self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7 61 | self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9 62 | 63 | def reset_tol(self): 64 | self.atol = 1e-7 65 | self.rtol = 1e-9 66 | self.atol_adjoint = 1e-7 67 | self.rtol_adjoint = 1e-9 68 | 69 | def set_time(self, time): 70 | self.t = torch.tensor([0, time]).to(self.device) 71 | 72 | def __repr__(self): 73 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 74 | + ")" 75 | 76 | 77 | class ODEFunc(MessagePassing): 78 | 79 | # currently requires in_features = out_features 80 | def __init__(self, opt, data, device): 81 | super(ODEFunc, self).__init__() 82 | self.opt = opt 83 | self.device = device 84 | self.edge_index = None 85 | self.edge_weight = None 86 | self.attention_weights = None 87 | self.attention_weights_2 = None 88 | self.alpha_train = nn.Parameter(torch.tensor(0.0))#nn.Parameter(torch.zeros(opt['hidden_dim'])) 89 
| self.beta_train = nn.Parameter(torch.tensor(0.0)) 90 | self.beta_train2 = nn.Parameter(torch.tensor(0.0)) 91 | self.x0 = None 92 | self.nfe = 0 93 | self.alpha_sc = nn.Parameter(torch.ones(1)) 94 | self.beta_sc = nn.Parameter(torch.ones(1)) 95 | 96 | def __repr__(self): 97 | return self.__class__.__name__ 98 | 99 | 100 | class BaseGNN(MessagePassing): 101 | def __init__(self, opt, dataset, device=torch.device('cpu')): 102 | super(BaseGNN, self).__init__() 103 | self.opt = opt 104 | self.T = opt['time'] 105 | self.num_classes = dataset.num_classes 106 | self.num_features = dataset.data.num_features 107 | self.device = device 108 | self.fm = Meter() 109 | self.bm = Meter() 110 | self.m1 = nn.Linear(dataset.data.num_features, opt['hidden_dim']) 111 | if self.opt['use_mlp']: 112 | self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 113 | self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 114 | if opt['use_labels']: 115 | # todo - fastest way to propagate this everywhere, but error prone - refactor later 116 | opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes 117 | else: 118 | self.hidden_dim = opt['hidden_dim'] 119 | if opt['fc_out']: 120 | self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 121 | self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes) 122 | if self.opt['batch_norm']: 123 | self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim']) 124 | self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim']) 125 | 126 | self.regularization_fns, self.regularization_coeffs = create_regularization_fns(self.opt) 127 | 128 | def getNFE(self): 129 | return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe 130 | 131 | def resetNFE(self): 132 | self.odeblock.odefunc.nfe = 0 133 | self.odeblock.reg_odefunc.odefunc.nfe = 0 134 | 135 | def reset(self): 136 | self.m1.reset_parameters() 137 | self.m2.reset_parameters() 138 | 139 | def __repr__(self): 140 | return self.__class__.__name__ 141 | -------------------------------------------------------------------------------- /src/homophilic_graphs/block_constant.py: -------------------------------------------------------------------------------- 1 | from base_classes import ODEblock 2 | import torch 3 | from utils import get_rw_adj, gcn_norm_fill_val 4 | 5 | class ConstantODEblock(ODEblock): 6 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])): 7 | super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 8 | 9 | self.aug_dim = 2 if opt['augment'] else 1 10 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 11 | if opt['data_norm'] == 'rw': 12 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 13 | fill_value=opt['self_loop_weight'], 14 | num_nodes=data.num_nodes, 15 | dtype=data.x.dtype) 16 | else: 17 | edge_index, edge_weight = gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr, 18 | fill_value=opt['self_loop_weight'], 19 | num_nodes=data.num_nodes, 20 | dtype=data.x.dtype) 21 | self.odefunc.edge_index = edge_index.to(device) 22 | self.odefunc.edge_weight = edge_weight.to(device) 23 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 24 | 25 | if (opt['method'] == 'symplectic_euler' or opt['method'] == 'leapfrog'): 26 | from odeint_geometric import odeint 27 | elif opt['adjoint']: 28 | from torchdiffeq import odeint_adjoint as odeint 29 | else: 30 | 
from torchdiffeq import odeint 31 | 32 | self.train_integrator = odeint 33 | self.test_integrator = odeint 34 | self.set_tol() 35 | 36 | def forward(self, x): 37 | t = self.t.type_as(x) 38 | 39 | integrator = self.train_integrator if self.training else self.test_integrator 40 | 41 | reg_states = tuple( torch.zeros(x.size(0)).to(x) for i in range(self.nreg) ) 42 | 43 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc 44 | state = (x,) + reg_states if self.training and self.nreg > 0 else x 45 | 46 | if self.opt["adjoint"] and self.training: 47 | state_dt = integrator( 48 | func, state, t, 49 | method=self.opt['method'], 50 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']), 51 | adjoint_method=self.opt['adjoint_method'], 52 | adjoint_options=dict(step_size = self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']), 53 | atol=self.atol, 54 | rtol=self.rtol, 55 | adjoint_atol=self.atol_adjoint, 56 | adjoint_rtol=self.rtol_adjoint) 57 | else: 58 | state_dt = integrator( 59 | func, state, t, 60 | method=self.opt['method'], 61 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']), 62 | atol=self.atol, 63 | rtol=self.rtol) 64 | 65 | if self.training and self.nreg > 0: 66 | z = state_dt[0][1] 67 | reg_states = tuple( st[1] for st in state_dt[1:] ) 68 | return z, reg_states 69 | else: 70 | z = state_dt[1] 71 | return z 72 | 73 | def __repr__(self): 74 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 75 | + ")" 76 | -------------------------------------------------------------------------------- /src/homophilic_graphs/block_transformer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from function_transformer_attention import SpGraphTransAttentionLayer 3 | from base_classes import ODEblock 4 | from utils import get_rw_adj 5 | 6 | 7 | class AttODEblock(ODEblock): 8 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5): 9 | super(AttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 10 | 11 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 12 | # self.odefunc.edge_index, self.odefunc.edge_weight = data.edge_index, edge_weight=data.edge_attr 13 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 14 | fill_value=opt['self_loop_weight'], 15 | num_nodes=data.num_nodes, 16 | dtype=data.x.dtype) 17 | self.odefunc.edge_index = edge_index.to(device) 18 | self.odefunc.edge_weight = edge_weight.to(device) 19 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 20 | 21 | if(opt['method']=='symplectic_euler' or opt['method']=='leapfrog'): 22 | from odeint_geometric import odeint 23 | elif opt['adjoint']: 24 | from torchdiffeq import odeint_adjoint as odeint 25 | else: 26 | from torchdiffeq import odeint 27 | self.train_integrator = odeint 28 | self.test_integrator = odeint 29 | self.set_tol() 30 | # parameter trading off between attention and the Laplacian 31 | self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt, 32 | device, edge_weights=self.odefunc.edge_weight).to(device) 33 | 34 | def get_attention_weights(self, x): 35 | attention, values = self.multihead_att_layer(x, self.odefunc.edge_index) 36 | 
return attention 37 | 38 | def forward(self, x_all): 39 | x = x_all[:,:self.opt['hidden_dim']] 40 | y = x_all[:, self.opt['hidden_dim']:] 41 | t = self.t.type_as(x) 42 | self.odefunc.attention_weights = self.get_attention_weights(y) 43 | self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights 44 | integrator = self.train_integrator if self.training else self.test_integrator 45 | 46 | reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg)) 47 | 48 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc 49 | state = (x_all,) + reg_states if self.training and self.nreg > 0 else x_all 50 | 51 | if self.opt["adjoint"] and self.training: 52 | state_dt = integrator( 53 | func, state, t, 54 | method=self.opt['method'], 55 | options={'step_size': self.opt['step_size']}, 56 | adjoint_method=self.opt['adjoint_method'], 57 | adjoint_options={'step_size': self.opt['adjoint_step_size']}, 58 | atol=self.atol, 59 | rtol=self.rtol, 60 | adjoint_atol=self.atol_adjoint, 61 | adjoint_rtol=self.rtol_adjoint) 62 | else: 63 | state_dt = integrator( 64 | func, state, t, 65 | method=self.opt['method'], 66 | options={'step_size': self.opt['step_size']}, 67 | atol=self.atol, 68 | rtol=self.rtol) 69 | 70 | if self.training and self.nreg > 0: 71 | z = state_dt[0][1] 72 | reg_states = tuple(st[1] for st in state_dt[1:]) 73 | return z, reg_states 74 | else: 75 | z = state_dt[1] 76 | return z 77 | 78 | def __repr__(self): 79 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 80 | + ")" 81 | -------------------------------------------------------------------------------- /src/homophilic_graphs/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code partially copied from 'Diffusion Improves Graph Learning' repo https://github.com/klicperajo/gdc/blob/master/data.py 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch_geometric.data import Data, InMemoryDataset 11 | from torch_geometric.datasets import Planetoid, Amazon, Coauthor 12 | from ogb.nodeproppred import PygNodePropPredDataset 13 | import torch_geometric.transforms as T 14 | from torch_geometric.utils import to_undirected 15 | 16 | DATA_PATH = '../../data' 17 | 18 | 19 | def get_dataset(opt: dict, data_dir, use_lcc: bool = False) -> InMemoryDataset: 20 | ds = opt['dataset'] 21 | path = os.path.join(data_dir, ds) 22 | if ds in ['Cora', 'Citeseer', 'Pubmed']: 23 | dataset = Planetoid(path, ds) 24 | elif ds in ['Computers', 'Photo']: 25 | dataset = Amazon(path, ds) 26 | elif ds == 'CoauthorCS': 27 | dataset = Coauthor(path, 'CS') 28 | elif ds == 'ogbn-arxiv': 29 | dataset = PygNodePropPredDataset(name=ds, root=path, 30 | transform=T.ToSparseTensor()) 31 | use_lcc = False # never need to calculate the lcc with ogb datasets 32 | else: 33 | raise Exception('Unknown dataset.') 34 | 35 | if use_lcc: 36 | lcc = get_largest_connected_component(dataset) 37 | 38 | x_new = dataset.data.x[lcc] 39 | y_new = dataset.data.y[lcc] 40 | 41 | row, col = dataset.data.edge_index.numpy() 42 | edges = [[i, j] for i, j in zip(row, col) if i in lcc and j in lcc] 43 | edges = remap_edges(edges, get_node_mapper(lcc)) 44 | 45 | data = Data( 46 | x=x_new, 47 | edge_index=torch.LongTensor(edges), 48 | y=y_new, 49 | train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool), 50 | test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool), 51 | val_mask=torch.zeros(y_new.size()[0], 
dtype=torch.bool) 52 | ) 53 | dataset.data = data 54 | 55 | train_mask_exists = True 56 | try: 57 | dataset.data.train_mask 58 | except AttributeError: 59 | train_mask_exists = False 60 | 61 | if ds == 'ogbn-arxiv': 62 | split_idx = dataset.get_idx_split() 63 | ei = to_undirected(dataset.data.edge_index) 64 | data = Data( 65 | x=dataset.data.x, 66 | edge_index=ei, 67 | y=dataset.data.y, 68 | train_mask=split_idx['train'], 69 | test_mask=split_idx['test'], 70 | val_mask=split_idx['valid']) 71 | dataset.data = data 72 | train_mask_exists = True 73 | 74 | if use_lcc or not train_mask_exists: 75 | dataset.data = set_train_val_test_split( 76 | 12345, 77 | dataset.data, 78 | num_development=5000 if ds == "CoauthorCS" else 1500) 79 | 80 | return dataset 81 | 82 | 83 | def get_component(dataset: InMemoryDataset, start: int = 0) -> set: 84 | visited_nodes = set() 85 | queued_nodes = set([start]) 86 | row, col = dataset.data.edge_index.numpy() 87 | while queued_nodes: 88 | current_node = queued_nodes.pop() 89 | visited_nodes.update([current_node]) 90 | neighbors = col[np.where(row == current_node)[0]] 91 | neighbors = [n for n in neighbors if n not in visited_nodes and n not in queued_nodes] 92 | queued_nodes.update(neighbors) 93 | return visited_nodes 94 | 95 | 96 | def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray: 97 | remaining_nodes = set(range(dataset.data.x.shape[0])) 98 | comps = [] 99 | while remaining_nodes: 100 | start = min(remaining_nodes) 101 | comp = get_component(dataset, start) 102 | comps.append(comp) 103 | remaining_nodes = remaining_nodes.difference(comp) 104 | return np.array(list(comps[np.argmax(list(map(len, comps)))])) 105 | 106 | 107 | def get_node_mapper(lcc: np.ndarray) -> dict: 108 | mapper = {} 109 | counter = 0 110 | for node in lcc: 111 | mapper[node] = counter 112 | counter += 1 113 | return mapper 114 | 115 | 116 | def remap_edges(edges: list, mapper: dict) -> list: 117 | row = [e[0] for e in edges] 118 | col = [e[1] for e in edges] 119 | row = list(map(lambda x: mapper[x], row)) 120 | col = list(map(lambda x: mapper[x], col)) 121 | return [row, col] 122 | 123 | 124 | def set_train_val_test_split( 125 | seed: int, 126 | data: Data, 127 | num_development: int = 1500, 128 | num_per_class: int = 20) -> Data: 129 | rnd_state = np.random.RandomState(seed) 130 | num_nodes = data.y.shape[0] 131 | development_idx = rnd_state.choice(num_nodes, num_development, replace=False) 132 | test_idx = [i for i in np.arange(num_nodes) if i not in development_idx] 133 | 134 | train_idx = [] 135 | rnd_state = np.random.RandomState(seed) 136 | for c in range(data.y.max() + 1): 137 | class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]] 138 | train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False)) 139 | 140 | val_idx = [i for i in development_idx if i not in train_idx] 141 | 142 | def get_mask(idx): 143 | mask = torch.zeros(num_nodes, dtype=torch.bool) 144 | mask[idx] = 1 145 | return mask 146 | 147 | data.train_mask = get_mask(train_idx) 148 | data.val_mask = get_mask(val_idx) 149 | data.test_mask = get_mask(test_idx) 150 | 151 | return data 152 | 153 | 154 | if __name__ == '__main__': 155 | # example for heterophilic datasets 156 | from heterophilic import get_fixed_splits 157 | 158 | opt = {'dataset': 'Cora', 'device': 'cpu'} 159 | dataset = get_dataset(opt) 160 | for fold in range(10): 161 | data = dataset[0] 162 | data = get_fixed_splits(data, opt['dataset'], fold) 163 | data = data.to(opt['device']) 164 | 
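# Usage note (illustrative sketch only): `get_dataset` takes the options dict
# plus a data directory, e.g. the module-level DATA_PATH above, and can
# optionally restrict the graph to its largest connected component:
#
#     opt = {'dataset': 'Cora'}
#     dataset = get_dataset(opt, DATA_PATH, use_lcc=True)
#
# If no public split is available (or the LCC is used), `set_train_val_test_split`
# draws `num_development` development nodes, keeps `num_per_class` (default 20)
# of them per class for training, uses the remaining development nodes for
# validation, and assigns all other nodes to the test set.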
-------------------------------------------------------------------------------- /src/homophilic_graphs/early_stop_solver.py: -------------------------------------------------------------------------------- 1 | import torchdiffeq 2 | from torchdiffeq._impl.dopri5 import _DORMAND_PRINCE_SHAMPINE_TABLEAU, DPS_C_MID 3 | from torchdiffeq._impl.solvers import FixedGridODESolver 4 | import torch 5 | import abc 6 | from geometric_solvers import GeomtricFixedGridODESolver 7 | from torchdiffeq._impl.misc import _check_inputs, _flat_to_shape 8 | import torch.nn.functional as F 9 | import copy 10 | from geometric_integrators import SymplecticEuler_step_func, Leapfrog_step_func 11 | 12 | from torchdiffeq._impl.interp import _interp_evaluate 13 | from torchdiffeq._impl.rk_common import RKAdaptiveStepsizeODESolver, rk4_alt_step_func 14 | from ogb.nodeproppred import Evaluator 15 | 16 | 17 | def run_evaluator(evaluator, data, y_pred): 18 | train_acc = evaluator.eval({ 19 | 'y_true': data.y[data.train_mask], 20 | 'y_pred': y_pred[data.train_mask], 21 | })['acc'] 22 | valid_acc = evaluator.eval({ 23 | 'y_true': data.y[data.val_mask], 24 | 'y_pred': y_pred[data.val_mask], 25 | })['acc'] 26 | test_acc = evaluator.eval({ 27 | 'y_true': data.y[data.test_mask], 28 | 'y_pred': y_pred[data.test_mask], 29 | })['acc'] 30 | return train_acc, valid_acc, test_acc 31 | 32 | 33 | class EarlyStopDopri5(RKAdaptiveStepsizeODESolver): 34 | order = 5 35 | tableau = _DORMAND_PRINCE_SHAMPINE_TABLEAU 36 | mid = DPS_C_MID 37 | 38 | def __init__(self, func, y0, rtol, atol, opt, **kwargs): 39 | super(EarlyStopDopri5, self).__init__(func, y0, rtol, atol, **kwargs) 40 | 41 | self.lf = torch.nn.CrossEntropyLoss() 42 | self.m2 = None 43 | self.data = None 44 | self.best_val = 0 45 | self.best_test = 0 46 | self.best_time = 0 47 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test 48 | self.dataset = opt['dataset'] 49 | if opt['dataset'] == 'ogbn-arxiv': 50 | self.lf = torch.nn.functional.nll_loss 51 | self.evaluator = Evaluator(name=opt['dataset']) 52 | 53 | def set_accs(self, train, val, test, time): 54 | self.best_train = train 55 | self.best_val = val 56 | self.best_test = test 57 | self.best_time = time.item() 58 | 59 | def integrate(self, t): 60 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 61 | solution[0] = self.y0 62 | t = t.to(self.dtype) 63 | self._before_integrate(t) 64 | new_t = t 65 | for i in range(1, len(t)): 66 | new_t, y = self.advance(t[i]) 67 | solution[i] = y 68 | return new_t, solution 69 | 70 | def advance(self, next_t): 71 | """ 72 | Takes steps dt to get to the next user specified time point next_t. 
In practice this goes past next_t and then interpolates 73 | :param next_t: 74 | :return: The state, x(next_t) 75 | """ 76 | n_steps = 0 77 | while next_t > self.rk_state.t1: 78 | assert n_steps < self.max_num_steps, 'max_num_steps exceeded ({}>={})'.format(n_steps, self.max_num_steps) 79 | self.rk_state = self._adaptive_step(self.rk_state) 80 | n_steps += 1 81 | train_acc, val_acc, test_acc = self.evaluate(self.rk_state) 82 | if val_acc > self.best_val: 83 | self.set_accs(train_acc, val_acc, test_acc, self.rk_state.t1) 84 | new_t = next_t 85 | return (new_t, _interp_evaluate(self.rk_state.interp_coeff, self.rk_state.t0, self.rk_state.t1, next_t)) 86 | 87 | @torch.no_grad() 88 | def test(self, logits): 89 | accs = [] 90 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'): 91 | pred = logits[mask].max(1)[1] 92 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item() 93 | accs.append(acc) 94 | return accs 95 | 96 | @torch.no_grad() 97 | def test_OGB(self, logits): 98 | evaluator = self.evaluator 99 | data = self.data 100 | y_pred = logits.argmax(dim=-1, keepdim=True) 101 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred) 102 | return [train_acc, valid_acc, test_acc] 103 | 104 | @torch.no_grad() 105 | def evaluate(self, rkstate): 106 | # Activation. 107 | z = rkstate.y1 108 | if not self.m2.in_features == z.shape[1]: # system has been augmented 109 | z = torch.split(z, self.m2.in_features, dim=1)[0] 110 | z = F.relu(z) 111 | z = self.m2(z) 112 | t0, t1 = float(self.rk_state.t0), float(self.rk_state.t1) 113 | if self.dataset == 'ogbn-arxiv': 114 | z = z.log_softmax(dim=-1) 115 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask]) 116 | else: 117 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask]) 118 | train_acc, val_acc, test_acc = self.ode_test(z) 119 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 120 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc)) 121 | return train_acc, val_acc, test_acc 122 | 123 | def set_m2(self, m2): 124 | self.m2 = copy.deepcopy(m2) 125 | 126 | def set_data(self, data): 127 | if self.data is None: 128 | self.data = data 129 | 130 | class EarlyStopRK4(FixedGridODESolver): 131 | order = 4 132 | 133 | def __init__(self, func, y0, opt, eps=0, **kwargs): 134 | super(EarlyStopRK4, self).__init__(func, y0, **kwargs) 135 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) 136 | self.lf = torch.nn.CrossEntropyLoss() 137 | self.m2 = None 138 | self.data = None 139 | self.best_val = 0 140 | self.best_test = 0 141 | self.best_time = 0 142 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test 143 | self.dataset = opt['dataset'] 144 | if opt['dataset'] == 'ogbn-arxiv': 145 | self.lf = torch.nn.functional.nll_loss 146 | self.evaluator = Evaluator(name=opt['dataset']) 147 | 148 | def _step_func(self, func, t, dt, t1, y): 149 | ver = torchdiffeq.__version__[0] + torchdiffeq.__version__[2] + torchdiffeq.__version__[4] 150 | if int(ver) >= 22: # '0.2.2' 151 | return rk4_alt_step_func(func, t + self.eps, dt - 2 * self.eps, t1, y) 152 | else: 153 | return rk4_alt_step_func(func, t + self.eps, dt - 2 * self.eps, y) 154 | 155 | def set_accs(self, train, val, test, time): 156 | self.best_train = train 157 | self.best_val = val 158 | self.best_test = test 159 | self.best_time = time.item() 160 | 161 | def integrate(self, t): 162 | time_grid = 
self.grid_constructor(self.func, self.y0, t) 163 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 164 | 165 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 166 | solution[0] = self.y0 167 | 168 | j = 1 169 | y0 = self.y0 170 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 171 | dy = self._step_func(self.func, t0, t1 - t0, t1, y0) 172 | y1 = y0 + dy 173 | train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1) 174 | if val_acc > self.best_val: 175 | self.set_accs(train_acc, val_acc, test_acc, t1) 176 | 177 | while j < len(t) and t1 >= t[j]: 178 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) 179 | j += 1 180 | y0 = y1 181 | 182 | return t1, solution 183 | 184 | @torch.no_grad() 185 | def test(self, logits): 186 | accs = [] 187 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'): 188 | pred = logits[mask].max(1)[1] 189 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item() 190 | accs.append(acc) 191 | return accs 192 | 193 | @torch.no_grad() 194 | def test_OGB(self, logits): 195 | evaluator = self.evaluator 196 | data = self.data 197 | y_pred = logits.argmax(dim=-1, keepdim=True) 198 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred) 199 | return [train_acc, valid_acc, test_acc] 200 | 201 | @torch.no_grad() 202 | def evaluate(self, z, t0, t1): 203 | # Activation. 204 | if not self.m2.in_features == z.shape[1]: # system has been augmented 205 | z = torch.split(z, self.m2.in_features, dim=1)[0] 206 | z = F.relu(z) 207 | z = self.m2(z) 208 | if self.dataset == 'ogbn-arxiv': 209 | z = z.log_softmax(dim=-1) 210 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask]) 211 | else: 212 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask]) 213 | train_acc, val_acc, test_acc = self.ode_test(z) 214 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 215 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc)) 216 | return train_acc, val_acc, test_acc 217 | 218 | def set_m2(self, m2): 219 | self.m2 = copy.deepcopy(m2) 220 | 221 | def set_data(self, data): 222 | if self.data is None: 223 | self.data = data 224 | 225 | 226 | class EarlyStopSympEuler(GeomtricFixedGridODESolver): 227 | order = 1 228 | 229 | def __init__(self, func, y0, rtol, atol, opt, eps=0, **kwargs): 230 | super(EarlyStopSympEuler, self).__init__(func, y0, step_size=opt['step_size']) 231 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) 232 | self.lf = torch.nn.CrossEntropyLoss() 233 | self.m2 = None 234 | self.data = None 235 | self.best_val = 0 236 | self.best_test = 0 237 | self.best_time = 0 238 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test 239 | self.dataset = opt['dataset'] 240 | if opt['dataset'] == 'ogbn-arxiv': 241 | self.lf = torch.nn.functional.nll_loss 242 | self.evaluator = Evaluator(name=opt['dataset']) 243 | 244 | def _step_func(self, func, t0, dt, t1, y0): 245 | return SymplecticEuler_step_func(func, t0, dt, t1, y0) 246 | 247 | def set_accs(self, train, val, test, time): 248 | self.best_train = train 249 | self.best_val = val 250 | self.best_test = test 251 | self.best_time = time.item() 252 | 253 | def integrate(self, t): 254 | time_grid = self.grid_constructor(self.func, self.y0, t) 255 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 256 | 257 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, 
device=self.y0.device) 258 | solution[0] = self.y0 259 | 260 | j = 1 261 | y0 = self.y0 262 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 263 | dt = t1 - t0 264 | y1 = self._step_func(self.func, t0, dt, t1, y0) 265 | train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1) 266 | if val_acc > self.best_val: 267 | self.set_accs(train_acc, val_acc, test_acc, t1) 268 | 269 | while j < len(t) and t1 >= t[j]: 270 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) 271 | j += 1 272 | y0 = y1 273 | 274 | return t1, solution 275 | 276 | @torch.no_grad() 277 | def test(self, logits): 278 | accs = [] 279 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'): 280 | pred = logits[mask].max(1)[1] 281 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item() 282 | accs.append(acc) 283 | return accs 284 | 285 | @torch.no_grad() 286 | def test_OGB(self, logits): 287 | evaluator = self.evaluator 288 | data = self.data 289 | y_pred = logits.argmax(dim=-1, keepdim=True) 290 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred) 291 | return [train_acc, valid_acc, test_acc] 292 | 293 | @torch.no_grad() 294 | def evaluate(self, z, t0, t1): 295 | # Activation. 296 | if not self.m2.in_features == z.shape[1]: # system has been augmented 297 | z = torch.split(z, self.m2.in_features, dim=1)[0] 298 | z = F.relu(z) 299 | z = self.m2(z) 300 | if self.dataset == 'ogbn-arxiv': 301 | z = z.log_softmax(dim=-1) 302 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask]) 303 | else: 304 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask]) 305 | train_acc, val_acc, test_acc = self.ode_test(z) 306 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 307 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc)) 308 | return train_acc, val_acc, test_acc 309 | 310 | def set_m2(self, m2): 311 | self.m2 = copy.deepcopy(m2) 312 | 313 | def set_data(self, data): 314 | if self.data is None: 315 | self.data = data 316 | 317 | 318 | class EarlyStopLepfrog(GeomtricFixedGridODESolver): 319 | order = 2 320 | 321 | def __init__(self, func, y0, rtol, atol, opt, eps=0, **kwargs): 322 | super(EarlyStopLepfrog, self).__init__(func, y0, step_size=opt['step_size']) 323 | self.eps = torch.as_tensor(eps, dtype=self.dtype, device=self.device) 324 | self.lf = torch.nn.CrossEntropyLoss() 325 | self.m2 = None 326 | self.data = None 327 | self.best_val = 0 328 | self.best_test = 0 329 | self.best_time = 0 330 | self.ode_test = self.test_OGB if opt['dataset'] == 'ogbn-arxiv' else self.test 331 | self.dataset = opt['dataset'] 332 | if opt['dataset'] == 'ogbn-arxiv': 333 | self.lf = torch.nn.functional.nll_loss 334 | self.evaluator = Evaluator(name=opt['dataset']) 335 | 336 | def _step_func(self, func, t0, dt, t1, y0): 337 | return Leapfrog_step_func(func, t0, dt, t1, y0) 338 | 339 | def set_accs(self, train, val, test, time): 340 | self.best_train = train 341 | self.best_val = val 342 | self.best_test = test 343 | self.best_time = time.item() 344 | 345 | def integrate(self, t): 346 | time_grid = self.grid_constructor(self.func, self.y0, t) 347 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 348 | 349 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 350 | solution[0] = self.y0 351 | 352 | j = 1 353 | y0 = self.y0 354 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 355 | dt = t1 - t0 356 | y1 = self._step_func(self.func, t0, dt, t1, y0) 357 | 
train_acc, val_acc, test_acc = self.evaluate(y1, t0, t1) 358 | if val_acc > self.best_val: 359 | self.set_accs(train_acc, val_acc, test_acc, t1) 360 | 361 | while j < len(t) and t1 >= t[j]: 362 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) 363 | j += 1 364 | y0 = y1 365 | 366 | return t1, solution 367 | 368 | @torch.no_grad() 369 | def test(self, logits): 370 | accs = [] 371 | for _, mask in self.data('train_mask', 'val_mask', 'test_mask'): 372 | pred = logits[mask].max(1)[1] 373 | acc = pred.eq(self.data.y[mask]).sum().item() / mask.sum().item() 374 | accs.append(acc) 375 | return accs 376 | 377 | @torch.no_grad() 378 | def test_OGB(self, logits): 379 | evaluator = self.evaluator 380 | data = self.data 381 | y_pred = logits.argmax(dim=-1, keepdim=True) 382 | train_acc, valid_acc, test_acc = run_evaluator(evaluator, data, y_pred) 383 | return [train_acc, valid_acc, test_acc] 384 | 385 | @torch.no_grad() 386 | def evaluate(self, z, t0, t1): 387 | # Activation. 388 | if not self.m2.in_features == z.shape[1]: # system has been augmented 389 | z = torch.split(z, self.m2.in_features, dim=1)[0] 390 | z = F.relu(z) 391 | z = self.m2(z) 392 | if self.dataset == 'ogbn-arxiv': 393 | z = z.log_softmax(dim=-1) 394 | loss = self.lf(z[self.data.train_mask], self.data.y.squeeze()[self.data.train_mask]) 395 | else: 396 | loss = self.lf(z[self.data.train_mask], self.data.y[self.data.train_mask]) 397 | train_acc, val_acc, test_acc = self.ode_test(z) 398 | log = 'ODE eval t0 {:.3f}, t1 {:.3f} Loss: {:.4f}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}' 399 | # print(log.format(t0, t1, loss, train_acc, val_acc, tmp_test_acc)) 400 | return train_acc, val_acc, test_acc 401 | 402 | def set_m2(self, m2): 403 | self.m2 = copy.deepcopy(m2) 404 | 405 | def set_data(self, data): 406 | if self.data is None: 407 | self.data = data 408 | 409 | 410 | SOLVERS = { 411 | 'dopri5': EarlyStopDopri5, 412 | 'rk4': EarlyStopRK4, 413 | 'symplectic_euler': EarlyStopSympEuler, 414 | 'leapfrog': EarlyStopLepfrog 415 | } 416 | 417 | 418 | class EarlyStopInt(torch.nn.Module): 419 | def __init__(self, t, opt, device=None): 420 | super(EarlyStopInt, self).__init__() 421 | self.device = device 422 | self.solver = None 423 | self.data = None 424 | self.m2 = None 425 | self.opt = opt 426 | self.t = torch.tensor([0, opt['earlystopxT'] * t], dtype=torch.float).to(self.device) 427 | 428 | def __call__(self, func, y0, t, method=None, rtol=1e-7, atol=1e-9, 429 | adjoint_method="dopri5", adjoint_atol=1e-9, adjoint_rtol=1e-7, options=None): 430 | """Integrate a system of ordinary differential equations. 431 | 432 | Solves the initial value problem for a non-stiff system of first order ODEs: 433 | ``` 434 | dy/dt = func(t, y), y(t[0]) = y0 435 | ``` 436 | where y is a Tensor of any shape. 437 | 438 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. 439 | 440 | Args: 441 | func: Function that maps a Tensor holding the state `y` and a scalar Tensor 442 | `t` into a Tensor of state derivatives with respect to time. 443 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. May 444 | have any floating point or complex dtype. 445 | t: 1-D Tensor holding a sequence of time points for which to solve for 446 | `y`. The initial time point should be the first element of this sequence, 447 | and each time must be larger than the previous time. May have any floating 448 | point dtype. Converted to a Tensor with float64 dtype. 
449 | rtol: optional float64 Tensor specifying an upper bound on relative error, 450 | per element of `y`. 451 | atol: optional float64 Tensor specifying an upper bound on absolute error, 452 | per element of `y`. 453 | method: optional string indicating the integration method to use. 454 | options: optional dict of configuring options for the indicated integration 455 | method. Can only be provided if a `method` is explicitly set. 456 | name: Optional name for this operation. 457 | 458 | Returns: 459 | y: Tensor, where the first dimension corresponds to different 460 | time points. Contains the solved value of y for each desired time point in 461 | `t`, with the initial value `y0` being the first element along the first 462 | dimension. 463 | 464 | Raises: 465 | ValueError: if an invalid `method` is provided. 466 | TypeError: if `options` is supplied without `method`, or if `t` or `y0` has 467 | an invalid dtype. 468 | """ 469 | method = self.opt['method'] 470 | assert method in ['rk4', 'dopri5','symplectic_euler', 'leapfrog'], "Only dopri5, rk4, symplectic_euler and leapfrog implemented with early stopping" 471 | 472 | ver = torchdiffeq.__version__[0] + torchdiffeq.__version__[2] + torchdiffeq.__version__[4] 473 | if int(ver) >= 22: #'0.2.2' 474 | event_fn = None 475 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, self.t, rtol, 476 | atol, method, options, 477 | event_fn, SOLVERS) 478 | else: 479 | shapes, func, y0, t, rtol, atol, method, options = _check_inputs(func, y0, self.t, rtol, atol, method, options, 480 | SOLVERS) 481 | 482 | self.solver = SOLVERS[method](func, y0, rtol=rtol, atol=atol, opt=self.opt, **options) 483 | if self.solver.data is None: 484 | self.solver.data = self.data 485 | self.solver.m2 = self.m2 486 | t, solution = self.solver.integrate(t) 487 | if shapes is not None: 488 | solution = _flat_to_shape(solution, (len(t),), shapes) 489 | return solution 490 | -------------------------------------------------------------------------------- /src/homophilic_graphs/experiment_configs/gat_pubmed.yaml: -------------------------------------------------------------------------------- 1 | program: run_GNN.py 2 | method: random 3 | parameters: 4 | method: 5 | distribution: categorical 6 | values: [ 'leapfrog' ] 7 | tol_scale: 8 | distribution: log_uniform 9 | min: -1 10 | max: 7 11 | function: 12 | distribution: constant 13 | value: GAT 14 | block: 15 | distribution: constant 16 | value: constant 17 | num_splits: 18 | distribution: constant 19 | value: 3 20 | heads: 21 | distribution: categorical 22 | values: [ 1, 2 ] 23 | epoch: 24 | distribution: constant 25 | value: 70 26 | time: 27 | distribution: uniform 28 | min: 3 29 | max: 25 30 | dataset: 31 | distribution: constant 32 | value: 'Pubmed' 33 | hidden_dim: 34 | distribution: categorical 35 | values: [ 32, 64 , 128 ] 36 | attention_dim: 37 | distribution: categorical 38 | values: [ 32, 64 , 128 ] 39 | input_dropout: 40 | distribution: uniform 41 | min: 0.0 42 | max: 0.5 43 | dropout: 44 | distribution: uniform 45 | min: 0.0 46 | max: 0.5 47 | lr: 48 | distribution: uniform 49 | min: 0.01 50 | max: 0.07 51 | decay: 52 | distribution: uniform 53 | min: 0 54 | max: 0.02 55 | command: 56 | - ${env} 57 | - python3 58 | - ${program} 59 | - ${args} 60 | - --wandb 61 | - --wandb_sweep 62 | - --adjoint 63 | - --add_source 64 | entity: 65 | graphcon 66 | project: 67 | gat_pubmed 68 | -------------------------------------------------------------------------------- 
/src/homophilic_graphs/experiment_configs/gcn_cora.yaml: -------------------------------------------------------------------------------- 1 | program: run_GNN.py 2 | method: random 3 | parameters: 4 | method: 5 | distribution: categorical 6 | values: [ 'dopri5' ] 7 | tol_scale: 8 | distribution: log_uniform 9 | min: 0 10 | max: 7 11 | function: 12 | distribution: constant 13 | value: gcn 14 | num_splits: 15 | distribution: constant 16 | value: 3 17 | epoch: 18 | distribution: constant 19 | value: 50 20 | time: 21 | distribution: uniform 22 | min: 1 23 | max: 20 24 | dataset: 25 | distribution: constant 26 | value: 'Cora' 27 | hidden_dim: 28 | values: [ 32, 64 , 128 ] 29 | input_dropout: 30 | distribution: uniform 31 | min: 0.0 32 | max: 0.5 33 | dropout: 34 | distribution: uniform 35 | min: 0.0 36 | max: 0.5 37 | lr: 38 | distribution: uniform 39 | min: 0.001 40 | max: 0.03 41 | decay: 42 | distribution: uniform 43 | min: 0 44 | max: 0.02 45 | command: 46 | - ${env} 47 | - python3 48 | - ${program} 49 | - ${args} 50 | - --wandb 51 | - --wandb_sweep 52 | - --function 53 | - function_gcn -------------------------------------------------------------------------------- /src/homophilic_graphs/experiment_configs/gcn_depth_random.yaml: -------------------------------------------------------------------------------- 1 | program: run_GNN.py 2 | method: random 3 | parameters: 4 | method: 5 | value: 'symplectic_euler' 6 | step_size: 7 | value: 1 8 | function: 9 | value: gcn 10 | epoch: 11 | value: 100 12 | hidden_dim: 13 | value: 64 14 | num_splits: 15 | value: 2 16 | dataset: 17 | values: [ 'Cora', 'Citeseer' ] 18 | time: 19 | values: [2, 4, 8, 16, 32, 64] 20 | input_dropout: 21 | distribution: uniform 22 | min: 0.0 23 | max: 0.5 24 | dropout: 25 | distribution: uniform 26 | min: 0.0 27 | max: 0.5 28 | lr: 29 | distribution: uniform 30 | min: 0.01 31 | max: 0.1 32 | decay: 33 | distribution: uniform 34 | min: 0.001 35 | max: 0.1 36 | command: 37 | - ${env} 38 | - python3 39 | - ${program} 40 | - ${args} 41 | - --wandb 42 | - --wandb_sweep 43 | entity: graphcon 44 | project: gcn_depth_random -------------------------------------------------------------------------------- /src/homophilic_graphs/experiment_configs/gcn_planetoid.yaml: -------------------------------------------------------------------------------- 1 | program: run_GNN.py 2 | method: random 3 | parameters: 4 | method: 5 | distribution: categorical 6 | values: [ 'dopri5', 'symplectic_euler', 'rk4', 'leapfrog' ] 7 | step_size: 8 | distribution: categorical 9 | values: [ 0.25, 0.5, 1 ] 10 | function: 11 | distribution: constant 12 | value: gcn 13 | epoch: 14 | distribution: constant 15 | value: 100 16 | time: 17 | distribution: uniform 18 | min: 4 19 | max: 15 20 | dataset: 21 | distribution: categorical 22 | values: [ 'Cora', 'Citeseer', 'Pubmed' ] 23 | hidden_dim: 24 | values: [ 8, 16, 32, 64 ] 25 | attention_dim: 26 | values: [ 8, 16, 32, 64 ] 27 | input_dropout: 28 | distribution: uniform 29 | min: 0.0 30 | max: 0.5 31 | dropout: 32 | distribution: uniform 33 | min: 0.0 34 | max: 0.5 35 | lr: 36 | distribution: uniform 37 | min: 0.01 38 | max: 0.1 39 | decay: 40 | distribution: uniform 41 | min: 0.001 42 | max: 0.1 43 | command: 44 | - ${env} 45 | - python3 46 | - ${program} 47 | - ${args} 48 | - --wandb 49 | - --wandb_sweep 50 | - --function 51 | - function_gcn -------------------------------------------------------------------------------- /src/homophilic_graphs/experiment_configs/run_sweeps.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {0..7} 4 | do 5 | CUDA_VISIBLE_DEVICES=$(($i % 8)) wandb agent graphcon/gat_pubmed/$1 & 6 | done -------------------------------------------------------------------------------- /src/homophilic_graphs/function_GAT_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import softmax 4 | import torch_sparse 5 | import torch.nn.functional as F 6 | from torch_geometric.utils.loop import add_remaining_self_loops 7 | from data import get_dataset 8 | from utils import MaxNFEException 9 | from base_classes import ODEFunc 10 | 11 | 12 | class ODEFuncAtt(ODEFunc): 13 | 14 | def __init__(self, in_features, out_features, opt, data, device): 15 | super(ODEFuncAtt, self).__init__(opt, data, device) 16 | 17 | if opt['self_loop_weight'] > 0: 18 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr, 19 | fill_value=opt['self_loop_weight']) 20 | else: 21 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr 22 | 23 | self.multihead_att_layer = SpGraphAttentionLayer(in_features, out_features, opt, 24 | device).to(device) 25 | try: 26 | self.attention_dim = opt['attention_dim'] 27 | except KeyError: 28 | self.attention_dim = out_features 29 | 30 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size" 31 | self.d_k = self.attention_dim // opt['heads'] 32 | 33 | def multiply_attention(self, x, attention, wx): 34 | if self.opt['mix_features']: 35 | wx = torch.mean(torch.stack( 36 | [torch_sparse.spmm(self.edge_index, attention[:, idx], wx.shape[0], wx.shape[0], wx) for idx in 37 | range(self.opt['heads'])], dim=0), 38 | dim=0) 39 | ax = torch.mm(wx, self.multihead_att_layer.Wout) 40 | else: 41 | ax = torch.mean(torch.stack( 42 | [torch_sparse.spmm(self.edge_index, attention[:, idx], x.shape[0], x.shape[0], x) for idx in 43 | range(self.opt['heads'])], dim=0), 44 | dim=0) 45 | return ax 46 | 47 | def forward(self, t, x_full): # t is needed when called by the integrator 48 | x = x_full[:, :self.opt['hidden_dim']] 49 | y = x_full[:, self.opt['hidden_dim']:] 50 | if self.nfe > self.opt["max_nfe"]: 51 | raise MaxNFEException 52 | self.nfe += 1 53 | attention, wy = self.multihead_att_layer(y, self.edge_index) 54 | ay = self.multiply_attention(y, attention, wy) 55 | # todo would be nice if this was more efficient 56 | 57 | if not self.opt['no_alpha_sigmoid']: 58 | alpha = torch.sigmoid(self.alpha_train) 59 | else: 60 | alpha = self.alpha_train 61 | f = (ay - y - x) 62 | if self.opt['add_source']: 63 | f = (1. - torch.sigmoid(self.beta_train)) * f + torch.sigmoid(self.beta_train) * self.x0[:, self.opt['hidden_dim']:] 64 | f = torch.cat([f, (1. 
- torch.sigmoid(self.beta_train2)) * alpha * x + torch.sigmoid(self.beta_train2) * self.x0[:,:self.opt['hidden_dim']]],dim=1) 65 | return f 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 69 | 70 | 71 | class SpGraphAttentionLayer(nn.Module): 72 | """ 73 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 74 | """ 75 | 76 | def __init__(self, in_features, out_features, opt, device, concat=True): 77 | super(SpGraphAttentionLayer, self).__init__() 78 | self.in_features = in_features 79 | self.out_features = out_features 80 | self.alpha = opt['leaky_relu_slope'] 81 | self.concat = concat 82 | self.device = device 83 | self.opt = opt 84 | self.h = opt['heads'] 85 | 86 | try: 87 | self.attention_dim = opt['attention_dim'] 88 | except KeyError: 89 | self.attention_dim = out_features 90 | 91 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size" 92 | self.d_k = self.attention_dim // opt['heads'] 93 | 94 | self.W = nn.Parameter(torch.zeros(size=(in_features, self.attention_dim))).to(device) 95 | nn.init.xavier_normal_(self.W.data, gain=1.414) 96 | 97 | self.Wout = nn.Parameter(torch.zeros(size=(self.attention_dim, self.in_features))).to(device) 98 | nn.init.xavier_normal_(self.Wout.data, gain=1.414) 99 | 100 | self.a = nn.Parameter(torch.zeros(size=(2 * self.d_k, 1, 1))).to(device) 101 | nn.init.xavier_normal_(self.a.data, gain=1.414) 102 | 103 | self.leakyrelu = nn.LeakyReLU(self.alpha) 104 | 105 | def forward(self, x, edge): 106 | wx = torch.mm(x, self.W) # h: N x out 107 | h = wx.view(-1, self.h, self.d_k) 108 | h = h.transpose(1, 2) 109 | 110 | # Self-attention on the nodes - Shared attention mechanism 111 | edge_h = torch.cat((h[edge[0, :], :, :], h[edge[1, :], :, :]), dim=1).transpose(0, 1).to( 112 | self.device) # edge: 2*D x E 113 | edge_e = self.leakyrelu(torch.sum(self.a * edge_h, dim=0)).to(self.device) 114 | attention = softmax(edge_e, edge[self.opt['attention_norm_idx']]) 115 | return attention, wx 116 | 117 | def __repr__(self): 118 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 119 | 120 | 121 | if __name__ == '__main__': 122 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 123 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 'K': 10, 'attention_norm_idx': 0, 124 | 'add_source':False, 'alpha_dim': 'sc', 'beta_dim': 'vc', 'max_nfe':1000, 'mix_features': False} 125 | dataset = get_dataset(opt, '../data', False) 126 | t = 1 127 | func = ODEFuncAtt(dataset.data.num_features, 6, opt, dataset.data, device) 128 | out = func(t, dataset.data.x) 129 | -------------------------------------------------------------------------------- /src/homophilic_graphs/function_gcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch_sparse 4 | import torch.nn.functional as F 5 | from base_classes import ODEFunc 6 | from utils import MaxNFEException 7 | from torch_geometric.nn.conv import GCNConv 8 | 9 | 10 | # Define the ODE function. 11 | # Input: 12 | # --- t: A tensor with shape [], meaning the current time. 13 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t. 14 | # Output: 15 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t. 
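# State layout (descriptive note): the integrator state `x_full` stacks the two
# GraphCON components along the feature dimension. The first half,
# x_full[:, :hidden_dim], is updated through the GCN coupling term
# GCNConv(y) - x - y, while the second half, x_full[:, hidden_dim:], evolves as
# alpha * x (plus the optional source term gated by beta when --add_source is
# set). forward() returns the two derivatives concatenated in the same order,
# which is what the split-step schemes in geometric_integrators.py assume when
# they halve the state.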
16 | class GCNFunc(ODEFunc): 17 | 18 | # currently requires in_features = out_features 19 | def __init__(self, in_features, out_features, opt, data, device): 20 | super(GCNFunc, self).__init__(opt, data, device) 21 | 22 | self.in_features = in_features 23 | self.out_features = out_features 24 | self.w = nn.Parameter(torch.eye(opt['hidden_dim'])) 25 | self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1) 26 | self.alpha_sc = nn.Parameter(torch.ones(1)) 27 | self.beta_sc = nn.Parameter(torch.ones(1)) 28 | self.conv = GCNConv(in_features, out_features, normalize=True) 29 | 30 | def sparse_multiply(self, x): 31 | if self.opt['block'] in ['attention']: # adj is a multihead attention 32 | # ax = torch.mean(torch.stack( 33 | # [torch_sparse.spmm(self.edge_index, self.attention_weights[:, idx], x.shape[0], x.shape[0], x) for idx in 34 | # range(self.opt['heads'])], dim=0), dim=0) 35 | mean_attention = self.attention_weights.mean(dim=1) 36 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x) 37 | elif self.opt['block'] in ['mixed', 'hard_attention']: # adj is a torch sparse matrix 38 | ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x) 39 | else: # adj is a torch sparse matrix 40 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x) 41 | return ax 42 | 43 | def forward(self, t, x_full): # the t param is needed by the ODE solver. 44 | x = x_full[:, :self.opt['hidden_dim']] 45 | y = x_full[:, self.opt['hidden_dim']:] 46 | if self.nfe > self.opt["max_nfe"]: 47 | raise MaxNFEException 48 | self.nfe += 1 49 | # ay = self.sparse_multiply(y) 50 | if not self.opt['no_alpha_sigmoid']: 51 | alpha = torch.sigmoid(self.alpha_train) 52 | else: 53 | alpha = self.alpha_train 54 | f = self.conv(y, self.edge_index) - x - y 55 | # f = (ay - x - y) 56 | if self.opt['add_source']: 57 | f = (1. - F.sigmoid(self.beta_train)) * f + F.sigmoid(self.beta_train) * self.x0[:, self.opt['hidden_dim']:] 58 | f = torch.cat([f, (1. - F.sigmoid(self.beta_train2)) * alpha * x + F.sigmoid(self.beta_train2) * 59 | self.x0[:, :self.opt['hidden_dim']]], dim=1) 60 | return f 61 | -------------------------------------------------------------------------------- /src/homophilic_graphs/function_laplacian_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch_sparse 4 | import torch.nn.functional as F 5 | from base_classes import ODEFunc 6 | from utils import MaxNFEException 7 | 8 | 9 | # Define the ODE function. 10 | # Input: 11 | # --- t: A tensor with shape [], meaning the current time. 12 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t. 13 | # Output: 14 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t. 
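# Descriptive note: the coupling below uses a sparse diffusion operator A built
# in sparse_multiply(): the head-averaged attention matrix for the 'attention'
# block, precomputed attention weights for 'mixed' / 'hard_attention', and the
# edge weights supplied by the ODE block otherwise. With
# x = x_full[:, :hidden_dim] and y = x_full[:, hidden_dim:], forward() returns
# the stacked derivative
#     d/dt [x, y] = [ A y - y - x ,  alpha * x ]
# plus the optional beta-gated source terms when 'add_source' is enabled.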
15 | class LaplacianODEFunc(ODEFunc): 16 | 17 | # currently requires in_features = out_features 18 | def __init__(self, in_features, out_features, opt, data, device): 19 | super(LaplacianODEFunc, self).__init__(opt, data, device) 20 | 21 | self.in_features = in_features 22 | self.out_features = out_features 23 | self.w = nn.Parameter(torch.eye(opt['hidden_dim'])) 24 | self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1) 25 | self.alpha_sc = nn.Parameter(torch.ones(1)) 26 | self.beta_sc = nn.Parameter(torch.ones(1)) 27 | 28 | def sparse_multiply(self, x): 29 | if self.opt['block'] in ['attention']: # adj is a multihead attention 30 | # ax = torch.mean(torch.stack( 31 | # [torch_sparse.spmm(self.edge_index, self.attention_weights[:, idx], x.shape[0], x.shape[0], x) for idx in 32 | # range(self.opt['heads'])], dim=0), dim=0) 33 | mean_attention = self.attention_weights.mean(dim=1) 34 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x) 35 | elif self.opt['block'] in ['mixed', 'hard_attention']: # adj is a torch sparse matrix 36 | ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x) 37 | else: # adj is a torch sparse matrix 38 | ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x) 39 | return ax 40 | 41 | def forward(self, t, x_full): # the t param is needed by the ODE solver. 42 | x = x_full[:,:self.opt['hidden_dim']] 43 | y = x_full[:,self.opt['hidden_dim']:] 44 | if self.nfe > self.opt["max_nfe"]: 45 | raise MaxNFEException 46 | self.nfe += 1 47 | ay = self.sparse_multiply(y) 48 | if not self.opt['no_alpha_sigmoid']: 49 | alpha = torch.sigmoid(self.alpha_train) 50 | else: 51 | alpha = self.alpha_train 52 | f = (ay - y - x) 53 | if self.opt['add_source']: 54 | f = (1.-F.sigmoid(self.beta_train))*f + F.sigmoid(self.beta_train) * self.x0[:,self.opt['hidden_dim']:] 55 | f = torch.cat([f,(1.-F.sigmoid(self.beta_train2))*alpha*x + F.sigmoid(self.beta_train2) * self.x0[:,:self.opt['hidden_dim']]],dim=1) 56 | return f 57 | -------------------------------------------------------------------------------- /src/homophilic_graphs/function_transformer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import softmax 4 | import torch_sparse 5 | from torch_geometric.utils.loop import add_remaining_self_loops 6 | import numpy as np 7 | from data import get_dataset 8 | from utils import MaxNFEException 9 | from base_classes import ODEFunc 10 | 11 | 12 | class ODEFuncTransformerAtt(ODEFunc): 13 | 14 | def __init__(self, in_features, out_features, opt, data, device): 15 | super(ODEFuncTransformerAtt, self).__init__(opt, data, device) 16 | 17 | if opt['self_loop_weight'] > 0: 18 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr, 19 | fill_value=opt['self_loop_weight']) 20 | else: 21 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr 22 | self.multihead_att_layer = SpGraphTransAttentionLayer(in_features, out_features, opt, 23 | device, edge_weights=self.edge_weight).to(device) 24 | 25 | def multiply_attention(self, x, attention, v=None): 26 | # todo would be nice if this was more efficient 27 | if self.opt['mix_features']: 28 | vx = torch.mean(torch.stack( 29 | [torch_sparse.spmm(self.edge_index, attention[:, idx], v.shape[0], v.shape[0], v[:, :, idx]) for idx in 30 | range(self.opt['heads'])], dim=0), 31 | dim=0) 32 | ax = 
self.multihead_att_layer.Wout(vx) 33 | else: 34 | mean_attention = attention.mean(dim=1) 35 | ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x) 36 | return ax 37 | 38 | def forward(self, t, x): # t is needed when called by the integrator 39 | if self.nfe > self.opt["max_nfe"]: 40 | raise MaxNFEException 41 | 42 | self.nfe += 1 43 | attention, values = self.multihead_att_layer(x, self.edge_index) 44 | ax = self.multiply_attention(x, attention, values) 45 | 46 | if not self.opt['no_alpha_sigmoid']: 47 | alpha = torch.sigmoid(self.alpha_train) 48 | else: 49 | alpha = self.alpha_train 50 | f = alpha * (ax - x) 51 | if self.opt['add_source']: 52 | f = f + self.beta_train * self.x0 53 | return f 54 | 55 | def __repr__(self): 56 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 57 | 58 | 59 | class SpGraphTransAttentionLayer(nn.Module): 60 | """ 61 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 62 | """ 63 | 64 | def __init__(self, in_features, out_features, opt, device, concat=True, edge_weights=None): 65 | super(SpGraphTransAttentionLayer, self).__init__() 66 | self.in_features = in_features 67 | self.out_features = out_features 68 | self.alpha = opt['leaky_relu_slope'] 69 | self.concat = concat 70 | self.device = device 71 | self.opt = opt 72 | self.h = int(opt['heads']) 73 | self.edge_weights = edge_weights 74 | 75 | try: 76 | self.attention_dim = opt['attention_dim'] 77 | except KeyError: 78 | self.attention_dim = out_features 79 | 80 | assert self.attention_dim % self.h == 0, "Number of heads ({}) must be a factor of the dimension size ({})".format( 81 | self.h, self.attention_dim) 82 | self.d_k = self.attention_dim // self.h 83 | 84 | self.Q = nn.Linear(in_features, self.attention_dim) 85 | self.init_weights(self.Q) 86 | 87 | self.V = nn.Linear(in_features, self.attention_dim) 88 | self.init_weights(self.V) 89 | 90 | self.K = nn.Linear(in_features, self.attention_dim) 91 | self.init_weights(self.K) 92 | 93 | self.activation = nn.Sigmoid() # nn.LeakyReLU(self.alpha) 94 | 95 | self.Wout = nn.Linear(self.d_k, in_features) 96 | self.init_weights(self.Wout) 97 | 98 | def init_weights(self, m): 99 | if type(m) == nn.Linear: 100 | # nn.init.xavier_uniform_(m.weight, gain=1.414) 101 | # m.bias.data.fill_(0.01) 102 | nn.init.constant_(m.weight, 1e-5) 103 | 104 | def forward(self, x, edge): 105 | q = self.Q(x) 106 | k = self.K(x) 107 | v = self.V(x) 108 | 109 | # perform linear operation and split into h heads 110 | 111 | k = k.view(-1, self.h, self.d_k) 112 | q = q.view(-1, self.h, self.d_k) 113 | v = v.view(-1, self.h, self.d_k) 114 | 115 | # transpose to get dimensions [n_nodes, attention_dim, n_heads] 116 | 117 | k = k.transpose(1, 2) 118 | q = q.transpose(1, 2) 119 | v = v.transpose(1, 2) 120 | 121 | src = q[edge[0, :], :, :] 122 | dst_k = k[edge[1, :], :, :] 123 | prods = torch.sum(src * dst_k, dim=1) / np.sqrt(self.d_k) 124 | if self.opt['reweight_attention'] and self.edge_weights is not None: 125 | prods = prods * self.edge_weights.unsqueeze(dim=1) 126 | attention = softmax(prods, edge[self.opt['attention_norm_idx']]) 127 | return attention, v 128 | 129 | def __repr__(self): 130 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 131 | 132 | 133 | if __name__ == '__main__': 134 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 135 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 
'heads': 2, 'K': 10, 136 | 'attention_norm_idx': 0, 'add_source': False, 137 | 'alpha_dim': 'sc', 'beta_dim': 'sc', 'max_nfe': 1000, 'mix_features': False 138 | } 139 | dataset = get_dataset(opt, '../data', False) 140 | t = 1 141 | func = ODEFuncTransformerAtt(dataset.data.num_features, 6, opt, dataset.data, device) 142 | out = func(t, dataset.data.x) 143 | -------------------------------------------------------------------------------- /src/homophilic_graphs/geometric_integrators.py: -------------------------------------------------------------------------------- 1 | from geometric_solvers import GeomtricFixedGridODESolver 2 | from torchdiffeq._impl.misc import Perturb 3 | import torch 4 | 5 | 6 | class SymplecticEuler(GeomtricFixedGridODESolver): 7 | order = 1 8 | 9 | def _step_func(self, func, t0, dt, t1, y0): 10 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1) 11 | f0x = torch.split(func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0] 12 | x1 = x0 + dt*f0x 13 | f0z = torch.split(func(t0, torch.cat([x1,z0],dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[1] 14 | z1 = z0 + dt*f0z 15 | y1 = torch.cat([x1,z1],dim=1) 16 | return y1 17 | 18 | class Leapfrog(GeomtricFixedGridODESolver): 19 | order = 2 20 | 21 | def _step_func(self, func, t0, dt, t1, y0): 22 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1) 23 | f0x_1 = torch.split(func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0] 24 | x_half = x0 + dt/2.*f0x_1 25 | 26 | f0z = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[1] 27 | z1 = z0 + dt * f0z 28 | 29 | f0x_2 = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if self.perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[0] 30 | x1 = x_half + dt/2.*f0x_2 31 | y1 = torch.cat([x1,z1],dim=1) 32 | return y1 33 | 34 | 35 | def SymplecticEuler_step_func(func, t0, dt, t1, y0, perturb=False): 36 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1) 37 | f0x = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0] 38 | x1 = x0 + dt*f0x 39 | y0 = torch.cat([x1,z0],dim=1) 40 | f0z = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[1] 41 | z1 = z0 + dt*f0z 42 | y1 = torch.cat([x1,z1],dim=1) 43 | return y1 44 | 45 | 46 | def Leapfrog_step_func(func, t0, dt, t1, y0, perturb=False): 47 | x0, z0 = torch.split(y0, y0.shape[1] // 2, dim=1) 48 | f0x_1 = torch.split(func(t0, y0, perturb=Perturb.NEXT if perturb else Perturb.NONE), y0.shape[1] // 2, dim=1)[0] 49 | x_half = x0 + dt / 2. * f0x_1 50 | 51 | f0z = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[1] 52 | z1 = z0 + dt * f0z 53 | 54 | f0x_2 = torch.split(func(t0, torch.cat([x_half, z0], dim=1), perturb=Perturb.NEXT if perturb else Perturb.NONE),y0.shape[1] // 2, dim=1)[0] 55 | x1 = x_half + dt / 2. 
* f0x_2 56 | y1 = torch.cat([x1, z1], dim=1) 57 | return y1 -------------------------------------------------------------------------------- /src/homophilic_graphs/geometric_solvers.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import torch 3 | from torchdiffeq._impl.misc import _handle_unused_kwargs 4 | 5 | 6 | class GeomtricFixedGridODESolver(metaclass=abc.ABCMeta): 7 | order: int 8 | 9 | def __init__(self, func, y0, step_size=None, grid_constructor=None, interp="linear", perturb=False, **unused_kwargs): 10 | self.func = func 11 | self.y0 = y0 12 | self.dtype = y0.dtype 13 | self.device = y0.device 14 | self.step_size = step_size 15 | self.interp = interp 16 | self.perturb = perturb 17 | self.grid_constructor = self._grid_constructor_from_step_size(step_size) 18 | 19 | @staticmethod 20 | def _grid_constructor_from_step_size(step_size): 21 | def _grid_constructor(func, y0, t): 22 | start_time = t[0] 23 | end_time = t[-1] 24 | 25 | niters = torch.ceil((end_time - start_time) / step_size + 1).item() 26 | t_infer = torch.arange(0, niters, dtype=t.dtype, device=t.device) * step_size + start_time 27 | t_infer[-1] = t[-1] 28 | 29 | return t_infer 30 | return _grid_constructor 31 | 32 | @abc.abstractmethod 33 | def _step_func(self, func, t0, dt, t1, y0): 34 | pass 35 | 36 | def integrate(self, t): 37 | time_grid = self.grid_constructor(self.func, self.y0, t) 38 | assert time_grid[0] == t[0] and time_grid[-1] == t[-1] 39 | 40 | solution = torch.empty(len(t), *self.y0.shape, dtype=self.y0.dtype, device=self.y0.device) 41 | solution[0] = self.y0 42 | 43 | j = 1 44 | y0 = self.y0 45 | for t0, t1 in zip(time_grid[:-1], time_grid[1:]): 46 | dt = t1 - t0 47 | y1 = self._step_func(self.func, t0, dt, t1, y0) 48 | while j < len(t) and t1 >= t[j]: 49 | if self.interp == "linear": 50 | solution[j] = self._linear_interp(t0, t1, y0, y1, t[j]) 51 | elif self.interp == "cubic": 52 | raise NotImplementedError("Not implemented for geometric integrators") 53 | else: 54 | raise ValueError(f"Unknown interpolation method {self.interp}") 55 | j += 1 56 | y0 = y1 57 | 58 | return solution 59 | 60 | def integrate_until_event(self, t0, event_fn): 61 | raise NotImplementedError("Not implemented for geometric integrators") 62 | 63 | def _cubic_hermite_interp(self, t0, y0, f0, t1, y1, f1, t): 64 | h = (t - t0) / (t1 - t0) 65 | h00 = (1 + 2 * h) * (1 - h) * (1 - h) 66 | h10 = h * (1 - h) * (1 - h) 67 | h01 = h * h * (3 - 2 * h) 68 | h11 = h * h * (h - 1) 69 | dt = (t1 - t0) 70 | return h00 * y0 + h10 * dt * f0 + h01 * y1 + h11 * dt * f1 71 | 72 | def _linear_interp(self, t0, t1, y0, y1, t): 73 | if t == t0: 74 | return y0 75 | if t == t1: 76 | return y1 77 | slope = (t - t0) / (t1 - t0) 78 | return y0 + slope * (y1 - y0) -------------------------------------------------------------------------------- /src/homophilic_graphs/good_params_graphCON.py: -------------------------------------------------------------------------------- 1 | """ 2 | Currently the adjoint method is not implemented and so the params for Pubmed are not correct as I have changed 3 | adjoint to False, but kept all other hyperparameter settings. 
4 | """ 5 | good_params_dict = { 6 | 'Cora': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False, 'adjoint_method': 'adaptive_heun', 7 | 'adjoint_step_size': 1, 'alpha': 1.0, 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 128, 8 | 'attention_norm_idx': 1, 'attention_rewiring': False, 'attention_type': 'scaled_dot', 'augment': False, 9 | 'baseline': False, 'batch_norm': False, 'beltrami': True, 'beta_dim': 'sc', 'block': 'attention', 'cpus': 1, 10 | 'data_norm': 'rw', 'dataset': 'Cora', 'decay': 0.00507685443154266, 'directional_penalty': None, 11 | 'dropout': 0.17167167167167166, 'dt': 0.001, 'dt_min': 1e-05, 'epoch': 100, 'exact': True, 'fc_out': False, 12 | 'feat_hidden_dim': 64, 'function': 'laplacian', 'gdc_avg_degree': 64, 'gdc_k': 64, 'gdc_method': 'ppr', 13 | 'gdc_sparsification': 'topk', 'gdc_threshold': 0.01, 'gpus': 0.5, 'grace_period': 20, 'heads': 8, 14 | 'heat_time': 3.0, 'hidden_dim': 80, 'input_dropout': 0.5, 'jacobian_norm2': None, 'kinetic_energy': None, 15 | 'label_rate': 0.5, 'leaky_relu_slope': 0.2, 'lr': 0.01334134134134134, 'max_epochs': 1000, 'max_iters': 100, 16 | 'max_nfe': 2000, 'method': 'symplectic_euler', 'metric': 'accuracy', 'mix_features': False, 17 | 'name': 'cora_beltrami_splits', 'new_edges': 'random', 'no_alpha_sigmoid': False, 'not_lcc': True, 18 | 'num_init': 1, 'num_samples': 1000, 'num_splits': 1, 'ode_blocks': 1, 'optimizer': 'adamax', 'patience': 100, 19 | 'pos_enc_hidden_dim': 16, 'pos_enc_orientation': 'row', 'pos_enc_type': 'GDC', 'ppr_alpha': 0.05, 20 | 'reduction_factor': 10, 'regularise': False, 'reweight_attention': False, 'rewire_KNN': False, 21 | 'rewire_KNN_T': 'T0', 'rewire_KNN_epoch': 10, 'rewire_KNN_k': 64, 'rewire_KNN_sym': False, 'rewiring': None, 22 | 'rw_addD': 0.02, 'rw_rmvR': 0.02, 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True, 23 | 'step_size': 1, 'threshold_type': 'addD_rvR', 'time': 18.294754260552843, 'tol_scale': 821.9773048827274, 24 | 'tol_scale_adjoint': 1.0, 'total_deriv': None, 'use_cora_defaults': False, 'use_flux': False, 25 | 'use_labels': False, 'use_lcc': True, 'use_mlp': False}, 26 | 'Citeseer': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False, 27 | 'adjoint_method': 'adaptive_heun', 'adjoint_step_size': 1, 'alpha': 1.0, 28 | 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 32, 'attention_norm_idx': 1, 29 | 'attention_rewiring': False, 'attention_type': 'exp_kernel', 'augment': False, 30 | 'baseline': False, 'batch_norm': False, 'beltrami': False, 'beta_dim': 'sc', 31 | 'block': 'attention', 'cpus': 1, 'data_norm': 'rw', 'dataset': 'Citeseer', 32 | 'decay': 0.1, 'directional_penalty': None, 'dropout': 0.8429429429429429, 'dt': 0.001, 33 | 'dt_min': 1e-05, 'epoch': 250, 'exact': True, 'fc_out': False, 'feat_hidden_dim': 64, 34 | 'function': 'laplacian', 'gdc_avg_degree': 64, 'gdc_k': 128, 'gdc_method': 'ppr', 35 | 'gdc_sparsification': 'topk', 'gdc_threshold': 0.01, 'gpus': 1.0, 'grace_period': 20, 36 | 'heads': 8, 'heat_time': 3.0, 'hidden_dim': 80, 'input_dropout': 0.6803233752085334, 37 | 'jacobian_norm2': None, 'kinetic_energy': None, 'label_rate': 0.5, 38 | 'leaky_relu_slope': 0.5825086997804176, 'lr': 0.006036036036036037, 'max_epochs': 1000, 39 | 'max_iters': 100, 'max_nfe': 3000, 'method': 'symplectic_euler', 'metric': 'accuracy', 40 | 'mix_features': False, 'name': 'Citeseer_beltrami_1_KNN', 'new_edges': 'random', 41 | 'no_alpha_sigmoid': False, 'not_lcc': True, 'num_class': 6, 'num_feature': 3703, 42 | 'num_init': 2, 'num_nodes': 
2120, 'num_samples': 400, 'num_splits': 1, 'ode_blocks': 1, 43 | 'optimizer': 'adam', 'patience': 100, 'pos_enc_dim': 'row', 'pos_enc_hidden_dim': 16, 44 | 'ppr_alpha': 0.05, 'reduction_factor': 4, 'regularise': False, 45 | 'reweight_attention': False, 'rewire_KNN': False, 'rewire_KNN_epoch': 10, 46 | 'rewire_KNN_k': 64, 'rewire_KNN_sym': False, 'rewiring': None, 'rw_addD': 0.02, 47 | 'rw_rmvR': 0.02, 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True, 48 | 'step_size': 1, 'threshold_type': 'addD_rvR', 'time': 7.874113442879092, 49 | 'tol_scale': 2.9010446330432815, 'tol_scale_adjoint': 1.0, 'total_deriv': None, 50 | 'use_cora_defaults': False, 'use_flux': False, 'use_labels': False, 'use_lcc': True, 51 | 'use_mlp': False}, 52 | 'Pubmed': {'M_nodes': 64, 'adaptive': False, 'add_source': True, 'adjoint': False, 53 | 'adjoint_method': 'adaptive_heun', 'adjoint_step_size': 1, 'alpha': 1.0, 54 | 'alpha_dim': 'sc', 'att_samp_pct': 1, 'attention_dim': 16, 'attention_norm_idx': 0, 55 | 'attention_rewiring': False, 'attention_type': 'cosine_sim', 'augment': False, 56 | 'baseline': False, 'batch_norm': False, 'beltrami': False, 'beta_dim': 'sc', 57 | 'block': 'attention', 'cpus': 1, 'data_norm': 'rw', 'dataset': 'Pubmed', 58 | 'decay': 0.0018236722171703636, 'directional_penalty': None, 59 | 'dropout': 0.18918918918918917, 'dt': 0.001, 'dt_min': 1e-05, 'epoch': 600, 60 | 'exact': False, 'fc_out': False, 'feat_hidden_dim': 64, 'function': 'laplacian', 61 | 'gdc_avg_degree': 64, 'gdc_k': 64, 'gdc_method': 'ppr', 'gdc_sparsification': 'topk', 62 | 'gdc_threshold': 0.01, 'gpus': 1.0, 'grace_period': 20, 'heads': 1, 'heat_time': 3.0, 63 | 'hidden_dim': 128, 'input_dropout': 0.5, 'jacobian_norm2': None, 'kinetic_energy': None, 64 | 'label_rate': 0.5, 'leaky_relu_slope': 0.2, 'lr': 0.02522522522522523, 65 | 'max_epochs': 1000, 'max_iters': 100, 'max_nfe': 5000, 'method': 'symplectic_euler', 66 | 'metric': 'test_acc', 'mix_features': False, 'name': None, 'new_edges': 'random', 67 | 'no_alpha_sigmoid': False, 'not_lcc': True, 'num_init': 1, 'num_samples': 400, 68 | 'num_splits': 1, 'ode_blocks': 1, 'optimizer': 'adamax', 'patience': 100, 69 | 'pos_enc_dim': 'row', 'pos_enc_hidden_dim': 16, 'ppr_alpha': 0.05, 70 | 'reduction_factor': 10, 'regularise': False, 'reweight_attention': False, 71 | 'rewire_KNN': False, 'rewire_KNN_T': 'T0', 'rewire_KNN_epoch': 10, 'rewire_KNN_k': 64, 72 | 'rewire_KNN_sym': False, 'rewiring': None, 'rw_addD': 0.02, 'rw_rmvR': 0.02, 73 | 'self_loop_weight': 1, 'sparsify': 'S_hat', 'square_plus': True, 'step_size': 1, 74 | 'threshold_type': 'addD_rvR', 'time': 12.942327880200853, 75 | 'tol_scale': 1991.0688305523001, 'tol_scale_adjoint': 16324.368093998313, 76 | 'total_deriv': None, 'use_cora_defaults': False, 'use_flux': False, 'use_labels': False, 77 | 'use_lcc': True, 'use_mlp': False, 'folder': 'pubmed_linear_att_beltrami_adj2', 78 | 'index': 0, 'run_with_KNN': False, 'change_att_sim_type': False, 'reps': 1, 79 | 'max_test_steps': 100, 'no_early': False, 'earlystopxT': 5.0, 'pos_enc_csv': False, 80 | 'pos_enc_type': 'GDC'} 81 | } 82 | -------------------------------------------------------------------------------- /src/homophilic_graphs/model_configurations.py: -------------------------------------------------------------------------------- 1 | from function_transformer_attention import ODEFuncTransformerAtt 2 | from function_GAT_attention import ODEFuncAtt 3 | from function_laplacian_diffusion import LaplacianODEFunc 4 | from function_gcn import GCNFunc 5 | from 
block_transformer_attention import AttODEblock 6 | from block_constant import ConstantODEblock 7 | 8 | 9 | class BlockNotDefined(Exception): 10 | pass 11 | 12 | 13 | class FunctionNotDefined(Exception): 14 | pass 15 | 16 | 17 | def set_block(opt): 18 | ode_str = opt['block'] 19 | if ode_str == 'attention': 20 | block = AttODEblock 21 | elif ode_str == 'constant': 22 | block = ConstantODEblock 23 | else: 24 | raise BlockNotDefined 25 | return block 26 | 27 | 28 | def set_function(opt): 29 | ode_str = opt['function'] 30 | if ode_str == 'laplacian': 31 | f = LaplacianODEFunc 32 | elif ode_str == 'GAT': 33 | f = ODEFuncAtt 34 | elif ode_str == 'transformer': 35 | f = ODEFuncTransformerAtt 36 | elif ode_str == 'gcn': 37 | f = GCNFunc 38 | else: 39 | raise FunctionNotDefined 40 | return f 41 | -------------------------------------------------------------------------------- /src/homophilic_graphs/odeint_geometric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.functional import vjp 3 | from geometric_integrators import SymplecticEuler, Leapfrog 4 | from torchdiffeq._impl.scipy_wrapper import ScipyWrapperODESolver 5 | from torchdiffeq._impl.misc import _check_inputs, _flat_to_shape 6 | 7 | SOLVERS = { 8 | 'symplectic_euler': SymplecticEuler, 9 | 'leapfrog': Leapfrog 10 | } 11 | 12 | 13 | def odeint(func, y0, t, *, rtol=1e-7, atol=1e-9, method=None, options=None, event_fn=None): 14 | """Integrate a system of ordinary differential equations. 15 | 16 | Solves the initial value problem for a non-stiff system of first order ODEs: 17 | ``` 18 | dy/dt = func(t, y), y(t[0]) = y0 19 | ``` 20 | where y is a Tensor or tuple of Tensors of any shape. 21 | 22 | Output dtypes and numerical precision are based on the dtypes of the inputs `y0`. 23 | 24 | Args: 25 | func: Function that maps a scalar Tensor `t` and a Tensor holding the state `y` 26 | into a Tensor of state derivatives with respect to time. Optionally, `y` 27 | can also be a tuple of Tensors. 28 | y0: N-D Tensor giving starting value of `y` at time point `t[0]`. Optionally, `y0` 29 | can also be a tuple of Tensors. 30 | t: 1-D Tensor holding a sequence of time points for which to solve for 31 | `y`, in either increasing or decreasing order. The first element of 32 | this sequence is taken to be the initial time point. 33 | rtol: optional float64 Tensor specifying an upper bound on relative error, 34 | per element of `y`. 35 | atol: optional float64 Tensor specifying an upper bound on absolute error, 36 | per element of `y`. 37 | method: optional string indicating the integration method to use. 38 | options: optional dict of configuring options for the indicated integration 39 | method. Can only be provided if a `method` is explicitly set. 40 | event_fn: Function that maps the state `y` to a Tensor. The solve terminates when 41 | event_fn evaluates to zero. If this is not None, all but the first elements of 42 | `t` are ignored. 43 | 44 | Returns: 45 | y: Tensor, where the first dimension corresponds to different 46 | time points. Contains the solved value of y for each desired time point in 47 | `t`, with the initial value `y0` being the first element along the first 48 | dimension. 49 | 50 | Raises: 51 | ValueError: if an invalid `method` is provided. 
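    Example (an illustrative sketch only; `f` and `y0` here are hypothetical, and the
        state layout and step-size conventions expected by the geometric solvers are
        defined in geometric_integrators.py):

            f = lambda t, y: -y
            y0 = torch.randn(8)
            t = torch.linspace(0., 1., 11)
            ys = odeint(f, y0, t, method='symplectic_euler')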
52 | """ 53 | 54 | shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS) 55 | 56 | solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options) 57 | 58 | if event_fn is None: 59 | solution = solver.integrate(t) 60 | else: 61 | event_t, solution = solver.integrate_until_event(t[0], event_fn) 62 | event_t = event_t.to(t) 63 | if t_is_reversed: 64 | event_t = -event_t 65 | 66 | if shapes is not None: 67 | solution = _flat_to_shape(solution, (len(t),), shapes) 68 | 69 | if event_fn is None: 70 | return solution 71 | else: 72 | return event_t, solution -------------------------------------------------------------------------------- /src/homophilic_graphs/regularized_ODE_function.py: -------------------------------------------------------------------------------- 1 | ## This code has been adapted from https://github.com/cfinlay/ffjord-rnode/ 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class RegularizedODEfunc(nn.Module): 8 | def __init__(self, odefunc, regularization_fns): 9 | super(RegularizedODEfunc, self).__init__() 10 | self.odefunc = odefunc 11 | self.regularization_fns = regularization_fns 12 | 13 | def before_odeint(self, *args, **kwargs): 14 | self.odefunc.before_odeint(*args, **kwargs) 15 | 16 | def forward(self, t, state): 17 | 18 | with torch.enable_grad(): 19 | x = state[0] 20 | x.requires_grad_(True) 21 | t.requires_grad_(True) 22 | dstate = self.odefunc(t, x) 23 | if len(state) > 1: 24 | dx = dstate 25 | reg_states = tuple(reg_fn(x, t, dx, self.odefunc) for reg_fn in self.regularization_fns) 26 | return (dstate,) + reg_states 27 | else: 28 | return dstate 29 | 30 | @property 31 | def _num_evals(self): 32 | return self.odefunc._num_evals 33 | 34 | 35 | def total_derivative(x, t, dx, unused_context): 36 | del unused_context 37 | 38 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 39 | 40 | try: 41 | u = torch.full_like(dx, 1 / x.numel(), requires_grad=True) 42 | tmp = torch.autograd.grad((u * dx).sum(), t, create_graph=True)[0] 43 | partial_dt = torch.autograd.grad(tmp.sum(), u, create_graph=True)[0] 44 | 45 | total_deriv = directional_dx + partial_dt 46 | except RuntimeError as e: 47 | if 'One of the differentiated Tensors' in e.__str__(): 48 | raise RuntimeError( 49 | 'No partial derivative with respect to time. Use mathematically equivalent "directional_derivative" regularizer instead') 50 | 51 | tdv2 = total_deriv.pow(2).view(x.size(0), -1) 52 | 53 | return 0.5 * tdv2.mean(dim=-1) 54 | 55 | 56 | def directional_derivative(x, t, dx, unused_context): 57 | del t, unused_context 58 | 59 | directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0] 60 | ddx2 = directional_dx.pow(2).view(x.size(0), -1) 61 | 62 | return 0.5 * ddx2.mean(dim=-1) 63 | 64 | 65 | def quadratic_cost(x, t, dx, unused_context): 66 | del x, t, unused_context 67 | dx = dx.view(dx.shape[0], -1) 68 | return 0.5 * dx.pow(2).mean(dim=-1) 69 | 70 | 71 | def divergence_bf(dx, x): 72 | sum_diag = 0. 
73 | for i in range(x.shape[1]): 74 | sum_diag += torch.autograd.grad(dx[:, i].sum(), x, create_graph=True)[0].contiguous()[:, i].contiguous() 75 | return sum_diag.contiguous() 76 | 77 | 78 | def jacobian_frobenius_regularization_fn(x, t, dx, context): 79 | del t 80 | return divergence_bf(dx, x) 81 | -------------------------------------------------------------------------------- /src/homophilic_graphs/run_GNN.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import torch 5 | from torch_geometric.nn import GCNConv, ChebConv # noqa 6 | import torch.nn.functional as F 7 | from GNN import GNN 8 | from GNN_early import GNNEarly 9 | import time 10 | from data import get_dataset, set_train_val_test_split 11 | from ogb.nodeproppred import Evaluator 12 | from good_params_graphCON import good_params_dict 13 | import wandb 14 | 15 | 16 | def get_optimizer(name, parameters, lr, weight_decay=0): 17 | if name == 'sgd': 18 | return torch.optim.SGD(parameters, lr=lr, weight_decay=weight_decay) 19 | elif name == 'rmsprop': 20 | return torch.optim.RMSprop(parameters, lr=lr, weight_decay=weight_decay) 21 | elif name == 'adagrad': 22 | return torch.optim.Adagrad(parameters, lr=lr, weight_decay=weight_decay) 23 | elif name == 'adam': 24 | return torch.optim.Adam(parameters, lr=lr, weight_decay=weight_decay) 25 | elif name == 'adamax': 26 | return torch.optim.Adamax(parameters, lr=lr, weight_decay=weight_decay) 27 | else: 28 | raise Exception("Unsupported optimizer: {}".format(name)) 29 | 30 | 31 | def add_labels(feat, labels, idx, num_classes, device): 32 | onehot = torch.zeros([feat.shape[0], num_classes]).to(device) 33 | if idx.dtype == torch.bool: 34 | idx = torch.where(idx)[0] # convert mask to linear index 35 | onehot[idx, labels.squeeze()[idx]] = 1 36 | 37 | return torch.cat([feat, onehot], dim=-1) 38 | 39 | 40 | def get_label_masks(data, mask_rate=0.5): 41 | """ 42 | when using labels as features need to split training nodes into training and prediction 43 | """ 44 | if data.train_mask.dtype == torch.bool: 45 | idx = torch.where(data.train_mask)[0] 46 | else: 47 | idx = data.train_mask 48 | mask = torch.rand(idx.shape) < mask_rate 49 | train_label_idx = idx[mask] 50 | train_pred_idx = idx[~mask] 51 | return train_label_idx, train_pred_idx 52 | 53 | 54 | def train(model, optimizer, data): 55 | model.train() 56 | optimizer.zero_grad() 57 | feat = data.x 58 | if model.opt['use_labels']: 59 | train_label_idx, train_pred_idx = get_label_masks(data, model.opt['label_rate']) 60 | 61 | feat = add_labels(feat, data.y, train_label_idx, model.num_classes, model.device) 62 | else: 63 | train_pred_idx = data.train_mask 64 | 65 | out = model(feat) 66 | if model.opt['dataset'] == 'ogbn-arxiv': 67 | lf = torch.nn.functional.nll_loss 68 | loss = lf(out.log_softmax(dim=-1)[data.train_mask], data.y.squeeze(1)[data.train_mask]) 69 | else: 70 | lf = torch.nn.CrossEntropyLoss() 71 | loss = lf(out[data.train_mask], data.y.squeeze()[data.train_mask]) 72 | 73 | if model.opt['wandb_watch_grad']: # Tell wandb to watch what the model gets up to: gradients, weights, and more! 
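        # wandb.watch registers hooks on the model so that parameter and gradient histograms
        # are logged every `log_freq` backward passes (log="all" tracks gradients and weights)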
74 | wandb.watch(model, lf, log="all", log_freq=10) 75 | 76 | if model.odeblock.nreg > 0: # add regularisation - slower for small data, but faster and better performance for large data 77 | reg_states = tuple(torch.mean(rs) for rs in model.reg_states) 78 | regularization_coeffs = model.regularization_coeffs 79 | 80 | reg_loss = sum( 81 | reg_state * coeff for reg_state, coeff in zip(reg_states, regularization_coeffs) if coeff != 0 82 | ) 83 | loss = loss + reg_loss 84 | 85 | model.fm.update(model.getNFE()) 86 | model.resetNFE() 87 | loss.backward() 88 | optimizer.step() 89 | model.bm.update(model.getNFE()) 90 | model.resetNFE() 91 | return loss.item() 92 | 93 | 94 | @torch.no_grad() 95 | def test(model, data, opt=None): # opt required for runtime polymorphism 96 | model.eval() 97 | feat = data.x 98 | if model.opt['use_labels']: 99 | feat = add_labels(feat, data.y, data.train_mask, model.num_classes, model.device) 100 | logits, accs = model(feat), [] 101 | for _, mask in data('train_mask', 'val_mask', 'test_mask'): 102 | pred = logits[mask].max(1)[1] 103 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item() 104 | accs.append(acc) 105 | return accs 106 | 107 | 108 | def print_model_params(model): 109 | print(model) 110 | for name, param in model.named_parameters(): 111 | if param.requires_grad: 112 | print(name) 113 | print(param.data.shape) 114 | 115 | 116 | @torch.no_grad() 117 | def test_OGB(model, data, opt): 118 | if opt['dataset'] == 'ogbn-arxiv': 119 | name = 'ogbn-arxiv' 120 | 121 | feat = data.x 122 | if model.opt['use_labels']: 123 | feat = add_labels(feat, data.y, data.train_mask, model.num_classes, model.device) 124 | 125 | evaluator = Evaluator(name=name) 126 | model.eval() 127 | 128 | out = model(feat).log_softmax(dim=-1) 129 | y_pred = out.argmax(dim=-1, keepdim=True) 130 | 131 | train_acc = evaluator.eval({ 132 | 'y_true': data.y[data.train_mask], 133 | 'y_pred': y_pred[data.train_mask], 134 | })['acc'] 135 | valid_acc = evaluator.eval({ 136 | 'y_true': data.y[data.val_mask], 137 | 'y_pred': y_pred[data.val_mask], 138 | })['acc'] 139 | test_acc = evaluator.eval({ 140 | 'y_true': data.y[data.test_mask], 141 | 'y_pred': y_pred[data.test_mask], 142 | })['acc'] 143 | 144 | return train_acc, valid_acc, test_acc 145 | 146 | 147 | def merge_cmd_args(cmd_opt, opt): 148 | if cmd_opt['function'] is not None: 149 | opt['function'] = cmd_opt['function'] 150 | if cmd_opt['block'] is not None: 151 | opt['block'] = cmd_opt['block'] 152 | if cmd_opt['self_loop_weight'] is not None: 153 | opt['self_loop_weight'] = cmd_opt['self_loop_weight'] 154 | if cmd_opt['method'] is not None: 155 | opt['method'] = cmd_opt['method'] 156 | if cmd_opt['step_size'] != 1: 157 | opt['step_size'] = cmd_opt['step_size'] 158 | if cmd_opt['time'] != 1: 159 | opt['time'] = cmd_opt['time'] 160 | if cmd_opt['epoch'] != 100: 161 | opt['epoch'] = cmd_opt['epoch'] 162 | 163 | 164 | def wandb_setup(opt): 165 | if opt['wandb']: 166 | if opt['use_wandb_offline']: 167 | os.environ["WANDB_MODE"] = "offline" 168 | else: 169 | os.environ["WANDB_MODE"] = "run" 170 | if opt['wandb'] and 'wandb_run_name' in opt.keys(): 171 | wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'], 172 | name=opt['wandb_run_name'], reinit=True, config=opt) 173 | else: 174 | wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'], 175 | reinit=True, config=opt) 176 | wandb.define_metric("epoch_step") # Customize axes - https://docs.wandb.ai/guides/track/log 177 | 
if opt['wandb_track_grad_flow']: 178 | wandb.define_metric("grad_flow_step") # Customize axes - https://docs.wandb.ai/guides/track/log 179 | wandb.define_metric("gf_e*", step_metric="grad_flow_step") # grad_flow_epoch* 180 | opt = wandb.config # access all HPs through wandb.config, so logging matches execution! 181 | else: 182 | os.environ["WANDB_MODE"] = "disabled" # sets as NOOP, saves keep writing: if opt['wandb']: 183 | return opt 184 | 185 | 186 | def main(cmd_opt): 187 | best_opt = good_params_dict[cmd_opt['dataset']] 188 | opt = {**cmd_opt, **best_opt} 189 | merge_cmd_args(cmd_opt, opt) 190 | 191 | dataset = get_dataset(opt, '../../data', opt['not_lcc']) 192 | 193 | opt = wandb_setup(opt) 194 | 195 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 196 | results = [] 197 | for rep in range(opt['num_splits']): 198 | model = GNN(opt, dataset, device).to(device) if opt["no_early"] else GNNEarly(opt, dataset, device).to(device) 199 | # print(opt) 200 | if not opt['planetoid_split'] and opt['dataset'] in ['Cora', 'Citeseer', 'Pubmed']: 201 | dataset.data = set_train_val_test_split(np.random.randint(0, 1000), dataset.data, 202 | num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500) 203 | # todo for some reason the submodule parameters inside the attention module don't show up when running on GPU. 204 | data = dataset.data.to(device) 205 | parameters = [p for p in model.parameters() if p.requires_grad] 206 | print_model_params(model) 207 | optimizer = get_optimizer(opt['optimizer'], parameters, lr=opt['lr'], weight_decay=opt['decay']) 208 | best_time = val_acc = test_acc = train_acc = best_epoch = 0 209 | this_test = test_OGB if opt['dataset'] == 'ogbn-arxiv' else test 210 | for epoch in range(1, opt['epoch']): 211 | start_time = time.time() 212 | 213 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt) 214 | loss = train(model, optimizer, data) 215 | 216 | if tmp_val_acc > val_acc: 217 | best_epoch = epoch 218 | train_acc = tmp_train_acc 219 | val_acc = tmp_val_acc 220 | test_acc = tmp_test_acc 221 | best_time = opt['time'] 222 | if not opt['no_early'] and model.odeblock.test_integrator.solver.best_val > val_acc: 223 | best_epoch = epoch 224 | val_acc = model.odeblock.test_integrator.solver.best_val 225 | test_acc = model.odeblock.test_integrator.solver.best_test 226 | train_acc = model.odeblock.test_integrator.solver.best_train 227 | best_time = model.odeblock.test_integrator.solver.best_time 228 | 229 | if opt['wandb'] and ((epoch) % opt['wandb_log_freq']) == 0 and rep == 0: 230 | wandb.log({"loss": loss, 231 | "train_acc": train_acc, "val_acc": val_acc, "test_acc": test_acc, 232 | "epoch_step": epoch}) # , step=epoch) wandb: WARNING Step must only increase in log calls 233 | 234 | log = 'Epoch: {:03d}, Runtime {:03f}, Loss {:03f}, forward nfe {:d}, backward nfe {:d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}, Best time: {:.4f}' 235 | 236 | print( 237 | log.format(epoch, time.time() - start_time, loss, model.fm.sum, model.bm.sum, train_acc, val_acc, 238 | test_acc, 239 | best_time)) 240 | print('best val accuracy {:03f} with test accuracy {:03f} at epoch {:d} and best time {:03f}'.format(val_acc, 241 | test_acc, 242 | best_epoch, 243 | best_time)) 244 | if opt['num_splits'] > 1: 245 | results.append([test_acc, val_acc, train_acc]) 246 | 247 | if opt['num_splits'] > 1: 248 | test_acc_mean, val_acc_mean, train_acc_mean = np.mean(results, axis=0) * 100 249 | test_acc_std = np.sqrt(np.var(results, axis=0)[0]) * 100 250 | if opt['wandb']: 
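            # log the accuracies aggregated over the random splits (reported in percent) to wandb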
251 | wandb_results = {'test_mean': test_acc_mean, 'val_mean': val_acc_mean, 'train_mean': train_acc_mean, 252 | 'test_acc_std': test_acc_std} 253 | wandb.log(wandb_results) 254 | if opt['wandb']: 255 | wandb.finish() 256 | return train_acc, val_acc, test_acc 257 | 258 | 259 | if __name__ == '__main__': 260 | parser = argparse.ArgumentParser() 261 | parser.add_argument('--use_cora_defaults', action='store_true', 262 | help='Whether to run with best params for cora. Overrides the choice of dataset') 263 | 264 | parser.add_argument('--id', type=int) 265 | parser.add_argument('--num_splits', type=int, default=1, help='number of random splits to use') 266 | # data args 267 | parser.add_argument('--dataset', type=str, default='Cora', 268 | help='Cora, Citeseer, Pubmed') 269 | parser.add_argument('--data_norm', type=str, default='rw', 270 | help='rw for random walk, gcn for symmetric gcn norm') 271 | parser.add_argument('--self_loop_weight', type=float, help='Weight of self-loops.') 272 | parser.add_argument('--use_labels', dest='use_labels', action='store_true', help='Also diffuse labels') 273 | parser.add_argument('--label_rate', type=float, default=0.5, 274 | help='% of training labels to use when --use_labels is set.') 275 | parser.add_argument('--planetoid_split', action='store_true', 276 | help='use planetoid splits for Cora/Citeseer/Pubmed') 277 | # GNN args 278 | parser.add_argument('--hidden_dim', type=int, default=16, help='Hidden dimension.') 279 | parser.add_argument('--fc_out', dest='fc_out', action='store_true', 280 | help='Add a fully connected layer to the decoder.') 281 | parser.add_argument('--input_dropout', type=float, default=0.5, help='Input dropout rate.') 282 | parser.add_argument('--dropout', type=float, default=0.0, help='Dropout rate.') 283 | parser.add_argument("--batch_norm", dest='batch_norm', action='store_true', help='apply batch normalisation') 284 | parser.add_argument('--optimizer', type=str, default='adam', help='One of sgd, rmsprop, adam, adagrad, adamax.') 285 | parser.add_argument('--lr', type=float, default=0.01, help='Learning rate.') 286 | parser.add_argument('--decay', type=float, default=5e-4, help='Weight decay for optimization') 287 | parser.add_argument('--epoch', type=int, default=100, help='Number of training epochs per iteration.') 288 | parser.add_argument('--alpha', type=float, default=1.0, help='Factor in front of matrix A.') 289 | parser.add_argument('--alpha_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) alpha') 290 | parser.add_argument('--no_alpha_sigmoid', dest='no_alpha_sigmoid', action='store_true', 291 | help='do not apply sigmoid before multiplying by alpha') 292 | parser.add_argument('--beta_dim', type=str, default='sc', help='choose either scalar (sc) or vector (vc) beta') 293 | parser.add_argument('--block', type=str, help='constant, mixed, attention, hard_attention, SDE') 294 | parser.add_argument('--function', type=str, help='laplacian, transformer, dorsey, GAT, SDE') 295 | parser.add_argument('--use_mlp', dest='use_mlp', action='store_true', 296 | help='Add a fully connected layer to the encoder.') 297 | parser.add_argument('--add_source', dest='add_source', action='store_true', 298 | help='If true, get rid of alpha param and the beta*x0 source term') 299 | 300 | # ODE args 301 | parser.add_argument('--time', type=float, default=1.0, help='End time of ODE integrator.') 302 | parser.add_argument('--augment', action='store_true', 303 | help='double the length of the feature vector by appending zeros to stabilise 
ODE learning') 304 | parser.add_argument('--method', type=str, 305 | help="set the numerical solver: dopri5, euler, rk4, midpoint") 306 | parser.add_argument('--step_size', type=float, default=1, 307 | help='fixed step size when using fixed step solvers e.g. rk4') 308 | parser.add_argument('--max_iters', type=float, default=100, help='maximum number of integration steps') 309 | parser.add_argument("--adjoint_method", type=str, default="adaptive_heun", 310 | help="set the numerical solver for the backward pass: dopri5, euler, rk4, midpoint") 311 | parser.add_argument('--adjoint', dest='adjoint', action='store_true', 312 | help='use the adjoint ODE method to reduce memory footprint') 313 | parser.add_argument('--adjoint_step_size', type=float, default=1, 314 | help='fixed step size when using fixed step adjoint solvers e.g. rk4') 315 | parser.add_argument('--tol_scale', type=float, default=1., help='multiplier for atol and rtol') 316 | parser.add_argument("--tol_scale_adjoint", type=float, default=1.0, 317 | help="multiplier for adjoint_atol and adjoint_rtol") 318 | parser.add_argument('--ode_blocks', type=int, default=1, help='number of ode blocks to run') 319 | parser.add_argument("--max_nfe", type=int, default=10000000, 320 | help="Maximum number of function evaluations in an epoch. Stiff ODEs will hang if not set.") 321 | parser.add_argument("--no_early", action="store_true", 322 | help="Whether or not to use early stopping of the ODE integrator when testing.") 323 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model') 324 | parser.add_argument("--max_test_steps", type=int, default=100, 325 | help="Maximum number steps for the dopri5Early test integrator. " 326 | "used if getting OOM errors at test time") 327 | 328 | # Attention args 329 | parser.add_argument('--leaky_relu_slope', type=float, default=0.2, 330 | help='slope of the negative part of the leaky relu used in attention') 331 | parser.add_argument('--attention_dropout', type=float, default=0., help='dropout of attention weights') 332 | parser.add_argument('--heads', type=int, default=4, help='number of attention heads') 333 | parser.add_argument('--attention_norm_idx', type=int, default=0, help='0 = normalise rows, 1 = normalise cols') 334 | parser.add_argument('--attention_dim', type=int, default=64, 335 | help='the size to project x to before calculating att scores') 336 | parser.add_argument('--mix_features', dest='mix_features', action='store_true', 337 | help='apply a feature transformation xW to the ODE') 338 | parser.add_argument('--reweight_attention', dest='reweight_attention', action='store_true', 339 | help="multiply attention scores by edge weights before softmax") 340 | # regularisation args 341 | parser.add_argument('--jacobian_norm2', type=float, default=None, help="int_t ||df/dx||_F^2") 342 | parser.add_argument('--total_deriv', type=float, default=None, help="int_t ||df/dt||^2") 343 | 344 | parser.add_argument('--kinetic_energy', type=float, default=None, help="int_t ||f||_2^2") 345 | parser.add_argument('--directional_penalty', type=float, default=None, help="int_t ||(df/dx)^T f||^2") 346 | 347 | # rewiring args 348 | parser.add_argument("--not_lcc", action="store_false", help="don't use the largest connected component") 349 | parser.add_argument('--rewiring', type=str, default=None, help="two_hop, gdc") 350 | parser.add_argument('--gdc_method', type=str, default='ppr', help="ppr, heat, coeff") 351 | parser.add_argument('--gdc_sparsification', 
type=str, default='topk', help="threshold, topk") 352 | parser.add_argument('--gdc_k', type=int, default=64, help="number of neighbours to sparsify to when using topk") 353 | parser.add_argument('--gdc_threshold', type=float, default=0.0001, 354 | help="above this edge weight, keep edges when using threshold") 355 | parser.add_argument('--gdc_avg_degree', type=int, default=64, 356 | help="if gdc_threshold is not given can be calculated by specifying avg degree") 357 | parser.add_argument('--ppr_alpha', type=float, default=0.05, help="teleport probability") 358 | parser.add_argument('--heat_time', type=float, default=3., help="time to run gdc heat kernel diffusion for") 359 | parser.add_argument('--att_samp_pct', type=float, default=1, 360 | help="float in [0,1). The percentage of edges to retain based on attention scores") 361 | parser.add_argument('--use_flux', dest='use_flux', action='store_true', 362 | help='incorporate the feature grad in attention based edge dropout') 363 | parser.add_argument("--exact", action="store_true", 364 | help="for small datasets can do exact diffusion. If dataset is too big for matrix inversion then you can't") 365 | 366 | # wandb logging and tuning 367 | parser.add_argument('--wandb', action='store_true', help="flag if logging to wandb") 368 | parser.add_argument('-wandb_offline', dest='use_wandb_offline', 369 | action='store_true') # https://docs.wandb.ai/guides/technical-faq 370 | 371 | parser.add_argument('--wandb_sweep', action='store_true', 372 | help="flag if sweeping") # if not it picks up params in greed_params 373 | parser.add_argument('--wandb_watch_grad', action='store_true', help='allows gradient tracking in train function') 374 | parser.add_argument('--wandb_track_grad_flow', action='store_true') 375 | 376 | parser.add_argument('--wandb_entity', default="graphcon", type=str) # not used as default set in web browser settings 377 | parser.add_argument('--wandb_project', default="graphcon", type=str) 378 | parser.add_argument('--wandb_group', default="testing", type=str, help="testing,tuning,eval") 379 | parser.add_argument('--wandb_run_name', default=None, type=str) 380 | parser.add_argument('--wandb_output_dir', default='./wandb_output', 381 | help='folder to output results, images and model checkpoints') 382 | parser.add_argument('--wandb_log_freq', type=int, default=1, help='Frequency to log metrics.') 383 | parser.add_argument('--wandb_epoch_list', nargs='+', default=[0, 1, 2, 4, 8, 16], 384 | help='list of epochs to log gradient flow') 385 | 386 | args = parser.parse_args() 387 | 388 | opt = vars(args) 389 | 390 | main(opt) 391 | -------------------------------------------------------------------------------- /src/homophilic_graphs/run_best_sweeps.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts the optimal hyperparameters found from a wandb hyperparameter sweep 3 | """ 4 | 5 | import yaml 6 | import argparse 7 | import os 8 | import sys 9 | import time 10 | 11 | import wandb 12 | import torch 13 | import numpy as np 14 | from ray import tune 15 | from functools import partial 16 | from ray.tune import CLIReporter 17 | 18 | from good_params_graphCON import good_params_dict 19 | from data import get_dataset, set_train_val_test_split 20 | from GNN import GNN 21 | from GNN_early import GNNEarly 22 | from run_GNN import get_optimizer, test, train 23 | from utils import get_sem, mean_confidence_interval 24 | 25 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 
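# make the parent src/ directory importable so that `from definitions import ROOT_DIR` below resolves when this script is run directly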
26 | from definitions import ROOT_DIR 27 | 28 | 29 | def main(opt, data_dir): 30 | # todo see if I can initialise wandb runs inside of ray processes 31 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 32 | dataset = get_dataset(opt, data_dir, opt['not_lcc']) 33 | 34 | if opt["num_splits"] > 0: 35 | dataset.data = set_train_val_test_split( 36 | 23 * np.random.randint(0, opt["num_splits"]), 37 | # random prime 23 to make the splits 'more' random. Could remove 38 | dataset.data, 39 | num_development=5000 if opt["dataset"] == "CoauthorCS" else 1500) 40 | 41 | model = GNN(opt, dataset, device) if opt["no_early"] else GNNEarly(opt, dataset, device) 42 | model, data = model.to(device), dataset.data.to(device) 43 | parameters = [p for p in model.parameters() if p.requires_grad] 44 | optimizer = get_optimizer(opt["optimizer"], parameters, lr=opt["lr"], weight_decay=opt["decay"]) 45 | 46 | this_test = test 47 | best_time = best_epoch = train_acc = val_acc = test_acc = 0 48 | for epoch in range(1, opt["epoch"]): 49 | loss = train(model, optimizer, data) 50 | if opt["no_early"]: 51 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt) 52 | best_time = opt['time'] 53 | else: 54 | tmp_train_acc, tmp_val_acc, tmp_test_acc = this_test(model, data, opt) 55 | if tmp_val_acc > val_acc: 56 | best_epoch = epoch 57 | train_acc = tmp_train_acc 58 | val_acc = tmp_val_acc 59 | test_acc = tmp_test_acc 60 | if model.odeblock.test_integrator.solver.best_val > val_acc: 61 | best_epoch = epoch 62 | val_acc = model.odeblock.test_integrator.solver.best_val 63 | test_acc = model.odeblock.test_integrator.solver.best_test 64 | train_acc = model.odeblock.test_integrator.solver.best_train 65 | best_time = model.odeblock.test_integrator.solver.best_time 66 | tune.report(loss=loss, val_acc=val_acc, test_acc=test_acc, train_acc=train_acc, best_time=best_time, 67 | best_epoch=best_epoch, 68 | forward_nfe=model.fm.sum, backward_nfe=model.bm.sum) 69 | res_dict = {"loss": loss, 70 | "train_acc": train_acc, "val_acc": val_acc, "test_acc": test_acc, "best_time": best_time, 71 | "best_epoch": best_epoch, "epoch_step": epoch} 72 | print(res_dict) 73 | 74 | 75 | def run_best(opt): 76 | opt['wandb_entity'] = 'bchamberlain' 77 | opt['wandb_project'] = 'waveGNN-src_node_level' 78 | opt['wandb_group'] = None 79 | if 'wandb_run_name' in opt.keys(): 80 | wandb_run = wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'], 81 | name=opt['wandb_run_name'], reinit=True, config=opt) 82 | else: 83 | wandb_run = wandb.init(entity=opt['wandb_entity'], project=opt['wandb_project'], group=opt['wandb_group'], 84 | reinit=True, config=opt) 85 | 86 | wandb.define_metric("epoch_step") 87 | 88 | yaml_path = f"./wandb/sweep-{opt['sweep']}/config-{opt['run']}.yaml" 89 | with open(yaml_path) as f: 90 | yaml_opt = yaml.load(f, Loader=yaml.FullLoader) 91 | temp_opt = {} 92 | for k, v in yaml_opt.items(): 93 | if type(v) == dict: 94 | temp_opt[k] = v['value'] 95 | else: 96 | temp_opt[k] = v 97 | yaml_opt = temp_opt 98 | dataset = yaml_opt['dataset'] 99 | 100 | opt = {**good_params_dict[dataset], **yaml_opt, **opt} 101 | opt['wandb'] = True 102 | opt['use_wandb_offline'] = False 103 | opt['wandb_best_run_id'] = opt['run'] 104 | opt['wandb_track_grad_flow'] = False 105 | opt['wandb_watch_grad'] = False 106 | 107 | reporter = CLIReporter( 108 | metric_columns=["val_acc", "loss", "test_acc", "train_acc", "best_time", "best_epoch"]) 109 | 110 | result = tune.run( 111 | partial(main, 
data_dir=f"{ROOT_DIR}/data"), 112 | name=opt['run'], 113 | resources_per_trial={"cpu": opt['cpus'], "gpu": opt['gpus']}, 114 | search_alg=None, 115 | keep_checkpoints_num=3, 116 | checkpoint_score_attr='val_acc', 117 | config=opt, 118 | num_samples=opt['reps'] if opt["num_splits"] == 0 else opt["num_splits"] * opt["reps"], 119 | scheduler=None, 120 | max_failures=1, # early stop solver can't recover from failure as it doesn't own m2. 121 | local_dir='../ray_tune', 122 | progress_reporter=reporter, 123 | raise_on_failed_trial=False) 124 | 125 | df = result.dataframe(metric='test_acc', mode="max").sort_values('test_acc', ascending=False) 126 | try: 127 | df.to_csv('../ray_results/{}_{}.csv'.format(opt['run'], time.strftime("%Y%m%d-%H%M%S"))) 128 | except: 129 | pass 130 | 131 | print(df[['val_acc', 'test_acc', 'train_acc', 'best_time', 'best_epoch']]) 132 | 133 | test_accs = df['test_acc'].values 134 | val_accs = df['val_acc'].values 135 | train_accs = df['train_acc'].values 136 | print("test accuracy {}".format(test_accs)) 137 | log = "mean test {:04f}, test std {:04f}, test sem {:04f}, test 95% conf {:04f}" 138 | log_dic = {'test_acc': test_accs.mean(), 'val_acc': val_accs.mean(), 'train_acc': train_accs.mean(), 139 | 'test_std': np.std(test_accs), 'test_sem': get_sem(test_accs), 140 | 'test_95_conf': mean_confidence_interval(test_accs)} 141 | wandb.log(log_dic) 142 | print(log.format(test_accs.mean(), np.std(test_accs), get_sem(test_accs), mean_confidence_interval(test_accs))) 143 | 144 | wandb_run.finish() 145 | 146 | 147 | if __name__ == '__main__': 148 | parser = argparse.ArgumentParser() 149 | parser.add_argument('--epoch', type=int, default=10, help='Number of training epochs per iteration.') 150 | parser.add_argument('--sweep', type=str, default=None, help='sweep folder to read', required=True) 151 | parser.add_argument('--run', type=str, default=None, help='the run IDs', required=True) 152 | parser.add_argument('--reps', type=int, default=1, help='the number of random weight initialisations to use') 153 | parser.add_argument('--name', type=str, default=None) 154 | parser.add_argument('--gpus', type=float, default=0, help='number of gpus per trial. Can be fractional') 155 | parser.add_argument('--cpus', type=float, default=1, help='number of cpus per trial. Can be fractional') 156 | parser.add_argument("--num_splits", type=int, default=0, help="Number of random slpits >= 0. 
0 for planetoid split") 157 | parser.add_argument("--adjoint", dest='adjoint', action='store_true', 158 | help="use the adjoint ODE method to reduce memory footprint") 159 | parser.add_argument("--max_nfe", type=int, default=5000, 160 | help="Maximum number of function evaluations allowed in an epcoh.") 161 | parser.add_argument("--no_early", action="store_true", 162 | help="Whether or not to use early stopping of the ODE integrator when testing.") 163 | 164 | parser.add_argument('--earlystopxT', type=float, default=3, help='multiplier for T used to evaluate best model') 165 | 166 | args = parser.parse_args() 167 | 168 | opt = vars(args) 169 | run_best(opt) 170 | -------------------------------------------------------------------------------- /src/homophilic_graphs/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utility functions 3 | """ 4 | 5 | import scipy 6 | from scipy.stats import sem 7 | import numpy as np 8 | import torch 9 | from torch_scatter import scatter_add 10 | from torch_geometric.utils import add_remaining_self_loops 11 | from torch_geometric.utils.num_nodes import maybe_num_nodes 12 | from torch_geometric.utils.convert import to_scipy_sparse_matrix 13 | from sklearn.preprocessing import normalize 14 | from torch_geometric.nn.conv.gcn_conv import gcn_norm 15 | 16 | 17 | class MaxNFEException(Exception): pass 18 | 19 | 20 | def print_model_params(model): 21 | total_num_params = 0 22 | print(model) 23 | for name, param in model.named_parameters(): 24 | if param.requires_grad: 25 | print(name) 26 | print(param.data.shape) 27 | total_num_params += param.numel() 28 | print("Model has a total of {} params".format(total_num_params)) 29 | 30 | 31 | def adjust_learning_rate(optimizer, lr, epoch, burnin=50): 32 | if epoch <= burnin: 33 | for param_group in optimizer.param_groups: 34 | param_group["lr"] = lr * epoch / burnin 35 | 36 | 37 | def gcn_norm_fill_val(edge_index, edge_weight=None, fill_value=0., num_nodes=None, dtype=None): 38 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 39 | 40 | if edge_weight is None: 41 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 42 | device=edge_index.device) 43 | 44 | if not int(fill_value) == 0: 45 | edge_index, tmp_edge_weight = add_remaining_self_loops( 46 | edge_index, edge_weight, fill_value, num_nodes) 47 | assert tmp_edge_weight is not None 48 | edge_weight = tmp_edge_weight 49 | 50 | row, col = edge_index[0], edge_index[1] 51 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes) 52 | deg_inv_sqrt = deg.pow_(-0.5) 53 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0) 54 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] 55 | 56 | 57 | def coo2tensor(coo, device=None): 58 | indices = np.vstack((coo.row, coo.col)) 59 | i = torch.LongTensor(indices) 60 | values = coo.data 61 | v = torch.FloatTensor(values) 62 | shape = coo.shape 63 | print('adjacency matrix generated with shape {}'.format(shape)) 64 | # test 65 | return torch.sparse.FloatTensor(i, v, torch.Size(shape)).to(device) 66 | 67 | 68 | def get_sym_adj(data, opt, improved=False): 69 | edge_index, edge_weight = gcn_norm( # yapf: disable 70 | data.edge_index, data.edge_attr, data.num_nodes, 71 | improved, opt['self_loop_weight'] > 0, dtype=data.x.dtype) 72 | coo = to_scipy_sparse_matrix(edge_index, edge_weight) 73 | return coo2tensor(coo) 74 | 75 | 76 | def get_rw_adj_old(data, opt): 77 | if opt['self_loop_weight'] > 0: 78 | edge_index, edge_weight = 
add_remaining_self_loops(data.edge_index, data.edge_attr, 79 | fill_value=opt['self_loop_weight']) 80 | else: 81 | edge_index, edge_weight = data.edge_index, data.edge_attr 82 | coo = to_scipy_sparse_matrix(edge_index, edge_weight) 83 | normed_csc = normalize(coo, norm='l1', axis=0) 84 | return coo2tensor(normed_csc.tocoo()) 85 | 86 | 87 | def get_rw_adj(edge_index, edge_weight=None, norm_dim=1, fill_value=0., num_nodes=None, dtype=None): 88 | num_nodes = maybe_num_nodes(edge_index, num_nodes) 89 | 90 | if edge_weight is None: 91 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype, 92 | device=edge_index.device) 93 | 94 | if not fill_value == 0: 95 | edge_index, tmp_edge_weight = add_remaining_self_loops( 96 | edge_index, edge_weight, fill_value, num_nodes) 97 | assert tmp_edge_weight is not None 98 | edge_weight = tmp_edge_weight 99 | 100 | row, col = edge_index[0], edge_index[1] 101 | indices = row if norm_dim == 0 else col 102 | deg = scatter_add(edge_weight, indices, dim=0, dim_size=num_nodes) 103 | deg_inv_sqrt = deg.pow_(-1) 104 | edge_weight = deg_inv_sqrt[indices] * edge_weight if norm_dim == 0 else edge_weight * deg_inv_sqrt[indices] 105 | return edge_index, edge_weight 106 | 107 | 108 | def mean_confidence_interval(data, confidence=0.95): 109 | """ 110 | As number of samples will be < 10 use t-test for the mean confidence intervals 111 | :param data: NDarray of metric means 112 | :param confidence: The desired confidence interval 113 | :return: Float confidence interval 114 | """ 115 | if len(data) < 2: 116 | return 0 117 | a = 1.0 * np.array(data) 118 | n = len(a) 119 | _, se = np.mean(a), scipy.stats.sem(a) 120 | h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1) 121 | return h 122 | 123 | 124 | def sparse_dense_mul(s, d): 125 | i = s._indices() 126 | v = s._values() 127 | return torch.sparse.FloatTensor(i, v * d, s.size()) 128 | 129 | 130 | def get_sem(vec): 131 | """ 132 | wrapper around the scipy standard error metric 133 | :param vec: List of metric means 134 | :return: 135 | """ 136 | if len(vec) > 1: 137 | retval = sem(vec) 138 | else: 139 | retval = 0. 140 | return retval 141 | 142 | 143 | # Counter of forward and backward passes. 144 | class Meter(object): 145 | 146 | def __init__(self): 147 | self.reset() 148 | 149 | def reset(self): 150 | self.val = None 151 | self.sum = 0 152 | self.cnt = 0 153 | 154 | def update(self, val): 155 | self.val = val 156 | self.sum += val 157 | self.cnt += 1 158 | 159 | def get_average(self): 160 | if self.cnt == 0: 161 | return 0 162 | return self.sum / self.cnt 163 | 164 | def get_value(self): 165 | return self.val 166 | --------------------------------------------------------------------------------