├── README.md
├── dataset
│   ├── CiteSeer
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── Cora
│   │   ├── .gitignore
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── Reddit
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── ogbn_proteins
│   │   ├── .gitignore
│   │   ├── mapping
│   │   │   └── .gitignore
│   │   ├── processed
│   │   │   └── .gitignore
│   │   ├── raw
│   │   │   └── .gitignore
│   │   └── split
│   │       └── species
│   │           └── .gitignore
│   └── pubmed
│       ├── .gitignore
│       ├── processed
│       │   └── .gitignore
│       └── raw
│           └── .gitignore
├── example
│   └── GCN.py
├── model
│   ├── __pycache__
│   │   ├── gnn_model_1_6.cpython-37.pyc
│   │   ├── gnn_model_1_6.cpython-38.pyc
│   │   ├── inf_model.cpython-38.pyc
│   │   └── inits.cpython-38.pyc
│   ├── gnn_model_1_6.py
│   ├── inf_model.py
│   └── inits.py
├── pyg1.svg
├── result.png
└── test
    ├── test_gnn_layer.py
    └── test_gnn_total.py

/README.md:
--------------------------------------------------------------------------------
# GNN-Feature-Decomposition


### ***NOTE!***
In collaboration with the [PyG team](https://github.com/pyg-team/pytorch_geometric), the feature decomposition method has been improved and added
to the [PyG framework](https://github.com/pyg-team/pytorch_geometric) as an optional feature. For details and usage, please refer to
[Feature Decomposition in PyG](#Feature-Decomposition-in-PyG).

---


### Overview

This repository contains the code for our work on GNN feature decomposition,
which was accepted at RTAS 2021 (Brief Industry Track) under the title ***"Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms"***.

For more details, please see our full paper: https://arxiv.org/abs/2104.03058

Graph neural networks (GNNs) have achieved state-of-the-art performance on various industrial tasks.
However, the poor efficiency of GNN inference and frequent Out-Of-Memory (OOM) problems limit the successful application of GNNs on edge computing platforms.
To tackle these problems, we propose a feature decomposition approach that optimizes the memory efficiency of GNN inference.
The proposed approach achieves significant optimization on various GNN models across a wide range of datasets, speeding up inference by up to 3x.
Furthermore, feature decomposition significantly reduces peak memory usage (up to 5x improvement in memory efficiency) and mitigates OOM problems during GNN inference.

### Requirements

Recent versions of PyTorch and NumPy are required, together with torch_geometric 1.6.3.


### Contents
There are two main scripts in the `test` directory:

1. `test_gnn_layer.py`: runs the GNN feature decomposition method on a single GNN layer.
2. `test_gnn_total.py`: runs the GNN feature decomposition method on a complete GNN model.

### Running the code
#### Test a single GNN layer with our feature decomposition method

    cd test
    python test_gnn_layer.py --hidden=32 --agg="gas" --m="GCN" --layer=32 --data="rd"

#### Test a complete GNN model with our feature decomposition method

    cd test
    python test_gnn_total.py --hidden=32 --agg="gas" --m="GCN" --layer="32,41" --data="rd"

- hidden: the hidden layer size of the GNN.
- agg: the aggregation mode, either "spmm" or "gas". To use feature decomposition, set it to "gas".
- m: the GNN model name, one of "GCN", "GAT", "GIN", "SAGE".
- layer: the number of decomposed layers along the feature-vector dimension; basic (undecomposed) GNN inference uses 1.
  When testing a complete GNN model, two comma-separated values are required (e.g. `--layer="32,41"`).
- data: the dataset name.


### Feature Decomposition in PyG

---

We integrated feature decomposition into the [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py)
module of the [PyG framework](https://github.com/pyg-team/pytorch_geometric). Specifically, we added an optional argument `decomposed_layers: int = 1` to the initialization
function of the [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) module, as shown below:

    def __init__(self, aggr: Optional[str] = "add",
                 flow: str = "source_to_target", node_dim: int = -2,
                 decomposed_layers: int = 1):

To use the feature decomposition method, pass `decomposed_layers` (> 1) when creating a layer, as follows:

    conv = GCNConv(16, 32, decomposed_layers=2)

For specific usage, please refer to the [example](example/GCN.py).

---
The following table shows the test results of the [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) layer on the Reddit dataset.
In this test, the best speedup is achieved when the number of decomposed layers is largest, i.e., at the finest decomposition granularity (this does not necessarily hold for all layers).
The only rule we can be sure of is that, for a given graph, there is an optimal dimension for each decomposed layer, and it does not change with the hidden layer dimension
(e.g., in the following table, it is optimal to set the dimension of each decomposed layer to 1, which means `decomposed_layers` equals the hidden layer dimension).

Test machine: Intel Core i7-8700K CPU, Ubuntu.
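The following is a minimal, framework-free sketch of what feature decomposition does during aggregation. It is our own illustration, not the implementation in `model/gnn_model_1_6.py` or in PyG: plain Python lists stand in for tensors, sum aggregation stands in for the layer's aggregator, and `split_sizes`, `aggregate` and `decomposed_aggregate` are hypothetical helper names. It assumes, as in a gather-scatter message-passing path, that the per-edge message buffer is what scales with the feature dimension; the returned entry count is only a rough stand-in for memory use. The slicing mirrors the rule used by the `*_3d` layers in this repository: `parts - 1` slices of `floor(dim / parts)` columns, with the last slice taking the remainder.

```python
import math
from typing import List, Tuple

Edge = Tuple[int, int]       # (source, target)
Matrix = List[List[float]]   # row i holds the features of node i


def split_sizes(dim: int, parts: int) -> List[int]:
    """Column widths: parts - 1 slices of floor(dim / parts) columns,
    with the last slice taking the remainder."""
    base = math.floor(dim / parts)
    return [base] * (parts - 1) + [dim - base * (parts - 1)]


def aggregate(edges: List[Edge], x: Matrix) -> Tuple[Matrix, int]:
    """Sum-aggregate source features into target nodes.

    Returns the output and the number of message entries materialized
    (one copied row per edge), used here as a stand-in for memory use.
    """
    dim = len(x[0])
    messages = [list(x[src]) for src, _ in edges]  # gather: E x dim buffer
    out = [[0.0] * dim for _ in x]
    for (_, dst), msg in zip(edges, messages):     # scatter-add into targets
        for k in range(dim):
            out[dst][k] += msg[k]
    return out, len(messages) * dim


def decomposed_aggregate(edges: List[Edge], x: Matrix,
                         parts: int) -> Tuple[Matrix, int]:
    """Aggregate `parts` column slices one at a time, then concatenate."""
    outs: List[Matrix] = []
    peak, start = 0, 0
    for size in split_sizes(len(x[0]), parts):
        x_slice = [row[start:start + size] for row in x]
        out_slice, buf = aggregate(edges, x_slice)
        outs.append(out_slice)
        peak = max(peak, buf)  # only one slice's messages are live at a time
        start += size
    # Concatenate the slices back along the feature dimension, in order.
    out = [[v for o in outs for v in o[i]] for i in range(len(x))]
    return out, peak


edges = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2)]
x = [[1.0, 2.0, 3.0, 4.0],
     [5.0, 6.0, 7.0, 8.0],
     [9.0, 10.0, 11.0, 12.0]]

full, full_buf = aggregate(edges, x)
dec, dec_buf = decomposed_aggregate(edges, x, parts=4)
assert dec == full            # same result as full-width aggregation
print(full_buf, dec_buf)      # 20 5
```

Because every slice is aggregated independently and the slices are concatenated in order, the result matches full-width aggregation, while the largest live message buffer shrinks from about `E * dim` to about `E * dim / parts` entries. The cost is one extra pass over the edge list per slice.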

- Horizontal axis: the hidden layer dimension of [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv).
- Vertical axis: the value of `decomposed_layers`, i.e., the granularity of feature decomposition.
- total: the running time of the whole layer.
- agg: the running time of the aggregation phase of GCNConv.

--------------------------------------------------------------------------------
/dataset/CiteSeer/processed/.gitignore:
--------------------------------------------------------------------------------
/data.pt
/pre_filter.pt
/pre_transform.pt

--------------------------------------------------------------------------------
/dataset/CiteSeer/raw/.gitignore:
--------------------------------------------------------------------------------
/ind.citeseer.allx
/ind.citeseer.ally
/ind.citeseer.graph
/ind.citeseer.test.index
/ind.citeseer.tx
/ind.citeseer.ty
/ind.citeseer.x
/ind.citeseer.y

--------------------------------------------------------------------------------
/dataset/Cora/.gitignore:
--------------------------------------------------------------------------------
/processed_cora.pkl

--------------------------------------------------------------------------------
/dataset/Cora/processed/.gitignore:
--------------------------------------------------------------------------------
/data.pt
/pre_filter.pt
/pre_transform.pt

--------------------------------------------------------------------------------
/dataset/Cora/raw/.gitignore:
--------------------------------------------------------------------------------
/ind.cora.allx
/ind.cora.ally
/ind.cora.graph
/ind.cora.test.index
/ind.cora.tx
/ind.cora.ty
/ind.cora.x
/ind.cora.y

--------------------------------------------------------------------------------
/dataset/Reddit/processed/.gitignore:
--------------------------------------------------------------------------------
/data.pt
/pre_filter.pt
/pre_transform.pt

--------------------------------------------------------------------------------
/dataset/Reddit/raw/.gitignore:
--------------------------------------------------------------------------------
/reddit_data.npz
/reddit_graph.npz

--------------------------------------------------------------------------------
/dataset/ogbn_proteins/.gitignore:
--------------------------------------------------------------------------------
/RELEASE_v1.txt

--------------------------------------------------------------------------------
/dataset/ogbn_proteins/mapping/.gitignore:
--------------------------------------------------------------------------------
/README.md
/labelidx2GO.csv.gz
/nodeidx2proteinid.csv.gz

--------------------------------------------------------------------------------
/dataset/ogbn_proteins/processed/.gitignore:
--------------------------------------------------------------------------------
/geometric_data_processed.pt
/pre_filter.pt
/pre_transform.pt

--------------------------------------------------------------------------------
/dataset/ogbn_proteins/raw/.gitignore:
--------------------------------------------------------------------------------
/edge-feat.csv.gz
/edge.csv.gz
/node-label.csv.gz
/node_species.csv.gz
/num-edge-list.csv.gz
/num-node-list.csv.gz

--------------------------------------------------------------------------------
/dataset/ogbn_proteins/split/species/.gitignore:
--------------------------------------------------------------------------------
/test.csv.gz
/train.csv.gz
/valid.csv.gz

--------------------------------------------------------------------------------
/dataset/pubmed/.gitignore:
--------------------------------------------------------------------------------
/processed_pubmed.pkl

--------------------------------------------------------------------------------
/dataset/pubmed/processed/.gitignore:
--------------------------------------------------------------------------------
/data.pt
/pre_filter.pt
/pre_transform.pt

--------------------------------------------------------------------------------
/dataset/pubmed/raw/.gitignore:
--------------------------------------------------------------------------------
/ind.pubmed.allx
/ind.pubmed.ally
/ind.pubmed.graph
/ind.pubmed.test.index
/ind.pubmed.tx
/ind.pubmed.ty
/ind.pubmed.x
/ind.pubmed.y

--------------------------------------------------------------------------------
/example/GCN.py:
--------------------------------------------------------------------------------
import os.path as osp
import argparse

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, ChebConv  # noqa
from torch_geometric.datasets import Reddit
parser = argparse.ArgumentParser()
parser.add_argument('--use_gdc', action='store_true',
                    help='Use GDC preprocessing.')
args = parser.parse_args()

dataset = 'Reddit'
path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', dataset)
dataset = Reddit(path, transform=T.NormalizeFeatures())
data = dataset[0]

if args.use_gdc:
    gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
                normalization_out='col',
                diffusion_kwargs=dict(method='ppr', alpha=0.05),
                sparsification_kwargs=dict(method='topk', k=128,
                                           dim=0), exact=True)
    data = gdc(data)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Use the feature decomposition method with decomposed_layers = 2.
        self.conv1 = GCNConv(dataset.num_features, 16, cached=True,
                             normalize=not args.use_gdc, decomposed_layers=2)
        self.conv2 = GCNConv(16, dataset.num_classes, cached=True,
                             normalize=not args.use_gdc, decomposed_layers=2)
        # self.conv1 = ChebConv(data.num_features, 16, K=2)
        # self.conv2 = ChebConv(16, dataset.num_classes, K=2)

    def forward(self):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)


device = torch.device('cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.


def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()


@torch.no_grad()
def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs


best_val_acc = test_acc = 0
for epoch in range(1, 201):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_acc, best_val_acc, test_acc))
--------------------------------------------------------------------------------
/model/__pycache__/gnn_model_1_6.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/gnn_model_1_6.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/gnn_model_1_6.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/gnn_model_1_6.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/inf_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/inf_model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/inits.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/inits.cpython-38.pyc
--------------------------------------------------------------------------------
/model/gnn_model_1_6.py:
--------------------------------------------------------------------------------
import math

from torch_sparse import SparseTensor, matmul

from typing import Union, Tuple, Optional
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
                                    OptTensor)
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter, Linear
from torch_sparse import SparseTensor, set_diag
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax

from .inits import glorot, zeros, reset
# GNN models modified from PyG 1.6
from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor, PairTensor

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, matmul, fill_diag, sum, mul
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

from typing import Callable, Union
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size

import torch
from torch import Tensor
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing


@torch.jit._overload
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor  # noqa
    pass


@torch.jit._overload
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor  # noqa
    pass


def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, dtype=None):
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = fill_diag(adj_t, fill_value)
        deg = sum(adj_t, dim=1)

        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
        return adj_t

    else:
        num_nodes = maybe_num_nodes(edge_index, num_nodes)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                                     device=edge_index.device)

        if add_self_loops:
            edge_index, tmp_edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, fill_value, num_nodes)
            assert tmp_edge_weight is not None
            edge_weight = tmp_edge_weight

        row, col = edge_index[0], edge_index[1]
        deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]


class SAGEConv_3d(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" `_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot
        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 bias: bool = True, **kwargs):  # yapf: disable
        super(SAGEConv_3d, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, layer=None) -> Tensor:
        """"""
        if layer is None or layer == 0:
            x_r = x
            if isinstance(x, Tensor):
                x: OptPairTensor = (x, x)
            out = self.propagate(edge_index, x=x, size=size, layer=None)
            out = self.lin_l(out)
            # propagate_type: (x: OptPairTensor)
        elif layer > 1:
            # Feature decomposition: aggregation runs on the raw input here,
            # so the slices must be sized from the input dimension
            # (x.shape[-1]), not from out_channels.
            x_r = x
            L = math.floor(x.shape[-1] / layer)
            parts = []
            for i in range(layer - 1):
                parts.append(L)
            parts.append(x.shape[-1] - L * (layer - 1))
            parts = tuple(parts)
            s_x = x.split(parts, dim=-1)
            # s_x_p: OptPairTensor = (s_x[0], s_x[0])
            out = self.propagate(edge_index, x=s_x[0], size=size)
            for i in range(1, layer):
                # s_x_p: OptPairTensor = (s_x[i], s_x[i])
                out_l = self.propagate(edge_index, x=s_x[i], size=size)
                out = torch.cat((out, out_l), dim=-1)
            out = self.lin_l(out)
        else:
            x_r = x
            out = self.propagate(edge_index, x=x, size=size, layer=layer)
            out = self.lin_l(out)
        if x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class SAGEConv_3d_com(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" `_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot
        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 bias: bool = True, **kwargs):  # yapf: disable
        super(SAGEConv_3d_com, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, layer=None) -> Tensor:
        """"""
        if layer is None or layer == 0:
            x_r = x
            x = self.lin_l(x)
            if isinstance(x, Tensor):
                x: OptPairTensor = (x, x)
            out = self.propagate(edge_index, x=x, size=size, layer=layer)


            # propagate_type: (x: OptPairTensor)
        elif layer > 1:
            # Feature decomposition: transform first with lin_l, then
            # aggregate the out_channels-wide result in `layer` slices.
            x_r = x
            x = self.lin_l(x)
            L = math.floor(self.out_channels / layer)
            parts = []
            for i in range(layer - 1):
                parts.append(L)
            parts.append(x.shape[-1] - L * (layer - 1))
            parts = tuple(parts)
            s_x = x.split(parts, dim=-1)
            # s_x_p: OptPairTensor = (s_x[0], s_x[0])
            out = self.propagate(edge_index, x=s_x[0], size=size)

            for i in range(1, layer):
                # s_x_p: OptPairTensor = (s_x[i], s_x[i])
                out_l = self.propagate(edge_index, x=s_x[i], size=size)
                out = torch.cat((out, out_l), dim=-1)
            # out = self.lin_l(out)
            # x_r = x
        else:
            x_r = x
            x = self.lin_l(x)
            out = self.propagate(edge_index, x=x, size=size, layer=layer)

        if x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

    def message_and_aggregate(self, adj_t: SparseTensor,
                              x: OptPairTensor) -> Tensor:
        adj_t = adj_t.set_value(None, layout=None)
        return matmul(adj_t, x[0], reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GCNConv_3d(MessagePassing):
    r"""The graph convolutional operator from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    `_ paper

    .. math::
        \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
    The adjacency matrix can include other values than :obj:`1` representing
    edge weights via the optional :obj:`edge_weight` tensor.

    Its node-wise formulation is given by:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j \in \mathcal{N}(i) \cup \{ i \}}
        \frac{e_{j,i}}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j

    with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where
    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target
    node :obj:`i` (default: :obj:`1`)

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        normalize (bool, optional): Whether to add self-loops and apply
            symmetric normalization. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self, in_channels: int, out_channels: int,
                 improved: bool = False, cached: bool = False,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = True, **kwargs):

        kwargs.setdefault('aggr', 'add')
        super(GCNConv_3d, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_edge_index = None
        self._cached_adj_t = None

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None, layer=1) -> Tensor:
        """"""

        if self.normalize:
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim),
                        self.improved, self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = torch.matmul(x, self.weight)
        if layer is not None and layer > 1:
            # Feature decomposition: split the transformed features into
            # `layer` slices, propagate each slice separately, concatenate.
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            L = math.floor(self.out_channels / layer)
            parts = []
            for i in range(layer - 1):
                parts.append(L)
            parts.append(x.shape[1] - L * (layer - 1))
            parts = tuple(parts)
            s_x = x.split(parts, dim=-1)
            # s_x = torch.chunk(x,layer,-1)
            # ew = torch.chunk(edge_weight,layer,-1)
            # s_x_p: OptPairTensor = (s_x[0], s_x[0])
            out = self.propagate(edge_index, x=s_x[0], edge_weight=edge_weight,
                                 size=None)
            for i in range(1, layer):
                # s_x_p: OptPairTensor = (s_x[i], s_x[i])
                out_l = self.propagate(edge_index, x=s_x[i], edge_weight=edge_weight,
                                       size=None)
                out = torch.cat((out, out_l), dim=-1)
        else:
            out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                                 size=None)

        if self.bias is not None:
            out += self.bias

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        if edge_weight is None:
            return x_j
        else:
            return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GATConv_3d(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    `_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv_3d, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels, False)
            self.lin_r = Linear(in_channels[1], heads * out_channels, False)

        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x: Union[Tensor,
OptPairTensor], edge_index: Adj, 529 | size: Size = None, return_attention_weights=None, layer=None): 530 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa 531 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa 532 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa 533 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa 534 | r""" 535 | 536 | Args: 537 | return_attention_weights (bool, optional): If set to :obj:`True`, 538 | will additionally return the tuple 539 | :obj:`(edge_index, attention_weights)`, holding the computed 540 | attention weights for each edge. (default: :obj:`None`) 541 | """ 542 | H, C = self.heads, self.out_channels 543 | 544 | x_l: OptTensor = None 545 | x_r: OptTensor = None 546 | alpha_l: OptTensor = None 547 | alpha_r: OptTensor = None 548 | if isinstance(x, Tensor): 549 | assert x.dim() == 2, 'Static graphs not supported in `GATConv`.' 550 | x_l = x_r = self.lin_l(x).view(-1, H, C) 551 | alpha_l = (x_l * self.att_l).sum(dim=-1) 552 | alpha_r = (x_r * self.att_r).sum(dim=-1) 553 | else: 554 | x_l, x_r = x[0], x[1] 555 | assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.' 
556 | x_l = self.lin_l(x_l).view(-1, H, C) 557 | alpha_l = (x_l * self.att_l).sum(dim=-1) 558 | if x_r is not None: 559 | x_r = self.lin_r(x_r).view(-1, H, C) 560 | alpha_r = (x_r * self.att_r).sum(dim=-1) 561 | 562 | assert x_l is not None 563 | assert alpha_l is not None 564 | 565 | if self.add_self_loops: 566 | if isinstance(edge_index, Tensor): 567 | num_nodes = x_l.size(0) 568 | if x_r is not None: 569 | num_nodes = min(num_nodes, x_r.size(0)) 570 | if size is not None: 571 | num_nodes = min(size[0], size[1]) 572 | edge_index, _ = remove_self_loops(edge_index) 573 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) 574 | elif isinstance(edge_index, SparseTensor): 575 | edge_index = set_diag(edge_index) 576 | 577 | # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) 578 | if layer != None and (layer > 1): 579 | # propagate_type: (x: Tensor, edge_weight: OptTensor) 580 | L = math.floor(self.out_channels / layer) 581 | parts = [] 582 | for i in range(layer - 1): 583 | parts.append(L) 584 | parts.append(x_l.shape[-1] - L * (layer - 1)) 585 | parts = tuple(parts) 586 | xl = x_l.split(parts, dim=-1) 587 | xr = x_r.split(parts, dim=-1) 588 | out = self.propagate(edge_index, x=(xl[0], xr[0]), 589 | alpha=(alpha_l, alpha_r), size=size) 590 | for i in range(1, layer): 591 | out1 = self.propagate(edge_index, x=(xl[i], xr[i]), 592 | alpha=(alpha_l, alpha_r), size=size) 593 | out = torch.cat((out, out1), dim=-1) 594 | else: 595 | out = self.propagate(edge_index, x=(x_l, x_r), 596 | alpha=(alpha_l, alpha_r), size=size) 597 | alpha = self._alpha 598 | self._alpha = None 599 | 600 | if self.concat: 601 | out = out.view(-1, self.heads * self.out_channels) 602 | else: 603 | out = out.mean(dim=1) 604 | 605 | if self.bias is not None: 606 | out += self.bias 607 | 608 | if isinstance(return_attention_weights, bool): 609 | assert alpha is not None 610 | if isinstance(edge_index, Tensor): 611 | return out, (edge_index, alpha) 612 | elif isinstance(edge_index, 
SparseTensor): 613 | return out, edge_index.set_value(alpha, layout='coo') 614 | else: 615 | return out 616 | 617 | def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor, 618 | index: Tensor, ptr: OptTensor, 619 | size_i: Optional[int]) -> Tensor: 620 | alpha = alpha_j if alpha_i is None else alpha_j + alpha_i 621 | alpha = F.leaky_relu(alpha, self.negative_slope) 622 | alpha = softmax(alpha, index, ptr, size_i) 623 | self._alpha = alpha 624 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 625 | return x_j * alpha.unsqueeze(-1) 626 | 627 | def __repr__(self): 628 | return '{}({}, {}, heads={})'.format(self.__class__.__name__, 629 | self.in_channels, 630 | self.out_channels, self.heads) 631 | 632 | 633 | class GINConv_3d(MessagePassing): 634 | r"""The graph isomorphism operator from the `"How Powerful are 635 | Graph Neural Networks?" `_ paper 636 | 637 | .. math:: 638 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot 639 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) 640 | 641 | or 642 | 643 | .. math:: 644 | \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + 645 | (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), 646 | 647 | here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP. 648 | 649 | Args: 650 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that 651 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to 652 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by 653 | :class:`torch.nn.Sequential`. 654 | eps (float, optional): (Initial) :math:`\epsilon`-value. 655 | (default: :obj:`0.`) 656 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` 657 | will be a trainable parameter. (default: :obj:`False`) 658 | **kwargs (optional): Additional arguments of 659 | :class:`torch_geometric.nn.conv.MessagePassing`.
660 | """ 661 | 662 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, 663 | **kwargs): 664 | kwargs.setdefault('aggr', 'add') 665 | super(GINConv_3d, self).__init__(**kwargs) 666 | self.nn = nn 667 | self.initial_eps = eps 668 | if train_eps: 669 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 670 | else: 671 | self.register_buffer('eps', torch.Tensor([eps])) 672 | self.reset_parameters() 673 | 674 | def reset_parameters(self): 675 | reset(self.nn) 676 | self.eps.data.fill_(self.initial_eps) 677 | 678 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 679 | size: Size = None, layer=None) -> Tensor: 680 | """""" 681 | 682 | # propagate_type: (x: OptPairTensor) 683 | # out = self.propagate(edge_index, x=x, size=size, layer=layer) 684 | 685 | if layer == None or layer == 0: 686 | if isinstance(x, Tensor): 687 | x: OptPairTensor = (x, x) 688 | out = self.propagate(edge_index, x=x, size=size) 689 | x_r = x[1] 690 | elif layer == 1: 691 | out = self.propagate(edge_index, x=x, size=size) 692 | x_r = x 693 | else: 694 | L = math.floor(x.shape[-1] / layer) 695 | parts = [] 696 | for i in range(layer - 1): 697 | parts.append(L) 698 | parts.append(x.shape[1] - L * (layer - 1)) 699 | parts = tuple(parts) 700 | xl = x.split(parts, dim=-1) 701 | out = self.propagate(edge_index, x=xl[0], size=size) 702 | for i in range(1, layer): 703 | out1 = self.propagate(edge_index, x=xl[i], size=size) 704 | out = torch.cat((out, out1), dim=-1) 705 | x_r = x 706 | 707 | if x_r is not None: 708 | out += (1 + self.eps) * x_r 709 | 710 | return self.nn(out) 711 | 712 | def message(self, x_j: Tensor) -> Tensor: 713 | return x_j 714 | 715 | def message_and_aggregate(self, adj_t: SparseTensor, 716 | x: Union[Tensor, OptPairTensor]) -> Tensor: 717 | adj_t = adj_t.set_value(None, layout=None) 718 | return matmul(adj_t, x[0] if isinstance(x, tuple) else x, reduce=self.aggr) # x is a plain Tensor when layer >= 1 719 | 720 | def __repr__(self): 721 | return '{}(nn={})'.format(self.__class__.__name__, self.nn) 722 | 723 | 724 | class
GINConv_3d_com(MessagePassing): 725 | r"""The graph isomorphism operator from the `"How Powerful are 726 | Graph Neural Networks?" `_ paper 727 | 728 | .. math:: 729 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot 730 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) 731 | 732 | or 733 | 734 | .. math:: 735 | \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + 736 | (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), 737 | 738 | here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP. Unlike :class:`GINConv_3d`, this variant applies :math:`h_{\mathbf{\Theta}}` *before* aggregation, so :meth:`forward` actually computes :math:`(1 + \epsilon) \cdot h_{\mathbf{\Theta}}(\mathbf{x}_i) + \sum_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}}(\mathbf{x}_j)`. 739 | 740 | Args: 741 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that 742 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to 743 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by 744 | :class:`torch.nn.Sequential`. 745 | eps (float, optional): (Initial) :math:`\epsilon`-value. 746 | (default: :obj:`0.`) 747 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` 748 | will be a trainable parameter. (default: :obj:`False`) 749 | **kwargs (optional): Additional arguments of 750 | :class:`torch_geometric.nn.conv.MessagePassing`.
751 | """ 752 | 753 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, 754 | **kwargs): 755 | kwargs.setdefault('aggr', 'add') 756 | super(GINConv_3d_com, self).__init__(**kwargs) 757 | self.nn = nn 758 | self.initial_eps = eps 759 | if train_eps: 760 | self.eps = torch.nn.Parameter(torch.Tensor([eps])) 761 | else: 762 | self.register_buffer('eps', torch.Tensor([eps])) 763 | self.reset_parameters() 764 | 765 | def reset_parameters(self): 766 | reset(self.nn) 767 | self.eps.data.fill_(self.initial_eps) 768 | 769 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 770 | size: Size = None, layer=None) -> Tensor: 771 | """""" 772 | 773 | # propagate_type: (x: OptPairTensor) 774 | # out = self.propagate(edge_index, x=x, size=size, layer=layer) 775 | x = self.nn(x) 776 | if layer == None or layer == 0: 777 | if isinstance(x, Tensor): 778 | x: OptPairTensor = (x, x) 779 | out = self.propagate(edge_index, x=x, size=size) 780 | x_r = x[1] 781 | elif layer == 1: 782 | out = self.propagate(edge_index, x=x, size=size) 783 | x_r = x 784 | else: 785 | L = math.floor(x.shape[-1] / layer) 786 | parts = [] 787 | for i in range(layer - 1): 788 | parts.append(L) 789 | parts.append(x.shape[1] - L * (layer - 1)) 790 | parts = tuple(parts) 791 | xl = x.split(parts, dim=-1) 792 | out = self.propagate(edge_index, x=xl[0], size=size) 793 | for i in range(1, layer): 794 | out1 = self.propagate(edge_index, x=xl[i], size=size) 795 | out = torch.cat((out, out1), dim=-1) 796 | x_r = x 797 | 798 | if x_r is not None: 799 | out += (1 + self.eps) * x_r 800 | 801 | return out 802 | 803 | def message(self, x_j: Tensor) -> Tensor: 804 | return x_j 805 | 806 | def message_and_aggregate(self, adj_t: SparseTensor, 807 | x: Union[Tensor, OptPairTensor]) -> Tensor: 808 | adj_t = adj_t.set_value(None, layout=None) 809 | return matmul(adj_t, x[0] if isinstance(x, tuple) else x, reduce=self.aggr) # x is a plain Tensor when layer >= 1 810 | 811 | def __repr__(self): 812 | return '{}(nn={})'.format(self.__class__.__name__, self.nn) 813 |
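Every `*_3d` layer above follows the same feature-decomposition pattern when `layer > 1`: it splits the node features along the last dimension into `layer` chunks (`layer - 1` chunks of `floor(C / layer)` columns, where `C` is the feature width, with the remainder in the last chunk), calls `propagate` once per chunk, and joins the partial outputs with `torch.cat`. Sum-style aggregation treats each feature column independently, so the concatenated result matches a single `propagate` call. Below is a minimal, framework-free sketch of that equivalence in plain Python. It assumes sum aggregation over an explicit edge list, and the helpers `split_sizes`, `sum_aggregate` and `decomposed_aggregate` are hypothetical illustrations, not functions from this repository or from PyG.

```python
import math
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]  # (source j, target i): the message flows j -> i


def split_sizes(num_channels: int, layer: int) -> List[int]:
    """Chunk widths as computed in the *_3d layers: (layer - 1) chunks of
    floor(num_channels / layer) columns, with the remainder in the last chunk."""
    width = math.floor(num_channels / layer)
    return [width] * (layer - 1) + [num_channels - width * (layer - 1)]


def sum_aggregate(edges: Sequence[Edge], x: List[List[float]],
                  num_nodes: int) -> List[List[float]]:
    """Plain sum aggregation: out[i] = sum of x[j] over all edges (j, i)."""
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for j, i in edges:
        for c in range(dim):
            out[i][c] += x[j][c]
    return out


def decomposed_aggregate(edges: Sequence[Edge], x: List[List[float]],
                         layer: int) -> List[List[float]]:
    """Aggregate one column chunk at a time, then concatenate the partial
    results along the feature dimension (split -> propagate -> cat)."""
    num_nodes = len(x)
    out: List[List[float]] = [[] for _ in range(num_nodes)]
    start = 0
    for size in split_sizes(len(x[0]), layer):
        chunk = [row[start:start + size] for row in x]  # columns [start, start + size)
        partial = sum_aggregate(edges, chunk, num_nodes)
        for i in range(num_nodes):
            out[i].extend(partial[i])
        start += size
    return out


# Toy graph: 3 nodes with 5 features each.
edges = [(0, 1), (1, 2), (2, 0), (0, 2)]
x = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [0, 1, 0, 1, 0]]
print(split_sizes(5, 2))   # [2, 3]
print(split_sizes(32, 3))  # [10, 10, 12]
assert decomposed_aggregate(edges, x, layer=2) == sum_aggregate(edges, x, num_nodes=3)
```

In the scatter-based (`gas`) path, each `propagate` call only builds a per-edge message tensor as wide as the current chunk, so the peak size of that buffer shrinks roughly by a factor of `layer`, at the price of `layer` passes over the edges. The GCN `edge_weight` and the per-head GAT attention coefficients apply uniformly across channels, which is why the same split carries over to those layers.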
-------------------------------------------------------------------------------- /model/inf_model.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import argparse 3 | from torch.nn import Sequential, Linear, ReLU 4 | import torch 5 | import torch.nn.functional as F 6 | from torch_geometric.datasets import Planetoid 7 | import torch_geometric.transforms as T 8 | from torch_geometric.nn import GCNConv, ChebConv # noqa 9 | import time 10 | from model.gnn_model_1_6 import GATConv_3d, GINConv_3d, GCNConv_3d, SAGEConv_3d, GINConv_3d_com, SAGEConv_3d_com 11 | 12 | 13 | class GCNNet(torch.nn.Module): 14 | def __init__(self, in_channels: int, out_channels: int, num_classes: int): 15 | super(GCNNet, self).__init__() 16 | self.conv1 = GCNConv_3d(in_channels, out_channels, cached=True, 17 | normalize=True) 18 | self.conv2 = GCNConv_3d(out_channels, num_classes, cached=True, 19 | normalize=True) 20 | 21 | def forward(self, x, edge_index, edge_attr=None, layer=(0, 0)): 22 | x = F.relu(self.conv1(x, edge_index, edge_attr, layer=layer[0])) 23 | x = F.dropout(x, training=self.training) 24 | x = self.conv2(x, edge_index, edge_attr, layer=layer[1]) 25 | return F.log_softmax(x, dim=1) 26 | 27 | 28 | class GATNet(torch.nn.Module): 29 | def __init__(self, in_channels: int, out_channels: int, num_classes: int): 30 | super(GATNet, self).__init__() 31 | self.conv1 = GATConv_3d(in_channels, out_channels, heads=1, dropout=0.6) 32 | # On the Pubmed dataset, use heads=8 in conv2. 
33 | self.conv2 = GATConv_3d(out_channels, num_classes, heads=1, concat=False, 34 | dropout=0.6) 35 | 36 | def forward(self, x, edge_index, layer=(0, 0)): 37 | x = F.dropout(x, p=0.6, training=self.training) 38 | x = F.elu(self.conv1(x, edge_index, layer=layer[0])) 39 | x = F.dropout(x, p=0.6, training=self.training) 40 | x = self.conv2(x, edge_index, layer=layer[1]) 41 | 42 | return F.log_softmax(x, dim=1) 43 | 44 | 45 | class SAGE_agg1(torch.nn.Module): 46 | def __init__(self, in_channels, hidden_channels, num_classes: int): 47 | super(SAGE_agg1, self).__init__() 48 | self.conv1 = SAGEConv_3d(in_channels, hidden_channels) 49 | self.conv2 = SAGEConv_3d(hidden_channels, num_classes) 50 | 51 | def forward(self, x, edge_index, layer=(0, 0)): 52 | x = self.conv1(x, edge_index, layer=layer[0]) 53 | x = x.relu() 54 | x = F.dropout(x, p=0.5, training=self.training) 55 | x = self.conv2(x, edge_index, layer=layer[1]) 56 | 57 | return F.log_softmax(x, dim=1) 58 | 59 | 60 | class SAGE_com1(torch.nn.Module): 61 | def __init__(self, in_channels, hidden_channels, num_classes: int): 62 | super(SAGE_com1, self).__init__() 63 | self.conv1 = SAGEConv_3d_com(in_channels, hidden_channels) 64 | self.conv2 = SAGEConv_3d_com(hidden_channels, num_classes) 65 | 66 | def forward(self, x, edge_index, layer=(0, 0)): 67 | x = self.conv1(x, edge_index, layer=layer[0]) 68 | x = x.relu() 69 | x = F.dropout(x, p=0.5, training=self.training) 70 | x = self.conv2(x, edge_index, layer=layer[1]) 71 | 72 | return F.log_softmax(x, dim=1) 73 | 74 | 75 | class GIN_agg(torch.nn.Module): 76 | def __init__(self, in_channels, hidden_channels, num_classes: int): 77 | super(GIN_agg, self).__init__() 78 | 79 | num_features = in_channels 80 | dim = hidden_channels 81 | 82 | nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim)) 83 | self.conv1 = GINConv_3d(nn1) 84 | self.bn1 = torch.nn.BatchNorm1d(dim) 85 | 86 | nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 87 | self.conv2 =
GINConv_3d(nn2) 88 | self.bn2 = torch.nn.BatchNorm1d(dim) 89 | 90 | nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 91 | self.conv3 = GINConv_3d(nn3) 92 | self.bn3 = torch.nn.BatchNorm1d(dim) 93 | 94 | nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 95 | self.conv4 = GINConv_3d(nn4) 96 | self.bn4 = torch.nn.BatchNorm1d(dim) 97 | 98 | nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 99 | self.conv5 = GINConv_3d(nn5) 100 | self.bn5 = torch.nn.BatchNorm1d(dim) 101 | 102 | self.fc1 = Linear(dim, dim) 103 | self.fc2 = Linear(dim, num_classes) 104 | 105 | def forward(self, x, edge_index, layer=(0, 0)): 106 | x = F.relu(self.conv1(x, edge_index, layer=layer[0])) 107 | x = self.bn1(x) 108 | x = F.relu(self.conv2(x, edge_index, layer=layer[1])) 109 | x = self.bn2(x) 110 | x = F.relu(self.conv3(x, edge_index, layer=layer[1])) 111 | x = self.bn3(x) 112 | x = F.relu(self.conv4(x, edge_index, layer=layer[1])) 113 | x = self.bn4(x) 114 | x = F.relu(self.conv5(x, edge_index, layer=layer[1])) 115 | x = self.bn5(x) 116 | x = F.relu(self.fc1(x)) 117 | x = F.dropout(x, p=0.5, training=self.training) 118 | x = self.fc2(x) 119 | return F.log_softmax(x, dim=-1) 120 | 121 | 122 | class GIN_com(torch.nn.Module): 123 | def __init__(self, in_channels, hidden_channels, num_classes: int): 124 | super(GIN_com, self).__init__() 125 | 126 | num_features = in_channels 127 | dim = hidden_channels 128 | 129 | nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim)) 130 | self.conv1 = GINConv_3d_com(nn1) 131 | self.bn1 = torch.nn.BatchNorm1d(dim) 132 | 133 | nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 134 | self.conv2 = GINConv_3d_com(nn2) 135 | self.bn2 = torch.nn.BatchNorm1d(dim) 136 | 137 | nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 138 | self.conv3 = GINConv_3d_com(nn3) 139 | self.bn3 = torch.nn.BatchNorm1d(dim) 140 | 141 | nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 142 | self.conv4 =
GINConv_3d_com(nn4) 143 | self.bn4 = torch.nn.BatchNorm1d(dim) 144 | 145 | nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim)) 146 | self.conv5 = GINConv_3d_com(nn5) 147 | self.bn5 = torch.nn.BatchNorm1d(dim) 148 | 149 | self.fc1 = Linear(dim, dim) 150 | self.fc2 = Linear(dim, num_classes) 151 | 152 | def forward(self, x, edge_index, layer=(0, 0)): 153 | x = F.relu(self.conv1(x, edge_index, layer=layer[0])) 154 | x = self.bn1(x) 155 | x = F.relu(self.conv2(x, edge_index, layer=layer[0])) 156 | x = self.bn2(x) 157 | x = F.relu(self.conv3(x, edge_index, layer=layer[0])) 158 | x = self.bn3(x) 159 | x = F.relu(self.conv4(x, edge_index, layer=layer[0])) 160 | x = self.bn4(x) 161 | x = F.relu(self.conv5(x, edge_index, layer=layer[1])) 162 | x = self.bn5(x) 163 | x = F.relu(self.fc1(x)) 164 | x = F.dropout(x, p=0.5, training=self.training) 165 | x = self.fc2(x) 166 | return F.log_softmax(x, dim=-1) 167 | -------------------------------------------------------------------------------- /model/inits.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def uniform(size, tensor): 7 | if tensor is not None: 8 | bound = 1.0 / math.sqrt(size) 9 | tensor.data.uniform_(-bound, bound) 10 | 11 | 12 | def kaiming_uniform(tensor, fan, a): 13 | if tensor is not None: 14 | bound = math.sqrt(6 / ((1 + a**2) * fan)) 15 | tensor.data.uniform_(-bound, bound) 16 | 17 | 18 | def glorot(tensor): 19 | if tensor is not None: 20 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1))) 21 | tensor.data.uniform_(-stdv, stdv) 22 | 23 | 24 | def glorot_orthogonal(tensor, scale): 25 | if tensor is not None: 26 | torch.nn.init.orthogonal_(tensor.data) 27 | scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var()) 28 | tensor.data *= scale.sqrt() 29 | 30 | 31 | def zeros(tensor): 32 | if tensor is not None: 33 | tensor.data.fill_(0) 34 | 35 | 36 | def ones(tensor): 37 | if tensor is not None: 38 | 
tensor.data.fill_(1) 39 | 40 | 41 | def normal(tensor, mean, std): 42 | if tensor is not None: 43 | tensor.data.normal_(mean, std) 44 | 45 | 46 | def reset(nn): 47 | def _reset(item): 48 | if hasattr(item, 'reset_parameters'): 49 | item.reset_parameters() 50 | 51 | if nn is not None: 52 | if hasattr(nn, 'children') and len(list(nn.children())) > 0: 53 | for item in nn.children(): 54 | _reset(item) 55 | else: 56 | _reset(nn) 57 | -------------------------------------------------------------------------------- /pyg1.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/result.png -------------------------------------------------------------------------------- /test/test_gnn_layer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import argparse 4 | import torch 5 | import sys 6 | # from memory_profiler import profile 7 | from torch.nn import Sequential 8 | from torch.nn import Sequential, Linear, ReLU 9 | 10 | sys.path.append("..") 11 | from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, EdgeConv 12 | from model.gnn_model_1_6 import GATConv_3d, GINConv_3d, GCNConv_3d, SAGEConv_3d, GINConv_3d_com, SAGEConv_3d_com 13 | import torch_geometric.transforms as T 14 | from torch_geometric.datasets import Reddit 15 | from torch_geometric.datasets import Planetoid 16 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset 17 | 18 | parser = argparse.ArgumentParser(description='manual to this script') 19 | parser.add_argument("--gpu", type=int, default="0", help='0:use cpu 1:use gpu') 20 | parser.add_argument("--hidden", type=int, default=8, help='hidden 
number') 21 | parser.add_argument("--agg", type=str, default="gas", help='gas = scatter-based aggregation on edge_index, spmm = sparse matrix multiplication on adj_t') 22 | parser.add_argument("--m", type=str, default="GCN", help='model name: GCN, GIN, GAT, SAGE') 23 | parser.add_argument("--layer", type=int, default=0, help='number of chunks to split the feature dimension into; ' 24 | 'values <= 1 disable splitting (for GIN, 0 = paired input x, 1 = single x)') 25 | parser.add_argument("--order", type=str, default='agg', help='agg or com') 26 | parser.add_argument("--data", type=str, default='rd', help='cr = Cora, pb = PubMed, cs = CiteSeer, pt = ogbn-proteins, rd = Reddit') 27 | 28 | args = parser.parse_args() 29 | layer = args.layer 30 | hidden = args.hidden 31 | 32 | data_use = args.data 33 | path = '' 34 | dataset = '' 35 | if data_use == 'rd': 36 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', 'Reddit') 37 | if args.agg == "gas": 38 | dataset = Reddit(path) 39 | print('Reddit dataset, gas model!') 40 | else: 41 | dataset = Reddit(path, transform=T.ToSparseTensor()) 42 | print('Reddit dataset, spmm model') 43 | elif data_use == 'cr': 44 | dataset = 'Cora' 45 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 46 | if args.agg == "gas": 47 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 48 | print('Cora dataset, gas model!') 49 | else: 50 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 51 | print('Cora dataset, spmm model') 52 | elif data_use == 'pb': 53 | dataset = 'pubmed' 54 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 55 | if args.agg == "gas": 56 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 57 | print('Pubmed dataset, gas model!') 58 | else: 59 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 60 | print('Pubmed dataset, spmm model') 61 | elif data_use == 'cs': 62 | dataset = 'CiteSeer' 63 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 64 | if args.agg == "gas": 65 | dataset =
Planetoid(path, dataset, transform=T.NormalizeFeatures()) 66 | print('Citeseer dataset, gas model!') 67 | else: 68 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 69 | print('Citeseer dataset, spmm model') 70 | elif data_use == 'pt': 71 | if args.agg == "gas": 72 | dataset = PygNodePropPredDataset(name='ogbn-proteins',root='../dataset') 73 | print('Ogbn-proteins dataset, gas model!') 74 | else: 75 | dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset',transform=T.ToSparseTensor()) 76 | print('Ogbn-proteins dataset, spmm model!') 77 | 78 | a = time.time() 79 | data = dataset[0] 80 | b = time.time() 81 | print('Dataset process time : ', b - a) 82 | 83 | device = torch.device('cuda' if args.gpu else 'cpu') 84 | model = '' 85 | 86 | if args.m == "GCN": 87 | print('using gcn') 88 | if data_use == 'pt': 89 | num_features = 1 90 | else: 91 | num_features = dataset.num_features 92 | model = GCNConv_3d(num_features, hidden) 93 | elif args.m == "GAT": 94 | print('using gat') 95 | if data_use == 'pt': 96 | num_features = 1 97 | else: 98 | num_features = dataset.num_features 99 | model = GATConv_3d(num_features, hidden) 100 | elif args.m == "SAGE": 101 | if args.order == 'agg': 102 | print('using 1st agg graphsage') 103 | if data_use == 'pt': 104 | num_features = 1 105 | else: 106 | num_features = dataset.num_features 107 | model = SAGEConv_3d(num_features, hidden) 108 | else: 109 | print('using 1st com graphsage') 110 | if data_use == 'pt': 111 | num_features = 1 112 | else: 113 | num_features = dataset.num_features 114 | model = SAGEConv_3d_com(num_features, hidden) 115 | elif args.m == "GIN": 116 | if args.order == 'agg': 117 | if data_use == 'pt': 118 | num_features = 1 119 | else: 120 | num_features = dataset.num_features 121 | nn1 = Sequential(Linear(num_features, args.hidden), ReLU(), Linear(args.hidden, args.hidden)) 122 | print('using 1st agg GIN , hid=', args.hidden) 123 | model = GINConv_3d(nn1) 124 | else: 125 | if data_use == 
'pt': 126 | num_features = 1 127 | else: 128 | num_features = dataset.num_features 129 | nn1 = Sequential(Linear(num_features, args.hidden), ReLU(), Linear(args.hidden, args.hidden)) 130 | print('using 1st com GIN, hid=', args.hidden) 131 | model = GINConv_3d_com(nn1) 132 | else: 133 | raise ValueError('unknown model {!r}; expected GCN, GIN, GAT or SAGE'.format(args.m)) 134 | 135 | model = model.to(device) 136 | 137 | x = data.x.to(device) 138 | y = data.y.squeeze().to(device) 139 | if args.agg == "gas": 140 | edge_index = data.edge_index.to(device) 141 | else: 142 | edge_index = data.adj_t.to(device) 143 | 144 | 145 | print('using ', device) 146 | print('layers : ', args.layer) 147 | 148 | 149 | @torch.no_grad() 150 | 151 | def test(): 152 | model.eval() 153 | # with torch.autograd.profiler.profile(use_cuda=True if args.gpu else False, profile_memory=True) as prof: 154 | t0 = time.time() 155 | out = model(x, edge_index, layer=layer) 156 | t1 = time.time() 157 | #print(prof.key_averages().table(sort_by="self_cpu_time_total" if args.gpu == 0 else "self_cuda_time_total")) 158 | print('Inference time:', t1 - t0) 159 | print('Total time:', t1 - t0 + b - a) 160 | 161 | 162 | if __name__ == '__main__': 163 | test() 164 | -------------------------------------------------------------------------------- /test/test_gnn_total.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import argparse 4 | import torch 5 | import sys 6 | 7 | from torch.nn import Sequential 8 | from torch.nn import Sequential, Linear, ReLU 9 | 10 | sys.path.append("..") 11 | 12 | import torch_geometric.transforms as T 13 | from torch_geometric.datasets import Reddit 14 | from torch_geometric.datasets import Planetoid 15 | from model.inf_model import GCNNet, GATNet, SAGE_agg1, SAGE_com1, GIN_agg, GIN_com 16 | from torch_geometric.datasets import Planetoid 17 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset 18 | 19 | parser =
argparse.ArgumentParser(description='manual to this script') 20 | parser.add_argument("--gpu", type=int, default=0, help='0: use CPU, 1: use GPU') 21 | parser.add_argument("--hidden", type=int, default=8, help='hidden layer size') 22 | parser.add_argument("--agg", type=str, default="gas", help='gas = scatter-based aggregation on edge_index, spmm = sparse matrix multiplication on adj_t') 23 | parser.add_argument("--m", type=str, default="GCN", help='model name: GCN, GIN, GAT, SAGE') 24 | parser.add_argument("--layer", type=str, default='0,0', help='two comma-separated chunk counts for the feature dimension, passed to the model as layer[0] and layer[1], ' 25 | 'e.g. "32,41"; values <= 1 disable splitting') 26 | parser.add_argument("--order", type=str, default='com', help='agg first or com first') 27 | parser.add_argument("--data", type=str, default='rd', help='cr = Cora, pb = PubMed, cs = CiteSeer, pt = ogbn-proteins, rd = Reddit') 28 | args = parser.parse_args() 29 | hidden = args.hidden 30 | layer = args.layer 31 | 32 | a = layer.split(',') 33 | a0 = int(a[0]) 34 | a1 = int(a[1]) 35 | layer = [a0, a1] 36 | data_use = args.data 37 | 38 | path = '' 39 | dataset = '' 40 | if data_use == 'rd': 41 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', 'Reddit') 42 | if args.agg == "gas": 43 | dataset = Reddit(path) 44 | print('Reddit dataset, gas model!') 45 | else: 46 | dataset = Reddit(path, transform=T.ToSparseTensor()) 47 | print('Reddit dataset, spmm model') 48 | elif data_use == 'cr': 49 | dataset = 'Cora' 50 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 51 | if args.agg == "gas": 52 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 53 | print('Cora dataset, gas model!') 54 | else: 55 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 56 | print('Cora dataset, spmm model') 57 | elif data_use == 'pb': 58 | dataset = 'pubmed' 59 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 60 | if args.agg == "gas": 61 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 62 | print('Pubmed
dataset, gas model!') 63 | else: 64 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 65 | print('Pubmed dataset, spmm model') 66 | elif data_use == 'cs': 67 | dataset = 'CiteSeer' 68 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset') 69 | if args.agg == "gas": 70 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) 71 | print('Citeseer dataset, gas model!') 72 | else: 73 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor()) 74 | print('Citeseer dataset, spmm model') 75 | elif data_use == 'pt': 76 | if args.agg == "gas": 77 | dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset') 78 | print('Ogbn-proteins dataset, gas model!') 79 | else: 80 | dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset', transform=T.ToSparseTensor()) 81 | print('Ogbn-proteins dataset, spmm model!') 82 | 83 | a = time.time() 84 | data = dataset[0] 85 | 86 | b = time.time() 87 | print('Dataset process time : ', b - a) 88 | 89 | device = torch.device('cuda' if args.gpu else 'cpu') 90 | 91 | model = '' 92 | if args.m == "GCN": 93 | print('using gcn') 94 | if data_use == 'pt': 95 | num_features = 1 96 | else: 97 | num_features = dataset.num_features 98 | model = GCNNet(num_features, hidden, dataset.num_classes) 99 | elif args.m == "GAT": 100 | print('using gat') 101 | if data_use == 'pt': 102 | num_features = 1 103 | else: 104 | num_features = dataset.num_features 105 | model = GATNet(num_features, hidden, dataset.num_classes) 106 | elif args.m == "SAGE": 107 | if args.order == 'agg': 108 | print('using 1st agg sage') 109 | if data_use == 'pt': 110 | num_features = 1 111 | else: 112 | num_features = dataset.num_features 113 | model = SAGE_agg1(num_features, hidden, dataset.num_classes) 114 | else: 115 | print('using 1st com sage') 116 | if data_use == 'pt': 117 | num_features = 1 118 | else: 119 | num_features = dataset.num_features 120 | model = SAGE_com1(num_features, hidden, dataset.num_classes) 121 | elif args.m == "GIN": 122
| if args.order == 'agg': 123 | print('using 1st agg GIN, hid=', args.hidden) 124 | if data_use == 'pt': 125 | num_features = 1 126 | else: 127 | num_features = dataset.num_features 128 | model = GIN_agg(num_features, hidden, dataset.num_classes) 129 | else: 130 | print('using 1st com GIN, hid=', args.hidden) 131 | if data_use == 'pt': 132 | num_features = 1 133 | else: 134 | num_features = dataset.num_features 135 | model = GIN_com(num_features, hidden, dataset.num_classes) 136 | else: 137 | raise ValueError('unknown model {!r}; expected GCN, GIN, GAT or SAGE'.format(args.m)) 138 | 139 | model = model.to(device) 140 | 141 | 142 | x = data.x.to(device) 143 | y = data.y.squeeze().to(device) 144 | 145 | if args.agg == "gas": 146 | edge_index = data.edge_index.to(device) 147 | else: 148 | edge_index = data.adj_t.to(device) 149 | 150 | 151 | print('using ', device) 152 | print('layers : ', args.layer) 153 | 154 | @torch.no_grad() 155 | def test(): 156 | model.eval() 157 | 158 | # with torch.autograd.profiler.profile(use_cuda=True if args.gpu else False, profile_memory=True) as prof: 159 | t0 = time.time() 160 | out = model(x, edge_index, layer=layer) 161 | t1 = time.time() 162 | # print(prof.key_averages().table(sort_by="self_cpu_time_total" if args.gpu == 0 else "self_cuda_time_total")) 163 | print('Inference time:', t1 - t0) 164 | print('Total time:', t1 - t0 + b - a) 165 | 166 | 167 | if __name__ == '__main__': 168 | test() 169 | --------------------------------------------------------------------------------
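test_gnn_total.py turns the `--layer` value into a two-element list by splitting it on `,` (the README passes `--layer="32,41"`), and the pair reaches the model as `layer[0]` and `layer[1]`. One possible way to let argparse validate that pair up front is sketched below. The `layer_pair` helper is a hypothetical addition, not part of the repository, and it assumes exactly two integer chunk counts are always expected.

```python
import argparse
from typing import List


def layer_pair(value: str) -> List[int]:
    """Hypothetical argparse ``type`` for --layer: '32,41' -> [32, 41].

    The first count is meant for layer[0] and the second for layer[1]
    in the model's forward(); counts <= 1 mean "do not split"."""
    parts = value.split(',')
    if len(parts) != 2:
        raise argparse.ArgumentTypeError(
            "expected two comma-separated integers such as '32,41', got {!r}".format(value))
    try:
        counts = [int(p) for p in parts]  # int() tolerates surrounding whitespace
    except ValueError:
        raise argparse.ArgumentTypeError(
            "split counts must be integers, got {!r}".format(value)) from None
    return counts


parser = argparse.ArgumentParser(description='feature decomposition inference test')
# A non-string default is not passed through `type`, so it is used as-is.
parser.add_argument('--layer', type=layer_pair, default=[0, 0],
                    help="two comma-separated chunk counts, e.g. '32,41'")

args = parser.parse_args(['--layer', '32,41'])
print(args.layer)  # [32, 41]
```

With a custom `type`, a malformed value such as `--layer=32` makes argparse print a usage error and exit immediately, instead of failing later inside the manual split.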