├── README.md
├── environment.yml
└── experiments
    ├── datasets.py
    ├── gcn.py
    ├── psgd.py
    ├── run.sh
    ├── train_eval.py
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
# Optimization of Graph Neural Networks with Natural Gradient Descent

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-citeseer)](https://paperswithcode.com/sota/node-classification-on-citeseer?p=optimization-of-graph-neural-networks-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-cora)](https://paperswithcode.com/sota/node-classification-on-cora?p=optimization-of-graph-neural-networks-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-pubmed)](https://paperswithcode.com/sota/node-classification-on-pubmed?p=optimization-of-graph-neural-networks-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-citeseer-with-public)](https://paperswithcode.com/sota/node-classification-on-citeseer-with-public?p=optimization-of-graph-neural-networks-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-pubmed-with-public)](https://paperswithcode.com/sota/node-classification-on-pubmed-with-public?p=optimization-of-graph-neural-networks-with)

[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/optimization-of-graph-neural-networks-with/node-classification-on-cora-with-public-split)](https://paperswithcode.com/sota/node-classification-on-cora-with-public-split?p=optimization-of-graph-neural-networks-with)

This repository contains the implementation of the paper [Optimization of Graph Neural Networks with Natural Gradient Descent](https://arxiv.org/abs/2008.09624). Most of the code is adapted from [github.com/rusty1s/pytorch_geometric](https://github.com/rusty1s/pytorch_geometric) and [github.com/Thrandis/EKFAC-pytorch](https://github.com/Thrandis/EKFAC-pytorch). To reproduce the results reported in the paper, follow the steps below in order.

- Clone the repository and change your current directory:
```
git clone https://github.com/russellizadi/ssp
cd ssp
```
- Create a new `conda` environment from the provided `environment.yml`:
```
conda env create
```
- Activate the environment:
```
conda activate ssp
```
- Go to the `experiments` folder:
```
cd experiments
```
- Run all the experiments performed in the paper:
```
./run.sh
```
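- To try a single configuration rather than the full sweep, any `gcn.py` invocation from `experiments/run.sh` can also be run on its own, for example (all flags are defined at the top of `experiments/gcn.py`):
```
python gcn.py --dataset=Cora --split=public --optimizer=Adam --preconditioner=KFAC --logger=GCN-Cora1-Adam-KFAC
```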
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: ssp
channels:
  - defaults
dependencies:
  - python=3.7
  - torchvision
  - pytorch
  - cudatoolkit=10.2
  - jupyter
  - jupyterlab
  - seaborn
  - tensorboard
  - matplotlib
# Note: the experiments also import torch_geometric and torch_scatter, which
# are not listed above; install them separately, matching your PyTorch/CUDA
# versions, per the PyTorch Geometric installation instructions.
--------------------------------------------------------------------------------
/experiments/datasets.py:
--------------------------------------------------------------------------------
import os.path as osp

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid


def get_planetoid_dataset(name, normalize_features=False, transform=None, split="public"):
    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)
    if split == 'complete':
        dataset = Planetoid(path, name)
        # Custom "complete" split: all but the last 1000 nodes for training,
        # then 500 nodes for validation and 500 for testing.
        dataset[0].train_mask.fill_(False)
        dataset[0].train_mask[:dataset[0].num_nodes - 1000] = True
        dataset[0].val_mask.fill_(False)
        dataset[0].val_mask[dataset[0].num_nodes - 1000:dataset[0].num_nodes - 500] = True
        dataset[0].test_mask.fill_(False)
        dataset[0].test_mask[dataset[0].num_nodes - 500:] = True
    else:
        dataset = Planetoid(path, name, split=split)
    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform
    return dataset


if __name__ == '__main__':
    lst_names = ['Cora', 'CiteSeer', 'PubMed']
    for name in lst_names:
        dataset = get_planetoid_dataset(name)
        print(f"dataset: {name}")
        print(f"num_nodes: {dataset[0]['x'].shape[0]}")
        print(f"num_edges: {dataset[0]['edge_index'].shape[1]}")
        print(f"num_classes: {dataset.num_classes}")
        print(f"num_features: {dataset.num_node_features}")
--------------------------------------------------------------------------------
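A minimal usage sketch for `get_planetoid_dataset` (illustrative only, not a file in the repository; it assumes the script is run from `experiments/`, and that the Planetoid files can be downloaded on first use, which `torch_geometric` does automatically):

```python
# Load Cora with row-normalized features and the standard public split,
# then inspect the dataset and the split sizes.
from datasets import get_planetoid_dataset

dataset = get_planetoid_dataset('Cora', normalize_features=True, split='public')
data = dataset[0]
print(dataset.num_classes, dataset.num_node_features)
print(int(data.train_mask.sum()), int(data.val_mask.sum()), int(data.test_mask.sum()))
```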
/experiments/gcn.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from datasets import get_planetoid_dataset
from train_eval import run

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, required=True)
parser.add_argument('--split', type=str, default='public')
parser.add_argument('--runs', type=int, default=10)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0005)
parser.add_argument('--early_stopping', type=int, default=0)
parser.add_argument('--hidden', type=int, default=16)
parser.add_argument('--dropout', type=float, default=0.5)
# Caveat: argparse's `type=bool` treats any non-empty string (including
# "False") as True, so rely on the default rather than passing this flag.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--logger', type=str, default=None)
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--preconditioner', type=str, default=None)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--eps', type=float, default=0.01)
parser.add_argument('--update_freq', type=int, default=50)
parser.add_argument('--gamma', type=float, default=None)
parser.add_argument('--alpha', type=float, default=None)
parser.add_argument('--hyperparam', type=str, default=None)
args = parser.parse_args()


class Net_orig(torch.nn.Module):
    def __init__(self, dataset):
        super(Net_orig, self).__init__()  # fixed: previously referenced an undefined `Net2`
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=args.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


class CRD(torch.nn.Module):
    """Representation block: convolution, ReLU, dropout."""

    def __init__(self, d_in, d_out, p):
        super(CRD, self).__init__()
        self.conv = GCNConv(d_in, d_out, cached=True)
        self.p = p

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        x = F.relu(self.conv(x, edge_index))
        x = F.dropout(x, p=self.p, training=self.training)
        return x


class CLS(torch.nn.Module):
    """Classification block: convolution, log-softmax."""

    def __init__(self, d_in, d_out):
        super(CLS, self).__init__()
        self.conv = GCNConv(d_in, d_out, cached=True)

    def reset_parameters(self):
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        x = self.conv(x, edge_index)
        x = F.log_softmax(x, dim=1)
        return x


class Net(torch.nn.Module):
    def __init__(self, dataset):
        super(Net, self).__init__()
        self.crd = CRD(dataset.num_features, args.hidden, args.dropout)
        self.cls = CLS(args.hidden, dataset.num_classes)

    def reset_parameters(self):
        self.crd.reset_parameters()
        self.cls.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.crd(x, edge_index, data.train_mask)
        x = self.cls(x, edge_index, data.train_mask)
        return x


dataset = get_planetoid_dataset(name=args.dataset, normalize_features=args.normalize_features, split=args.split)

kwargs = {
    'dataset': dataset,
    'model': Net(dataset),
    'str_optimizer': args.optimizer,
    'str_preconditioner': args.preconditioner,
    'runs': args.runs,
    'epochs': args.epochs,
    'lr': args.lr,
    'weight_decay': args.weight_decay,
    'early_stopping': args.early_stopping,
    'logger': args.logger,
    'momentum': args.momentum,
    'eps': args.eps,
    'update_freq': args.update_freq,
    'gamma': args.gamma,
    'alpha': args.alpha,
    'hyperparam': args.hyperparam
}

if args.hyperparam == 'eps':
    for param in np.logspace(-3, 0, 10, endpoint=True):
        print(f"{args.hyperparam}: {param}")
        kwargs[args.hyperparam] = param
        run(**kwargs)
elif args.hyperparam == 'update_freq':
    for param in [4, 8, 16, 32, 64, 128]:
        print(f"{args.hyperparam}: {param}")
        kwargs[args.hyperparam] = param
        run(**kwargs)
elif args.hyperparam == 'gamma':
    for param in np.linspace(1., 10., 10, endpoint=True):
        print(f"{args.hyperparam}: {param}")
        kwargs[args.hyperparam] = param
        run(**kwargs)
else:
    run(**kwargs)
--------------------------------------------------------------------------------
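`Net` is a standard two-layer GCN split into the `CRD` and `CLS` blocks above. A self-contained sketch of the same forward computation on random data (illustrative only; shapes and hyperparameters are hard-coded here because `gcn.py` reads them from the command line):

```python
# The CRD -> CLS pipeline is: conv -> ReLU -> dropout -> conv -> log_softmax.
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

num_nodes, d_in, d_hidden, d_out = 10, 8, 16, 3
x = torch.randn(num_nodes, d_in)
edge_index = torch.randint(0, num_nodes, (2, 30))  # random edges for illustration

conv1, conv2 = GCNConv(d_in, d_hidden), GCNConv(d_hidden, d_out)
h = F.dropout(F.relu(conv1(x, edge_index)), p=0.5, training=True)
log_probs = F.log_softmax(conv2(h, edge_index), dim=1)
print(log_probs.shape)  # torch.Size([10, 3])
```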
/experiments/psgd.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch_scatter import scatter
from torch.optim.optimizer import Optimizer


class KFAC(Optimizer):

    def __init__(self, net, eps, sua=False, pi=False, update_freq=1,
                 alpha=1.0, constraint_norm=False):
        """ K-FAC preconditioner for GCN layers.
        Computes the K-FAC approximation of the second moment of the
        gradients. This adaptation hooks the `CRD` and `CLS` blocks defined
        in `gcn.py` and silently skips other modules.
        Args:
            net (torch.nn.Module): Network to precondition.
            eps (float): Tikhonov regularization parameter for the inverses.
            sua (bool): Applies SUA approximation.
            pi (bool): Computes pi correction for Tikhonov regularization.
            update_freq (int): Perform inverses every update_freq updates.
            alpha (float): Running average parameter (if == 1, no running average).
            constraint_norm (bool): Scale the gradients by the squared
                Fisher norm.
        """
        self.eps = eps
        self.sua = sua
        self.pi = pi
        self.update_freq = update_freq
        self.alpha = alpha
        self.constraint_norm = constraint_norm
        self.params = []
        self._fwd_handles = []
        self._bwd_handles = []
        self._iteration_counter = 0

        for mod in net.modules():
            mod_name = mod.__class__.__name__
            if mod_name in ['CRD', 'CLS']:
                handle = mod.register_forward_pre_hook(self._save_input)
                self._fwd_handles.append(handle)

                # Each CRD/CLS block is expected to contain exactly one
                # weighted submodule (its GCNConv); the counter is initialized
                # outside the loop so the assert below is meaningful.
                i_sub_mod = 0
                for sub_mod in mod.modules():
                    if hasattr(sub_mod, 'weight'):
                        assert i_sub_mod == 0
                        handle = sub_mod.register_backward_hook(self._save_grad_output)
                        self._bwd_handles.append(handle)

                        params = [sub_mod.weight]
                        if sub_mod.bias is not None:
                            params.append(sub_mod.bias)

                        d = {'params': params, 'mod': mod, 'sub_mod': sub_mod}
                        self.params.append(d)
                        i_sub_mod += 1

        super(KFAC, self).__init__(self.params, {})

    def step(self, update_stats=True, update_params=True, lam=0.):
        """Performs one step of preconditioning."""
        self.lam = lam
        fisher_norm = 0.
        for group in self.param_groups:

            if len(group['params']) == 2:
                weight, bias = group['params']
            else:
                weight = group['params'][0]
                bias = None
            state = self.state[weight]

            # Update covariances and inverses
            if update_stats:
                if self._iteration_counter % self.update_freq == 0:
                    self._compute_covs(group, state)
                    ixxt, iggt = self._inv_covs(state['xxt'], state['ggt'],
                                                state['num_locations'])
                    state['ixxt'] = ixxt
                    state['iggt'] = iggt
                else:
                    if self.alpha != 1:
                        self._compute_covs(group, state)

            if update_params:
                gw, gb = self._precond(weight, bias, group, state)

                # Updating gradients
                if self.constraint_norm:
                    fisher_norm += (weight.grad * gw).sum()

                weight.grad.data = gw
                if bias is not None:
                    if self.constraint_norm:
                        fisher_norm += (bias.grad * gb).sum()
                    bias.grad.data = gb

            # Cleaning
            if 'x' in self.state[group['mod']]:
                del self.state[group['mod']]['x']
            if 'gy' in self.state[group['mod']]:
                del self.state[group['mod']]['gy']

        # Optionally scale the norm of the gradients
        if update_params and self.constraint_norm:
            scale = (1. / fisher_norm) ** 0.5
            for group in self.param_groups:
                for param in group['params']:
                    param.grad.data *= scale

        if update_stats:
            self._iteration_counter += 1

    def _save_input(self, mod, i):
        """Saves input of layer to compute covariance."""
        # i = (x, edge_index, mask), as passed to CRD/CLS.forward
        if mod.training:
            self.state[mod]['x'] = i[0]

        self.mask = i[-1]

    def _save_grad_output(self, mod, grad_input, grad_output):
        """Saves grad on output of layer to compute covariance."""
        if mod.training:
            self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(1)
            self._cached_edge_index = mod._cached_edge_index

    def _precond(self, weight, bias, group, state):
        """Applies preconditioning."""
        ixxt = state['ixxt']  # [d_in x d_in]
        iggt = state['iggt']  # [d_out x d_out]
        g = weight.grad.data  # [d_in x d_out]
        s = g.shape

        g = g.contiguous().view(-1, g.shape[-1])

        if bias is not None:
            gb = bias.grad.data
            g = torch.cat([g, gb.view(1, gb.shape[0])], dim=0)

        g = torch.mm(ixxt, torch.mm(g, iggt))
        if bias is not None:
            gb = g[-1].contiguous().view(*bias.shape)
            g = g[:-1]
        else:
            gb = None
        g = g.contiguous().view(*s)
        return g, gb

    def _compute_covs(self, group, state):
        """Computes the covariances."""
        sub_mod = group['sub_mod']
        x = self.state[group['mod']]['x']        # [n x d_in]
        gy = self.state[group['sub_mod']]['gy']  # [n x d_out]
        edge_index, edge_weight = self._cached_edge_index  # [2, n_edges], [n_edges]

        # Effective sample size: labeled nodes plus lam-weighted unlabeled nodes.
        n = float(self.mask.sum() + self.lam * ((~self.mask).sum()))

        # Propagate the inputs through the normalized adjacency, as GCNConv does.
        x = scatter(x[edge_index[0]] * edge_weight[:, None], edge_index[1], dim=0)

        x = x.data.t()

        if sub_mod.weight.ndim == 3:
            x = x.repeat(sub_mod.weight.shape[0], 1)

        if sub_mod.bias is not None:
            ones = torch.ones_like(x[:1])
            x = torch.cat([x, ones], dim=0)

        if self._iteration_counter == 0:
            state['xxt'] = torch.mm(x, x.t()) / n
        else:
            state['xxt'].addmm_(mat1=x, mat2=x.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / n)

        gy = gy.data.t()  # [d_out x n]

        state['num_locations'] = 1
        if self._iteration_counter == 0:
            state['ggt'] = torch.mm(gy, gy.t()) / n
        else:
            state['ggt'].addmm_(mat1=gy, mat2=gy.t(),
                                beta=(1. - self.alpha),
                                alpha=self.alpha / n)

    def _inv_covs(self, xxt, ggt, num_locations):
        """Inverts the covariances."""
        # Compute the pi correction
        pi = 1.0
        if self.pi:
            tx = torch.trace(xxt) * ggt.shape[0]
            tg = torch.trace(ggt) * xxt.shape[0]
            pi = (tx / tg)
        # Regularize and invert
        eps = self.eps / num_locations
        diag_xxt = xxt.new(xxt.shape[0]).fill_((eps * pi) ** 0.5)
        diag_ggt = ggt.new(ggt.shape[0]).fill_((eps / pi) ** 0.5)
        ixxt = (xxt + torch.diag(diag_xxt)).inverse()
        iggt = (ggt + torch.diag(diag_ggt)).inverse()

        return ixxt, iggt

    def __del__(self):
        for handle in self._fwd_handles + self._bwd_handles:
            handle.remove()
--------------------------------------------------------------------------------
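For intuition, the two-sided multiply in `_precond` (`ixxt @ g @ iggt`) is the Kronecker-factored inverse curvature applied to the flattened gradient. A small numerical check of that identity (illustrative only; `torch.kron` may require a newer PyTorch than the pinned environment provides):

```python
# Check: (xxt ⊗ ggt)^(-1) vec(g) == vec(xxt^(-1) @ g @ ggt^(-1)) under
# row-major flattening, which is exactly what KFAC._precond computes.
import torch

torch.manual_seed(0)
d_in, d_out, n = 5, 3, 40
x = torch.randn(n, d_in)
gy = torch.randn(n, d_out)
xxt = x.t() @ x / n + 0.1 * torch.eye(d_in)     # regularized input covariance
ggt = gy.t() @ gy / n + 0.1 * torch.eye(d_out)  # regularized grad-output covariance
g = torch.randn(d_in, d_out)                    # a weight gradient

two_sided = xxt.inverse() @ g @ ggt.inverse()
kron_solve = torch.kron(xxt, ggt).inverse() @ g.reshape(-1)
print(torch.allclose(two_sided.reshape(-1), kron_solve, atol=1e-4))  # True
```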
/experiments/run.sh:
--------------------------------------------------------------------------------
#!/bin/sh

echo "GCN"

echo "Cora"
echo "===="

python gcn.py --dataset=Cora --split=public --optimizer=Adam --logger=GCN-Cora1-Adam
python gcn.py --dataset=Cora --split=full --optimizer=Adam --logger=GCN-Cora2-Adam
python gcn.py --dataset=Cora --split=complete --optimizer=Adam --logger=GCN-Cora3-Adam

python gcn.py --dataset=Cora --split=public --optimizer=Adam --hyperparam=gamma --logger=GCN-Cora1-Adam
python gcn.py --dataset=Cora --split=full --optimizer=Adam --hyperparam=gamma --logger=GCN-Cora2-Adam
python gcn.py --dataset=Cora --split=complete --optimizer=Adam --hyperparam=gamma --logger=GCN-Cora3-Adam

python gcn.py --dataset=Cora --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora1-Adam-KFAC
python gcn.py --dataset=Cora --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora2-Adam-KFAC
python gcn.py --dataset=Cora --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora3-Adam-KFAC

python gcn.py --dataset=Cora --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora1-Adam-KFAC
python gcn.py --dataset=Cora --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora2-Adam-KFAC
python gcn.py --dataset=Cora --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora3-Adam-KFAC

python gcn.py --dataset=Cora --split=public --optimizer=SGD --logger=GCN-Cora1-SGD
python gcn.py --dataset=Cora --split=full --optimizer=SGD --logger=GCN-Cora2-SGD
python gcn.py --dataset=Cora --split=complete --optimizer=SGD --logger=GCN-Cora3-SGD

python gcn.py --dataset=Cora --split=public --optimizer=SGD --hyperparam=gamma --logger=GCN-Cora1-SGD
python gcn.py --dataset=Cora --split=full --optimizer=SGD --hyperparam=gamma --logger=GCN-Cora2-SGD
python gcn.py --dataset=Cora --split=complete --optimizer=SGD --hyperparam=gamma --logger=GCN-Cora3-SGD

python gcn.py --dataset=Cora --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora1-SGD-KFAC
python gcn.py --dataset=Cora --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora2-SGD-KFAC
python gcn.py --dataset=Cora --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-Cora3-SGD-KFAC

python gcn.py --dataset=Cora --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora1-SGD-KFAC
python gcn.py --dataset=Cora --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora2-SGD-KFAC
python gcn.py --dataset=Cora --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-Cora3-SGD-KFAC

echo "CiteSeer"
echo "========"

python gcn.py --dataset=CiteSeer --split=public --optimizer=Adam --logger=GCN-CiteSeer1-Adam
python gcn.py --dataset=CiteSeer --split=full --optimizer=Adam --logger=GCN-CiteSeer2-Adam
python gcn.py --dataset=CiteSeer --split=complete --optimizer=Adam --logger=GCN-CiteSeer3-Adam

python gcn.py --dataset=CiteSeer --split=public --optimizer=Adam --hyperparam=gamma --logger=GCN-CiteSeer1-Adam
python gcn.py --dataset=CiteSeer --split=full --optimizer=Adam --hyperparam=gamma --logger=GCN-CiteSeer2-Adam
python gcn.py --dataset=CiteSeer --split=complete --optimizer=Adam --hyperparam=gamma --logger=GCN-CiteSeer3-Adam

python gcn.py --dataset=CiteSeer --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer1-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer2-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer3-Adam-KFAC

python gcn.py --dataset=CiteSeer --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer1-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer2-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer3-Adam-KFAC

python gcn.py --dataset=CiteSeer --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer1-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer2-Adam-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer3-Adam-KFAC

python gcn.py --dataset=CiteSeer --split=public --optimizer=SGD --logger=GCN-CiteSeer1-SGD
python gcn.py --dataset=CiteSeer --split=full --optimizer=SGD --logger=GCN-CiteSeer2-SGD
python gcn.py --dataset=CiteSeer --split=complete --optimizer=SGD --logger=GCN-CiteSeer3-SGD

python gcn.py --dataset=CiteSeer --split=public --optimizer=SGD --hyperparam=gamma --logger=GCN-CiteSeer1-SGD
python gcn.py --dataset=CiteSeer --split=full --optimizer=SGD --hyperparam=gamma --logger=GCN-CiteSeer2-SGD
python gcn.py --dataset=CiteSeer --split=complete --optimizer=SGD --hyperparam=gamma --logger=GCN-CiteSeer3-SGD

python gcn.py --dataset=CiteSeer --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer1-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer2-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-CiteSeer3-SGD-KFAC

python gcn.py --dataset=CiteSeer --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer1-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer2-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-CiteSeer3-SGD-KFAC

python gcn.py --dataset=CiteSeer --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer1-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer2-SGD-KFAC
python gcn.py --dataset=CiteSeer --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=update_freq --logger=GCN-CiteSeer3-SGD-KFAC

echo "PubMed"
echo "======"

python gcn.py --dataset=PubMed --split=public --optimizer=Adam --logger=GCN-PubMed1-Adam
python gcn.py --dataset=PubMed --split=full --optimizer=Adam --logger=GCN-PubMed2-Adam
python gcn.py --dataset=PubMed --split=complete --optimizer=Adam --logger=GCN-PubMed3-Adam

python gcn.py --dataset=PubMed --split=public --optimizer=Adam --hyperparam=gamma --logger=GCN-PubMed1-Adam
python gcn.py --dataset=PubMed --split=full --optimizer=Adam --hyperparam=gamma --logger=GCN-PubMed2-Adam
python gcn.py --dataset=PubMed --split=complete --optimizer=Adam --hyperparam=gamma --logger=GCN-PubMed3-Adam

python gcn.py --dataset=PubMed --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed1-Adam-KFAC
python gcn.py --dataset=PubMed --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed2-Adam-KFAC
python gcn.py --dataset=PubMed --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed3-Adam-KFAC

python gcn.py --dataset=PubMed --split=public --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed1-Adam-KFAC
python gcn.py --dataset=PubMed --split=full --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed2-Adam-KFAC
python gcn.py --dataset=PubMed --split=complete --optimizer=Adam --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed3-Adam-KFAC

python gcn.py --dataset=PubMed --split=public --optimizer=SGD --logger=GCN-PubMed1-SGD
python gcn.py --dataset=PubMed --split=full --optimizer=SGD --logger=GCN-PubMed2-SGD
python gcn.py --dataset=PubMed --split=complete --optimizer=SGD --logger=GCN-PubMed3-SGD

python gcn.py --dataset=PubMed --split=public --optimizer=SGD --hyperparam=gamma --logger=GCN-PubMed1-SGD
python gcn.py --dataset=PubMed --split=full --optimizer=SGD --hyperparam=gamma --logger=GCN-PubMed2-SGD
python gcn.py --dataset=PubMed --split=complete --optimizer=SGD --hyperparam=gamma --logger=GCN-PubMed3-SGD

python gcn.py --dataset=PubMed --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed1-SGD-KFAC
python gcn.py --dataset=PubMed --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed2-SGD-KFAC
python gcn.py --dataset=PubMed --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=eps --logger=GCN-PubMed3-SGD-KFAC

python gcn.py --dataset=PubMed --split=public --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed1-SGD-KFAC
python gcn.py --dataset=PubMed --split=full --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed2-SGD-KFAC
python gcn.py --dataset=PubMed --split=complete --optimizer=SGD --preconditioner=KFAC --hyperparam=gamma --logger=GCN-PubMed3-SGD-KFAC
--------------------------------------------------------------------------------
/experiments/train_eval.py:
--------------------------------------------------------------------------------
from __future__ import division

import time
import os
import shutil
import torch
import torch.nn.functional as F
from torch import tensor
from torch.utils.tensorboard import SummaryWriter

import utils as ut
import psgd

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

path_runs = "runs"


def run(
    dataset,
    model,
    str_optimizer,
    str_preconditioner,
    runs,
    epochs,
    lr,
    weight_decay,
    early_stopping,
    logger,
    momentum,
    eps,
    update_freq,
    gamma,
    alpha,
    hyperparam
):
    if logger is not None:
        if hyperparam:
            # eval(hyperparam) looks up the swept argument (e.g. eps) by name,
            # so each sweep value gets its own run directory.
            logger += f"-{hyperparam}{eval(hyperparam)}"
        path_logger = os.path.join(path_runs, logger)
        print(f"path logger: {path_logger}")

        ut.empty_dir(path_logger)
    logger = SummaryWriter(log_dir=os.path.join(path_runs, logger)) if logger is not None else None

    val_losses, accs, durations = [], [], []
    torch.manual_seed(42)
    for i_run in range(runs):
        data = dataset[0]
        data = data.to(device)

        model.to(device).reset_parameters()
        if str_preconditioner == 'KFAC':
            preconditioner = psgd.KFAC(
                model,
                eps,
                sua=False,
                pi=False,
                update_freq=update_freq,
                alpha=alpha if alpha is not None else 1.,
                constraint_norm=False
            )
        else:
            preconditioner = None

        if str_optimizer == 'Adam':
            optimizer = torch.optim.Adam(
                model.parameters(),
                lr=lr,
                weight_decay=weight_decay
            )
        elif str_optimizer == 'SGD':
            optimizer = torch.optim.SGD(
                model.parameters(),
                lr=lr,
                momentum=momentum,
            )

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_start = time.perf_counter()

        best_val_loss = float('inf')
        test_acc = 0
        val_loss_history = []

        for epoch in range(1, epochs + 1):
            # lam ramps from ~0 to 1 over training, controlling the weight of
            # the self-training term on unlabeled nodes (see train below).
            lam = (float(epoch) / float(epochs)) ** gamma if gamma is not None else 0.
            train(model, optimizer, data, preconditioner, lam)
            eval_info = evaluate(model, data)
            eval_info['epoch'] = int(epoch)
            eval_info['run'] = int(i_run + 1)
            eval_info['time'] = time.perf_counter() - t_start
            eval_info['eps'] = eps
            eval_info['update-freq'] = update_freq

            if gamma is not None:
                eval_info['gamma'] = gamma

            if alpha is not None:
                eval_info['alpha'] = alpha

            if logger is not None:
                for k, v in eval_info.items():
                    logger.add_scalar(k, v, global_step=epoch)

            if eval_info['val loss'] < best_val_loss:
                best_val_loss = eval_info['val loss']
                test_acc = eval_info['test acc']

            val_loss_history.append(eval_info['val loss'])
            if early_stopping > 0 and epoch > epochs // 2:
                tmp = tensor(val_loss_history[-(early_stopping + 1):-1])
                if eval_info['val loss'] > tmp.mean().item():
                    break

        if torch.cuda.is_available():
            torch.cuda.synchronize()

        t_end = time.perf_counter()

        val_losses.append(best_val_loss)
        accs.append(test_acc)
        durations.append(t_end - t_start)

    if logger is not None:
        logger.close()
    loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)
    print('Val Loss: {:.4f}, Test Accuracy: {:.2f} ± {:.2f}, Duration: {:.3f} \n'.
          format(loss.mean().item(),
                 100 * acc.mean().item(),
                 100 * acc.std().item(),
                 duration.mean().item()))


def train(model, optimizer, data, preconditioner=None, lam=0.):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    # Self-training targets: ground-truth labels on training nodes, the
    # model's own argmax predictions (pseudo-labels) everywhere else.
    label = out.max(1)[1]
    label[data.train_mask] = data.y[data.train_mask]
    label.requires_grad = False

    loss = F.nll_loss(out[data.train_mask], label[data.train_mask])
    # The pseudo-label term on non-training nodes is weighted by lam.
    loss += lam * F.nll_loss(out[~data.train_mask], label[~data.train_mask])

    loss.backward(retain_graph=True)
    if preconditioner:
        preconditioner.step(lam=lam)
    optimizer.step()


def evaluate(model, data):
    model.eval()

    with torch.no_grad():
        logits = model(data)

    outs = {}
    for key in ['train', 'val', 'test']:
        mask = data['{}_mask'.format(key)]
        loss = F.nll_loss(logits[mask], data.y[mask]).item()
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()

        outs['{} loss'.format(key)] = loss
        outs['{} acc'.format(key)] = acc

    return outs
--------------------------------------------------------------------------------
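A note on `gamma`: it shapes the schedule `lam = (epoch/epochs)**gamma` used above, so the pseudo-label loss is phased in over training. A tiny illustration of the schedule (values computed at a few checkpoints, purely for intuition):

```python
# Illustrative: lam at a few checkpoints of a 200-epoch run.
epochs = 200
for gamma in (1.0, 2.0, 10.0):
    schedule = [round((epoch / epochs) ** gamma, 6) for epoch in (50, 100, 150, 200)]
    print(f"gamma={gamma}: {schedule}")
# gamma=1 ramps linearly; larger gamma keeps lam near zero until late in training.
```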
/experiments/utils.py:
--------------------------------------------------------------------------------
import os
import shutil
from collections import defaultdict

import numpy as np
import pandas as pd
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def tabulate_events(dpath):
    summary_iterators = [EventAccumulator(os.path.join(dpath, dname)).Reload() for dname in os.listdir(dpath) if dname.startswith('events')]
    assert len(summary_iterators) == 1
    tags = set(*[si.Tags()['scalars'] for si in summary_iterators])

    out = defaultdict(list)
    steps = []

    for tag in tags:
        steps = [e.step for e in summary_iterators[0].Scalars(tag)]
        for events in zip(*[acc.Scalars(tag) for acc in summary_iterators]):
            assert len(set(e.step for e in events)) == 1
            out[tag].append([e.value for e in events])
    return out, steps


def to_csv(dpath):
    d, steps = tabulate_events(dpath)
    tags, values = zip(*d.items())
    np_values = np.array(values)
    df = pd.DataFrame(dict((f"{tags[i]}", np_values[i][:, 0]) for i in range(np_values.shape[0])), index=steps, columns=tags)
    df.to_csv(os.path.join(dpath, "logger.csv"))


def read_event(path):
    to_csv(path)
    return pd.read_csv(os.path.join(path, "logger.csv"), index_col=0)


def empty_dir(folder):
    if os.path.exists(folder):
        for filename in os.listdir(folder):
            file_path = os.path.join(folder, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)  # requires the shutil import above, previously missing
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))
--------------------------------------------------------------------------------
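A hedged usage sketch for pulling a finished run's scalars back out of TensorBoard (the run directory name is one example from `run.sh`; the column names are the `eval_info` keys logged by `train_eval.run`):

```python
# Illustrative only: convert one run's event file to CSV and inspect it.
from utils import read_event

df = read_event('runs/GCN-Cora1-Adam')  # example run directory produced by run.sh
print(df.columns.tolist())              # logged tags, e.g. 'val loss', 'test acc'
print(df[['val loss', 'test acc']].tail())
```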