├── LICENSE
├── README.md
├── base_classes.py
├── block_constant.py
├── block_mixed.py
├── block_transformer_attention.py
├── block_transformer_hard_attention.py
├── block_transformer_rewiring.py
├── config.py
├── data.py
├── distances_kNN.py
├── function_GAT_attention.py
├── function_laplacian_diffusion.py
├── function_transformer_attention.py
├── graph_rewiring.py
├── grb
│   ├── __init__.py
│   ├── attack
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── injection
│   │   │   ├── __init__.py
│   │   │   ├── fgsm.py
│   │   │   ├── pgd.py
│   │   │   ├── rand.py
│   │   │   ├── speit.py
│   │   │   └── tdgia.py
│   │   └── modification
│   │       ├── __init__.py
│   │       ├── dice.py
│   │       ├── fga.py
│   │       ├── flip.py
│   │       ├── nea.py
│   │       ├── pgd.py
│   │       ├── prbcd.py
│   │       ├── rand.py
│   │       └── stack.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── defense
│   │   ├── __init__.py
│   │   ├── adv_trainer.py
│   │   ├── base.py
│   │   ├── gcnsvd.py
│   │   ├── gnnguard.py
│   │   └── robustgcn.py
│   ├── evaluator
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   └── metric.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── cogdl
│   │   │   ├── __init__.py
│   │   │   └── gcn.py
│   │   ├── dgl
│   │   │   ├── __init__.py
│   │   │   ├── gat.py
│   │   │   ├── gatode.py
│   │   │   ├── gcn.py
│   │   │   ├── gin.py
│   │   │   └── grand.py
│   │   └── torch
│   │       ├── MeanCurv.py
│   │       ├── __init__.py
│   │       ├── appnp.py
│   │       ├── belguard.py
│   │       ├── beltrami.py
│   │       ├── beltrami2.py
│   │       ├── beltramii.py
│   │       ├── gcn.py
│   │       ├── gcnode.py
│   │       ├── gcnode2.py
│   │       ├── gin.py
│   │       ├── graphsage.py
│   │       ├── heat.py
│   │       ├── mlp.py
│   │       ├── pLaplace.py
│   │       ├── sgcn.py
│   │       └── tagcn.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── trainer.py
│   │   ├── trainer2.py
│   │   └── trainer_helper.py
│   └── utils
│       ├── __init__.py
│       ├── logger.py
│       ├── normalize.py
│       ├── utils.py
│       └── visualize.py
├── heterophilic.py
├── hyperbolic_distances.py
├── logger.py
├── model_configurations.py
├── regularized_ODE_function.py
├── run_GNN2.py
├── saved_models2
│   ├── grb-AmazonCoBuyComputerDataset
│   │   ├── beltrami_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── gcnsvd_1ln_noAdvT
│   │   │   └── model_at_0.pt
│   │   ├── gcnsvd_2ln_noAdvT
│   │   │   └── model_at_0.pt
│   │   ├── gcnsvd_3ln_noAdvT
│   │   │   └── model_at_0.pt
│   │   ├── grand_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   └── meancurv_3noAdvT_drop05_attsamp095
│   │       └── model_at_0.pt
│   ├── grb-coauthor
│   │   ├── beltrami_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── heat_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   └── meancurv_3noAdvT_drop05_attsamp095
│   │       └── model_at_0.pt
│   ├── grb-flickr
│   │   ├── beltrami2_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami2_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── beltrami2_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── grand_3noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_1noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   ├── meancurv_2noAdvT_drop05_attsamp095
│   │   │   └── model_at_0.pt
│   │   └── meancurv_3noAdvT_drop05_attsamp095
│   │       └── model_at_0.pt
│   └── grb-pubmed
│       └── grand_noAdvT_drop05_attsamp095
│           └── model_at_0.pt
├── train_target.py
├── tsne_plot.py
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Stanislas
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | This repository contains the implementation to reproduce the numerical experiments
3 | of the **NeurIPS 2022** paper [On the Robustness of Graph Neural Diffusion to Topology Perturbations](https://arxiv.org/abs/2209.07754).
4 | 
5 | ## Running the experiments
6 | The code is built on
7 | - https://github.com/twitter-research/graph-neural-pde
8 | - https://github.com/THUDM/grb
9 | 
10 | ### Experiments
11 | For example, to run the main experiment script:
12 | ```
13 | cd src
14 | python run_GNN2.py
15 | ```
16 | Saved models are available in `./saved_models2`.
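For instance, a minimal sketch of restoring one of these checkpoints (assuming a plain PyTorch environment; whether `model_at_0.pt` holds the whole module or only a `state_dict` depends on how it was saved, so both cases are handled):

```python
import torch

ckpt = "saved_models2/grb-pubmed/grand_noAdvT_drop05_attsamp095/model_at_0.pt"
obj = torch.load(ckpt, map_location="cpu")

if isinstance(obj, torch.nn.Module):
    model = obj.eval()  # the file held the full model
else:
    # the file held a state_dict: rebuild the matching architecture first,
    # then restore the weights with model.load_state_dict(obj)
    print(f"state_dict with {len(obj)} tensors")
```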
17 | 
18 | ### Requirements
19 | 
20 | * scipy==1.5.2
21 | * numpy==1.19.1
22 | * torch==1.8.0
23 | * networkx==2.5
24 | * pandas~=1.2.3
25 | * cogdl~=0.3.0.post1
26 | * torch-cluster==1.5.9
27 | * torch-geometric==1.7.0
28 | * torch-scatter==2.0.6
29 | * torch-sparse==0.6.9
30 | * torch-spline-conv==1.2.1
31 | * torchdiffeq==0.2.1
32 | 
33 | ## Citation
34 | If you find our work useful in your research, please cite our paper:
35 | ```bibtex
36 | @INPROCEEDINGS{SonKanWan:C22,
37 |   author = {Yang Song and Qiyu Kang and Sijie Wang and Kai Zhao and Wee Peng Tay},
38 |   title = {On the Robustness of Graph Neural Diffusion to Topology Perturbations},
39 |   booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
40 |   month = {Nov.},
41 |   year = {2022},
42 |   address = {New Orleans, USA},
43 | }
44 | ```
45 | (Also consider starring the project on GitHub.)
46 | 
--------------------------------------------------------------------------------
/base_classes.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch_geometric.nn.conv import MessagePassing
4 | from utils import Meter
5 | from regularized_ODE_function import RegularizedODEfunc
6 | import regularized_ODE_function as reg_lib
7 | import six
8 | 
9 | 
10 | REGULARIZATION_FNS = {
11 |     "kinetic_energy": reg_lib.quadratic_cost,
12 |     "jacobian_norm2": reg_lib.jacobian_frobenius_regularization_fn,
13 |     "total_deriv": reg_lib.total_derivative,
14 |     "directional_penalty": reg_lib.directional_derivative
15 | }
16 | 
17 | 
18 | def create_regularization_fns(args):
19 |   regularization_fns = []
20 |   regularization_coeffs = []
21 | 
22 |   for arg_key, reg_fn in six.iteritems(REGULARIZATION_FNS):
23 |     if args[arg_key] is not None:
24 |       regularization_fns.append(reg_fn)
25 |       regularization_coeffs.append(args[arg_key])
26 | 
27 |   regularization_fns = regularization_fns
28 |   regularization_coeffs = regularization_coeffs
29 |   return regularization_fns, regularization_coeffs
30 | 
31 | 
32 | class ODEblock(nn.Module):
33 |   def __init__(self, odefunc, regularization_fns, opt, data, device, t):
34 |     super(ODEblock, self).__init__()
35 |     self.opt = opt
36 |     self.t = t
37 | 
38 |     self.aug_dim = 2 if opt['augment'] else 1
39 |     self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device)
40 | 
41 |     self.nreg = len(regularization_fns)
42 |     self.reg_odefunc = RegularizedODEfunc(self.odefunc, regularization_fns)
43 | 
44 |     if opt['adjoint']:
45 |       from torchdiffeq import odeint_adjoint as odeint
46 |     else:
47 |       from torchdiffeq import odeint
48 |     self.train_integrator = odeint
49 |     self.test_integrator = None
50 |     self.set_tol()
51 | 
52 |   def set_x0(self, x0):
53 |     self.odefunc.x0 = x0.clone().detach()
54 |     self.reg_odefunc.odefunc.x0 = x0.clone().detach()
55 | 
56 |   def set_tol(self):
57 |     self.atol = self.opt['tol_scale'] * 1e-7
58 |     self.rtol = self.opt['tol_scale'] * 1e-9
59 |     if self.opt['adjoint']:
60 |       self.atol_adjoint = self.opt['tol_scale_adjoint'] * 1e-7
61 |       self.rtol_adjoint = self.opt['tol_scale_adjoint'] * 1e-9
62 | 
63 |   def reset_tol(self):
64 |     self.atol = 1e-7
65 |     self.rtol = 1e-9
66 |     self.atol_adjoint = 1e-7
67 |     self.rtol_adjoint = 1e-9
68 | 
69 |   def set_time(self, time):
70 |     self.t = torch.tensor([0, time]).to(self.t.device)
71 | 
72 |   def __repr__(self):
73 |     return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \
74 |            + ")"
75 | 
76 | 
77 | class 
ODEFunc(MessagePassing): 78 | 79 | # currently requires in_features = out_features 80 | def __init__(self, opt, data, device): 81 | super(ODEFunc, self).__init__() 82 | self.opt = opt 83 | self.device = device 84 | self.edge_index = None 85 | self.edge_weight = None 86 | self.attention_weights = None 87 | self.alpha_train = nn.Parameter(torch.tensor(0.0)) 88 | self.beta_train = nn.Parameter(torch.tensor(0.0)) 89 | self.x0 = None 90 | self.nfe = 0 91 | self.alpha_sc = nn.Parameter(torch.ones(1)) 92 | self.beta_sc = nn.Parameter(torch.ones(1)) 93 | 94 | def __repr__(self): 95 | return self.__class__.__name__ 96 | 97 | 98 | class BaseGNN(MessagePassing): 99 | def __init__(self, opt, dataset, device=torch.device('cpu')): 100 | super(BaseGNN, self).__init__() 101 | self.opt = opt 102 | self.T = opt['time'] 103 | self.num_classes = dataset.num_classes 104 | self.num_features = dataset.data.num_features 105 | self.num_nodes = dataset.data.num_nodes 106 | self.device = device 107 | self.fm = Meter() 108 | self.bm = Meter() 109 | 110 | if opt['beltrami']: 111 | self.mx = nn.Linear(self.num_features, opt['feat_hidden_dim']) 112 | self.mp = nn.Linear(opt['pos_enc_dim'], opt['pos_enc_hidden_dim']) 113 | opt['hidden_dim'] = opt['feat_hidden_dim'] + opt['pos_enc_hidden_dim'] 114 | else: 115 | self.m1 = nn.Linear(self.num_features, opt['hidden_dim']) 116 | 117 | if self.opt['use_mlp']: 118 | self.m11 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 119 | self.m12 = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 120 | if opt['use_labels']: 121 | # todo - fastest way to propagate this everywhere, but error prone - refactor later 122 | opt['hidden_dim'] = opt['hidden_dim'] + dataset.num_classes 123 | else: 124 | self.hidden_dim = opt['hidden_dim'] 125 | if opt['fc_out']: 126 | self.fc = nn.Linear(opt['hidden_dim'], opt['hidden_dim']) 127 | self.m2 = nn.Linear(opt['hidden_dim'], dataset.num_classes) 128 | if self.opt['batch_norm']: 129 | self.bn_in = torch.nn.BatchNorm1d(opt['hidden_dim']) 130 | self.bn_out = torch.nn.BatchNorm1d(opt['hidden_dim']) 131 | 132 | self.regularization_fns, self.regularization_coeffs = create_regularization_fns(self.opt) 133 | 134 | def getNFE(self): 135 | return self.odeblock.odefunc.nfe + self.odeblock.reg_odefunc.odefunc.nfe 136 | 137 | def resetNFE(self): 138 | self.odeblock.odefunc.nfe = 0 139 | self.odeblock.reg_odefunc.odefunc.nfe = 0 140 | 141 | def reset(self): 142 | self.m1.reset_parameters() 143 | self.m2.reset_parameters() 144 | 145 | def __repr__(self): 146 | return self.__class__.__name__ 147 | -------------------------------------------------------------------------------- /block_constant.py: -------------------------------------------------------------------------------- 1 | from base_classes import ODEblock 2 | import torch 3 | from utils import get_rw_adj, gcn_norm_fill_val 4 | 5 | 6 | class ConstantODEblock(ODEblock): 7 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1])): 8 | super(ConstantODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 9 | 10 | self.aug_dim = 2 if opt['augment'] else 1 11 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 12 | if opt['data_norm'] == 'rw': 13 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 14 | fill_value=opt['self_loop_weight'], 15 | num_nodes=data.num_nodes, 16 | dtype=data.x.dtype) 17 | else: 18 | edge_index, edge_weight = 
gcn_norm_fill_val(data.edge_index, edge_weight=data.edge_attr, 19 | fill_value=opt['self_loop_weight'], 20 | num_nodes=data.num_nodes, 21 | dtype=data.x.dtype) 22 | self.odefunc.edge_index = edge_index.to(device) 23 | self.odefunc.edge_weight = edge_weight.to(device) 24 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 25 | 26 | if opt['adjoint']: 27 | from torchdiffeq import odeint_adjoint as odeint 28 | else: 29 | from torchdiffeq import odeint 30 | 31 | self.train_integrator = odeint 32 | self.test_integrator = odeint 33 | self.set_tol() 34 | 35 | def forward(self, x): 36 | t = self.t.type_as(x) 37 | 38 | integrator = self.train_integrator if self.training else self.test_integrator 39 | 40 | reg_states = tuple( torch.zeros(x.size(0)).to(x) for i in range(self.nreg) ) 41 | 42 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc 43 | state = (x,) + reg_states if self.training and self.nreg > 0 else x 44 | 45 | if self.opt["adjoint"] and self.training: 46 | state_dt = integrator( 47 | func, state, t, 48 | method=self.opt['method'], 49 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']), 50 | adjoint_method=self.opt['adjoint_method'], 51 | adjoint_options=dict(step_size = self.opt['adjoint_step_size'], max_iters=self.opt['max_iters']), 52 | atol=self.atol, 53 | rtol=self.rtol, 54 | adjoint_atol=self.atol_adjoint, 55 | adjoint_rtol=self.rtol_adjoint) 56 | else: 57 | state_dt = integrator( 58 | func, state, t, 59 | method=self.opt['method'], 60 | options=dict(step_size=self.opt['step_size'], max_iters=self.opt['max_iters']), 61 | atol=self.atol, 62 | rtol=self.rtol) 63 | 64 | if self.training and self.nreg > 0: 65 | z = state_dt[0][1] 66 | reg_states = tuple( st[1] for st in state_dt[1:] ) 67 | return z, reg_states 68 | else: 69 | z = state_dt[1] 70 | return z 71 | 72 | def __repr__(self): 73 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 74 | + ")" 75 | -------------------------------------------------------------------------------- /block_mixed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from function_transformer_attention import SpGraphTransAttentionLayer 4 | from base_classes import ODEblock 5 | from utils import get_rw_adj 6 | 7 | 8 | class MixedODEblock(ODEblock): 9 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.): 10 | super(MixedODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 11 | 12 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 13 | # self.odefunc.edge_index, self.odefunc.edge_weight = data.edge_index, edge_weight=data.edge_attr 14 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 15 | fill_value=opt['self_loop_weight'], 16 | num_nodes=data.num_nodes, 17 | dtype=data.x.dtype) 18 | self.odefunc.edge_index = edge_index.to(device) 19 | self.odefunc.edge_weight = edge_weight.to(device) 20 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 21 | 22 | if opt['adjoint']: 23 | from torchdiffeq import odeint_adjoint as odeint 24 | else: 25 | from torchdiffeq import odeint 26 | self.train_integrator = odeint 27 | self.test_integrator = odeint 28 | 
self.set_tol() 29 | # parameter trading off between attention and the Laplacian 30 | self.gamma = nn.Parameter(gamma * torch.ones(1)) 31 | self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt, 32 | device).to(device) 33 | 34 | def get_attention_weights(self, x): 35 | attention, values = self.multihead_att_layer(x, self.odefunc.edge_index) 36 | return attention 37 | 38 | def get_mixed_attention(self, x): 39 | gamma = torch.sigmoid(self.gamma) 40 | attention = self.get_attention_weights(x) 41 | mixed_attention = attention.mean(dim=1) * (1 - gamma) + self.odefunc.edge_weight * gamma 42 | return mixed_attention 43 | 44 | def forward(self, x): 45 | t = self.t.type_as(x) 46 | self.odefunc.attention_weights = self.get_mixed_attention(x) 47 | integrator = self.train_integrator if self.training else self.test_integrator 48 | if self.opt["adjoint"] and self.training: 49 | z = integrator( 50 | self.odefunc, x, t, 51 | method=self.opt['method'], 52 | options={'step_size': self.opt['step_size']}, 53 | adjoint_method=self.opt['adjoint_method'], 54 | adjoint_options={'step_size': self.opt['adjoint_step_size']}, 55 | atol=self.atol, 56 | rtol=self.rtol, 57 | adjoint_atol=self.atol_adjoint, 58 | adjoint_rtol=self.rtol_adjoint)[1] 59 | else: 60 | z = integrator( 61 | self.odefunc, x, t, 62 | method=self.opt['method'], 63 | options={'step_size': self.opt['step_size']}, 64 | atol=self.atol, 65 | rtol=self.rtol)[1] 66 | 67 | return z 68 | 69 | def __repr__(self): 70 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 71 | + ")" 72 | -------------------------------------------------------------------------------- /block_transformer_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from function_transformer_attention import SpGraphTransAttentionLayer 3 | from base_classes import ODEblock 4 | from utils import get_rw_adj 5 | 6 | 7 | class AttODEblock(ODEblock): 8 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5): 9 | super(AttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 10 | 11 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 12 | # self.odefunc.edge_index, self.odefunc.edge_weight = data.edge_index, edge_weight=data.edge_attr 13 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 14 | fill_value=opt['self_loop_weight'], 15 | num_nodes=data.num_nodes, 16 | dtype=data.x.dtype) 17 | self.odefunc.edge_index = edge_index.to(device) 18 | self.odefunc.edge_weight = edge_weight.to(device) 19 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 20 | 21 | if opt['adjoint']: 22 | from torchdiffeq import odeint_adjoint as odeint 23 | else: 24 | from torchdiffeq import odeint 25 | self.train_integrator = odeint 26 | self.test_integrator = odeint 27 | self.set_tol() 28 | # parameter trading off between attention and the Laplacian 29 | self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt, 30 | device, edge_weights=self.odefunc.edge_weight).to(device) 31 | 32 | def get_attention_weights(self, x): 33 | attention, values = self.multihead_att_layer(x, self.odefunc.edge_index) 34 | return attention 35 | 36 | def forward(self, x): 37 | t = self.t.type_as(x) 38 | 
self.odefunc.attention_weights = self.get_attention_weights(x) 39 | self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights 40 | integrator = self.train_integrator if self.training else self.test_integrator 41 | 42 | reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg)) 43 | 44 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc 45 | state = (x,) + reg_states if self.training and self.nreg > 0 else x 46 | 47 | if self.opt["adjoint"] and self.training: 48 | state_dt = integrator( 49 | func, state, t, 50 | method=self.opt['method'], 51 | options={'step_size': self.opt['step_size']}, 52 | adjoint_method=self.opt['adjoint_method'], 53 | adjoint_options={'step_size': self.opt['adjoint_step_size']}, 54 | atol=self.atol, 55 | rtol=self.rtol, 56 | adjoint_atol=self.atol_adjoint, 57 | adjoint_rtol=self.rtol_adjoint) 58 | else: 59 | state_dt = integrator( 60 | func, state, t, 61 | method=self.opt['method'], 62 | options={'step_size': self.opt['step_size']}, 63 | atol=self.atol, 64 | rtol=self.rtol) 65 | 66 | if self.training and self.nreg > 0: 67 | z = state_dt[0][1] 68 | reg_states = tuple(st[1] for st in state_dt[1:]) 69 | return z, reg_states 70 | else: 71 | z = state_dt[1] 72 | return z 73 | 74 | def __repr__(self): 75 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 76 | + ")" 77 | -------------------------------------------------------------------------------- /block_transformer_hard_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from function_transformer_attention import SpGraphTransAttentionLayer 3 | from base_classes import ODEblock 4 | from utils import get_rw_adj 5 | from torch_scatter import scatter 6 | 7 | class HardAttODEblock(ODEblock): 8 | def __init__(self, odefunc, regularization_fns, opt, data, device, t=torch.tensor([0, 1]), gamma=0.5): 9 | super(HardAttODEblock, self).__init__(odefunc, regularization_fns, opt, data, device, t) 10 | assert opt['att_samp_pct'] > 0 and opt['att_samp_pct'] <= 1, "attention sampling threshold must be in (0,1]" 11 | self.opt = opt 12 | self.odefunc = odefunc(self.aug_dim * opt['hidden_dim'], self.aug_dim * opt['hidden_dim'], opt, data, device) 13 | # self.odefunc.edge_index, self.odefunc.edge_weight = data.edge_index, edge_weight=data.edge_attr 14 | self.num_nodes = data.num_nodes 15 | edge_index, edge_weight = get_rw_adj(data.edge_index, edge_weight=data.edge_attr, norm_dim=1, 16 | fill_value=opt['self_loop_weight'], 17 | num_nodes=data.num_nodes, 18 | dtype=data.x.dtype) 19 | self.data_edge_index = edge_index.to(device) 20 | self.odefunc.edge_index = edge_index.to(device) # this will be changed by attention scores 21 | self.odefunc.edge_weight = edge_weight.to(device) 22 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 23 | 24 | if opt['adjoint']: 25 | from torchdiffeq import odeint_adjoint as odeint 26 | else: 27 | from torchdiffeq import odeint 28 | self.train_integrator = odeint 29 | self.test_integrator = odeint 30 | self.set_tol() 31 | # parameter trading off between attention and the Laplacian 32 | if opt['function'] not in {'GAT', 'transformer'}: 33 | self.multihead_att_layer = SpGraphTransAttentionLayer(opt['hidden_dim'], opt['hidden_dim'], opt, 34 | device, edge_weights=self.odefunc.edge_weight).to(device) 35 | 36 | def get_attention_weights(self, x): 37 | if 
self.opt['function'] not in {'GAT', 'transformer'}: 38 | attention, values = self.multihead_att_layer(x, self.data_edge_index) 39 | else: 40 | attention, values = self.odefunc.multihead_att_layer(x, self.data_edge_index) 41 | return attention 42 | 43 | def renormalise_attention(self, attention): 44 | index = self.odefunc.edge_index[self.opt['attention_norm_idx']] 45 | att_sums = scatter(attention, index, dim=0, dim_size=self.num_nodes, reduce='sum')[index] 46 | return attention / (att_sums + 1e-16) 47 | 48 | def forward(self, x): 49 | t = self.t.type_as(x) 50 | attention_weights = self.get_attention_weights(x) 51 | # create attention mask 52 | if self.training: 53 | with torch.no_grad(): 54 | mean_att = attention_weights.mean(dim=1, keepdim=False) 55 | if self.opt['use_flux']: 56 | src_features = x[self.data_edge_index[0, :], :] 57 | dst_features = x[self.data_edge_index[1, :], :] 58 | delta = torch.linalg.norm(src_features-dst_features, dim=1) 59 | mean_att = mean_att * delta 60 | threshold = torch.quantile(mean_att, 1-self.opt['att_samp_pct']) 61 | mask = mean_att > threshold 62 | self.odefunc.edge_index = self.data_edge_index[:, mask.T] 63 | sampled_attention_weights = self.renormalise_attention(mean_att[mask]) 64 | print('retaining {} of {} edges'.format(self.odefunc.edge_index.shape[1], self.data_edge_index.shape[1])) 65 | self.odefunc.attention_weights = sampled_attention_weights 66 | else: 67 | self.odefunc.edge_index = self.data_edge_index 68 | self.odefunc.attention_weights = attention_weights.mean(dim=1, keepdim=False) 69 | self.reg_odefunc.odefunc.edge_index, self.reg_odefunc.odefunc.edge_weight = self.odefunc.edge_index, self.odefunc.edge_weight 70 | self.reg_odefunc.odefunc.attention_weights = self.odefunc.attention_weights 71 | integrator = self.train_integrator if self.training else self.test_integrator 72 | 73 | reg_states = tuple(torch.zeros(x.size(0)).to(x) for i in range(self.nreg)) 74 | 75 | func = self.reg_odefunc if self.training and self.nreg > 0 else self.odefunc 76 | state = (x,) + reg_states if self.training and self.nreg > 0 else x 77 | 78 | if self.opt["adjoint"] and self.training: 79 | state_dt = integrator( 80 | func, state, t, 81 | method=self.opt['method'], 82 | options={'step_size': self.opt['step_size']}, 83 | adjoint_method=self.opt['adjoint_method'], 84 | adjoint_options={'step_size': self.opt['adjoint_step_size']}, 85 | atol=self.atol, 86 | rtol=self.rtol, 87 | adjoint_atol=self.atol_adjoint, 88 | adjoint_rtol=self.rtol_adjoint) 89 | else: 90 | state_dt = integrator( 91 | func, state, t, 92 | method=self.opt['method'], 93 | options={'step_size': self.opt['step_size']}, 94 | atol=self.atol, 95 | rtol=self.rtol) 96 | 97 | if self.training and self.nreg > 0: 98 | z = state_dt[0][1] 99 | reg_states = tuple(st[1] for st in state_dt[1:]) 100 | return z, reg_states 101 | else: 102 | z = state_dt[1] 103 | return z 104 | 105 | def __repr__(self): 106 | return self.__class__.__name__ + '( Time Interval ' + str(self.t[0].item()) + ' -> ' + str(self.t[1].item()) \ 107 | + ")" 108 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code partially copied from 'Diffusion Improves Graph Learning' repo https://github.com/klicperajo/gdc/blob/master/data.py 3 | """ 4 | 5 | import os 6 | 7 | import numpy as np 8 | 9 | import torch 10 | from torch_geometric.data import Data, InMemoryDataset 11 | from torch_geometric.datasets import Planetoid, 
Amazon, Coauthor 12 | from graph_rewiring import get_two_hop, apply_gdc 13 | from ogb.nodeproppred import PygNodePropPredDataset 14 | import torch_geometric.transforms as T 15 | from torch_geometric.utils import to_undirected 16 | from graph_rewiring import make_symmetric, apply_pos_dist_rewire 17 | from heterophilic import WebKB, WikipediaNetwork, Actor 18 | from utils import ROOT_DIR 19 | 20 | DATA_PATH = f'{ROOT_DIR}/data' 21 | 22 | 23 | def rewire(data, opt, data_dir): 24 | rw = opt['rewiring'] 25 | if rw == 'two_hop': 26 | data = get_two_hop(data) 27 | elif rw == 'gdc': 28 | data = apply_gdc(data, opt) 29 | elif rw == 'pos_enc_knn': 30 | data = apply_pos_dist_rewire(data, opt, data_dir) 31 | return data 32 | 33 | 34 | def get_dataset(opt: dict, data_dir, use_lcc: bool = False) -> InMemoryDataset: 35 | ds = opt['dataset'] 36 | path = os.path.join(data_dir, ds) 37 | if ds in ['Cora', 'Citeseer', 'Pubmed']: 38 | dataset = Planetoid(path, ds) 39 | elif ds in ['Computers', 'Photo']: 40 | dataset = Amazon(path, ds) 41 | elif ds == 'CoauthorCS': 42 | dataset = Coauthor(path, 'CS') 43 | elif ds in ['cornell', 'texas', 'wisconsin']: 44 | dataset = WebKB(root=path, name=ds, transform=T.NormalizeFeatures()) 45 | elif ds in ['chameleon', 'squirrel']: 46 | dataset = WikipediaNetwork(root=path, name=ds, transform=T.NormalizeFeatures()) 47 | elif ds == 'film': 48 | dataset = Actor(root=path, transform=T.NormalizeFeatures()) 49 | elif ds == 'ogbn-arxiv': 50 | dataset = PygNodePropPredDataset(name=ds, root=path, 51 | transform=T.ToSparseTensor()) 52 | use_lcc = False # never need to calculate the lcc with ogb datasets 53 | else: 54 | raise Exception('Unknown dataset.') 55 | 56 | if use_lcc: 57 | lcc = get_largest_connected_component(dataset) 58 | 59 | x_new = dataset.data.x[lcc] 60 | y_new = dataset.data.y[lcc] 61 | 62 | row, col = dataset.data.edge_index.numpy() 63 | edges = [[i, j] for i, j in zip(row, col) if i in lcc and j in lcc] 64 | edges = remap_edges(edges, get_node_mapper(lcc)) 65 | 66 | data = Data( 67 | x=x_new, 68 | edge_index=torch.LongTensor(edges), 69 | y=y_new, 70 | train_mask=torch.zeros(y_new.size()[0], dtype=torch.bool), 71 | test_mask=torch.zeros(y_new.size()[0], dtype=torch.bool), 72 | val_mask=torch.zeros(y_new.size()[0], dtype=torch.bool) 73 | ) 74 | dataset.data = data 75 | if opt['rewiring'] is not None: 76 | dataset.data = rewire(dataset.data, opt, data_dir) 77 | train_mask_exists = True 78 | try: 79 | dataset.data.train_mask 80 | except AttributeError: 81 | train_mask_exists = False 82 | 83 | if ds == 'ogbn-arxiv': 84 | split_idx = dataset.get_idx_split() 85 | ei = to_undirected(dataset.data.edge_index) 86 | data = Data( 87 | x=dataset.data.x, 88 | edge_index=ei, 89 | y=dataset.data.y, 90 | train_mask=split_idx['train'], 91 | test_mask=split_idx['test'], 92 | val_mask=split_idx['valid']) 93 | dataset.data = data 94 | train_mask_exists = True 95 | 96 | #todo this currently breaks with heterophilic datasets if you don't pass --geom_gcn_splits 97 | if (use_lcc or not train_mask_exists) and not opt['geom_gcn_splits']: 98 | dataset.data = set_train_val_test_split( 99 | 12345, 100 | dataset.data, 101 | num_development=5000 if ds == "CoauthorCS" else 1500) 102 | 103 | return dataset 104 | 105 | 106 | def get_component(dataset: InMemoryDataset, start: int = 0) -> set: 107 | visited_nodes = set() 108 | queued_nodes = set([start]) 109 | row, col = dataset.data.edge_index.numpy() 110 | while queued_nodes: 111 | current_node = queued_nodes.pop() 112 | 
visited_nodes.update([current_node]) 113 | neighbors = col[np.where(row == current_node)[0]] 114 | neighbors = [n for n in neighbors if n not in visited_nodes and n not in queued_nodes] 115 | queued_nodes.update(neighbors) 116 | return visited_nodes 117 | 118 | 119 | def get_largest_connected_component(dataset: InMemoryDataset) -> np.ndarray: 120 | remaining_nodes = set(range(dataset.data.x.shape[0])) 121 | comps = [] 122 | while remaining_nodes: 123 | start = min(remaining_nodes) 124 | comp = get_component(dataset, start) 125 | comps.append(comp) 126 | remaining_nodes = remaining_nodes.difference(comp) 127 | return np.array(list(comps[np.argmax(list(map(len, comps)))])) 128 | 129 | 130 | def get_node_mapper(lcc: np.ndarray) -> dict: 131 | mapper = {} 132 | counter = 0 133 | for node in lcc: 134 | mapper[node] = counter 135 | counter += 1 136 | return mapper 137 | 138 | 139 | def remap_edges(edges: list, mapper: dict) -> list: 140 | row = [e[0] for e in edges] 141 | col = [e[1] for e in edges] 142 | row = list(map(lambda x: mapper[x], row)) 143 | col = list(map(lambda x: mapper[x], col)) 144 | return [row, col] 145 | 146 | 147 | def set_train_val_test_split( 148 | seed: int, 149 | data: Data, 150 | num_development: int = 1500, 151 | num_per_class: int = 20) -> Data: 152 | rnd_state = np.random.RandomState(seed) 153 | num_nodes = data.y.shape[0] 154 | development_idx = rnd_state.choice(num_nodes, num_development, replace=False) 155 | test_idx = [i for i in np.arange(num_nodes) if i not in development_idx] 156 | 157 | train_idx = [] 158 | rnd_state = np.random.RandomState(seed) 159 | for c in range(data.y.max() + 1): 160 | class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]] 161 | train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False)) 162 | 163 | val_idx = [i for i in development_idx if i not in train_idx] 164 | 165 | def get_mask(idx): 166 | mask = torch.zeros(num_nodes, dtype=torch.bool) 167 | mask[idx] = 1 168 | return mask 169 | 170 | data.train_mask = get_mask(train_idx) 171 | data.val_mask = get_mask(val_idx) 172 | data.test_mask = get_mask(test_idx) 173 | 174 | return data 175 | -------------------------------------------------------------------------------- /distances_kNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.neighbors import NearestNeighbors, KDTree, BallTree, DistanceMetric 3 | 4 | 5 | def apply_feat_KNN(x, k): 6 | nbrs = NearestNeighbors(n_neighbors=k).fit(x) 7 | distances, indices = nbrs.kneighbors(x) 8 | src = np.linspace(0, len(x) * k, len(x) * k + 1)[:-1] // k 9 | dst = indices.reshape(-1) 10 | ei = np.vstack((src, dst)) 11 | return ei 12 | 13 | def apply_dist_KNN(x, k): 14 | nbrs = NearestNeighbors(n_neighbors=k, metric='precomputed').fit(x) 15 | distances, indices = nbrs.kneighbors(x) 16 | src = np.linspace(0, len(x) * k, len(x) * k + 1)[:-1] // k 17 | dst = indices.reshape(-1) 18 | ei = np.vstack((src, dst)) 19 | return ei 20 | 21 | def threshold_mat(dist, quant=1/1000): 22 | thresh = np.quantile(dist, quant, axis=None) 23 | A = dist <= thresh 24 | return A 25 | 26 | def make_ei(A): 27 | src, dst = np.where(A) 28 | ei = np.vstack((src, dst)) 29 | return ei 30 | 31 | def apply_dist_threshold(dist, quant=1/1000): 32 | return make_ei(threshold_mat(dist, quant)) 33 | 34 | 35 | def get_distances(x): 36 | dist = DistanceMetric.get_metric('euclidean') 37 | return dist.pairwise(x) 38 | 39 | if __name__ == "__main__": 40 | # triangele 41 | # dist = 
np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]]) 42 | # square 43 | dist = np.array([[0, 1, 1, np.sqrt(2)], [1, 0, np.sqrt(2), 1], [1, np.sqrt(2), 0, 1], [np.sqrt(2), 1, 1, 0]]) 44 | print(f"distances \n {dist}") 45 | 46 | for k in range(4): # 3 47 | print(f"{k + 1} edges \n {apply_dist_KNN(dist, k + 1)}") 48 | 49 | quant= 0.75 50 | thresh = np.quantile(dist, quant, axis=None) 51 | 52 | A = threshold_mat(dist, quant) 53 | print(f"Threshold mat \n {A}") 54 | print(f"Edge index1 \n {make_ei(A)}") 55 | print(f"Edge index2 \n {apply_dist_threshold(dist, quant)}") 56 | 57 | square = np.array([[0,1],[1,1],[0,0],[1,0]]) 58 | sq_dist = get_distances(square) 59 | print(f"sq_dist \n {sq_dist}") -------------------------------------------------------------------------------- /function_GAT_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_geometric.utils import softmax 4 | import torch_sparse 5 | from torch_geometric.utils.loop import add_remaining_self_loops 6 | from data import get_dataset 7 | from utils import MaxNFEException 8 | from base_classes import ODEFunc 9 | 10 | 11 | class ODEFuncAtt(ODEFunc): 12 | 13 | def __init__(self, in_features, out_features, opt, data, device): 14 | super(ODEFuncAtt, self).__init__(opt, data, device) 15 | 16 | if opt['self_loop_weight'] > 0: 17 | self.edge_index, self.edge_weight = add_remaining_self_loops(data.edge_index, data.edge_attr, 18 | fill_value=opt['self_loop_weight']) 19 | else: 20 | self.edge_index, self.edge_weight = data.edge_index, data.edge_attr 21 | 22 | self.multihead_att_layer = SpGraphAttentionLayer(in_features, out_features, opt, 23 | device).to(device) 24 | try: 25 | self.attention_dim = opt['attention_dim'] 26 | except KeyError: 27 | self.attention_dim = out_features 28 | 29 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size" 30 | self.d_k = self.attention_dim // opt['heads'] 31 | 32 | def multiply_attention(self, x, attention, wx): 33 | if self.opt['mix_features']: 34 | wx = torch.mean(torch.stack( 35 | [torch_sparse.spmm(self.edge_index, attention[:, idx], wx.shape[0], wx.shape[0], wx) for idx in 36 | range(self.opt['heads'])], dim=0), 37 | dim=0) 38 | ax = torch.mm(wx, self.multihead_att_layer.Wout) 39 | else: 40 | ax = torch.mean(torch.stack( 41 | [torch_sparse.spmm(self.edge_index, attention[:, idx], x.shape[0], x.shape[0], x) for idx in 42 | range(self.opt['heads'])], dim=0), 43 | dim=0) 44 | return ax 45 | 46 | def forward(self, t, x): # t is needed when called by the integrator 47 | 48 | if self.nfe > self.opt["max_nfe"]: 49 | raise MaxNFEException 50 | 51 | self.nfe += 1 52 | 53 | attention, wx = self.multihead_att_layer(x, self.edge_index) 54 | ax = self.multiply_attention(x, attention, wx) 55 | # todo would be nice if this was more efficient 56 | 57 | if not self.opt['no_alpha_sigmoid']: 58 | alpha = torch.sigmoid(self.alpha_train) 59 | else: 60 | alpha = self.alpha_train 61 | 62 | f = alpha * (ax - x) 63 | if self.opt['add_source']: 64 | f = f + self.beta_train * self.x0 65 | return f 66 | 67 | def __repr__(self): 68 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 69 | 70 | 71 | class SpGraphAttentionLayer(nn.Module): 72 | """ 73 | Sparse version GAT layer, similar to https://arxiv.org/abs/1710.10903 74 | """ 75 | 76 | def __init__(self, in_features, out_features, opt, device, concat=True): 77 | super(SpGraphAttentionLayer, 
self).__init__() 78 | self.in_features = in_features 79 | self.out_features = out_features 80 | self.alpha = opt['leaky_relu_slope'] 81 | self.concat = concat 82 | self.device = device 83 | self.opt = opt 84 | self.h = opt['heads'] 85 | 86 | try: 87 | self.attention_dim = opt['attention_dim'] 88 | except KeyError: 89 | self.attention_dim = out_features 90 | 91 | assert self.attention_dim % opt['heads'] == 0, "Number of heads must be a factor of the dimension size" 92 | self.d_k = self.attention_dim // opt['heads'] 93 | 94 | self.W = nn.Parameter(torch.zeros(size=(in_features, self.attention_dim))).to(device) 95 | nn.init.xavier_normal_(self.W.data, gain=1.414) 96 | 97 | self.Wout = nn.Parameter(torch.zeros(size=(self.attention_dim, self.in_features))).to(device) 98 | nn.init.xavier_normal_(self.Wout.data, gain=1.414) 99 | 100 | self.a = nn.Parameter(torch.zeros(size=(2 * self.d_k, 1, 1))).to(device) 101 | nn.init.xavier_normal_(self.a.data, gain=1.414) 102 | 103 | self.leakyrelu = nn.LeakyReLU(self.alpha) 104 | 105 | def forward(self, x, edge): 106 | wx = torch.mm(x, self.W) # h: N x out 107 | h = wx.view(-1, self.h, self.d_k) 108 | h = h.transpose(1, 2) 109 | 110 | # Self-attention on the nodes - Shared attention mechanism 111 | edge_h = torch.cat((h[edge[0, :], :, :], h[edge[1, :], :, :]), dim=1).transpose(0, 1).to( 112 | self.device) # edge: 2*D x E 113 | edge_e = self.leakyrelu(torch.sum(self.a * edge_h, dim=0)).to(self.device) 114 | attention = softmax(edge_e, edge[self.opt['attention_norm_idx']]) 115 | return attention, wx 116 | 117 | def __repr__(self): 118 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 119 | 120 | 121 | if __name__ == '__main__': 122 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 123 | opt = {'dataset': 'Cora', 'self_loop_weight': 1, 'leaky_relu_slope': 0.2, 'beta_dim': 'vc', 'heads': 2, 'K': 10, 'attention_norm_idx': 0, 124 | 'add_source':False, 'alpha_dim': 'sc', 'beta_dim': 'vc', 'max_nfe':1000, 'mix_features': False} 125 | dataset = get_dataset(opt, '../data', False) 126 | t = 1 127 | func = ODEFuncAtt(dataset.data.num_features, 6, opt, dataset.data, device) 128 | out = func(t, dataset.data.x) 129 | -------------------------------------------------------------------------------- /function_laplacian_diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch_sparse 4 | 5 | from base_classes import ODEFunc 6 | from utils import MaxNFEException 7 | 8 | 9 | # Define the ODE function. 10 | # Input: 11 | # --- t: A tensor with shape [], meaning the current time. 12 | # --- x: A tensor with shape [#batches, dims], meaning the value of x at t. 13 | # Output: 14 | # --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t. 
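# Concretely, for LaplacianODEFunc below the derivative returned by forward() is
#   dx/dt = sigmoid(alpha_train) * (A x - x)  [+ beta_train * x0 if opt['add_source']]
# where A x is computed by sparse multiplication with the normalised edge
# weights, or with the attention weights when opt['block'] selects an attention block.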
15 | class LaplacianODEFunc(ODEFunc):
16 | 
17 |   # currently requires in_features = out_features
18 |   def __init__(self, in_features, out_features, opt, data, device):
19 |     super(LaplacianODEFunc, self).__init__(opt, data, device)
20 | 
21 |     self.in_features = in_features
22 |     self.out_features = out_features
23 |     self.w = nn.Parameter(torch.eye(opt['hidden_dim']))
24 |     self.d = nn.Parameter(torch.zeros(opt['hidden_dim']) + 1)
25 |     self.alpha_sc = nn.Parameter(torch.ones(1))
26 |     self.beta_sc = nn.Parameter(torch.ones(1))
27 | 
28 |   def sparse_multiply(self, x):
29 |     if self.opt['block'] in ['attention']:  # adj is a multihead attention
30 |       mean_attention = self.attention_weights.mean(dim=1)
31 |       ax = torch_sparse.spmm(self.edge_index, mean_attention, x.shape[0], x.shape[0], x)
32 |     elif self.opt['block'] in ['mixed', 'hard_attention']:  # adj is a torch sparse matrix
33 |       ax = torch_sparse.spmm(self.edge_index, self.attention_weights, x.shape[0], x.shape[0], x)
34 |     else:  # adj is a torch sparse matrix
35 |       ax = torch_sparse.spmm(self.edge_index, self.edge_weight, x.shape[0], x.shape[0], x)
36 |     return ax
37 | 
38 |   def forward(self, t, x):  # the t param is needed by the ODE solver.
39 |     if self.nfe > self.opt["max_nfe"]:
40 |       raise MaxNFEException
41 |     self.nfe += 1
42 |     ax = self.sparse_multiply(x)
43 |     if not self.opt['no_alpha_sigmoid']:
44 |       alpha = torch.sigmoid(self.alpha_train)
45 |     else:
46 |       alpha = self.alpha_train
47 | 
48 |     f = alpha * (ax - x)
49 |     if self.opt['add_source']:
50 |       f = f + self.beta_train * self.x0
51 |     return f
52 | 
--------------------------------------------------------------------------------
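As a toy illustration of the update computed by `LaplacianODEFunc.forward` above (a self-contained sketch; the 3-node path graph and its weights are made up, and `torch_sparse.spmm(index, value, m, n, mat)` is the same call the class uses):

```python
import torch
import torch_sparse

# 3-node path graph with row-normalised edge weights (illustrative values)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_weight = torch.tensor([1.0, 0.5, 0.5, 1.0])
x = torch.randn(3, 4)

ax = torch_sparse.spmm(edge_index, edge_weight, 3, 3, x)  # A @ x
alpha = torch.sigmoid(torch.tensor(0.0))  # alpha_train is initialised to 0
f = alpha * (ax - x)                      # dx/dt, without the source term
print(f.shape)  # torch.Size([3, 4])
```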
/grb/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
--------------------------------------------------------------------------------
/grb/attack/__init__.py:
--------------------------------------------------------------------------------
1 | """Attack Module for implementation of graph adversarial attacks"""
2 | from .base import Attack, InjectionAttack, ModificationAttack
3 | 
--------------------------------------------------------------------------------
/grb/attack/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | 
3 | 
4 | class Attack(metaclass=ABCMeta):
5 |     r"""
6 | 
7 |     Description
8 |     -----------
9 |     Abstract class for graph adversarial attack.
10 | 
11 |     """
12 |     @abstractmethod
13 |     def attack(self, model, adj, features, **kwargs):
14 |         r"""
15 | 
16 |         Parameters
17 |         ----------
18 |         model : torch.nn.Module
19 |             Model implemented based on ``torch.nn.Module``.
20 |         adj : scipy.sparse.csr.csr_matrix
21 |             Adjacency matrix in form of ``N * N`` sparse matrix.
22 |         features : torch.FloatTensor
23 |             Features in form of ``N * D`` torch float tensor.
24 |         kwargs :
25 |             Keyword-only arguments.
26 | 
27 |         """
28 | 
29 | 
30 | class ModificationAttack(Attack):
31 |     r"""
32 | 
33 |     Description
34 |     -----------
35 |     Abstract class for graph modification attack.
36 | 
37 |     """
38 |     @abstractmethod
39 |     def attack(self, **kwargs):
40 |         """
41 | 
42 |         Parameters
43 |         ----------
44 |         kwargs :
45 |             Keyword-only arguments.
46 | 
47 |         """
48 | 
49 |     @abstractmethod
50 |     def modification(self, **kwargs):
51 |         """
52 | 
53 |         Parameters
54 |         ----------
55 |         kwargs :
56 |             Keyword-only arguments.
57 | 
58 |         """
59 | 
60 | 
61 | class InjectionAttack(Attack):
62 |     r"""
63 | 
64 |     Description
65 |     -----------
66 |     Abstract class for graph injection attack.
67 | 
68 |     """
69 |     @abstractmethod
70 |     def attack(self, **kwargs):
71 |         """
72 | 
73 |         Parameters
74 |         ----------
75 |         kwargs :
76 |             Keyword-only arguments.
77 | 
78 |         """
79 | 
80 |     @abstractmethod
81 |     def injection(self, **kwargs):
82 |         """
83 | 
84 |         Parameters
85 |         ----------
86 |         kwargs :
87 |             Keyword-only arguments.
88 | 
89 |         """
90 | 
91 |     @abstractmethod
92 |     def update_features(self, **kwargs):
93 |         """
94 | 
95 |         Parameters
96 |         ----------
97 |         kwargs :
98 |             Keyword-only arguments.
99 | 
100 |         """
101 | 
102 | 
103 | class EarlyStop(object):
104 |     r"""
105 | 
106 |     Description
107 |     -----------
108 |     Strategy to early-stop the attack process.
109 | 
110 |     """
111 |     def __init__(self, patience=1000, epsilon=1e-4):
112 |         r"""
113 | 
114 |         Parameters
115 |         ----------
116 |         patience : int, optional
117 |             Number of epochs to wait if there is no further improvement. Default: ``1000``.
118 |         epsilon : float, optional
119 |             Tolerance range of improvement. Default: ``1e-4``.
120 | 
121 |         """
122 |         self.patience = patience
123 |         self.epsilon = epsilon
124 |         self.min_score = None
125 |         self.stop = False
126 |         self.count = 0
127 | 
128 |     def __call__(self, score):
129 |         r"""
130 | 
131 |         Parameters
132 |         ----------
133 |         score : float
134 |             Value of the attack score.
135 | 
136 |         """
137 |         if self.min_score is None:
138 |             self.min_score = score
139 |         elif self.min_score - score > 0:
140 |             self.count = 0
141 |             self.min_score = score
142 |         elif self.min_score - score < self.epsilon:
143 |             self.count += 1
144 |             if self.count > self.patience:
145 |                 self.stop = True
146 | 
--------------------------------------------------------------------------------
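To make the contract above concrete, here is a minimal sketch of a `ModificationAttack` subclass (the random edge-flip strategy is illustrative only and is not one of the attacks shipped under `modification/`):

```python
import numpy as np

from grb.attack.base import ModificationAttack


class RandomFlip(ModificationAttack):
    """Flip ``n_edge_mod`` randomly chosen entries of the adjacency matrix."""

    def __init__(self, n_edge_mod, seed=0):
        self.n_edge_mod = n_edge_mod
        self.rng = np.random.default_rng(seed)

    def attack(self, adj, **kwargs):
        return self.modification(adj)

    def modification(self, adj):
        adj_attack = adj.tolil()  # LIL format supports cheap element updates
        n = adj_attack.shape[0]
        for _ in range(self.n_edge_mod):
            i, j = self.rng.integers(0, n, size=2)
            adj_attack[i, j] = 1 - adj_attack[i, j]
        return adj_attack.tocsr()
```

The shipped attacks follow the same `attack()`/`modification()` split, as `dice.py` below shows.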
/grb/attack/injection/__init__.py:
--------------------------------------------------------------------------------
1 | """Graph injection attacks"""
2 | from .fgsm import FGSM
3 | from .pgd import PGD
4 | from .rand import RAND
5 | from .speit import SPEIT
6 | from .tdgia import TDGIA
--------------------------------------------------------------------------------
/grb/attack/modification/__init__.py:
--------------------------------------------------------------------------------
1 | from .dice import DICE
2 | from .fga import FGA
3 | from .flip import FLIP
4 | from .nea import NEA
5 | from .rand import RAND
6 | from .stack import STACK
7 | from .pgd import PGD
8 | from .prbcd import PRBCD
--------------------------------------------------------------------------------
/grb/attack/modification/dice.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm.auto import tqdm
3 | 
4 | from ..base import ModificationAttack
5 | 
6 | 
7 | class DICE(ModificationAttack):
8 |     """
9 |     DICE (delete internally, connect externally)
10 |     """
11 | 
12 |     def __init__(self,
13 |                  n_edge_mod,
14 |                  ratio_delete=0.6,
15 |                  allow_isolate=True,
16 |                  device="cpu",
17 |                  verbose=True):
18 |         self.n_edge_mod = n_edge_mod
19 |         self.ratio_delete = ratio_delete
20 |         self.allow_isolate = allow_isolate
21 |         self.device = device
22 |         self.verbose = verbose
23 | 
24 |     def attack(self, adj, index_target, labels):
25 |         adj_attack = self.modification(adj, index_target, labels)
26 | 
27 |         return adj_attack
28 | 
29 |     def modification(self, adj, index_target, labels):
30 |         adj_attack = adj.tolil()
31 |         degrees = adj_attack.getnnz(axis=1)
32 | 
33 |         # delete internally
34 | 
print("Delete internally......") 35 | n_delete = int(np.floor(self.n_edge_mod * self.ratio_delete)) 36 | index_i, index_j = index_target[adj_attack[index_target].nonzero()[0]], adj_attack[index_target].nonzero()[1] 37 | target_index_pair = [] 38 | for index in tqdm(zip(index_i, index_j), total=len(index_i)): 39 | if index[0] != index[1] and labels[index[0]] == labels[index[1]]: 40 | if self.allow_isolate: 41 | # if index[::-1] not in target_index_pair: 42 | target_index_pair.append(index) 43 | else: 44 | if degrees[index[0]] > 1 and degrees[index[1]] > 1: 45 | # if index[::-1] not in target_index_pair: 46 | target_index_pair.append(index) 47 | degrees[index[0]] -= 1 48 | degrees[index[1]] -= 1 49 | 50 | index_delete = np.random.permutation(target_index_pair)[:n_delete] 51 | if index_delete != []: 52 | adj_attack[index_delete[:, 0], index_delete[:, 1]] = 0 53 | adj_attack[index_delete[:, 1], index_delete[:, 0]] = 0 54 | 55 | # connect externally 56 | print("Connect externally......") 57 | n_connect = self.n_edge_mod - n_delete 58 | index_i, index_j = index_target, np.arange(adj_attack.shape[1]) 59 | target_index_pair = [] 60 | for index in tqdm(zip(index_i, index_j), total=len(index_i)): 61 | if index[0] != index[1] and labels[index[0]] != labels[index[1]] and adj_attack[index[0], index[1]] == 0: 62 | # if index[::-1] not in target_index_pair: 63 | target_index_pair.append(index) 64 | 65 | index_connect = np.random.permutation(target_index_pair)[:n_connect] 66 | adj_attack[index_connect[:, 0], index_connect[:, 1]] = 1 67 | adj_attack[index_connect[:, 1], index_connect[:, 0]] = 1 68 | adj_attack = adj_attack.tocsr() 69 | adj_attack.eliminate_zeros() 70 | 71 | if self.verbose: 72 | print( 73 | "DICE attack finished. {:d} edges were removed, {:d} edges were connected.".format(n_delete, n_connect)) 74 | 75 | return adj_attack 76 | -------------------------------------------------------------------------------- /grb/attack/modification/fga.py: -------------------------------------------------------------------------------- 1 | import scipy.sparse as sp 2 | import torch 3 | import torch.nn.functional as F 4 | from tqdm.auto import tqdm 5 | 6 | from ..base import ModificationAttack 7 | from ...utils import utils 8 | 9 | 10 | class FGA(ModificationAttack): 11 | """ 12 | FGA: Fast Gradient Attack on Network Embedding (https://arxiv.org/pdf/1809.02797.pdf) 13 | """ 14 | 15 | def __init__(self, 16 | n_edge_mod, 17 | loss=F.cross_entropy, 18 | allow_isolate=True, 19 | device="cpu", 20 | verbose=True): 21 | self.n_edge_mod = n_edge_mod 22 | self.allow_isolate = allow_isolate 23 | self.loss = loss 24 | self.device = device 25 | self.verbose = verbose 26 | 27 | def attack(self, 28 | model, 29 | adj, 30 | features, 31 | index_target, 32 | feat_norm=None, 33 | adj_norm_func=None): 34 | 35 | features = utils.feat_preprocess(features=features, 36 | feat_norm=model.feat_norm if feat_norm is None else feat_norm, 37 | device=self.device) 38 | adj_tensor = utils.adj_preprocess(adj=adj, 39 | adj_norm_func=model.adj_norm_func if adj_norm_func is None else adj_norm_func, 40 | model_type=model.model_type, 41 | device=self.device) 42 | model.to(self.device) 43 | pred_origin = model(features, adj_tensor) 44 | labels_origin = torch.argmax(pred_origin, dim=1) 45 | 46 | adj_attack = self.modification(model=model, 47 | adj_origin=adj, 48 | features_origin=features, 49 | labels_origin=labels_origin, 50 | index_target=index_target, 51 | feat_norm=feat_norm, 52 | adj_norm_func=adj_norm_func) 53 | 54 | return adj_attack 55 | 
56 | def modification(self, 57 | model, 58 | adj_origin, 59 | features_origin, 60 | labels_origin, 61 | index_target, 62 | feat_norm=None, 63 | adj_norm_func=None): 64 | model.eval() 65 | adj_attack = adj_origin.todense() 66 | adj_attack = torch.FloatTensor(adj_attack) 67 | features_origin = utils.feat_preprocess(features=features_origin, 68 | feat_norm=model.feat_norm if feat_norm is None else feat_norm, 69 | device=self.device) 70 | adj_attack.requires_grad = True 71 | n_edge_flip = 0 72 | for _ in tqdm(range(adj_attack.shape[1])): 73 | if n_edge_flip >= self.n_edge_mod: 74 | break 75 | adj_attack_tensor = utils.adj_preprocess(adj=adj_attack, 76 | adj_norm_func=model.adj_norm_func if adj_norm_func is None else adj_norm_func, 77 | model_type=model.model_type, 78 | device=self.device) 79 | degs = adj_attack_tensor.sum(dim=1) 80 | pred = model(features_origin, adj_attack_tensor) 81 | loss = self.loss(pred[index_target], labels_origin[index_target]) 82 | grad = torch.autograd.grad(loss, adj_attack)[0] 83 | grad = (grad + grad.T) / torch.Tensor([2.0]) 84 | grad_max = torch.max(grad[index_target], dim=1) 85 | index_max_i = torch.argmax(grad_max.values) 86 | index_max_j = grad_max.indices[index_max_i] 87 | index_max_i = index_target[index_max_i] 88 | if adj_attack[index_max_i][index_max_j] == 0: 89 | adj_attack.data[index_max_i][index_max_j] = 1 90 | adj_attack.data[index_max_j][index_max_i] = 1 91 | n_edge_flip += 1 92 | else: 93 | if self.allow_isolate: 94 | adj_attack.data[index_max_i][index_max_j] = 0 95 | adj_attack.data[index_max_j][index_max_i] = 0 96 | n_edge_flip += 1 97 | else: 98 | if degs[index_max_i] > 1 and degs[index_max_j] > 1: 99 | adj_attack.data[index_max_i][index_max_j] = 0 100 | adj_attack.data[index_max_j][index_max_i] = 0 101 | degs[index_max_i] -= 1 102 | degs[index_max_j] -= 1 103 | n_edge_flip += 1 104 | 105 | adj_attack = adj_attack.detach().cpu().numpy() 106 | adj_attack = sp.csr_matrix(adj_attack) 107 | if self.verbose: 108 | print("FGA attack finished. {:d} edges were flipped.".format(n_edge_flip)) 109 | 110 | return adj_attack 111 | -------------------------------------------------------------------------------- /grb/attack/modification/flip.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | from tqdm.auto import tqdm 4 | 5 | from ..base import ModificationAttack 6 | 7 | 8 | class FLIP(ModificationAttack): 9 | """ 10 | FLIP, degree, betweenness, eigen. 
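Ranks the candidate edges incident to the target nodes by the centrality of their endpoints (``flip_type``: "deg" for degree, "bet" for betweenness, "eigen" for eigenvector centrality), in ascending or descending order (``mode``), and flips up to ``n_edge_mod`` of them.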
11 | """ 12 | 13 | def __init__(self, 14 | n_edge_mod, 15 | flip_type="deg", 16 | mode="descend", 17 | allow_isolate=True, 18 | device="cpu", 19 | verbose=True): 20 | self.n_edge_mod = n_edge_mod 21 | self.flip_type = flip_type 22 | self.mode = mode 23 | self.allow_isolate = allow_isolate 24 | self.device = device 25 | self.verbose = verbose 26 | 27 | def attack(self, adj, index_target, **kwargs): 28 | adj_attack = self.modification(adj=adj, 29 | index_target=index_target, 30 | flip_type=self.flip_type, 31 | mode=self.mode, 32 | **kwargs) 33 | 34 | return adj_attack 35 | 36 | def modification(self, adj, index_target, flip_type="deg", saved=None, mode="descend"): 37 | adj_attack = adj.copy() 38 | degs = adj_attack.getnnz(axis=1) 39 | if flip_type == "deg": 40 | flip_edges = get_degree_flips_edges(adj, index_target, mode=mode) 41 | elif flip_type == "bet": 42 | flip_edges = betweenness_flips(adj, index_target, saved_bets=saved, mode=mode) 43 | elif flip_type == "eigen": 44 | flip_edges = eigen_flips(adj, index_target, saved_eigens=saved, mode=mode) 45 | else: 46 | raise NotImplementedError 47 | n_edge_flip = 0 48 | for index in tqdm(flip_edges): 49 | if n_edge_flip >= self.n_edge_mod: 50 | break 51 | if adj_attack[index[0], index[1]] == 0: 52 | adj_attack[index[0], index[1]] = 1 53 | adj_attack[index[1], index[0]] = 1 54 | degs[index[0]] += 1 55 | degs[index[1]] += 1 56 | n_edge_flip += 1 57 | else: 58 | if self.allow_isolate: 59 | adj_attack[index[0], index[1]] = 0 60 | adj_attack[index[1], index[0]] = 0 61 | n_edge_flip += 1 62 | else: 63 | if degs[index[0]] > 1 and degs[index[1]] > 1: 64 | adj_attack[index[0], index[1]] = 0 65 | adj_attack[index[1], index[0]] = 0 66 | degs[index[0]] -= 1 67 | degs[index[1]] -= 1 68 | n_edge_flip += 1 69 | if self.verbose: 70 | print("FLIP attack finished. 
{:d} edges were flipped.".format(n_edge_flip)) 71 | 72 | return adj_attack 73 | 74 | 75 | def get_degree_flips_edges(adj, index_target, mode="descend"): 76 | degs = adj.getnnz(axis=1) 77 | index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1] 78 | deg_score = degs[index_i] + degs[index_j] 79 | if mode == "ascend": 80 | deg_score = deg_score 81 | elif mode == "descend": 82 | deg_score = -deg_score 83 | else: 84 | raise NotImplementedError 85 | edges_target = np.column_stack([index_i, index_j]) 86 | flip_edges_idx = np.argsort(deg_score, axis=0) 87 | flip_edges = edges_target[flip_edges_idx].squeeze() 88 | 89 | return flip_edges 90 | 91 | 92 | def betweenness_flips(adj, index_target, saved_bets=None, mode="descend"): 93 | if saved_bets is None: 94 | g = nx.from_scipy_sparse_matrix(adj) 95 | bets = nx.betweenness_centrality(g) 96 | bets = np.array(list(bets.values())) 97 | else: 98 | bets = saved_bets 99 | index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1] 100 | bet_score = bets[index_i] + bets[index_j] 101 | if mode == "ascend": 102 | bet_score = bet_score 103 | elif mode == "descend": 104 | bet_score = -bet_score 105 | else: 106 | raise NotImplementedError 107 | edges_target = np.column_stack([index_i, index_j]) 108 | flip_edges_idx = np.argsort(bet_score, axis=0) 109 | flip_edges = edges_target[flip_edges_idx].squeeze() 110 | 111 | return flip_edges 112 | 113 | 114 | def eigen_flips(adj, index_target, saved_eigens=None, mode="descend"): 115 | if saved_eigens is None: 116 | g = nx.from_scipy_sparse_matrix(adj) 117 | eigens = nx.eigenvector_centrality(g) 118 | eigens = np.array(list(eigens.values())) 119 | else: 120 | eigens = saved_eigens 121 | index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1] 122 | eigen_score = eigens[index_i] + eigens[index_j] 123 | if mode == "ascend": 124 | eigen_score = eigen_score 125 | elif mode == "descend": 126 | eigen_score = -eigen_score 127 | else: 128 | raise NotImplementedError 129 | 130 | edges_target = np.column_stack([index_i, index_j]) 131 | flip_edges_idx = np.argsort(eigen_score, axis=0) 132 | flip_edges = edges_target[flip_edges_idx].squeeze() 133 | 134 | return flip_edges 135 | -------------------------------------------------------------------------------- /grb/attack/modification/nea.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import scipy.linalg as spl 4 | from tqdm.auto import tqdm 5 | 6 | from ..base import ModificationAttack 7 | 8 | 9 | class NEA(ModificationAttack): 10 | """ 11 | NEA: Adversarial Attacks on Node Embeddings via Graph Poisoning. 
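For every candidate edge flip it estimates, to first order, the change in the generalized eigenvalues of the self-loop-augmented adjacency, scores the flip with a sum-of-powers spectral loss, and applies the highest-scoring flips first.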
12 | """ 13 | 14 | def __init__(self, 15 | n_edge_mod, 16 | allow_isolate=True, 17 | device="cpu", 18 | verbose=True): 19 | self.n_edge_mod = n_edge_mod 20 | self.allow_isolate = allow_isolate 21 | self.device = device 22 | self.verbose = verbose 23 | 24 | def attack(self, adj, index_target, **kwargs): 25 | adj_attack = self.modification(adj=adj, index_target=index_target) 26 | 27 | return adj_attack 28 | 29 | def modification(self, adj, index_target): 30 | adj_attack = adj.copy() 31 | degs = adj_attack.getnnz(axis=1) 32 | adj_ = adj + sp.eye(adj.shape[0]) 33 | eigen_vals, eigen_vecs = spl.eigh(adj_.toarray(), np.diag(adj_.getnnz(axis=1))) 34 | index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1] 35 | edges_target = np.column_stack([index_i, index_j]) 36 | flip_indicator = 1 - 2 * np.array(adj[tuple(edges_target.T)])[0] 37 | eigen_scores = np.zeros(len(edges_target)) 38 | for k in range(len(edges_target)): 39 | i, j = edges_target[k] 40 | vals_est = eigen_vals + flip_indicator[k] * ( 41 | 2 * eigen_vecs[i] * eigen_vecs[j] - eigen_vals * (eigen_vecs[i] ** 2 + eigen_vecs[j] ** 2)) 42 | vals_sum_powers = sum_of_powers(vals_est, 5) 43 | loss_ij = np.sqrt(np.sum(np.sort(vals_sum_powers ** 2)[:adj.shape[0] - 32])) 44 | eigen_scores[k] = loss_ij 45 | struct_scores = - np.expand_dims(eigen_scores, 1) 46 | 47 | flip_edges_idx = np.argsort(struct_scores, axis=0) 48 | flip_edges = edges_target[flip_edges_idx].squeeze() 49 | 50 | n_edge_flip = 0 51 | for index in tqdm(flip_edges): 52 | if n_edge_flip >= self.n_edge_mod: 53 | break 54 | if adj_attack[index[0], index[1]] == 0: 55 | adj_attack[index[0], index[1]] = 1 56 | adj_attack[index[1], index[0]] = 1 57 | degs[index[0]] += 1 58 | degs[index[1]] += 1 59 | n_edge_flip += 1 60 | else: 61 | if self.allow_isolate: 62 | adj_attack[index[0], index[1]] = 0 63 | adj_attack[index[1], index[0]] = 0 64 | n_edge_flip += 1 65 | else: 66 | if degs[index[0]] > 1 and degs[index[1]] > 1: 67 | adj_attack[index[0], index[1]] = 0 68 | adj_attack[index[1], index[0]] = 0 69 | degs[index[0]] -= 1 70 | degs[index[1]] -= 1 71 | n_edge_flip += 1 72 | if self.verbose: 73 | print("NEA attack finished. {:d} edges were flipped.".format(n_edge_flip)) 74 | 75 | return adj_attack 76 | 77 | 78 | def sum_of_powers(x, power): 79 | n = x.shape[0] 80 | sum_powers = np.zeros((power, n)) 81 | for i, i_power in enumerate(range(1, power + 1)): 82 | sum_powers[i] = np.power(x, i_power) 83 | 84 | return sum_powers.sum(0) 85 | -------------------------------------------------------------------------------- /grb/attack/modification/rand.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm.auto import tqdm 3 | 4 | from ..base import ModificationAttack 5 | 6 | 7 | class RAND(ModificationAttack): 8 | """ 9 | FLIP, degree, betweenness, eigen. 
10 | """ 11 | 12 | def __init__(self, 13 | n_edge_mod, 14 | allow_isolate=True, 15 | device="cpu", 16 | verbose=True): 17 | self.n_edge_mod = n_edge_mod 18 | self.allow_isolate = allow_isolate 19 | self.device = device 20 | self.verbose = verbose 21 | 22 | def attack(self, adj, index_target): 23 | adj_attack = self.modification(adj, index_target) 24 | 25 | return adj_attack 26 | 27 | def modification(self, adj, index_target): 28 | adj_attack = adj.copy() 29 | degs = adj_attack.getnnz(axis=1) 30 | 31 | # Randomly flip edges 32 | index_i, index_j = index_target[adj_attack[index_target].nonzero()[0]], adj_attack[index_target].nonzero()[1] 33 | flip_edges = np.random.permutation(np.column_stack([index_i, index_j])) 34 | n_edge_flip = 0 35 | for index in tqdm(flip_edges): 36 | if n_edge_flip >= self.n_edge_mod: 37 | break 38 | if adj_attack[index[0], index[1]] == 0: 39 | adj_attack[index[0], index[1]] = 1 40 | adj_attack[index[1], index[0]] = 1 41 | degs[index[0]] += 1 42 | degs[index[1]] += 1 43 | n_edge_flip += 1 44 | else: 45 | if self.allow_isolate: 46 | adj_attack[index[0], index[1]] = 0 47 | adj_attack[index[1], index[0]] = 0 48 | n_edge_flip += 1 49 | else: 50 | if degs[index[0]] > 1 and degs[index[1]] > 1: 51 | adj_attack[index[0], index[1]] = 0 52 | adj_attack[index[1], index[0]] = 0 53 | degs[index[0]] -= 1 54 | degs[index[1]] -= 1 55 | n_edge_flip += 1 56 | 57 | if self.verbose: 58 | print("RAND attack finished. {:d} edges were randomly flipped.".format(n_edge_flip)) 59 | 60 | return adj_attack 61 | -------------------------------------------------------------------------------- /grb/attack/modification/stack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import scipy.linalg as spl 4 | from tqdm.auto import tqdm 5 | 6 | from ..base import ModificationAttack 7 | 8 | 9 | class STACK(ModificationAttack): 10 | """ 11 | STACK. 
12 | """ 13 | 14 | def __init__(self, 15 | n_edge_mod, 16 | allow_isolate=True, 17 | device="cpu", 18 | verbose=True): 19 | self.n_edge_mod = n_edge_mod 20 | self.allow_isolate = allow_isolate 21 | self.device = device 22 | self.verbose = verbose 23 | 24 | def attack(self, adj, index_target, **kwargs): 25 | adj_attack = self.modification(adj=adj, index_target=index_target) 26 | 27 | return adj_attack 28 | 29 | def modification(self, adj, index_target): 30 | adj_attack = adj.copy() 31 | degs = adj_attack.getnnz(axis=1) 32 | adj_ = adj + sp.eye(adj.shape[0]) 33 | eigen_vals, eigen_vecs = spl.eigh(adj_.toarray(), np.diag(adj_.getnnz(axis=1))) 34 | index_i, index_j = index_target[adj[index_target].nonzero()[0]], adj[index_target].nonzero()[1] 35 | edges_target = np.column_stack([index_i, index_j]) 36 | 37 | flip_indicator = 1 - 2 * np.array(adj[tuple(edges_target.T)])[0] 38 | eigen_scores = np.zeros(len(edges_target)) 39 | sub_org = np.sqrt(np.sum(eigen_vals ** 2)) 40 | for x in range(len(edges_target)): 41 | i, j = edges_target[x] 42 | vals_est = eigen_vals + flip_indicator[x] * ( 43 | 2 * eigen_vecs[i] * eigen_vecs[j] - eigen_vals * (eigen_vecs[i] ** 2 + eigen_vecs[j] ** 2)) 44 | loss_ij = np.abs(sub_org - np.sqrt(np.sum(vals_est ** 2))) 45 | eigen_scores[x] = loss_ij 46 | struct_scores = np.expand_dims(eigen_scores, 1) 47 | flip_edges_idx = np.argsort(struct_scores, axis=0)[::-1] 48 | flip_edges = edges_target[flip_edges_idx].squeeze() 49 | 50 | n_edge_flip = 0 51 | for index in tqdm(flip_edges): 52 | if n_edge_flip >= self.n_edge_mod: 53 | break 54 | if adj_attack[index[0], index[1]] == 0: 55 | adj_attack[index[0], index[1]] = 1 56 | adj_attack[index[1], index[0]] = 1 57 | degs[index[0]] += 1 58 | degs[index[1]] += 1 59 | n_edge_flip += 1 60 | else: 61 | if self.allow_isolate: 62 | adj_attack[index[0], index[1]] = 0 63 | adj_attack[index[1], index[0]] = 0 64 | n_edge_flip += 1 65 | else: 66 | if degs[index[0]] > 1 and degs[index[1]] > 1: 67 | adj_attack[index[0], index[1]] = 0 68 | adj_attack[index[1], index[0]] = 0 69 | degs[index[0]] -= 1 70 | degs[index[1]] -= 1 71 | n_edge_flip += 1 72 | if self.verbose: 73 | print("STACK attack finished. 
{:d} edges were flipped.".format(n_edge_flip)) 74 | 75 | return adj_attack 76 | 77 | 78 | def sum_of_powers(x, power): 79 | n = x.shape[0] 80 | sum_powers = np.zeros((power, n)) 81 | for i, i_power in enumerate(range(1, power + 1)): 82 | sum_powers[i] = np.power(x, i_power) 83 | 84 | return sum_powers.sum(0) 85 | -------------------------------------------------------------------------------- /grb/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """Dataset Module for loading or customizing datasets.""" 2 | 3 | GRB_SUPPORTED_DATASETS = {"grb-cora", 4 | "grb-citeseer", 5 | "grb-aminer", 6 | "grb-reddit", 7 | "grb-flickr"} 8 | URLs = { 9 | "grb-cora" : {"adj.npz" : "https://cloud.tsinghua.edu.cn/f/2e522f282e884907a39f/?dl=1", 10 | "features.npz": "https://cloud.tsinghua.edu.cn/f/46fd09a8c1d04f11afbb/?dl=1", 11 | "labels.npz" : "https://cloud.tsinghua.edu.cn/f/88fccac46ee94161b48f/?dl=1", 12 | "index.npz" : "https://cloud.tsinghua.edu.cn/f/d8488cbf78a34a8c9c5b/?dl=1"}, 13 | "grb-citeseer": {"adj.npz" : "https://cloud.tsinghua.edu.cn/f/d3063e4e010e431b95a6/?dl=1", 14 | "features.npz": "https://cloud.tsinghua.edu.cn/f/172b66d454d348458bca/?dl=1", 15 | "labels.npz" : "https://cloud.tsinghua.edu.cn/f/f594655156c744da9ef6/?dl=1", 16 | "index.npz" : "https://cloud.tsinghua.edu.cn/f/cb25124f9a454dcf989f/?dl=1"}, 17 | "grb-reddit" : {"adj.npz" : "https://cloud.tsinghua.edu.cn/f/22e91d7f34494784a670/?dl=1", 18 | "features.npz": "https://cloud.tsinghua.edu.cn/f/000dc5cd8dd643dcbfc6/?dl=1", 19 | "labels.npz" : "https://cloud.tsinghua.edu.cn/f/3e228140ede64b7886b2/?dl=1", 20 | "index.npz" : "https://cloud.tsinghua.edu.cn/f/24310393f5394e3a8b73/?dl=1"}, 21 | "grb-aminer" : {"adj.npz" : "https://cloud.tsinghua.edu.cn/f/dca1075cd8cc408bb4c0/?dl=1", 22 | "features.npz": "https://cloud.tsinghua.edu.cn/f/e93ba93dbdd94673bce3/?dl=1", 23 | "labels.npz" : "https://cloud.tsinghua.edu.cn/f/0ddbca54864245f3b4e1/?dl=1", 24 | "index.npz" : "https://cloud.tsinghua.edu.cn/f/3444a2e87ef745e89828/?dl=1"}, 25 | "grb-flickr" : {"adj.npz" : "https://cloud.tsinghua.edu.cn/f/90a513e35f0a4f3896eb/?dl=1", 26 | "features.npz": "https://cloud.tsinghua.edu.cn/f/54b2f1d7ee7c4d5bbcd4/?dl=1", 27 | "labels.npz" : "https://cloud.tsinghua.edu.cn/f/43e9ec09458e4d30b528/?dl=1", 28 | "index.npz" : "https://cloud.tsinghua.edu.cn/f/8239dc6a729e489da44f/?dl=1"}, 29 | } 30 | 31 | from .dataset import Dataset, CustomDataset, CogDLDataset, OGBDataset 32 | -------------------------------------------------------------------------------- /grb/dataset/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/dataset/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/dataset/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/dataset/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/dataset/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/dataset/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /grb/dataset/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/dataset/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/__init__.py: -------------------------------------------------------------------------------- 1 | """Attack Module for implementation of graph adversarial defenses""" 2 | from .adv_trainer import AdvTrainer 3 | from .gcnsvd import GCNSVD 4 | from .gnnguard import GCNGuard, GATGuard 5 | from .robustgcn import RobustGCN 6 | -------------------------------------------------------------------------------- /grb/defense/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/adv_trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/adv_trainer.cpython-36.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/adv_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/adv_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/gcnsvd.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/gcnsvd.cpython-36.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/gcnsvd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/gcnsvd.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/gnnguard.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/gnnguard.cpython-36.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/gnnguard.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/gnnguard.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/robustgcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/robustgcn.cpython-36.pyc -------------------------------------------------------------------------------- /grb/defense/__pycache__/robustgcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/defense/__pycache__/robustgcn.cpython-38.pyc -------------------------------------------------------------------------------- /grb/defense/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class Defense(metaclass=ABCMeta): 5 | """ 6 | Abstract class for defense. 7 | """ 8 | @abstractmethod 9 | def defense(self, model, adj, features, **kwargs): 10 | r""" 11 | 12 | Parameters 13 | ---------- 14 | model : torch.nn.module 15 | Model implemented based on ``torch.nn.module``. 16 | adj : scipy.sparse.csr.csr_matrix 17 | Adjacency matrix in form of ``N * N`` sparse matrix. 18 | features : torch.FloatTensor 19 | Features in form of ``N * D`` torch float tensor. 20 | kwargs : 21 | Keyword-only arguments. 22 | 23 | """ 24 | -------------------------------------------------------------------------------- /grb/defense/gcnsvd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import grb.utils as utils 7 | from grb.model.torch.gcn import GCNConv 8 | from grb.utils.normalize import GCNAdjNorm 9 | 10 | 11 | class GCNSVD(nn.Module): 12 | def __init__(self, 13 | in_features, 14 | out_features, 15 | hidden_features, 16 | n_layers, 17 | activation=F.relu, 18 | layer_norm=False, 19 | feat_norm=None, 20 | adj_norm_func=None, 21 | residual=False, 22 | dropout=0.0, 23 | k=50): 24 | super(GCNSVD, self).__init__() 25 | self.in_features = in_features 26 | self.out_features = out_features 27 | self.feat_norm = feat_norm 28 | self.adj_norm_func = adj_norm_func 29 | if type(hidden_features) is int: 30 | hidden_features = [hidden_features] * (n_layers - 1) 31 | elif type(hidden_features) is list or type(hidden_features) is tuple: 32 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 
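# e.g. n_layers=3 with hidden_features=[64, 64] gives per-layer sizes in_features -> 64 -> 64 -> out_features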
33 | n_features = [in_features] + hidden_features + [out_features] 34 | 35 | self.layers = nn.ModuleList() 36 | for i in range(n_layers): 37 | if layer_norm: 38 | self.layers.append(nn.LayerNorm(n_features[i])) 39 | self.layers.append(GCNConv(in_features=n_features[i], 40 | out_features=n_features[i + 1], 41 | activation=activation if i != n_layers - 1 else None, 42 | residual=residual if i != n_layers - 1 else False, 43 | dropout=dropout if i != n_layers - 1 else 0.0)) 44 | self.k = k 45 | 46 | @property 47 | def model_type(self): 48 | return "torch" 49 | 50 | def forward(self, x, adj): 51 | adj = self.truncatedSVD(adj, self.k) 52 | adj = utils.adj_preprocess(adj=adj, adj_norm_func=self.adj_norm_func, device=x.device) 53 | for layer in self.layers: 54 | if isinstance(layer, nn.LayerNorm): 55 | x = layer(x) 56 | else: 57 | x = layer(x, adj) 58 | 59 | return x 60 | 61 | def truncatedSVD(self, adj, k=50): 62 | edge_index = adj._indices() 63 | row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu().data.numpy()[:] 64 | 65 | adj = sp.csr_matrix((np.ones(len(row)), (row, col))) 66 | if sp.issparse(adj): 67 | adj = adj.asfptype() 68 | U, S, V = sp.linalg.svds(adj, k=k) 69 | diag_S = np.diag(S) 70 | else: 71 | U, S, V = np.linalg.svd(adj) 72 | U = U[:, :k] 73 | S = S[:k] 74 | V = V[:k, :] 75 | diag_S = np.diag(S) 76 | 77 | new_adj = U @ diag_S @ V 78 | new_adj = sp.csr_matrix(new_adj) 79 | 80 | return new_adj 81 | -------------------------------------------------------------------------------- /grb/defense/robustgcn.py: -------------------------------------------------------------------------------- 1 | """Torch module for RobustGCN.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from grb.utils.normalize import RobustGCNAdjNorm 7 | 8 | class RobustGCN(nn.Module): 9 | r""" 10 | 11 | Description 12 | ----------- 13 | Robust Graph Convolutional Networks (`RobustGCN `__) 14 | 15 | Parameters 16 | ---------- 17 | in_features : int 18 | Dimension of input features. 19 | out_features : int 20 | Dimension of output features. 21 | hidden_features : int or list of int 22 | Dimension of hidden features. List if multi-layer. 23 | feat_norm : str, optional 24 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 25 | adj_norm_func : func of utils.normalize, optional 26 | Function that normalizes adjacency matrix. Default: ``RobustAdjNorm``. 27 | dropout : float, optional 28 | Rate of dropout. Default: ``0.0``. 29 | 30 | """ 31 | 32 | def __init__(self, 33 | in_features, 34 | out_features, 35 | hidden_features, 36 | n_layers, 37 | feat_norm=None, 38 | adj_norm_func=RobustGCNAdjNorm, 39 | dropout=0.0): 40 | super(RobustGCN, self).__init__() 41 | self.in_features = in_features 42 | self.out_features = out_features 43 | self.feat_norm = feat_norm 44 | self.adj_norm_func = adj_norm_func 45 | if type(hidden_features) is int: 46 | hidden_features = [hidden_features] * (n_layers - 1) 47 | elif type(hidden_features) is list or type(hidden_features) is tuple: 48 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 
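# Each RobustGCNConv propagates a (mean, variance) pair; forward() below draws the final output as mean + eps * sqrt(var), with eps ~ N(0, I).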
49 | n_features = [in_features] + hidden_features + [out_features] 50 | 51 | self.act0 = F.elu 52 | self.act1 = F.relu 53 | 54 | self.layers = nn.ModuleList() 55 | for i in range(n_layers): 56 | self.layers.append(RobustGCNConv(n_features[i], n_features[i + 1], act0=self.act0, act1=self.act1, 57 | initial=True if i == 0 else False, 58 | dropout=dropout if i != n_layers - 1 else 0.0)) 59 | 60 | @property 61 | def model_type(self): 62 | """Indicate type of implementation.""" 63 | return "torch" 64 | 65 | def forward(self, x, adj): 66 | r""" 67 | 68 | Parameters 69 | ---------- 70 | x : torch.Tensor 71 | Tensor of input features. 72 | adj : list of torch.SparseTensor 73 | List of sparse tensor of adjacency matrix. 74 | 75 | Returns 76 | ------- 77 | x : torch.Tensor 78 | Output of model (logits without activation). 79 | 80 | """ 81 | 82 | adj0, adj1 = adj 83 | mean = x 84 | var = x 85 | for layer in self.layers: 86 | mean, var = layer(mean, var=var, adj0=adj0, adj1=adj1) 87 | sample = torch.randn(var.shape).to(x.device) 88 | output = mean + sample * torch.pow(var, 0.5) 89 | 90 | return output 91 | 92 | 93 | class RobustGCNConv(nn.Module): 94 | r""" 95 | 96 | Description 97 | ----------- 98 | RobustGCN convolutional layer. 99 | 100 | Parameters 101 | ---------- 102 | in_features : int 103 | Dimension of input features. 104 | out_features : int 105 | Dimension of output features. 106 | act0 : func of torch.nn.functional, optional 107 | Activation function. Default: ``F.elu``. 108 | act1 : func of torch.nn.functional, optional 109 | Activation function. Default: ``F.relu``. 110 | initial : bool, optional 111 | Whether to initialize variance. 112 | dropout : float, optional 113 | Rate of dropout. Default: ``0.0``. 114 | 115 | """ 116 | 117 | def __init__(self, in_features, out_features, act0=F.elu, act1=F.relu, initial=False, dropout=0.0): 118 | super(RobustGCNConv, self).__init__() 119 | self.mean_conv = nn.Linear(in_features, out_features) 120 | self.var_conv = nn.Linear(in_features, out_features) 121 | self.act0 = act0 122 | self.act1 = act1 123 | self.initial = initial 124 | if dropout > 0.0: 125 | self.dropout = nn.Dropout(dropout) 126 | else: 127 | self.dropout = None 128 | 129 | def forward(self, mean, var=None, adj0=None, adj1=None): 130 | r""" 131 | 132 | Parameters 133 | ---------- 134 | mean : torch.Tensor 135 | Tensor of mean of input features. 136 | var : torch.Tensor, optional 137 | Tensor of variance of input features. Default: ``None``. 138 | adj0 : torch.SparseTensor, optional 139 | Sparse tensor of adjacency matrix 0. Default: ``None``. 140 | adj1 : torch.SparseTensor, optional 141 | Sparse tensor of adjacency matrix 1. Default: ``None``. 
142 | 143 | Returns 144 | ------- 145 | (mean, var) : tuple of torch.Tensor, the updated mean and variance. 146 | """ 147 | mean = self.mean_conv(mean) 148 | if self.initial: 149 | var = mean * 1 150 | else: 151 | var = self.var_conv(var) 152 | mean = self.act0(mean) 153 | var = self.act1(var) 154 | attention = torch.exp(-var) 155 | 156 | mean = mean * attention 157 | var = var * attention * attention 158 | mean = torch.spmm(adj0, mean) 159 | var = torch.spmm(adj1, var) 160 | if self.dropout: 161 | mean = self.act0(mean) 162 | var = self.act0(var) 163 | if self.dropout is not None: 164 | mean = self.dropout(mean) 165 | var = self.dropout(var) 166 | 167 | return mean, var 168 | -------------------------------------------------------------------------------- /grb/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | """Evaluator module for unified evaluation on adversarial robustness.""" 2 | from .evaluator import AttackEvaluator, DefenseEvaluator 3 | -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/evaluator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/evaluator.cpython-36.pyc -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/evaluator.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/evaluator.cpython-38.pyc -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/metric.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/metric.cpython-36.pyc -------------------------------------------------------------------------------- /grb/evaluator/__pycache__/metric.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/evaluator/__pycache__/metric.cpython-38.pyc -------------------------------------------------------------------------------- /grb/evaluator/metric.py: -------------------------------------------------------------------------------- 1 | """Evaluation metrics""" 2 | import torch 3 | import numpy as np 4 | from sklearn.metrics import roc_auc_score 5 | 6 | 7 | 
def eval_acc(pred, labels, mask=None): 8 | r""" 9 | 10 | Description 11 | ----------- 12 | Accuracy metric for node classification. 13 | 14 | Parameters 15 | ---------- 16 | pred : torch.Tensor 17 | Output logits of model in form of ``N * 1``. 18 | labels : torch.LongTensor 19 | Labels in form of ``N * 1``. 20 | mask : torch.Tensor, optional 21 | Mask of nodes to evaluate in form of ``N * 1`` torch bool tensor. Default: ``None``. 22 | 23 | Returns 24 | ------- 25 | acc : float 26 | Node classification accuracy. 27 | 28 | """ 29 | 30 | if mask is not None: 31 | pred, labels = pred[mask], labels[mask] 32 | if pred is None or labels is None: 33 | return 0.0 34 | 35 | acc = (torch.argmax(pred, dim=1) == labels).float().sum() / len(pred) 36 | 37 | return acc 38 | 39 | 40 | def eval_rocauc(pred, labels, mask=None): 41 | r""" 42 | 43 | Description 44 | ----------- 45 | ROC-AUC score for multi-label node classification. 46 | 47 | Parameters 48 | ---------- 49 | pred : torch.Tensor 50 | Output logits of model in form of ``N * 1``. 51 | labels : torch.LongTensor 52 | Labels in form of ``N * 1``. 53 | mask : torch.Tensor, optional 54 | Mask of nodes to evaluate in form of ``N * 1`` torch bool tensor. Default: ``None``. 55 | 56 | 57 | Returns 58 | ------- 59 | rocauc : float 60 | Average ROC-AUC score across different labels. 61 | 62 | """ 63 | 64 | rocauc_list = [] 65 | if mask is not None: 66 | pred, labels = pred[mask], labels[mask] 67 | if pred is None or labels is None: 68 | return 0.0 69 | pred = pred.detach().cpu().numpy() 70 | labels = labels.detach().cpu().numpy() 71 | for i in range(labels.shape[1]): 72 | # AUC is only defined when there is at least one positive data. 73 | if np.sum(labels[:, i] == 1) > 0 and np.sum(labels[:, i] == 0) > 0: 74 | rocauc_list.append(roc_auc_score(y_true=labels[:, i], 75 | y_score=pred[:, i])) 76 | 77 | if len(rocauc_list) == 0: 78 | raise RuntimeError('No positively labeled data available. Cannot compute ROC-AUC.') 79 | 80 | rocauc = sum(rocauc_list) / len(rocauc_list) 81 | 82 | return rocauc 83 | 84 | 85 | def eval_f1multilabel(pred, labels, mask=None): 86 | r""" 87 | 88 | Description 89 | ----------- 90 | F1 score for multi-label node classification. 91 | 92 | Parameters 93 | ---------- 94 | pred : torch.Tensor 95 | Output logits of model in form of ``N * 1``. 96 | labels : torch.LongTensor 97 | Labels in form of ``N * 1``. 98 | mask : torch.Tensor, optional 99 | Mask of nodes to evaluate in form of ``N * 1`` torch bool tensor. Default: ``None``. 100 | 101 | 102 | Returns 103 | ------- 104 | f1 : float 105 | Average F1 score across different labels. 106 | 107 | """ 108 | 109 | if mask is not None: 110 | pred, labels = pred[mask], labels[mask] 111 | if pred is None or labels is None: 112 | return 0.0 113 | pred[pred > 0.5] = 1 114 | pred[pred <= 0.5] = 0 115 | tp = (labels * pred).sum().float() 116 | fp = ((1 - labels) * pred).sum().float() 117 | fn = (labels * (1 - pred)).sum().float() 118 | 119 | epsilon = 1e-7 120 | precision = tp / (tp + fp + epsilon) 121 | recall = tp / (tp + fn + epsilon) 122 | f1 = (2 * precision * recall) / (precision + recall + epsilon) 123 | f1 = f1.item() 124 | 125 | return f1 126 | 127 | 128 | def get_weights_arithmetic(n, w_1, order='a'): 129 | r""" 130 | 131 | Description 132 | ----------- 133 | Arithmetic weights for calculating weighted robust score. 134 | 135 | Parameters 136 | ---------- 137 | n : int 138 | Number of scores. 139 | w_1 : float 140 | Initial weight of the first term. 
141 | order : str, optional 142 | ``a`` for ascending order, ``d`` for descending order. Default: ``a``. 143 | 144 | Returns 145 | ------- 146 | weights : list 147 | List of weights. 148 | 149 | """ 150 | 151 | weights = [] 152 | epsilon = 2 / (n - 1) * (1 / n - w_1) 153 | for i in range(1, n + 1): 154 | weights.append(w_1 + (i - 1) * epsilon) 155 | 156 | if 'd' in order: 157 | weights.reverse() 158 | 159 | return weights 160 | 161 | 162 | def get_weights_polynomial(n, p=2, order='a'): 163 | r""" 164 | 165 | Description 166 | ----------- 167 | Polynomial weights for calculating weighted robust score. 168 | 169 | Parameters 170 | ---------- 171 | n : int 172 | Number of scores. 173 | p : float 174 | Power of denominator. 175 | order : str, optional 176 | ``a`` for ascending order, ``d`` for descending order. Default: ``a``. 177 | 178 | Returns 179 | ------- 180 | weights_norm : list 181 | List of normalized polynomial weights. 182 | 183 | """ 184 | 185 | weights = [] 186 | for i in range(1, n + 1): 187 | weights.append(1 / i ** p) 188 | weights_norm = [weights[i] / sum(weights) for i in range(n)] 189 | if 'a' in order: 190 | weights_norm = weights_norm[::-1] 191 | 192 | return weights_norm 193 | -------------------------------------------------------------------------------- /grb/model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /grb/model/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/cogdl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/cogdl/__init__.py -------------------------------------------------------------------------------- /grb/model/cogdl/gcn.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from cogdl.layers import GCNLayer 5 | 6 | 7 | class GCN(nn.Module): 8 | def __init__(self, 9 | in_features, 10 | out_features, 11 | hidden_features, 12 | activation=F.relu, 13 | layer_norm=False): 14 | 15 | super(GCN, self).__init__() 16 | self.layers = nn.ModuleList() 17 | if layer_norm: 18 | self.layers.append(nn.LayerNorm(in_features)) 19 | self.layers.append(GCNLayer(in_features, hidden_features[0], activation=activation)) 20 | for i in range(len(hidden_features) - 1): 21 | if layer_norm: 22 | self.layers.append(nn.LayerNorm(hidden_features[i])) 23 | self.layers.append( 24 | GCNLayer(hidden_features[i], hidden_features[i + 1], activation=activation)) 25 | self.layers.append(GCNLayer(hidden_features[-1], out_features)) 26 | 27 | @property 28 | def model_type(self): 29 | return 
"cogdl" 30 | 31 | def forward(self, x, adj, dropout=0): 32 | graph = dgl.from_scipy(adj).to(x.device) 33 | graph.ndata['features'] = x 34 | 35 | for i, layer in enumerate(self.layers): 36 | if isinstance(layer, nn.LayerNorm): 37 | x = layer(x) 38 | else: 39 | x = layer(graph, x) 40 | if i != len(self.layers) - 1: 41 | x = F.dropout(x, dropout) 42 | 43 | return x 44 | -------------------------------------------------------------------------------- /grb/model/dgl/__init__.py: -------------------------------------------------------------------------------- 1 | from .gcn import GCN 2 | from .gat import GAT 3 | from .gatode import GATODE 4 | from .grand import GRAND -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gat.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gat.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gatode.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gatode.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gatode.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gatode.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gcn.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/gcn.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/gcn.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/grand.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/grand.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/dgl/__pycache__/grand.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/dgl/__pycache__/grand.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/dgl/gat.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch import GATConv 5 | 6 | from grb.utils.normalize import GCNAdjNorm 7 | 8 | 9 | class GAT(nn.Module): 10 | r""" 11 | 12 | Description 13 | ----------- 14 | Graph Attention Networks (`GAT `__) 15 | 16 | Parameters 17 | ---------- 18 | in_features : int 19 | Dimension of input features. 20 | out_features : int 21 | Dimension of output features. 22 | hidden_features : int or list of int 23 | Dimension of hidden features. List if multi-layer. 24 | n_layers : int 25 | Number of layers. 26 | layer_norm : bool, optional 27 | Whether to use layer normalization. Default: ``False``. 28 | activation : func of torch.nn.functional, optional 29 | Activation function. Default: ``torch.nn.functional.leaky_relu``. 30 | feat_norm : str, optional 31 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 32 | adj_norm_func : func of utils.normalize, optional 33 | Function that normalizes adjacency matrix. Default: ``None``. 34 | feat_dropout : float, optional 35 | Dropout rate for input features. Default: ``0.0``. 36 | attn_dropout : float, optional 37 | Dropout rate for attention. Default: ``0.0``. 38 | residual : bool, optional 39 | Whether to use residual connection. Default: ``False``. 40 | dropout : float, optional 41 | Dropout rate during training. Default: ``0.0``. 42 | 43 | """ 44 | def __init__(self, 45 | in_features, 46 | out_features, 47 | hidden_features, 48 | n_layers, 49 | n_heads, 50 | activation=F.leaky_relu, 51 | layer_norm=False, 52 | feat_norm=None, 53 | adj_norm_func=None, 54 | feat_dropout=0.0, 55 | attn_dropout=0.0, 56 | residual=False, 57 | dropout=0.0): 58 | super(GAT, self).__init__() 59 | self.in_features = in_features 60 | self.out_features = out_features 61 | self.feat_norm = feat_norm 62 | self.adj_norm_func = adj_norm_func 63 | if type(hidden_features) is int: 64 | hidden_features = [hidden_features] * (n_layers - 1) 65 | elif type(hidden_features) is list or type(hidden_features) is tuple: 66 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 
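# Head outputs are flattened between layers (see forward), so every layer after the first consumes n_features[i] * n_heads inputs; the final layer uses a single head.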
67 | n_features = [in_features] + hidden_features + [out_features] 68 | 69 | self.layers = nn.ModuleList() 70 | for i in range(n_layers): 71 | if layer_norm: 72 | if i == 0: 73 | self.layers.append(nn.LayerNorm(n_features[i])) 74 | else: 75 | self.layers.append(nn.LayerNorm(n_features[i] * n_heads)) 76 | self.layers.append(GATConv(in_feats=n_features[i] * n_heads if i != 0 else n_features[i], 77 | out_feats=n_features[i + 1], 78 | num_heads=n_heads if i != n_layers - 1 else 1, 79 | feat_drop=feat_dropout if i != n_layers - 1 else 0.0, 80 | attn_drop=attn_dropout if i != n_layers - 1 else 0.0, 81 | residual=residual if i != n_layers - 1 else False, 82 | activation=activation if i != n_layers - 1 else None)) 83 | if dropout > 0.0: 84 | self.dropout = nn.Dropout(dropout) 85 | else: 86 | self.dropout = None 87 | 88 | @property 89 | def model_type(self): 90 | return "dgl" 91 | 92 | @property 93 | def model_name(self): 94 | return "gat" 95 | 96 | def forward(self, x, adj): 97 | r""" 98 | 99 | Parameters 100 | ---------- 101 | x : torch.Tensor 102 | Tensor of input features. 103 | adj : torch.SparseTensor 104 | Sparse tensor of adjacency matrix. 105 | 106 | Returns 107 | ------- 108 | x : torch.Tensor 109 | Output of layer. 110 | 111 | """ 112 | 113 | graph = dgl.from_scipy(adj).to(x.device) 114 | graph = dgl.remove_self_loop(graph) 115 | graph = dgl.add_self_loop(graph) 116 | graph.ndata['features'] = x 117 | 118 | for i, layer in enumerate(self.layers): 119 | if isinstance(layer, nn.LayerNorm): 120 | x = layer(x) 121 | else: 122 | x = layer(graph, x).flatten(1) 123 | if i != len(self.layers) - 1: 124 | if self.dropout is not None: 125 | x = self.dropout(x) 126 | 127 | return x 128 | -------------------------------------------------------------------------------- /grb/model/dgl/gatode.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch import GATConv 5 | 6 | from grb.utils.normalize import GCNAdjNorm 7 | 8 | from torchdiffeq import odeint 9 | from torch.nn.utils import spectral_norm 10 | import torch 11 | import geotorch 12 | 13 | class ODEfunc_feat(nn.Module): 14 | def __init__(self, dim): 15 | super(ODEfunc_feat, self).__init__() 16 | self.fc = nn.Linear(dim, dim) 17 | #geotorch.low_rank(self.fc, "weight", dim-1) 18 | def forward(self, t, x): 19 | x = self.fc(x) 20 | return x 21 | 22 | class ODEBlock(nn.Module): 23 | def __init__(self, odefunc): 24 | super(ODEBlock, self).__init__() 25 | self.odefunc = odefunc 26 | self.integration_time = torch.tensor([0, 1]).float() 27 | def forward(self, x): 28 | self.integration_time = self.integration_time.type_as(x) 29 | out = odeint(self.odefunc, x, self.integration_time, rtol=1e-3, atol=1e-3) 30 | return out[1] 31 | 32 | class GATODE(nn.Module): 33 | r""" 34 | 35 | Description 36 | ----------- 37 | Graph Attention Networks (`GAT `__) 38 | 39 | Parameters 40 | ---------- 41 | in_features : int 42 | Dimension of input features. 43 | out_features : int 44 | Dimension of output features. 45 | hidden_features : int or list of int 46 | Dimension of hidden features. List if multi-layer. 47 | n_layers : int 48 | Number of layers. 49 | layer_norm : bool, optional 50 | Whether to use layer normalization. Default: ``False``. 51 | activation : func of torch.nn.functional, optional 52 | Activation function. Default: ``torch.nn.functional.leaky_relu``. 
53 | feat_norm : str, optional 54 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 55 | adj_norm_func : func of utils.normalize, optional 56 | Function that normalizes adjacency matrix. Default: ``None``. 57 | feat_dropout : float, optional 58 | Dropout rate for input features. Default: ``0.0``. 59 | attn_dropout : float, optional 60 | Dropout rate for attention. Default: ``0.0``. 61 | residual : bool, optional 62 | Whether to use residual connection. Default: ``False``. 63 | dropout : float, optional 64 | Dropout rate during training. Default: ``0.0``. 65 | 66 | """ 67 | def __init__(self, 68 | in_features, 69 | out_features, 70 | hidden_features, 71 | n_layers, 72 | n_heads, 73 | activation=F.leaky_relu, 74 | layer_norm=False, 75 | feat_norm=None, 76 | adj_norm_func=None, 77 | feat_dropout=0.0, 78 | attn_dropout=0.0, 79 | residual=False, 80 | dropout=0.0): 81 | super(GATODE, self).__init__() 82 | self.in_features = in_features 83 | self.out_features = out_features 84 | self.feat_norm = feat_norm 85 | self.adj_norm_func = adj_norm_func 86 | self.odeblk = ODEBlock(ODEfunc_feat(out_features)) 87 | if type(hidden_features) is int: 88 | hidden_features = [hidden_features] * (n_layers - 1) 89 | elif type(hidden_features) is list or type(hidden_features) is tuple: 90 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 91 | n_features = [in_features] + hidden_features + [out_features] 92 | 93 | self.layers = nn.ModuleList() 94 | for i in range(n_layers): 95 | if layer_norm: 96 | if i == 0: 97 | self.layers.append(nn.LayerNorm(n_features[i])) 98 | else: 99 | self.layers.append(nn.LayerNorm(n_features[i] * n_heads)) 100 | self.layers.append(GATConv(in_feats=n_features[i] * n_heads if i != 0 else n_features[i], 101 | out_feats=n_features[i + 1], 102 | num_heads=n_heads if i != n_layers - 1 else 1, 103 | feat_drop=feat_dropout if i != n_layers - 1 else 0.0, 104 | attn_drop=attn_dropout if i != n_layers - 1 else 0.0, 105 | residual=residual if i != n_layers - 1 else False, 106 | activation=activation if i != n_layers - 1 else None)) 107 | if dropout > 0.0: 108 | self.dropout = nn.Dropout(dropout) 109 | else: 110 | self.dropout = None 111 | 112 | @property 113 | def model_type(self): 114 | return "dgl" 115 | 116 | @property 117 | def model_name(self): 118 | return "gat" 119 | 120 | def forward(self, x, adj): 121 | r""" 122 | 123 | Parameters 124 | ---------- 125 | x : torch.Tensor 126 | Tensor of input features. 127 | adj : torch.SparseTensor 128 | Sparse tensor of adjacency matrix. 129 | 130 | Returns 131 | ------- 132 | x : torch.Tensor 133 | Output of layer. 
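Examples
--------
A minimal usage sketch (illustrative only: 5 nodes, 16-dim features, 7 classes,
and an identity adjacency; assumes ``scipy`` and ``torchdiffeq`` are available):

>>> import torch
>>> import scipy.sparse as sp
>>> model = GATODE(in_features=16, out_features=7,
...                hidden_features=32, n_layers=2, n_heads=4)
>>> x = torch.randn(5, 16)
>>> adj = sp.identity(5, format="csr")
>>> out = model(x, adj)  # GAT layers, then the ODE block applied to the output
>>> out.shape
torch.Size([5, 7])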
134 | 135 | """ 136 | 137 | graph = dgl.from_scipy(adj).to(x.device) 138 | graph = dgl.remove_self_loop(graph) 139 | graph = dgl.add_self_loop(graph) 140 | graph.ndata['features'] = x 141 | 142 | for i, layer in enumerate(self.layers): 143 | if isinstance(layer, nn.LayerNorm): 144 | x = layer(x) 145 | else: 146 | x = layer(graph, x).flatten(1) 147 | if i != len(self.layers) - 1: 148 | if self.dropout is not None: 149 | x = self.dropout(x) 150 | x = self.odeblk(x) 151 | return x 152 | -------------------------------------------------------------------------------- /grb/model/dgl/gcn.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from dgl.nn.pytorch import GraphConv 5 | 6 | 7 | class GCN(nn.Module): 8 | def __init__(self, 9 | in_features, 10 | out_features, 11 | hidden_features, 12 | activation=F.relu, 13 | layer_norm=False): 14 | 15 | super(GCN, self).__init__() 16 | self.layers = nn.ModuleList() 17 | if layer_norm: 18 | self.layers.append(nn.LayerNorm(in_features)) 19 | self.layers.append(GraphConv(in_features, hidden_features[0], activation=activation)) 20 | for i in range(len(hidden_features) - 1): 21 | if layer_norm: 22 | self.layers.append(nn.LayerNorm(hidden_features[i])) 23 | self.layers.append( 24 | GraphConv(hidden_features[i], hidden_features[i + 1], activation=activation)) 25 | self.layers.append(GraphConv(hidden_features[-1], out_features)) 26 | 27 | @property 28 | def model_type(self): 29 | return "dgl" 30 | 31 | def forward(self, x, adj, dropout=0): 32 | graph = dgl.from_scipy(adj).to(x.device) 33 | graph.ndata['features'] = x 34 | 35 | for i, layer in enumerate(self.layers): 36 | if isinstance(layer, nn.LayerNorm): 37 | x = layer(x) 38 | else: 39 | x = layer(graph, x) 40 | if i != len(self.layers) - 1: 41 | x = F.dropout(x, dropout) 42 | 43 | return x 44 | -------------------------------------------------------------------------------- /grb/model/dgl/gin.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from dgl.nn.pytorch.conv import GINConv 6 | from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling 7 | 8 | 9 | class ApplyNodeFunc(nn.Module): 10 | """Update the node feature hv with MLP, BN and ReLU.""" 11 | 12 | def __init__(self, mlp): 13 | super(ApplyNodeFunc, self).__init__() 14 | self.mlp = mlp 15 | self.bn = nn.BatchNorm1d(self.mlp.output_dim) 16 | 17 | def forward(self, h): 18 | h = self.mlp(h) 19 | h = self.bn(h) 20 | h = F.relu(h) 21 | return h 22 | 23 | 24 | class MLP(nn.Module): 25 | """MLP with linear output""" 26 | 27 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 28 | """MLP layers construction 29 | Paramters 30 | --------- 31 | num_layers: int 32 | The number of linear layers 33 | input_dim: int 34 | The dimensionality of input features 35 | hidden_dim: int 36 | The dimensionality of hidden units at ALL layers 37 | output_dim: int 38 | The number of classes for prediction 39 | """ 40 | super(MLP, self).__init__() 41 | self.linear_or_not = True # default is linear model 42 | self.num_layers = num_layers 43 | self.output_dim = output_dim 44 | 45 | if num_layers < 1: 46 | raise ValueError("number of layers should be positive!") 47 | elif num_layers == 1: 48 | # Linear model 49 | self.linear = nn.Linear(input_dim, output_dim) 50 | else: 51 | # Multi-layer model 52 | self.linear_or_not 
= False 53 | self.linears = torch.nn.ModuleList() 54 | self.batch_norms = torch.nn.ModuleList() 55 | 56 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 57 | for layer in range(num_layers - 2): 58 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 59 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 60 | 61 | for layer in range(num_layers - 1): 62 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 63 | 64 | def forward(self, x): 65 | if self.linear_or_not: 66 | # If linear model 67 | return self.linear(x) 68 | else: 69 | # If MLP 70 | h = x 71 | for i in range(self.num_layers - 1): 72 | h = F.relu(self.batch_norms[i](self.linears[i](h))) 73 | return self.linears[-1](h) 74 | 75 | 76 | class GIN(nn.Module): 77 | """GIN model""" 78 | 79 | def __init__(self, 80 | in_features, 81 | hidden_features, 82 | out_features, 83 | learn_eps=True, 84 | neighbor_pooling_type='sum', 85 | num_mlp_layers=1): 86 | super(GIN, self).__init__() 87 | self.learn_eps = learn_eps 88 | 89 | # List of MLPs 90 | self.layers = torch.nn.ModuleList() 91 | self.batch_norms = torch.nn.ModuleList() 92 | 93 | for i in range(len(hidden_features)): 94 | if i == 0: 95 | mlp = MLP(num_mlp_layers, in_features, hidden_features[i], hidden_features[i]) 96 | else: 97 | mlp = MLP(num_mlp_layers, hidden_features[i], hidden_features[i], hidden_features[i]) 98 | 99 | self.layers.append( 100 | GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0, self.learn_eps)) 101 | self.batch_norms.append(nn.BatchNorm1d(hidden_features[i])) 102 | 103 | self.linear1 = nn.Linear(hidden_features[-2], hidden_features[-1]) 104 | self.linear2 = nn.Linear(hidden_features[-1], out_features) 105 | 106 | @property 107 | def model_type(self): 108 | return "dgl" 109 | 110 | def forward(self, x, adj, dropout=0): 111 | graph = dgl.from_scipy(adj).to(x.device) 112 | graph.ndata['features'] = x 113 | 114 | for i in range(len(self.layers) - 1): 115 | x = self.layers[i](graph, x) 116 | x = self.batch_norms[i](x) 117 | x = F.relu(x) 118 | 119 | x = F.relu(self.linear1(x)) 120 | x = F.dropout(x, dropout) 121 | x = self.linear2(x) 122 | 123 | return x 124 | -------------------------------------------------------------------------------- /grb/model/dgl/grand.py: -------------------------------------------------------------------------------- 1 | import dgl 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import dgl.function as fn 6 | 7 | from grb.model.torch import MLP 8 | 9 | 10 | class GRAND(nn.Module): 11 | r""" 12 | 13 | Description 14 | ----------- 15 | Graph Random Neural Networks (`GRAND `__) 16 | 17 | 18 | Parameters 19 | ----------- 20 | in_features : int 21 | Dimension of input features. 22 | out_features : int 23 | Dimension of output features. 24 | hidden_features : int or list of int 25 | Dimension of hidden features. List if multi-layer. 26 | n_layers : int 27 | Number of layers. 28 | s: int 29 | Number of Augmentation samples 30 | k: int 31 | Number of Propagation Steps 32 | node_dropout: float 33 | Dropout rate on node features. 
34 | input_dropout: float 35 | Dropout rate of the input layer of a MLP 36 | hidden_dropout: float 37 | Dropout rate of the hidden layer of a MLP 38 | """ 39 | 40 | def __init__(self, 41 | in_features, 42 | out_features, 43 | hidden_features, 44 | n_layers=2, 45 | s=1, 46 | k=3, 47 | temp=1.0, 48 | lam=1.0, 49 | feat_norm=None, 50 | adj_norm_func=None, 51 | node_dropout=0.0, 52 | input_dropout=0.0, 53 | hidden_dropout=0.0): 54 | super(GRAND, self).__init__() 55 | self.in_features = in_features 56 | self.out_features = out_features 57 | self.feat_norm = feat_norm 58 | self.adj_norm_func = adj_norm_func 59 | self.s = s 60 | self.k = k 61 | self.temp = temp 62 | self.lam = lam 63 | self.mlp = MLP(in_features=in_features, 64 | out_features=out_features, 65 | hidden_features=hidden_features, 66 | n_layers=n_layers, 67 | dropout=hidden_dropout) 68 | self.node_dropout = node_dropout 69 | if input_dropout > 0.0: 70 | self.input_dropout = nn.Dropout(input_dropout) 71 | else: 72 | self.input_dropout = None 73 | 74 | def forward(self, x, adj): 75 | graph = dgl.from_scipy(adj).to(x.device) 76 | graph.ndata['features'] = x 77 | 78 | if self.input_dropout is not None: 79 | x = self.input_dropout(x) 80 | if self.training: 81 | output_list = [] 82 | for s in range(self.s): 83 | x_drop = drop_node(x, self.node_dropout, training=True) 84 | x_drop = GRANDConv(graph, x_drop, hop=self.k) 85 | output_list.append(torch.log_softmax(self.mlp(x_drop), dim=-1)) 86 | 87 | return output_list 88 | else: 89 | x = GRANDConv(graph, x, self.k) 90 | x = self.mlp(x) 91 | 92 | return x 93 | 94 | @property 95 | def model_type(self): 96 | return "dgl" 97 | 98 | @property 99 | def model_name(self): 100 | return "grand" 101 | 102 | 103 | def drop_node(features, drop_rate, training=True): 104 | n = features.shape[0] 105 | drop_rates = torch.FloatTensor(np.ones(n) * drop_rate) 106 | 107 | if training: 108 | masks = torch.bernoulli(1. 
- drop_rates).unsqueeze(1) 109 | features = masks.to(features.device) * features 110 | 111 | return features 112 | 113 | 114 | def GRANDConv(graph, features, hop): 115 | r""" 116 | 117 | Parameters 118 | ----------- 119 | graph : dgl.Graph 120 | The input graph 121 | features : torch.Tensor 122 | Tensor of node features 123 | hop : int 124 | Propagation Steps 125 | 126 | """ 127 | 128 | with graph.local_scope(): 129 | degs = graph.in_degrees().float().clamp(min=1) 130 | norm = torch.pow(degs, -0.5).to(features.device).unsqueeze(1) 131 | 132 | graph.ndata['norm'] = norm 133 | graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight')) 134 | 135 | x = features 136 | y = 0.0 + features 137 | 138 | for i in range(hop): 139 | graph.ndata['h'] = x 140 | graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h')) 141 | x = graph.ndata.pop('h') 142 | y.add_(x) 143 | 144 | return y / (hop + 1) 145 | -------------------------------------------------------------------------------- /grb/model/torch/__init__.py: -------------------------------------------------------------------------------- 1 | """Module for implementing GNN models based on pure Torch""" 2 | from .appnp import APPNP 3 | from .gcn import GCN, GCNGC 4 | from .gin import GIN 5 | from .graphsage import GraphSAGE 6 | from .sgcn import SGCN 7 | from .tagcn import TAGCN 8 | from .mlp import MLP 9 | from .gcnode import GCNODE 10 | from .gcnode2 import GCNODE2 11 | from .beltrami import BELTRAMI 12 | from .beltrami2 import BELTRAMI2 13 | from .beltramii import BELTRAMII 14 | from .MeanCurv import MEANCURV 15 | from .heat import HEAT 16 | from .pLaplace import pLAPLACE -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/MeanCurv.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/MeanCurv.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/MeanCurv.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/MeanCurv.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/appnp.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/appnp.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/appnp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/appnp.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/beltrami.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/beltrami.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/beltrami.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/beltrami.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/beltrami2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/beltrami2.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/beltrami2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/beltrami2.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/beltramii.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/beltramii.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcn.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcn.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcnode.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcnode.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcnode.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcnode.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcnode2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcnode2.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gcnode2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gcnode2.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gin.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gin.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/gin.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/gin.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/graphsage.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/graphsage.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/graphsage.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/graphsage.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/heat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/heat.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/heat.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/heat.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/mlp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/mlp.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/mlp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/mlp.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/pLaplace.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/pLaplace.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/pLaplace.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/pLaplace.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/sgcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/sgcn.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/sgcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/sgcn.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/tagcn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/tagcn.cpython-36.pyc -------------------------------------------------------------------------------- /grb/model/torch/__pycache__/tagcn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/model/torch/__pycache__/tagcn.cpython-38.pyc -------------------------------------------------------------------------------- /grb/model/torch/appnp.py: -------------------------------------------------------------------------------- 1 | """Torch module for APPNP.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from grb.utils.normalize import 
GCNAdjNorm 7 | 8 | 9 | class APPNP(nn.Module): 10 | r""" 11 | 12 | Description 13 | ----------- 14 | Approximated Personalized Propagation of Neural Predictions (`APPNP `__) 15 | 16 | Parameters 17 | ---------- 18 | in_features : int 19 | Dimension of input features. 20 | out_features : int 21 | Dimension of output features. 22 | hidden_features : int or list of int 23 | Dimension of hidden features. List if multi-layer. 24 | n_layers : int 25 | Number of layers. 26 | layer_norm : bool, optional 27 | Whether to use layer normalization. Default: ``False``. 28 | activation : func of torch.nn.functional, optional 29 | Activation function. Default: ``torch.nn.functional.relu``. 30 | feat_norm : str, optional 31 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 32 | adj_norm_func : func of utils.normalize, optional 33 | Function that normalizes adjacency matrix. Default: ``GCNAdjNorm``. 34 | edge_drop : float, optional 35 | Rate of edge drop. 36 | alpha : float, optional 37 | Hyper-parameter, refer to original paper. Default: ``0.01``. 38 | k : int, optional 39 | Hyper-parameter, refer to original paper. Default: ``10``. 40 | dropout : float, optional 41 | Dropout rate during training. Default: ``0.0``. 42 | 43 | """ 44 | 45 | def __init__(self, 46 | in_features, 47 | out_features, 48 | hidden_features, 49 | n_layers, 50 | layer_norm=False, 51 | activation=F.relu, 52 | edge_drop=0.0, 53 | alpha=0.01, 54 | k=10, 55 | feat_norm=None, 56 | adj_norm_func=GCNAdjNorm, 57 | dropout=0.0): 58 | super(APPNP, self).__init__() 59 | self.in_features = in_features 60 | self.out_features = out_features 61 | self.feat_norm = feat_norm 62 | self.adj_norm_func = adj_norm_func 63 | if type(hidden_features) is int: 64 | hidden_features = [hidden_features] * (n_layers - 1) 65 | elif type(hidden_features) is list or type(hidden_features) is tuple: 66 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 67 | n_features = [in_features] + hidden_features + [out_features] 68 | 69 | self.layers = nn.ModuleList() 70 | for i in range(n_layers): 71 | if layer_norm: 72 | self.layers.append(nn.LayerNorm(n_features[i])) 73 | self.layers.append(nn.Linear(n_features[i], n_features[i + 1])) 74 | self.alpha = alpha 75 | self.k = k 76 | self.activation = activation 77 | if edge_drop > 0.0: 78 | self.edge_dropout = SparseEdgeDrop(edge_drop) 79 | else: 80 | self.edge_dropout = None 81 | if dropout > 0.0: 82 | self.dropout = nn.Dropout(dropout) 83 | else: 84 | self.dropout = None 85 | 86 | @property 87 | def model_type(self): 88 | """Indicate type of implementation.""" 89 | return "torch" 90 | 91 | @property 92 | def model_name(self): 93 | return "appnp" 94 | 95 | def reset_parameters(self): 96 | """Reset parameters.""" 97 | for layer in self.layers: 98 | layer.reset_parameters() 99 | 100 | def forward(self, x, adj): 101 | r""" 102 | 103 | Parameters 104 | ---------- 105 | x : torch.Tensor 106 | Tensor of input features. 107 | adj : torch.SparseTensor 108 | Sparse tensor of adjacency matrix. 109 | 110 | Returns 111 | ------- 112 | x : torch.Tensor 113 | Output of model (logits without activation). 
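Examples
--------
A minimal usage sketch (illustrative only; in practice ``adj`` is a normalized
sparse adjacency, e.g. from ``GCNAdjNorm``, not the identity used here):

>>> import torch
>>> model = APPNP(in_features=16, out_features=7,
...               hidden_features=32, n_layers=2)
>>> x = torch.randn(5, 16)
>>> adj = torch.eye(5).to_sparse()
>>> out = model(x, adj)  # k steps of (1 - alpha) * A @ x + alpha * x
>>> out.shape
torch.Size([5, 7])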
114 | 115 | """ 116 | 117 | for layer in self.layers: 118 | if isinstance(layer, nn.LayerNorm): 119 | x = layer(x) 120 | else: 121 | x = layer(x) 122 | x = self.activation(x) 123 | if self.dropout is not None: 124 | x = self.dropout(x) 125 | for i in range(self.k): 126 | if self.edge_dropout is not None and self.training: 127 | adj = self.edge_dropout(adj) 128 | x = (1 - self.alpha) * torch.spmm(adj, x) + self.alpha * x 129 | 130 | return x 131 | 132 | 133 | class SparseEdgeDrop(nn.Module): 134 | r""" 135 | 136 | Description 137 | ----------- 138 | Sparse implementation of edge drop. 139 | 140 | Parameters 141 | ---------- 142 | edge_drop : float 143 | Rate of edge drop. 144 | 145 | """ 146 | 147 | def __init__(self, edge_drop): 148 | super(SparseEdgeDrop, self).__init__() 149 | self.edge_drop = edge_drop 150 | 151 | def forward(self, adj): 152 | """Sparse edge drop""" 153 | mask = ((torch.rand(adj._values().size()) + self.edge_drop) > 1.0) 154 | rc = adj._indices() 155 | val = adj._values().clone() 156 | val[mask] = 0.0 157 | 158 | return torch.sparse.FloatTensor(rc, val) 159 | -------------------------------------------------------------------------------- /grb/model/torch/gin.py: -------------------------------------------------------------------------------- 1 | """Torch module for GIN.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class GIN(nn.Module): 8 | r""" 9 | 10 | Description 11 | ----------- 12 | Graph Isomorphism Network (`GIN `__) 13 | 14 | Parameters 15 | ---------- 16 | in_features : int 17 | Dimension of input features. 18 | out_features : int 19 | Dimension of output features. 20 | hidden_features : int or list of int 21 | Dimension of hidden features. List if multi-layer. 22 | n_layers : int 23 | Number of layers. 24 | n_mlp_layers : int 25 | Number of layers. 26 | layer_norm : bool, optional 27 | Whether to use layer normalization. Default: ``False``. 28 | batch_norm : bool, optional 29 | Whether to apply batch normalization. Default: ``True``. 30 | eps : float, optional 31 | Hyper-parameter, refer to original paper. Default: ``0.0``. 32 | activation : func of torch.nn.functional, optional 33 | Activation function. Default: ``torch.nn.functional.relu``. 34 | feat_norm : str, optional 35 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 36 | adj_norm_func : func of utils.normalize, optional 37 | Function that normalizes adjacency matrix. Default: ``None``. 38 | dropout : float, optional 39 | Rate of dropout. Default: ``0.0``. 40 | 41 | """ 42 | 43 | def __init__(self, 44 | in_features, 45 | out_features, 46 | hidden_features, 47 | n_layers, 48 | n_mlp_layers=2, 49 | activation=F.relu, 50 | layer_norm=False, 51 | batch_norm=True, 52 | eps=0.0, 53 | feat_norm=None, 54 | adj_norm_func=None, 55 | dropout=0.0): 56 | super(GIN, self).__init__() 57 | self.in_features = in_features 58 | self.out_features = out_features 59 | self.feat_norm = feat_norm 60 | self.adj_norm_func = adj_norm_func 61 | self.activation = activation 62 | if type(hidden_features) is int: 63 | hidden_features = [hidden_features] * (n_layers - 1) 64 | elif type(hidden_features) is list or type(hidden_features) is tuple: 65 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 
66 | n_features = [in_features] + hidden_features + [out_features] 67 | 68 | self.layers = nn.ModuleList() 69 | for i in range(n_layers - 1): 70 | if layer_norm: 71 | self.layers.append(nn.LayerNorm(n_features[i])) 72 | self.layers.append(GINConv(in_features=n_features[i], 73 | out_features=n_features[i + 1], 74 | batch_norm=batch_norm, 75 | eps=eps, 76 | activation=activation, 77 | dropout=dropout)) 78 | self.mlp_layers = nn.ModuleList() 79 | for i in range(n_mlp_layers): 80 | if i == n_mlp_layers - 1: 81 | self.mlp_layers.append(nn.Linear(hidden_features[-1], out_features)) 82 | else: 83 | self.mlp_layers.append(nn.Linear(hidden_features[-1], hidden_features[-1])) 84 | if dropout > 0.0: 85 | self.dropout = nn.Dropout(dropout) 86 | else: 87 | self.dropout = None 88 | self.reset_parameters() 89 | 90 | @property 91 | def model_type(self): 92 | """Indicate type of implementation.""" 93 | return "torch" 94 | 95 | @property 96 | def model_name(self): 97 | return "gin" 98 | 99 | def reset_parameters(self): 100 | """Reset parameters.""" 101 | for layer in self.layers: 102 | layer.reset_parameters() 103 | for layer in self.mlp_layers: 104 | layer.reset_parameters() 105 | 106 | def forward(self, x, adj): 107 | r""" 108 | 109 | Parameters 110 | ---------- 111 | x : torch.Tensor 112 | Tensor of input features. 113 | adj : torch.SparseTensor 114 | Sparse tensor of adjacency matrix. 115 | 116 | Returns 117 | ------- 118 | x : torch.Tensor 119 | Output of model (logits without activation). 120 | 121 | """ 122 | 123 | for layer in self.layers: 124 | if isinstance(layer, nn.LayerNorm): 125 | x = layer(x) 126 | else: 127 | x = layer(x, adj) 128 | 129 | for i, layer in enumerate(self.mlp_layers): 130 | x = layer(x) 131 | if i != len(self.mlp_layers) - 1: 132 | x = self.activation(x) 133 | if self.dropout is not None: 134 | x = self.dropout(x) 135 | 136 | return x 137 | 138 | 139 | class GINConv(nn.Module): 140 | r""" 141 | 142 | Description 143 | ----------- 144 | GIN convolutional layer. 145 | 146 | Parameters 147 | ---------- 148 | in_features : int 149 | Dimension of input features. 150 | out_features : int 151 | Dimension of output features. 152 | activation : func of torch.nn.functional, optional 153 | Activation function. Default: ``None``. 154 | eps : float, optional 155 | Hyper-parameter, refer to original paper. Default: ``0.0``. 156 | batch_norm : bool, optional 157 | Whether to apply batch normalization. Default: ``True``. 158 | dropout : float, optional 159 | Rate of dropout. Default: ``0.0``. 
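Examples
--------
A minimal single-layer sketch (illustrative shapes; the identity adjacency
stands in for a real normalized one):

>>> import torch
>>> conv = GINConv(in_features=16, out_features=32)
>>> x = torch.randn(5, 16)
>>> adj = torch.eye(5).to_sparse()
>>> conv(x, adj).shape  # A @ x + (1 + eps) * x, then the two linear layers
torch.Size([5, 32])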
160 | 161 | """ 162 | 163 | def __init__(self, 164 | in_features, 165 | out_features, 166 | activation=F.relu, 167 | eps=0.0, 168 | batch_norm=True, 169 | dropout=0.0): 170 | super(GINConv, self).__init__() 171 | self.linear1 = nn.Linear(in_features, out_features) 172 | self.linear2 = nn.Linear(out_features, out_features) 173 | self.activation = activation 174 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 175 | self.batch_norm = batch_norm 176 | if batch_norm: 177 | self.norm = nn.BatchNorm1d(out_features) 178 | if dropout > 0.0: 179 | self.dropout = nn.Dropout(dropout) 180 | else: 181 | self.dropout = None 182 | self.reset_parameters() 183 | 184 | def reset_parameters(self): 185 | """Reset parameters.""" 186 | if self.activation == F.leaky_relu: 187 | gain = nn.init.calculate_gain('leaky_relu') 188 | else: 189 | gain = nn.init.calculate_gain('relu') 190 | nn.init.xavier_normal_(self.linear1.weight, gain=gain) 191 | nn.init.xavier_normal_(self.linear2.weight, gain=gain) 192 | 193 | def forward(self, x, adj): 194 | r""" 195 | 196 | Parameters 197 | ---------- 198 | x : torch.Tensor 199 | Tensor of input features. 200 | adj : torch.SparseTensor 201 | Sparse tensor of adjacency matrix. 202 | 203 | Returns 204 | ------- 205 | x : torch.Tensor 206 | Output of layer. 207 | 208 | """ 209 | 210 | y = torch.spmm(adj, x) 211 | x = y + (1 + self.eps) * x 212 | x = self.linear1(x) 213 | if self.activation is not None: 214 | x = self.activation(x) 215 | x = self.linear2(x) 216 | if self.batch_norm: 217 | x = self.norm(x) 218 | if self.activation is not None: 219 | x = self.activation(x) 220 | if self.dropout is not None: 221 | x = self.dropout(x) 222 | 223 | return x 224 | -------------------------------------------------------------------------------- /grb/model/torch/graphsage.py: -------------------------------------------------------------------------------- 1 | """Torch module for GraphSAGE.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from grb.utils.normalize import SAGEAdjNorm 7 | 8 | 9 | class GraphSAGE(nn.Module): 10 | r""" 11 | 12 | Description 13 | ----------- 14 | Inductive Representation Learning on Large Graphs (`GraphSAGE `__) 15 | 16 | Parameters 17 | ---------- 18 | in_features : int 19 | Dimension of input features. 20 | out_features : int 21 | Dimension of output features. 22 | n_layers : int 23 | Number of layers. 24 | hidden_features : int or list of int 25 | Dimension of hidden features. List if multi-layer. 26 | layer_norm : bool, optional 27 | Whether to use layer normalization. Default: ``False``. 28 | activation : func of torch.nn.functional, optional 29 | Activation function. Default: ``torch.nn.functional.relu``. 30 | feat_norm : str, optional 31 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 32 | adj_norm_func : func of utils.normalize, optional 33 | Function that normalizes adjacency matrix. Default: ``SAGEAdjNorm``. 34 | mu : float, optional 35 | Hyper-parameter, refer to original paper. Default: ``2.0``. 36 | dropout : float, optional 37 | Rate of dropout. Default: ``0.0``. 
38 | 39 | """ 40 | 41 | def __init__(self, 42 | in_features, 43 | out_features, 44 | hidden_features, 45 | n_layers, 46 | activation=F.relu, 47 | layer_norm=False, 48 | feat_norm=None, 49 | adj_norm_func=SAGEAdjNorm, 50 | mu=2.0, 51 | dropout=0.0): 52 | super(GraphSAGE, self).__init__() 53 | self.in_features = in_features 54 | self.out_features = out_features 55 | self.feat_norm = feat_norm 56 | self.adj_norm_func = adj_norm_func 57 | if type(hidden_features) is int: 58 | hidden_features = [hidden_features] * (n_layers - 1) 59 | elif type(hidden_features) is list or type(hidden_features) is tuple: 60 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 61 | n_features = [in_features] + hidden_features + [out_features] 62 | 63 | self.layers = nn.ModuleList() 64 | for i in range(n_layers): 65 | if layer_norm: 66 | self.layers.append(nn.LayerNorm(n_features[i])) 67 | self.layers.append(SAGEConv(in_features=n_features[i], 68 | pool_features=n_features[i], 69 | out_features=n_features[i + 1], 70 | activation=activation if i != n_layers - 1 else None, 71 | mu=mu, 72 | dropout=dropout if i != n_layers - 1 else 0.0)) 73 | self.reset_parameters() 74 | 75 | @property 76 | def model_type(self): 77 | """Indicate type of implementation.""" 78 | return "torch" 79 | 80 | @property 81 | def model_name(self): 82 | return "graphsage" 83 | 84 | def reset_parameters(self): 85 | """Reset parameters.""" 86 | for layer in self.layers: 87 | layer.reset_parameters() 88 | 89 | def forward(self, x, adj): 90 | r""" 91 | 92 | Parameters 93 | ---------- 94 | x : torch.Tensor 95 | Tensor of input features. 96 | adj : torch.SparseTensor 97 | Sparse tensor of adjacency matrix. 98 | 99 | Returns 100 | ------- 101 | x : torch.Tensor 102 | Output of model (logits without activation). 103 | 104 | """ 105 | 106 | for layer in self.layers: 107 | if isinstance(layer, nn.LayerNorm): 108 | x = layer(x) 109 | else: 110 | x = F.normalize(x, dim=1) 111 | x = layer(x, adj) 112 | 113 | return x 114 | 115 | 116 | class SAGEConv(nn.Module): 117 | r""" 118 | 119 | Description 120 | ----------- 121 | SAGE convolutional layer. 122 | 123 | Parameters 124 | ---------- 125 | in_features : int 126 | Dimension of input features. 127 | pool_features : int 128 | Dimension of pooling features. 129 | out_features : int 130 | Dimension of output features. 131 | activation : func of torch.nn.functional, optional 132 | Activation function. Default: ``None``. 133 | dropout : float, optional 134 | Rate of dropout. Default: ``0.0``. 135 | mu : float, optional 136 | Hyper-parameter, refer to original paper. Default: ``2.0``. 
137 | 138 | """ 139 | 140 | def __init__(self, 141 | in_features, 142 | pool_features, 143 | out_features, 144 | activation=None, 145 | mu=2.0, 146 | dropout=0.0): 147 | super(SAGEConv, self).__init__() 148 | self.pool_layer = nn.Linear(in_features, pool_features) 149 | self.linear1 = nn.Linear(pool_features, out_features) 150 | self.linear2 = nn.Linear(pool_features, out_features) 151 | self.activation = activation 152 | self.mu = mu 153 | 154 | if dropout > 0.0: 155 | self.dropout = nn.Dropout(dropout) 156 | else: 157 | self.dropout = None 158 | 159 | self.reset_parameters() 160 | 161 | def reset_parameters(self): 162 | """Reset parameters.""" 163 | if self.activation == F.leaky_relu: 164 | gain = nn.init.calculate_gain('leaky_relu') 165 | else: 166 | gain = nn.init.calculate_gain('relu') 167 | nn.init.xavier_normal_(self.linear1.weight, gain=gain) 168 | nn.init.xavier_normal_(self.linear2.weight, gain=gain) 169 | nn.init.xavier_normal_(self.pool_layer.weight, gain=gain) 170 | 171 | def forward(self, x, adj): 172 | r""" 173 | 174 | Parameters 175 | ---------- 176 | x : torch.Tensor 177 | Tensor of input features. 178 | adj : torch.SparseTensor 179 | Sparse tensor of adjacency matrix. 180 | 181 | 182 | Returns 183 | ------- 184 | x : torch.Tensor 185 | Output of layer. 186 | 187 | """ 188 | 189 | x = F.relu(self.pool_layer(x)) 190 | x_ = x ** self.mu 191 | x_ = torch.spmm(adj, x_) ** (1 / self.mu) 192 | 193 | # In original model this is actually max-pool, but **10/**0.1 result in gradient explosion. 194 | # However we can still achieve similar performance using 2-norm. 195 | x = self.linear1(x) 196 | x_ = self.linear2(x_) 197 | x = x + x_ 198 | if self.activation is not None: 199 | x = self.activation(x) 200 | if self.dropout is not None: 201 | x = self.dropout(x) 202 | 203 | return x 204 | -------------------------------------------------------------------------------- /grb/model/torch/mlp.py: -------------------------------------------------------------------------------- 1 | """Torch module for MLP.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MLP(nn.Module): 8 | r""" 9 | 10 | Description 11 | ----------- 12 | Multi-Layer Perceptron (without considering the graph structure). 13 | 14 | Parameters 15 | ---------- 16 | in_features : int 17 | Dimension of input features. 18 | out_features : int 19 | Dimension of output features. 20 | hidden_features : int or list of int 21 | Dimension of hidden features. List if multi-layer. 22 | n_layers : int 23 | Number of layers. 24 | layer_norm : bool, optional 25 | Whether to use layer normalization. Default: ``False``. 26 | activation : func of torch.nn.functional, optional 27 | Activation function. Default: ``torch.nn.functional.relu``. 28 | dropout : float, optional 29 | Dropout rate during training. Default: ``0.0``. 
30 | 31 | """ 32 | 33 | def __init__(self, 34 | in_features, 35 | out_features, 36 | hidden_features, 37 | n_layers, 38 | activation=F.relu, 39 | feat_norm=None, 40 | adj_norm_func=None, 41 | batch_norm=False, 42 | layer_norm=False, 43 | dropout=0.0): 44 | super(MLP, self).__init__() 45 | self.in_features = in_features 46 | self.out_features = out_features 47 | self.feat_norm = feat_norm 48 | self.adj_norm_func = adj_norm_func 49 | if type(hidden_features) is int: 50 | hidden_features = [hidden_features] * (n_layers - 1) 51 | elif type(hidden_features) is list or type(hidden_features) is tuple: 52 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 53 | n_features = [in_features] + hidden_features + [out_features] 54 | 55 | self.layers = nn.ModuleList() 56 | for i in range(n_layers): 57 | if layer_norm: 58 | self.layers.append(nn.LayerNorm(n_features[i])) 59 | self.layers.append(MLPLayer(in_features=n_features[i], 60 | out_features=n_features[i + 1], 61 | activation=activation if i != n_layers - 1 else None, 62 | batch_norm=batch_norm if i != n_layers - 1 else None, 63 | dropout=dropout if i != n_layers - 1 else 0.0)) 64 | self.reset_parameters() 65 | 66 | @property 67 | def model_type(self): 68 | """Indicate type of implementation.""" 69 | return "torch" 70 | 71 | @property 72 | def model_name(self): 73 | return "mlp" 74 | 75 | def reset_parameters(self): 76 | """Reset parameters.""" 77 | for layer in self.layers: 78 | layer.reset_parameters() 79 | 80 | def forward(self, x, adj=None): 81 | r""" 82 | 83 | Parameters 84 | ---------- 85 | x : torch.Tensor 86 | Tensor of input features. 87 | adj : torch.SparseTensor 88 | Sparse tensor of adjacency matrix. 89 | 90 | Returns 91 | ------- 92 | x : torch.Tensor 93 | Output of model (logits without activation). 94 | 95 | """ 96 | 97 | for layer in self.layers: 98 | if isinstance(layer, nn.LayerNorm): 99 | x = layer(x) 100 | else: 101 | x = layer(x, adj) 102 | 103 | return x 104 | 105 | 106 | class MLPLayer(nn.Module): 107 | r""" 108 | 109 | Description 110 | ----------- 111 | MLP layer. 112 | 113 | Parameters 114 | ---------- 115 | in_features : int 116 | Dimension of input features. 117 | out_features : int 118 | Dimension of output features. 119 | activation : func of torch.nn.functional, optional 120 | Activation function. Default: ``None``. 121 | dropout : float, optional 122 | Rate of dropout. Default: ``0.0``. 123 | 124 | """ 125 | 126 | def __init__(self, in_features, out_features, activation=None, batch_norm=False, dropout=0.0): 127 | super(MLPLayer, self).__init__() 128 | self.in_features = in_features 129 | self.out_features = out_features 130 | self.linear = nn.Linear(in_features, out_features) 131 | self.activation = activation 132 | if dropout > 0.0: 133 | self.dropout = nn.Dropout(dropout) 134 | else: 135 | self.dropout = None 136 | if batch_norm: 137 | self.batch_norm = nn.BatchNorm1d(out_features) 138 | else: 139 | self.batch_norm = None 140 | self.reset_parameters() 141 | 142 | def reset_parameters(self): 143 | """Reset parameters.""" 144 | if self.activation == F.leaky_relu: 145 | gain = nn.init.calculate_gain('leaky_relu') 146 | else: 147 | gain = nn.init.calculate_gain('relu') 148 | nn.init.xavier_normal_(self.linear.weight, gain=gain) 149 | 150 | def forward(self, x, adj=None): 151 | r""" 152 | 153 | Parameters 154 | ---------- 155 | x : torch.Tensor 156 | Tensor of input features. 157 | adj : torch.SparseTensor 158 | Sparse tensor of adjacency matrix. 
159 | 160 | Returns 161 | ------- 162 | x : torch.Tensor 163 | Output of layer. 164 | 165 | """ 166 | 167 | x = self.linear(x) 168 | if hasattr(self, 'batch_norm'): 169 | if self.batch_norm is not None: 170 | x = self.batch_norm(x) 171 | if self.activation is not None: 172 | x = self.activation(x) 173 | if self.dropout is not None: 174 | x = self.dropout(x) 175 | 176 | return x 177 | -------------------------------------------------------------------------------- /grb/model/torch/sgcn.py: -------------------------------------------------------------------------------- 1 | """Torch module for SGCN.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from grb.utils.normalize import GCNAdjNorm 7 | 8 | 9 | class SGCN(nn.Module): 10 | r""" 11 | 12 | Description 13 | ----------- 14 | Simplifying Graph Convolutional Networks (`SGCN `__) 15 | 16 | Parameters 17 | ---------- 18 | in_features : int 19 | Dimension of input features. 20 | out_features : int 21 | Dimension of output features. 22 | hidden_features : int or list of int 23 | Dimension of hidden features. List if multi-layer. 24 | n_layers : int 25 | Number of layers. 26 | layer_norm : bool, optional 27 | Whether to use layer normalization. Default: ``False``. 28 | activation : func of torch.nn.functional, optional 29 | Activation function. Default: ``torch.tanh``. 30 | k : int, optional 31 | Hyper-parameter, refer to original paper. Default: ``4``. 32 | feat_norm : str, optional 33 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 34 | adj_norm_func : func of utils.normalize, optional 35 | Function that normalizes adjacency matrix. Default: ``GCNAdjNorm``. 36 | dropout : float, optional 37 | Rate of dropout. Default: ``0.0``. 38 | 39 | """ 40 | 41 | def __init__(self, 42 | in_features, 43 | out_features, 44 | hidden_features, 45 | n_layers, 46 | activation=torch.tanh, 47 | feat_norm=None, 48 | adj_norm_func=GCNAdjNorm, 49 | layer_norm=False, 50 | batch_norm=False, 51 | k=4, 52 | dropout=0.0): 53 | super(SGCN, self).__init__() 54 | self.in_features = in_features 55 | self.out_features = out_features 56 | self.feat_norm = feat_norm 57 | self.adj_norm_func = adj_norm_func 58 | if type(hidden_features) is int: 59 | hidden_features = [hidden_features] * (n_layers - 1) 60 | elif type(hidden_features) is list or type(hidden_features) is tuple: 61 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 62 | 63 | if batch_norm: 64 | self.batch_norm = nn.BatchNorm1d(in_features) 65 | else: 66 | self.batch_norm = None 67 | self.in_conv = nn.Linear(in_features, hidden_features[0]) 68 | self.out_conv = nn.Linear(hidden_features[-1], out_features) 69 | self.activation = activation 70 | self.layers = nn.ModuleList() 71 | 72 | for i in range(n_layers - 2): 73 | if layer_norm: 74 | self.layers.append(nn.LayerNorm(hidden_features[i])) 75 | self.layers.append(SGConv(in_features=hidden_features[i], 76 | out_features=hidden_features[i + 1], 77 | k=k)) 78 | 79 | if dropout > 0.0: 80 | self.dropout = nn.Dropout(dropout) 81 | else: 82 | self.dropout = None 83 | 84 | @property 85 | def model_type(self): 86 | """Indicate type of implementation.""" 87 | return "torch" 88 | 89 | @property 90 | def model_name(self): 91 | return "sgcn" 92 | 93 | def forward(self, x, adj): 94 | r""" 95 | 96 | Parameters 97 | ---------- 98 | x : torch.Tensor 99 | Tensor of input features. 
100 | adj : torch.SparseTensor 101 | Sparse tensor of adjacency matrix. 102 | 103 | Returns 104 | ------- 105 | x : torch.Tensor 106 | Output of model (logits without activation). 107 | 108 | """ 109 | if self.batch_norm is not None: 110 | x = self.batch_norm(x) 111 | x = self.in_conv(x) 112 | x = F.relu(x) 113 | if self.dropout is not None: 114 | x = self.dropout(x) 115 | 116 | for layer in self.layers: 117 | if isinstance(layer, nn.LayerNorm): 118 | x = layer(x) 119 | else: 120 | x = layer(x, adj) 121 | if self.activation is not None: 122 | x = self.activation(x) 123 | if self.dropout is not None: 124 | x = self.dropout(x) 125 | x = self.out_conv(x) 126 | 127 | return x 128 | 129 | 130 | class SGConv(nn.Module): 131 | r""" 132 | 133 | Description 134 | ----------- 135 | SGCN convolutional layer. 136 | 137 | Parameters 138 | ---------- 139 | in_features : int 140 | Dimension of input features. 141 | out_features : int 142 | Dimension of output features. 143 | k : int, optional 144 | Hyper-parameter, refer to original paper. Default: ``4``. 145 | 146 | Returns 147 | ------- 148 | x : torch.Tensor 149 | Output of layer. 150 | 151 | """ 152 | 153 | def __init__(self, in_features, out_features, k): 154 | super(SGConv, self).__init__() 155 | self.in_features = in_features 156 | self.out_features = out_features 157 | self.linear = nn.Linear(in_features, out_features) 158 | self.k = k 159 | 160 | def forward(self, x, adj): 161 | r""" 162 | 163 | Parameters 164 | ---------- 165 | x : torch.Tensor 166 | Tensor of input features. 167 | adj : torch.SparseTensor 168 | Sparse tensor of adjacency matrix. 169 | 170 | Returns 171 | ------- 172 | x : torch.Tensor 173 | Output of layer. 174 | 175 | """ 176 | 177 | for i in range(self.k): 178 | x = torch.spmm(adj, x) 179 | x = self.linear(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /grb/model/torch/tagcn.py: -------------------------------------------------------------------------------- 1 | """Torch module for TAGCN.""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from grb.utils.normalize import GCNAdjNorm 7 | 8 | 9 | class TAGCN(nn.Module): 10 | r""" 11 | 12 | Description 13 | ----------- 14 | Topological Adaptive Graph Convolutional Networks (`TAGCN `__) 15 | 16 | Parameters 17 | ---------- 18 | in_features : int 19 | Dimension of input features. 20 | out_features : int 21 | Dimension of output features. 22 | hidden_features : int or list of int 23 | Dimension of hidden features. List if multi-layer. 24 | n_layers : int 25 | Number of layers. 26 | k : int 27 | Hyper-parameter, k-hop adjacency matrix, refer to original paper. 28 | layer_norm : bool, optional 29 | Whether to use layer normalization. Default: ``False``. 30 | batch_norm : bool, optional 31 | Whether to apply batch normalization. Default: ``False``. 32 | activation : func of torch.nn.functional, optional 33 | Activation function. Default: ``torch.nn.functional.leaky_relu``. 34 | feat_norm : str, optional 35 | Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``. 36 | adj_norm_func : func of utils.normalize, optional 37 | Function that normalizes adjacency matrix. Default: ``GCNAdjNorm``. 38 | dropout : float, optional 39 | Rate of dropout. Default: ``0.0``. 
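Examples
--------
A minimal usage sketch (illustrative only; ``adj`` would normally be
normalized via ``GCNAdjNorm`` rather than the identity used here):

>>> import torch
>>> model = TAGCN(in_features=16, out_features=7,
...               hidden_features=32, n_layers=2, k=2)
>>> model(torch.randn(5, 16), torch.eye(5).to_sparse()).shape
torch.Size([5, 7])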
40 | 41 | """ 42 | 43 | def __init__(self, 44 | in_features, 45 | out_features, 46 | hidden_features, 47 | n_layers, 48 | k, 49 | activation=F.leaky_relu, 50 | feat_norm=None, 51 | adj_norm_func=GCNAdjNorm, 52 | layer_norm=False, 53 | batch_norm=False, 54 | dropout=0.0): 55 | super(TAGCN, self).__init__() 56 | self.in_features = in_features 57 | self.out_features = out_features 58 | self.feat_norm = feat_norm 59 | self.adj_norm_func = adj_norm_func 60 | if type(hidden_features) is int: 61 | hidden_features = [hidden_features] * (n_layers - 1) 62 | elif type(hidden_features) is list or type(hidden_features) is tuple: 63 | assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." 64 | n_features = [in_features] + hidden_features + [out_features] 65 | 66 | self.layers = nn.ModuleList() 67 | for i in range(n_layers): 68 | if layer_norm: 69 | self.layers.append(nn.LayerNorm(n_features[i])) 70 | self.layers.append(TAGConv(in_features=n_features[i], 71 | out_features=n_features[i + 1], 72 | k=k, 73 | batch_norm=batch_norm if i != n_layers - 1 else False, 74 | activation=activation if i != n_layers - 1 else None, 75 | dropout=dropout if i != n_layers - 1 else 0.0)) 76 | self.reset_parameters() 77 | 78 | @property 79 | def model_type(self): 80 | """Indicate type of implementation.""" 81 | return "torch" 82 | 83 | @property 84 | def model_name(self): 85 | return "tagcn" 86 | 87 | def reset_parameters(self): 88 | """Reset paramters.""" 89 | for layer in self.layers: 90 | layer.reset_parameters() 91 | 92 | def forward(self, x, adj): 93 | r""" 94 | 95 | Parameters 96 | ---------- 97 | x : torch.Tensor 98 | Tensor of input features. 99 | adj : torch.SparseTensor 100 | Sparse tensor of adjacency matrix. 101 | 102 | Returns 103 | ------- 104 | x : torch.Tensor 105 | Output of model (logits without activation). 106 | 107 | """ 108 | 109 | for layer in self.layers: 110 | if isinstance(layer, nn.LayerNorm): 111 | x = layer(x) 112 | else: 113 | x = layer(x, adj) 114 | 115 | return x 116 | 117 | 118 | class TAGConv(nn.Module): 119 | r""" 120 | 121 | Description 122 | ----------- 123 | TAGCN convolutional layer. 124 | 125 | Parameters 126 | ---------- 127 | in_features : int 128 | Dimension of input features. 129 | out_features : int 130 | Dimension of output features. 131 | k : int, optional 132 | Hyper-parameter, refer to original paper. Default: ``2``. 133 | activation : func of torch.nn.functional, optional 134 | Activation function. Default: ``None``. 135 | 136 | batch_norm : bool, optional 137 | Whether to apply batch normalization. Default: ``False``. 138 | dropout : float, optional 139 | Rate of dropout. Default: ``0.0``. 
140 | 141 | """ 142 | 143 | def __init__(self, 144 | in_features, 145 | out_features, 146 | k=2, 147 | activation=None, 148 | batch_norm=False, 149 | dropout=0.0): 150 | super(TAGConv, self).__init__() 151 | self.in_features = in_features 152 | self.out_features = out_features 153 | self.linear = nn.Linear(in_features * (k + 1), out_features) 154 | self.batch_norm = batch_norm 155 | if batch_norm: 156 | self.norm_func = nn.BatchNorm1d(out_features, affine=False) 157 | self.activation = activation 158 | if dropout > 0.0: 159 | self.dropout = nn.Dropout(dropout) 160 | else: 161 | self.dropout = None 162 | self.k = k 163 | self.reset_parameters() 164 | 165 | def reset_parameters(self): 166 | """Reset parameters.""" 167 | if self.activation == F.leaky_relu: 168 | gain = nn.init.calculate_gain('leaky_relu') 169 | else: 170 | gain = nn.init.calculate_gain('relu') 171 | nn.init.xavier_normal_(self.linear.weight, gain=gain) 172 | 173 | def forward(self, x, adj): 174 | r""" 175 | 176 | Parameters 177 | ---------- 178 | x : torch.Tensor 179 | Tensor of input features. 180 | adj : torch.SparseTensor 181 | Sparse tensor of adjacency matrix. 182 | 183 | Returns 184 | ------- 185 | x : torch.Tensor 186 | Output of layer. 187 | 188 | """ 189 | 190 | fstack = [x] 191 | for i in range(self.k): 192 | y = torch.spmm(adj, fstack[-1]) 193 | fstack.append(y) 194 | x = torch.cat(fstack, dim=-1) 195 | x = self.linear(x) 196 | if self.batch_norm: 197 | x = self.norm_func(x) 198 | if self.activation is not None: 199 | x = self.activation(x) 200 | if self.dropout is not None: 201 | x = self.dropout(x) 202 | 203 | return x 204 | -------------------------------------------------------------------------------- /grb/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer, GraphTrainer, AutoTrainer 2 | from .trainer2 import Trainer -------------------------------------------------------------------------------- /grb/trainer/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /grb/trainer/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /grb/trainer/__pycache__/trainer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer.cpython-36.pyc -------------------------------------------------------------------------------- /grb/trainer/__pycache__/trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer.cpython-38.pyc -------------------------------------------------------------------------------- /grb/trainer/__pycache__/trainer2.cpython-36.pyc: 
--------------------------------------------------------------------------------
/grb/trainer/__init__.py:
--------------------------------------------------------------------------------
 1 | from .trainer import Trainer, GraphTrainer, AutoTrainer
 2 | # NOTE: this re-import shadows .trainer.Trainer, so the trainer2 implementation
 3 | # is the one actually exported as grb.trainer.Trainer.
 4 | from .trainer2 import Trainer
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer2.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer2.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer_helper.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer_helper.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/trainer/__pycache__/trainer_helper.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/trainer/__pycache__/trainer_helper.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/trainer/trainer_helper.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | # Consistency loss: pushes each prediction in `logits` toward the
 5 | # temperature-sharpened average of all predictions (sharper as `temp` shrinks).
 6 | def consistency_loss(logits, temp, lam):
 7 |     ps = [torch.exp(p) for p in logits]
 8 |     ps = torch.stack(ps, dim=2)
 9 | 
10 |     avg_p = torch.mean(ps, dim=2)
11 |     sharp_p = (torch.pow(avg_p, 1. / temp) / torch.sum(torch.pow(avg_p, 1. / temp), dim=1, keepdim=True)).detach()
12 | 
13 |     sharp_p = sharp_p.unsqueeze(2)
14 |     loss = torch.mean(torch.sum(torch.pow(ps - sharp_p, 2), dim=1, keepdim=True))
15 | 
16 |     loss = lam * loss
17 | 
18 |     return loss
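19 | 
20 | # Shape sketch (assumption: `logits` is a list of S log-softmax outputs, each of
21 | # shape [N, C]):
22 | #   ps      -> [N, C, S]  stacked probabilities
23 | #   avg_p   -> [N, C]     mean over the S predictions
24 | #   sharp_p -> [N, C, 1]  sharpened mean, detached from the graph
25 | # The loss is lam * the mean squared deviation of each prediction from sharp_p.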
--------------------------------------------------------------------------------
/grb/utils/__init__.py:
--------------------------------------------------------------------------------
 1 | from .logger import Logger
 2 | from . import normalize
 3 | from .utils import *
--------------------------------------------------------------------------------
/grb/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/logger.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/logger.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/normalize.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/normalize.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/normalize.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/normalize.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/utils/__pycache__/utils2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/grb/utils/__pycache__/utils2.cpython-38.pyc
--------------------------------------------------------------------------------
/grb/utils/logger.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import sys
 3 | 
 4 | 
 5 | class Logger(object):
 6 |     def __init__(self, file_dir="./logs", file_name="default.out", stream=sys.stdout):
 7 |         self.terminal = stream
 8 |         if not os.path.exists(file_dir):
 9 |             os.makedirs(file_dir)
10 |         self.log = open(os.path.join(file_dir, file_name), "a")
11 | 
12 |     def write(self, message):
13 |         self.terminal.write(message)
14 |         self.log.write(message)
15 | 
16 |     def flush(self):
17 |         # Flush both streams so the terminal and the log file stay in sync.
18 |         self.terminal.flush()
19 |         self.log.flush()
20 | 
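21 | # Usage sketch (assumption: the class is meant as a stdout tee, so every print
22 | # goes both to the terminal and to the log file):
23 | #   import sys
24 | #   sys.stdout = Logger(file_dir="./logs", file_name="run.out", stream=sys.stdout)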
--------------------------------------------------------------------------------
/grb/utils/normalize.py:
--------------------------------------------------------------------------------
  1 | import torch
  2 | import numpy as np
  3 | import scipy.sparse as sp
  4 | 
  5 | 
  6 | def GCNAdjNorm(adj, order=-0.5):
  7 |     r"""
  8 | 
  9 |     Description
 10 |     -----------
 11 |     Normalization of adjacency matrix proposed in `GCN <https://arxiv.org/abs/1609.02907>`__.
 12 | 
 13 |     Parameters
 14 |     ----------
 15 |     adj : scipy.sparse.csr.csr_matrix or torch.FloatTensor
 16 |         Adjacency matrix in form of ``N * N`` sparse matrix (or in form of ``N * N`` dense tensor).
 17 |     order : float, optional
 18 |         Order of degree matrix. Default: ``-0.5``.
 19 | 
 20 | 
 21 |     Returns
 22 |     -------
 23 |     adj : scipy.sparse.csr.csr_matrix
 24 |         Normalized adjacency matrix in form of ``N * N`` sparse matrix.
 25 | 
 26 |     """
 27 |     if sp.issparse(adj):
 28 |         adj = sp.eye(adj.shape[0]) + adj
 29 |         adj.data[np.where((adj.data > 0) * (adj.data != 1))[0]] = 1  # binarize positive edge weights
 30 |         adj = sp.coo_matrix(adj)
 31 |         rowsum = np.array(adj.sum(1))
 32 |         d_inv = np.power(rowsum, order).flatten()
 33 |         d_inv[np.isinf(d_inv)] = 0.
 34 |         d_mat_inv = sp.diags(d_inv)
 35 |         adj = d_mat_inv @ adj @ d_mat_inv
 36 |     else:
 37 |         rowsum = torch.sparse.mm(adj, torch.ones((adj.shape[0], 1), device=adj.device)) + 1
 38 |         d_inv = torch.pow(rowsum, order).flatten()
 39 |         d_inv[torch.isinf(d_inv)] = 0.
 40 | 
 41 |         self_loop_idx = torch.stack((
 42 |             torch.arange(adj.shape[0], device=adj.device),
 43 |             torch.arange(adj.shape[0], device=adj.device)
 44 |         ))
 45 |         self_loop_val = torch.ones_like(self_loop_idx[0], dtype=adj.dtype)
 46 |         indices = torch.cat((self_loop_idx, adj.indices()), dim=1)
 47 |         values = torch.cat((self_loop_val, adj.values()))
 48 |         values = d_inv[indices[0]] * values * d_inv[indices[1]]
 49 |         adj = torch.sparse.FloatTensor(indices, values, adj.shape).coalesce()
 50 | 
 51 |     return adj
 52 | 
 53 | 
 54 | def SAGEAdjNorm(adj, order=-1):
 55 |     r"""
 56 | 
 57 |     Description
 58 |     -----------
 59 |     Normalization of adjacency matrix proposed in `GraphSAGE <https://arxiv.org/abs/1706.02216>`__.
 60 | 
 61 |     Parameters
 62 |     ----------
 63 |     adj : scipy.sparse.csr.csr_matrix
 64 |         Adjacency matrix in form of ``N * N`` sparse matrix.
 65 |     order : float, optional
 66 |         Order of degree matrix. Default: ``-1``.
 67 | 
 68 | 
 69 |     Returns
 70 |     -------
 71 |     adj : scipy.sparse.csr.csr_matrix
 72 |         Normalized adjacency matrix in form of ``N * N`` sparse matrix.
 73 | 
 74 |     """
 75 |     if sp.issparse(adj):
 76 |         adj = sp.eye(adj.shape[0]) + adj
 77 |         for i in range(len(adj.data)):
 78 |             if adj.data[i] > 0 and adj.data[i] != 1:
 79 |                 adj.data[i] = 1
 80 |             if adj.data[i] < 0:
 81 |                 adj.data[i] = 0
 82 |         adj.eliminate_zeros()
 83 |         adj = sp.coo_matrix(adj)
 84 |         if order == 0:
 85 |             return adj.tocoo()
 86 |         rowsum = np.array(adj.sum(1))
 87 |         d_inv = np.power(rowsum, order).flatten()
 88 |         d_inv[np.isinf(d_inv)] = 0.
 89 |         d_mat_inv = sp.diags(d_inv)
 90 |         adj = d_mat_inv @ adj
 91 |     else:
 92 |         adj = torch.eye(adj.shape[0]).to(adj.device) + adj
 93 |         rowsum = adj.sum(1)
 94 |         d_inv = torch.pow(rowsum, order).flatten()
 95 |         d_inv[torch.isinf(d_inv)] = 0.
 96 |         d_mat_inv = torch.diag(d_inv)
 97 |         adj = d_mat_inv @ adj
 98 | 
 99 |     return adj
100 | 
101 | 
102 | def SPARSEAdjNorm(adj, order=-0.5):
103 |     r"""
104 | 
105 |     Description
106 |     -----------
107 |     Normalization of adjacency matrix proposed in `GCN <https://arxiv.org/abs/1609.02907>`__.
108 | 
109 |     Parameters
110 |     ----------
111 |     adj : scipy.sparse.csr.csr_matrix or torch.FloatTensor
112 |         Adjacency matrix in form of ``N * N`` sparse matrix (or in form of ``N * N`` dense tensor).
113 |     order : float, optional
114 |         Order of degree matrix. Default: ``-0.5``.
115 | 
116 | 
117 |     Returns
118 |     -------
119 |     adj : scipy.sparse.csr.csr_matrix
120 |         Normalized adjacency matrix in form of ``N * N`` sparse matrix.
121 | 
122 |     """
123 |     if sp.issparse(adj):
124 |         adj = sp.eye(adj.shape[0]) + adj
125 |         adj.data[np.where((adj.data > 0) * (adj.data != 1))[0]] = 1  # binarize positive edge weights
126 |         adj = sp.coo_matrix(adj)
127 |         rowsum = np.array(adj.sum(1))
128 |         d_inv = np.power(rowsum, order).flatten()
129 |         d_inv[np.isinf(d_inv)] = 0.
130 |         d_mat_inv = sp.diags(d_inv)
131 |         adj = d_mat_inv @ adj @ d_mat_inv
132 |     else:
133 |         rowsum = torch.sparse.mm(adj, torch.ones((adj.shape[0], 1), device=adj.device)) + 1
134 |         d_inv = torch.pow(rowsum, order).flatten()
135 |         d_inv[torch.isinf(d_inv)] = 0.
136 | 
137 |         self_loop_idx = torch.stack((
138 |             torch.arange(adj.shape[0], device=adj.device),
139 |             torch.arange(adj.shape[0], device=adj.device)
140 |         ))
141 |         self_loop_val = torch.ones_like(self_loop_idx[0], dtype=adj.dtype)
142 |         indices = torch.cat((self_loop_idx, adj.indices()), dim=1)
143 |         values = torch.cat((self_loop_val, adj.values()))
144 |         values = d_inv[indices[0]] * values * d_inv[indices[1]]
145 |         adj = torch.sparse.FloatTensor(indices, values, adj.shape).coalesce()
146 | 
147 |     return adj
148 | 
149 | 
150 | def RobustGCNAdjNorm(adj):
151 |     r"""
152 | 
153 |     Description
154 |     -----------
155 |     Normalization of adjacency matrix proposed in RobustGCN (Zhu et al., KDD 2019).
156 | 
157 |     Parameters
158 |     ----------
159 |     adj : scipy.sparse.csr.csr_matrix
160 |         Adjacency matrix in form of ``N * N`` sparse matrix.
161 | 
162 |     Returns
163 |     -------
164 |     adj0 : scipy.sparse.csr.csr_matrix
165 |         Adjacency matrix normalized with ``order=-0.5``, in form of ``N * N`` sparse matrix.
166 |     adj1 : scipy.sparse.csr.csr_matrix
167 |         Adjacency matrix normalized with ``order=-1``, in form of ``N * N`` sparse matrix.
168 | 
169 |     """
170 |     adj0 = GCNAdjNorm(adj, order=-0.5)
171 |     adj1 = GCNAdjNorm(adj, order=-1)
172 | 
173 |     return adj0, adj1
174 | 
175 | 
176 | def feature_normalize(features):
177 |     x_sum = torch.sum(features, dim=1)
178 |     x_rev = x_sum.pow(-1).flatten()
179 |     x_rev[torch.isnan(x_rev)] = 0.0
180 |     x_rev[torch.isinf(x_rev)] = 0.0
181 |     features = features * x_rev.unsqueeze(-1).expand_as(features)
182 | 
183 |     return features
184 | 
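185 | # Worked example for GCNAdjNorm (illustrative; a 2-node graph with one edge):
186 | #   A + I = [[1, 1], [1, 1]], row sums d = (2, 2), so every entry of
187 | #   D^-0.5 (A + I) D^-0.5 is 1 / sqrt(2 * 2) = 0.5.
188 | #
189 | #   >>> GCNAdjNorm(sp.csr_matrix(np.array([[0., 1.], [1., 0.]]))).toarray()
190 | #   array([[0.5, 0.5],
191 | #          [0.5, 0.5]])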
168 | 169 | """ 170 | adj0 = GCNAdjNorm(adj, order=-0.5) 171 | adj1 = GCNAdjNorm(adj, order=-1) 172 | 173 | return adj0, adj1 174 | 175 | 176 | def feature_normalize(features): 177 | x_sum = torch.sum(features, dim=1) 178 | x_rev = x_sum.pow(-1).flatten() 179 | x_rev[torch.isnan(x_rev)] = 0.0 180 | x_rev[torch.isinf(x_rev)] = 0.0 181 | features = features * x_rev.unsqueeze(-1).expand_as(features) 182 | 183 | return features 184 | -------------------------------------------------------------------------------- /grb/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_graph(adj, pos, labels, nodelist=None, figsize=(12, 12), title=None): 6 | graph = nx.from_scipy_sparse_matrix(adj) 7 | plt.figure(figsize=figsize) 8 | plt.axis('off') 9 | if title is not None: 10 | plt.title(title) 11 | nx.draw_networkx(graph, pos=pos, nodelist=nodelist, node_size=50, cmap=plt.get_cmap('coolwarm'), 12 | node_color=labels, edge_color='k', 13 | arrows=False, width=1, style='dotted', with_labels=False) 14 | plt.savefig("./images/{}.png".format(title)) 15 | plt.show() 16 | -------------------------------------------------------------------------------- /hyperbolic_distances.py: -------------------------------------------------------------------------------- 1 | import time 2 | from scipy.spatial.distance import squareform, pdist 3 | import numpy as np 4 | import argparse 5 | import pickle 6 | 7 | def hyperbolize(x): 8 | n = pdist(x.detach().numpy(), "sqeuclidean") 9 | MACHINE_EPSILON = np.finfo(np.double).eps 10 | m = squareform(n) 11 | qsqr = np.sum(x ** 2, axis=1) 12 | divisor = np.maximum(1 - qsqr[:, np.newaxis], MACHINE_EPSILON) * np.maximum(1 - qsqr[np.newaxis, :], MACHINE_EPSILON) 13 | m = np.arccosh(1 + 2 * m / divisor ) #** 2 14 | return m 15 | 16 | def main(opt): 17 | dataset = opt['dataset'] 18 | for emb_dim in [16, 8, 4, 2]: 19 | with open(f"../data/pos_encodings/{dataset}_HYPS{emb_dim:02d}.pkl", "rb") as f: 20 | emb = pickle.load(f) 21 | t = time.time() 22 | sqdist = pdist(emb.detach().numpy(), "sqeuclidean") 23 | distances_ = hyperbolize(emb.detach().numpy(), sqdist) 24 | print("Distances calculated in %.2f sec" % (time.time()-t)) 25 | #with open(f"../data/pos_encodings/{dataset}_HYPS{emb_dim:02d}_dists.pkl", "wb") as f: 26 | # pickle.dump(distances, f) 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--dataset', type=str, default='ALL', 30 | help='Cora, Citeseer, Pubmed, Computers, Photo, CoauthorCS') 31 | args = parser.parse_args() 32 | opt = vars(args) 33 | main(opt) -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | class Logger(object): 2 | def __init__(self, runs, info=None): 3 | self.info = info 4 | self.results = [[] for _ in range(runs)] 5 | 6 | def add_result(self, run, result): 7 | assert len(result) == 3 8 | assert run >= 0 and run < len(self.results) 9 | self.results[run].append(result) 10 | 11 | def print_statistics(self, run=None): 12 | if run is not None: 13 | result = 100 * torch.tensor(self.results[run]) 14 | argmax = result[:, 1].argmax().item() 15 | print(f'Run {run + 1:02d}:') 16 | print(f'Highest Train: {result[:, 0].max():.2f}') 17 | print(f'Highest Valid: {result[:, 1].max():.2f}') 18 | print(f' Final Train: {result[argmax, 0]:.2f}') 19 | print(f' Final Test: {result[argmax, 
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | 
 3 | 
 4 | class Logger(object):
 5 |     def __init__(self, runs, info=None):
 6 |         self.info = info
 7 |         self.results = [[] for _ in range(runs)]
 8 | 
 9 |     def add_result(self, run, result):
10 |         assert len(result) == 3
11 |         assert run >= 0 and run < len(self.results)
12 |         self.results[run].append(result)
13 | 
14 |     def print_statistics(self, run=None):
15 |         if run is not None:
16 |             result = 100 * torch.tensor(self.results[run])
17 |             argmax = result[:, 1].argmax().item()
18 |             print(f'Run {run + 1:02d}:')
19 |             print(f'Highest Train: {result[:, 0].max():.2f}')
20 |             print(f'Highest Valid: {result[:, 1].max():.2f}')
21 |             print(f'  Final Train: {result[argmax, 0]:.2f}')
22 |             print(f'   Final Test: {result[argmax, 2]:.2f}')
23 |         else:
24 |             result = 100 * torch.tensor(self.results)
25 | 
26 |             best_results = []
27 |             for r in result:
28 |                 train1 = r[:, 0].max().item()
29 |                 valid = r[:, 1].max().item()
30 |                 train2 = r[r[:, 1].argmax(), 0].item()
31 |                 test = r[r[:, 1].argmax(), 2].item()
32 |                 best_results.append((train1, valid, train2, test))
33 | 
34 |             best_result = torch.tensor(best_results)
35 | 
36 |             print(f'All runs:')
37 |             r = best_result[:, 0]
38 |             print(f'Highest Train: {r.mean():.2f} ± {r.std():.2f}')
39 |             r = best_result[:, 1]
40 |             print(f'Highest Valid: {r.mean():.2f} ± {r.std():.2f}')
41 |             r = best_result[:, 2]
42 |             print(f'  Final Train: {r.mean():.2f} ± {r.std():.2f}')
43 |             r = best_result[:, 3]
44 |             print(f'   Final Test: {r.mean():.2f} ± {r.std():.2f}')
--------------------------------------------------------------------------------
/model_configurations.py:
--------------------------------------------------------------------------------
 1 | from function_transformer_attention import ODEFuncTransformerAtt
 2 | from function_GAT_attention import ODEFuncAtt
 3 | from function_laplacian_diffusion import LaplacianODEFunc
 4 | from block_transformer_attention import AttODEblock
 5 | from block_constant import ConstantODEblock
 6 | from block_mixed import MixedODEblock
 7 | from block_transformer_hard_attention import HardAttODEblock
 8 | from block_transformer_rewiring import RewireAttODEblock
 9 | 
10 | class BlockNotDefined(Exception):
11 |     pass
12 | 
13 | class FunctionNotDefined(Exception):
14 |     pass
15 | 
16 | 
17 | def set_block(opt):
18 |     ode_str = opt['block']
19 |     if ode_str == 'mixed':
20 |         block = MixedODEblock
21 |     elif ode_str == 'attention':
22 |         block = AttODEblock
23 |     elif ode_str == 'hard_attention':
24 |         block = HardAttODEblock
25 |     elif ode_str == 'rewire_attention':
26 |         block = RewireAttODEblock
27 |     elif ode_str == 'constant':
28 |         block = ConstantODEblock
29 |     else:
30 |         raise BlockNotDefined
31 |     return block
32 | 
33 | 
34 | def set_function(opt):
35 |     ode_str = opt['function']
36 |     if ode_str == 'laplacian':
37 |         f = LaplacianODEFunc
38 |     elif ode_str == 'GAT':
39 |         f = ODEFuncAtt
40 |     elif ode_str == 'transformer':
41 |         f = ODEFuncTransformerAtt
42 |     else:
43 |         raise FunctionNotDefined
44 |     return f
45 | 
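46 | # Usage sketch (assumption: `opt` is the options dict built by the run scripts,
47 | # e.g. run_GNN2.py; only the keys read above matter here):
48 | #   block_cls = set_block({'block': 'attention'})        # -> AttODEblock
49 | #   func_cls = set_function({'function': 'laplacian'})   # -> LaplacianODEFunc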
--------------------------------------------------------------------------------
/regularized_ODE_function.py:
--------------------------------------------------------------------------------
 1 | ## This code has been adapted from https://github.com/cfinlay/ffjord-rnode/
 2 | ## MIT License
 3 | 
 4 | import torch
 5 | import torch.nn as nn
 6 | 
 7 | 
 8 | class RegularizedODEfunc(nn.Module):
 9 |     def __init__(self, odefunc, regularization_fns):
10 |         super(RegularizedODEfunc, self).__init__()
11 |         self.odefunc = odefunc
12 |         self.regularization_fns = regularization_fns
13 | 
14 |     def before_odeint(self, *args, **kwargs):
15 |         self.odefunc.before_odeint(*args, **kwargs)
16 | 
17 |     def forward(self, t, state):
18 | 
19 |         with torch.enable_grad():
20 |             x = state[0]
21 |             x.requires_grad_(True)
22 |             t.requires_grad_(True)
23 |             dstate = self.odefunc(t, x)
24 |             if len(state) > 1:
25 |                 dx = dstate
26 |                 reg_states = tuple(reg_fn(x, t, dx, self.odefunc) for reg_fn in self.regularization_fns)
27 |                 return (dstate,) + reg_states
28 |             else:
29 |                 return dstate
30 | 
31 |     @property
32 |     def _num_evals(self):
33 |         return self.odefunc._num_evals
34 | 
35 | 
36 | def total_derivative(x, t, dx, unused_context):
37 |     del unused_context
38 | 
39 |     directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
40 | 
41 |     try:
42 |         u = torch.full_like(dx, 1 / x.numel(), requires_grad=True)
43 |         tmp = torch.autograd.grad((u * dx).sum(), t, create_graph=True)[0]
44 |         partial_dt = torch.autograd.grad(tmp.sum(), u, create_graph=True)[0]
45 | 
46 |         total_deriv = directional_dx + partial_dt
47 |     except RuntimeError as e:
48 |         if 'One of the differentiated Tensors' not in e.__str__():
49 |             raise e
50 |         raise RuntimeError('No partial derivative with respect to time. Use mathematically equivalent "directional_derivative" regularizer instead')
51 | 
52 |     tdv2 = total_deriv.pow(2).view(x.size(0), -1)
53 | 
54 |     return 0.5 * tdv2.mean(dim=-1)
55 | 
56 | 
57 | def directional_derivative(x, t, dx, unused_context):
58 |     del t, unused_context
59 | 
60 |     directional_dx = torch.autograd.grad(dx, x, dx, create_graph=True)[0]
61 |     ddx2 = directional_dx.pow(2).view(x.size(0), -1)
62 | 
63 |     return 0.5 * ddx2.mean(dim=-1)
64 | 
65 | 
66 | def quadratic_cost(x, t, dx, unused_context):
67 |     del x, t, unused_context
68 |     dx = dx.view(dx.shape[0], -1)
69 |     return 0.5 * dx.pow(2).mean(dim=-1)
70 | 
71 | 
72 | def divergence_bf(dx, x):
73 |     sum_diag = 0.
74 |     for i in range(x.shape[1]):
75 |         sum_diag += torch.autograd.grad(dx[:, i].sum(), x, create_graph=True)[0].contiguous()[:, i].contiguous()
76 |     return sum_diag.contiguous()
77 | 
78 | 
79 | def jacobian_frobenius_regularization_fn(x, t, dx, context):
80 |     del t
81 |     return divergence_bf(dx, x)
82 | 
--------------------------------------------------------------------------------
/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_1noAdvT_drop05_attsamp095/model_at_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_1noAdvT_drop05_attsamp095/model_at_0.pt
--------------------------------------------------------------------------------
/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_2noAdvT_drop05_attsamp095/model_at_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_2noAdvT_drop05_attsamp095/model_at_0.pt
--------------------------------------------------------------------------------
/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_3noAdvT_drop05_attsamp095/model_at_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/beltrami_3noAdvT_drop05_attsamp095/model_at_0.pt
--------------------------------------------------------------------------------
/saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_1ln_noAdvT/model_at_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_1ln_noAdvT/model_at_0.pt
--------------------------------------------------------------------------------
/saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_2ln_noAdvT/model_at_0.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_2ln_noAdvT/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_3ln_noAdvT/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/gcnsvd_3ln_noAdvT/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/grand_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/grand_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/grand_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/grand_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/grand_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/grand_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/heat_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/heat_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/heat_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/heat_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/heat_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/heat_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- 
/saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-AmazonCoBuyComputerDataset/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/beltrami_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/beltrami_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/beltrami_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/beltrami_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/beltrami_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/beltrami_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/grand_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/grand_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/grand_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/grand_2noAdvT_drop05_attsamp095/model_at_0.pt 
-------------------------------------------------------------------------------- /saved_models2/grb-coauthor/grand_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/grand_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/heat_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/heat_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/heat_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/heat_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/heat_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/heat_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-coauthor/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-coauthor/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/beltrami2_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/beltrami2_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- 
/saved_models2/grb-flickr/beltrami2_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/beltrami2_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/beltrami2_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/beltrami2_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/grand_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/grand_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/grand_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/grand_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/grand_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/grand_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/meancurv_1noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/meancurv_2noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-flickr/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-flickr/meancurv_3noAdvT_drop05_attsamp095/model_at_0.pt -------------------------------------------------------------------------------- /saved_models2/grb-pubmed/grand_noAdvT_drop05_attsamp095/model_at_0.pt: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zknus/Robustness-of-Graph-Neural-Diffusion/f42c166aaa69327782a9e8522b69e77000c25c18/saved_models2/grb-pubmed/grand_noAdvT_drop05_attsamp095/model_at_0.pt
--------------------------------------------------------------------------------
/tsne_plot.py:
--------------------------------------------------------------------------------
 1 | import matplotlib.pyplot as plt
 2 | import networkx as nx
 3 | import seaborn as sns
 4 | import torch
 5 | from dgl import DGLGraph
 6 | from grb.dataset import Dataset
 7 | from grb.utils.normalize import GCNAdjNorm
 8 | from grb.utils import utils
 9 | from grb.model.torch import GCN, GCNODE, GCNODE2, BELTRAMI, MEANCURV, pLAPLACE, HEAT
10 | import os
11 | import numpy as np
12 | import scipy
13 | from torch_geometric.utils import remove_self_loops
14 | from sklearn.manifold import TSNE
15 | from scipy.sparse import csr_matrix
16 | 
17 | device = torch.device('cuda')
18 | 
19 | def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
20 |          edges_to_plot=None, nodes_pos=None, nodes_colors=None,
21 |          edge_colormap=plt.cm.Greys):
22 | 
23 |     if nodes_to_plot is None:
24 |         nodes_to_plot = g.nodes()
25 |     if edges_to_plot is None:
26 |         assert isinstance(g, nx.DiGraph), 'Expected g to be a networkx.DiGraph ' \
27 |                                           'object, got {}.'.format(type(g))
28 |         edges_to_plot = sorted(g.edges())
29 |     nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
30 |                            edge_color=attention, edge_cmap=edge_colormap,
31 |                            width=1, alpha=0.5, ax=ax, edge_vmin=0,
32 |                            edge_vmax=1)
33 | 
34 |     if nodes_colors is None:
35 |         nodes_colors = sns.color_palette("deep", max(nodes_labels) + 1)
36 | 
37 |     nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax, node_size=10,
38 |                            node_color=[nodes_colors[nodes_labels[v]] for v in nodes_to_plot],
39 |                            alpha=0.9)
40 | 
41 | 
42 | dataset_name = 'grb-cora'
43 | dataset = Dataset(name=dataset_name, mode='easy', feat_norm='arctan')
44 | 
45 | adj = dataset.adj
46 | features = dataset.features
47 | labels = dataset.labels
48 | num_features = dataset.num_features
49 | num_classes = dataset.num_classes
50 | mask = dataset.train_mask
51 | 
52 | 
53 | save_name = "model_at.pt"
54 | save_name = save_name.split('.')[0] + "_{}.pt".format(0)
55 | model_name = "meancurv_noAdvT_drop05_attsamp095"
56 | 
57 | save_dir = "./saved_models2/{}/{}".format(dataset_name, model_name)
58 | 
59 | if model_name.split('_')[0] == "beltrami":
60 |     model = BELTRAMI(in_features=dataset.num_features,
61 |                      out_features=dataset.num_classes,
62 |                      hidden_features=128,
63 |                      n_layers=4,
64 |                      adj_norm_func=GCNAdjNorm,
65 |                      layer_norm=True,
66 |                      residual=False,
67 |                      dropout=0.5)
68 | 
69 | if model_name.split('_')[0] == "meancurv":
70 |     model = MEANCURV(in_features=dataset.num_features,
71 |                      out_features=dataset.num_classes,
72 |                      hidden_features=128,
73 |                      n_layers=4,
74 |                      adj_norm_func=GCNAdjNorm,
75 |                      layer_norm=True,
76 |                      residual=False,
77 |                      dropout=0.5)
78 | 
79 | if model_name.split('_')[0] == "heat":
80 |     model = HEAT(in_features=dataset.num_features,
81 |                  out_features=dataset.num_classes,
82 |                  hidden_features=64,
83 |                  n_layers=3,
84 |                  adj_norm_func=GCNAdjNorm,
85 |                  layer_norm=True,
86 |                  residual=False,
87 |                  dropout=0.5)
88 | if model_name.split('_')[0] == "grand":
89 |     model = GCNODE(in_features=dataset.num_features,
90 |                    out_features=dataset.num_classes,
91 |                    hidden_features=64,
92 |                    n_layers=3,
93 |                    adj_norm_func=GCNAdjNorm,
94 |                    layer_norm=True,
95 |                    residual=False,
96 |                    dropout=0.5)
 97 | 
 98 | if model_name.split('_')[0] == "grand2":
 99 |     model = GCNODE2(in_features=dataset.num_features,
100 |                     out_features=dataset.num_classes,
101 |                     hidden_features=64,
102 |                     n_layers=3,
103 |                     adj_norm_func=GCNAdjNorm,
104 |                     layer_norm=True,
105 |                     residual=False,
106 |                     dropout=0.5)
107 | 
108 | ckp = torch.load(os.path.join(save_dir, save_name), map_location=device)
109 | model.load_state_dict(ckp['model'])
110 | model.to(device)
111 | model.eval()
112 | 
113 | adj_ = utils.adj_preprocess(adj, adj_norm_func=model.adj_norm_func, mask=mask, model_type=model.model_type)
114 | logits, att, edge_index = model(features[mask].to(device), adj_.to(device))
115 | 
116 | t_sne_embeddings = TSNE(n_components=2, perplexity=30, method='barnes_hut').fit_transform(logits.cpu().detach().numpy())
117 | pos = {}
118 | for i in range(logits.shape[0]):
119 |     x = t_sne_embeddings[i, 0]
120 |     y = t_sne_embeddings[i, 1]
121 |     pos[i] = (x, y)
122 | 
123 | # node_labels = labels[mask]
124 | # cora_label_to_color_map = {0: "red", 1: "blue", 2: "green", 3: "orange", 4: "yellow", 5: "pink", 6: "gray"}
125 | # fig = plt.figure(figsize=(12,8), dpi=80)  # otherwise plots are really small in Jupyter Notebook
126 | # for class_id in range(num_classes):
127 | #     plt.scatter(t_sne_embeddings[node_labels == class_id, 0], t_sne_embeddings[node_labels == class_id, 1], s=20, color=cora_label_to_color_map[class_id], edgecolors='black', linewidths=0.2)
128 | # plt.show()
129 | # plt.savefig('b.png')
130 | 
131 | # mean_att = 0
132 | # for i in range(len(att)):
133 | #     mean_att += att[i]
134 | # mean_att /= len(att)
135 | 
136 | 
137 | # adj2 = adj_.to_dense().cpu().detach().numpy()
138 | # adj2 -= np.diag(np.diag(adj2))
139 | 
140 | #src, dst = np.nonzero(adj2)
141 | 
142 | 
143 | ns = list(range(mask[mask==True].shape[0]))
144 | for layer in range(1):
145 |     print(layer, len(att), edge_index[layer])
146 |     edge_index_noself, new_weights = remove_self_loops(edge_index[layer], att[layer])
147 | 
148 |     src = edge_index_noself[0].cpu().detach().numpy()
149 |     dst = edge_index_noself[1].cpu().detach().numpy()
150 |     g_dgl = DGLGraph((src, dst)).cpu()
151 |     dg = g_dgl.to_networkx().to_directed()
152 |     dg.edges()
153 | 
154 | 
155 |     a = csr_matrix((new_weights.cpu().detach().numpy(), (src, dst)), shape=(len(ns), len(ns)))
156 |     b = a.toarray()
157 | 
158 |     U, S, Vh = np.linalg.svd(b, full_matrices=False)
159 |     plt.plot(list(range(len(S))), S)
160 |     plt.savefig(model_name + '_layer' + str(layer) + '_svd.png')
161 |     plt.show()
162 |     plt.close()
163 | 
164 |     ent = []
165 |     for i in range(len(ns)):
166 |         p = att[layer][edge_index[layer][0]==i].cpu().detach().numpy()
167 |         print(np.sum(p))
168 |         if p.shape[0] == 0:
169 |             ent.append(0)
170 |         else:
171 |             a = np.ones((p.shape[0], 1))/p.shape[0]
172 |             z = np.sum(-np.log(a)*a)  # entropy of the uniform distribution (normaliser)
173 |             ent.append(np.sum(-np.log(p)*p)/z)
174 |     plt.plot(ns, ent)
175 |     plt.savefig(model_name + '_layer' + str(layer) + '.png')
176 |     plt.show()
177 |     plt.close()
178 | 
179 |     #Y.reshape(Y.shape[0])
180 |     grEd = dg.edges()
181 |     att_list = new_weights.cpu().detach().numpy().tolist()
182 | 
183 | 
184 |     fig, ax = plt.subplots()
185 |     plot(dg, att_list, ax=ax, nodes_pos=pos, nodes_labels=labels[mask])
186 |     # ax.set_axis_off()
187 |     # sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
188 |     # sm.set_array([])
189 |     # plt.colorbar(sm, fraction=0.046, pad=0.01)
190 |     plt.savefig(model_name + '_layer' + str(layer) + '_tsne.png')
191 |     plt.show()
192 |     plt.close()
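193 | 
194 | # Usage note (assumption: checkpoints live under ./saved_models2/<dataset>/<model>/
195 | # as in this repository): set `dataset_name` and `model_name` above, then run
196 | #   python tsne_plot.py
197 | # to regenerate, per attention layer, the singular-value spectrum (*_svd.png),
198 | # attention-entropy (*_layer<k>.png) and t-SNE (*_tsne.png) figures.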
--------------------------------------------------------------------------------