├── README.md
├── dataset
│   ├── CiteSeer
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── Cora
│   │   ├── .gitignore
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── Reddit
│   │   ├── processed
│   │   │   └── .gitignore
│   │   └── raw
│   │       └── .gitignore
│   ├── ogbn_proteins
│   │   ├── .gitignore
│   │   ├── mapping
│   │   │   └── .gitignore
│   │   ├── processed
│   │   │   └── .gitignore
│   │   ├── raw
│   │   │   └── .gitignore
│   │   └── split
│   │       └── species
│   │           └── .gitignore
│   └── pubmed
│       ├── .gitignore
│       ├── processed
│       │   └── .gitignore
│       └── raw
│           └── .gitignore
├── example
│   └── GCN.py
├── model
│   ├── __pycache__
│   │   ├── gnn_model_1_6.cpython-37.pyc
│   │   ├── gnn_model_1_6.cpython-38.pyc
│   │   ├── inf_model.cpython-38.pyc
│   │   └── inits.cpython-38.pyc
│   ├── gnn_model_1_6.py
│   ├── inf_model.py
│   └── inits.py
├── pyg1.svg
├── result.png
└── test
    ├── test_gnn_layer.py
    └── test_gnn_total.py
/README.md:
--------------------------------------------------------------------------------
1 | # GNN-Feature-Decomposition
2 |
3 |
4 | ### ***NOTE!***
5 | After our cooperation with the [PyG team](https://github.com/pyg-team/pytorch_geometric), the feature decomposition method has been improved and added
6 | to the [PyG framework](https://github.com/pyg-team/pytorch_geometric) as an optional feature. For more details and usage instructions, please refer to
7 | [Feature Decomposition in PyG](#Feature-Decomposition-in-PyG).
8 |
9 |
10 |
11 |
12 |
13 | --------------------------------------------------------------------------------
14 |
15 |
16 | ### Overview
17 |
18 | This is the repository for our work on GNN feature decomposition,
19 | which was accepted at RTAS 2021 (Brief Industry Track) under the title ***"Optimizing Memory Efficiency of Graph Neural Networks on Edge Computing Platforms"***.
20 |
21 | For more details, please see our full paper: https://arxiv.org/abs/2104.03058
22 |
23 | Graph neural networks (GNNs) have achieved state-of-the-art performance on various industrial tasks.
24 | However, the poor efficiency of GNN inference and frequent Out-Of-Memory (OOM) problems limit the successful application of GNNs on edge computing platforms.
25 | To tackle these problems, we propose a feature decomposition approach for optimizing the memory efficiency of GNN inference.
26 | The proposed approach achieves substantial gains across various GNN models and a wide range of datasets, speeding up inference by up to 3x.
27 | Furthermore, feature decomposition significantly reduces peak memory usage (up to a 5x improvement in memory efficiency) and mitigates OOM problems during GNN inference.
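
In brief, instead of aggregating the full feature matrix in one pass, feature decomposition splits the matrix into slices along the feature dimension, aggregates each slice separately, and concatenates the partial results, so the working set of the aggregation stays small. Below is a minimal sketch of this splitting logic (it mirrors the `forward` implementations in `model/gnn_model_1_6.py`; `propagate` here stands in for one layer's message-passing aggregation):

    import torch

    def aggregate_decomposed(propagate, edge_index, x, num_slices):
        # With a single slice this is ordinary aggregation.
        if num_slices <= 1:
            return propagate(edge_index, x=x)
        # Split x column-wise; the last slice absorbs the remainder,
        # matching the splitting scheme in model/gnn_model_1_6.py.
        width = x.size(-1) // num_slices
        parts = [width] * (num_slices - 1) + [x.size(-1) - width * (num_slices - 1)]
        # Aggregate slice by slice: peak memory of the aggregation scales
        # with the slice width rather than with the full feature width.
        outs = [propagate(edge_index, x=s) for s in x.split(parts, dim=-1)]
        return torch.cat(outs, dim=-1)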
28 |
29 | ### Requirements
30 |
31 | Recent versions of PyTorch, numpy, and torch_geometric (1.6.3) are required.
32 |
33 |
34 | ### Contents
35 | There are two main top-level scripts in this repo:
36 |
37 | 1. test_gnn_layer.py: runs the feature decomposition method on a single GNN layer.
38 | 2. test_gnn_total.py: runs the feature decomposition method on a complete GNN model.
39 |
40 | ### Running the code
41 | #### Test a single GNN layer with our feature decomposition method:
42 |     cd test
43 |     python test_gnn_layer.py --hidden=32 --agg="gas" --m="GCN" --layer=32 --data="rd"
44 |
45 | #### Test a whole GNN model with our feature decomposition method:
46 |     cd test
47 |     python test_gnn_total.py --hidden=32 --agg="gas" --m="GCN" --layer="32,41" --data="rd"
48 |
49 | - hidden: the hidden layer size of the GNN.
50 | - agg: the aggregation mode, either "spmm" or "gas"; feature decomposition requires "gas".
51 | - m: the GNN model name, one of "GCN", "GAT", "GIN", "SAGE".
52 | - layer: the number of decomposed layers along the feature-vector dimension; basic GNN inference uses 1 layer.
53 |   When testing a whole GNN model, two values are required (one per convolution layer).
54 | - data: the dataset name ("cr" = Cora, "pb" = Pubmed, "cs" = CiteSeer, "pt" = ogbn-proteins, "rd" = Reddit).
55 |
56 |
57 | ### Feature Decomposition in PyG
58 |
59 | ---
60 |
61 | We integrated the feature decomposition into the [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py)
62 | module of the [PyG framework](https://github.com/pyg-team/pytorch_geometric). Specifically, we added an optional argument `decomposed_layers: int = 1` to the initialization
63 | function of the [`MessagePassing`](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py) module, as shown below.
64 |
65 |
66 | def __init__(self, aggr: Optional[str] = "add",
67 | flow: str = "source_to_target", node_dim: int = -2,
68 | decomposed_layers: int = 1):
69 |
70 | When creating a layer, pass in `decomposed_layers` (> 1) to use the feature decomposition method, as follows:
71 |
72 |     conv = GCNConv(16, 32, decomposed_layers=2)
73 |
74 | For specific usage, please refer to the [example](example/GCN.py)
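
A condensed version of that example (a sketch assuming a PyG build that ships the `decomposed_layers` argument; `dataset` stands for any PyG node-classification dataset):

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv

    class Net(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Decompose the aggregation of both convolutions into 2 slices.
            self.conv1 = GCNConv(dataset.num_features, 16, decomposed_layers=2)
            self.conv2 = GCNConv(16, dataset.num_classes, decomposed_layers=2)

        def forward(self, x, edge_index):
            x = F.relu(self.conv1(x, edge_index))
            x = F.dropout(x, training=self.training)
            return F.log_softmax(self.conv2(x, edge_index), dim=1)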
75 |
76 | ---
77 | The following table shows the test results of the [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv) layer on the Reddit dataset.
78 | It shows that the best speedup is achieved when the decomposition granularity is largest (this does not mean that all layers behave this way).
79 | The only rule we can be sure of is: for a given graph, there is an optimal dimension for one decomposed slice, and it does not change with the hidden layer dimension
80 | (e.g., in the table below it is optimal to set the dimension of each decomposed slice to 1, which means `decomposed_layers` = hidden layer dimension).
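
For instance, following that rule for a hidden size of 32 (a hypothetical configuration illustrating the rule, not a measured one):

    hidden = 32
    # One dimension per slice: as many decomposed layers as hidden dimensions.
    conv = GCNConv(16, hidden, decomposed_layers=hidden)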
81 |
82 | Test machine: Intel Core i7-8700K CPU, Ubuntu.
83 |
84 | ![result](result.png)
85 |
86 |
87 |
88 | - The horizontal axis: the hidden layer dimension of [GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv).
89 | - The vertical axis: equal to `decomposed_layers`, i.e., the granularity of feature decomposition.
90 | - total: the running time of the whole layer.
91 | - agg: the running time of the aggregation phase of GCNConv.
92 |
93 |
--------------------------------------------------------------------------------
/dataset/CiteSeer/processed/.gitignore:
--------------------------------------------------------------------------------
1 | /data.pt
2 | /pre_filter.pt
3 | /pre_transform.pt
4 |
--------------------------------------------------------------------------------
/dataset/CiteSeer/raw/.gitignore:
--------------------------------------------------------------------------------
1 | /ind.citeseer.allx
2 | /ind.citeseer.ally
3 | /ind.citeseer.graph
4 | /ind.citeseer.test.index
5 | /ind.citeseer.tx
6 | /ind.citeseer.ty
7 | /ind.citeseer.x
8 | /ind.citeseer.y
9 |
--------------------------------------------------------------------------------
/dataset/Cora/.gitignore:
--------------------------------------------------------------------------------
1 | /processed_cora.pkl
2 |
--------------------------------------------------------------------------------
/dataset/Cora/processed/.gitignore:
--------------------------------------------------------------------------------
1 | /data.pt
2 | /pre_filter.pt
3 | /pre_transform.pt
4 |
--------------------------------------------------------------------------------
/dataset/Cora/raw/.gitignore:
--------------------------------------------------------------------------------
1 | /ind.cora.allx
2 | /ind.cora.ally
3 | /ind.cora.graph
4 | /ind.cora.test.index
5 | /ind.cora.tx
6 | /ind.cora.ty
7 | /ind.cora.x
8 | /ind.cora.y
9 |
--------------------------------------------------------------------------------
/dataset/Reddit/processed/.gitignore:
--------------------------------------------------------------------------------
1 | /data.pt
2 | /pre_filter.pt
3 | /pre_transform.pt
4 |
--------------------------------------------------------------------------------
/dataset/Reddit/raw/.gitignore:
--------------------------------------------------------------------------------
1 | /reddit_data.npz
2 | /reddit_graph.npz
3 |
--------------------------------------------------------------------------------
/dataset/ogbn_proteins/.gitignore:
--------------------------------------------------------------------------------
1 | /RELEASE_v1.txt
2 |
--------------------------------------------------------------------------------
/dataset/ogbn_proteins/mapping/.gitignore:
--------------------------------------------------------------------------------
1 | /README.md
2 | /labelidx2GO.csv.gz
3 | /nodeidx2proteinid.csv.gz
4 |
--------------------------------------------------------------------------------
/dataset/ogbn_proteins/processed/.gitignore:
--------------------------------------------------------------------------------
1 | /geometric_data_processed.pt
2 | /pre_filter.pt
3 | /pre_transform.pt
4 |
--------------------------------------------------------------------------------
/dataset/ogbn_proteins/raw/.gitignore:
--------------------------------------------------------------------------------
1 | /edge-feat.csv.gz
2 | /edge.csv.gz
3 | /node-label.csv.gz
4 | /node_species.csv.gz
5 | /num-edge-list.csv.gz
6 | /num-node-list.csv.gz
7 |
--------------------------------------------------------------------------------
/dataset/ogbn_proteins/split/species/.gitignore:
--------------------------------------------------------------------------------
1 | /test.csv.gz
2 | /train.csv.gz
3 | /valid.csv.gz
4 |
--------------------------------------------------------------------------------
/dataset/pubmed/.gitignore:
--------------------------------------------------------------------------------
1 | /processed_pubmed.pkl
2 |
--------------------------------------------------------------------------------
/dataset/pubmed/processed/.gitignore:
--------------------------------------------------------------------------------
1 | /data.pt
2 | /pre_filter.pt
3 | /pre_transform.pt
4 |
--------------------------------------------------------------------------------
/dataset/pubmed/raw/.gitignore:
--------------------------------------------------------------------------------
1 | /ind.pubmed.allx
2 | /ind.pubmed.ally
3 | /ind.pubmed.graph
4 | /ind.pubmed.test.index
5 | /ind.pubmed.tx
6 | /ind.pubmed.ty
7 | /ind.pubmed.x
8 | /ind.pubmed.y
9 |
--------------------------------------------------------------------------------
/example/GCN.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import argparse
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_geometric.datasets import Planetoid
7 | import torch_geometric.transforms as T
8 | from torch_geometric.nn import GCNConv, ChebConv # noqa
9 | from torch_geometric.datasets import Reddit
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--use_gdc', action='store_true',
12 | help='Use GDC preprocessing.')
13 | args = parser.parse_args()
14 |
15 | dataset = 'Reddit'
16 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', dataset)
17 | dataset = Reddit(path,transform=T.NormalizeFeatures())
18 | data = dataset[0]
19 |
20 | if args.use_gdc:
21 | gdc = T.GDC(self_loop_weight=1, normalization_in='sym',
22 | normalization_out='col',
23 | diffusion_kwargs=dict(method='ppr', alpha=0.05),
24 | sparsification_kwargs=dict(method='topk', k=128,
25 | dim=0), exact=True)
26 | data = gdc(data)
27 |
28 |
29 | class Net(torch.nn.Module):
30 | def __init__(self):
31 | super(Net, self).__init__()
32 |         # use the feature decomposition method and set decomposed_layers = 2
33 |         self.conv1 = GCNConv(dataset.num_features, 16, cached=True,
34 |                              normalize=not args.use_gdc, decomposed_layers=2)
35 |         self.conv2 = GCNConv(16, dataset.num_classes, cached=True,
36 |                              normalize=not args.use_gdc, decomposed_layers=2)
37 | # self.conv1 = ChebConv(data.num_features, 16, K=2)
38 | # self.conv2 = ChebConv(16, data.num_features, K=2)
39 |
40 | def forward(self):
41 | x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr
42 | x = F.relu(self.conv1(x, edge_index, edge_weight))
43 | x = F.dropout(x, training=self.training)
44 | x = self.conv2(x, edge_index, edge_weight)
45 | return F.log_softmax(x, dim=1)
46 |
47 |
48 | device = torch.device('cpu')
49 | model, data = Net().to(device), data.to(device)
50 | optimizer = torch.optim.Adam([
51 | dict(params=model.conv1.parameters(), weight_decay=5e-4),
52 | dict(params=model.conv2.parameters(), weight_decay=0)
53 | ], lr=0.01) # Only perform weight-decay on first convolution.
54 |
55 |
56 | def train():
57 | model.train()
58 | optimizer.zero_grad()
59 | F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
60 | optimizer.step()
61 |
62 |
63 | @torch.no_grad()
64 | def test():
65 | model.eval()
66 | logits, accs = model(), []
67 | for _, mask in data('train_mask', 'val_mask', 'test_mask'):
68 | pred = logits[mask].max(1)[1]
69 | acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
70 | accs.append(acc)
71 | return accs
72 |
73 |
74 | best_val_acc = test_acc = 0
75 | for epoch in range(1, 201):
76 | train()
77 | train_acc, val_acc, tmp_test_acc = test()
78 | if val_acc > best_val_acc:
79 | best_val_acc = val_acc
80 | test_acc = tmp_test_acc
81 | log = 'Epoch: {:03d}, Train: {:.4f}, Val: {:.4f}, Test: {:.4f}'
82 | print(log.format(epoch, train_acc, best_val_acc, test_acc))
--------------------------------------------------------------------------------
/model/__pycache__/gnn_model_1_6.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/gnn_model_1_6.cpython-37.pyc
--------------------------------------------------------------------------------
/model/__pycache__/gnn_model_1_6.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/gnn_model_1_6.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/inf_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/inf_model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/inits.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/model/__pycache__/inits.cpython-38.pyc
--------------------------------------------------------------------------------
/model/gnn_model_1_6.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch_sparse import SparseTensor, matmul
4 |
5 | from typing import Union, Tuple, Optional
6 | from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
7 | OptTensor)
8 | import torch
9 | from torch import Tensor
10 | import torch.nn.functional as F
11 | from torch.nn import Parameter, Linear
12 | from torch_sparse import SparseTensor, set_diag
13 | from torch_geometric.nn.conv import MessagePassing
14 | from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
15 |
16 | from .inits import glorot, zeros, reset
17 | # GNN models modified from the implementations in PyG 1.6
18 | from typing import Optional, Tuple
19 | from torch_geometric.typing import Adj, OptTensor, PairTensor
20 |
21 | import torch
22 | from torch import Tensor
23 | from torch.nn import Parameter
24 | from torch_scatter import scatter_add
25 | from torch_sparse import SparseTensor, matmul, fill_diag, sum, mul
26 | from torch_geometric.nn.conv import MessagePassing
27 | from torch_geometric.utils import add_remaining_self_loops
28 | from torch_geometric.utils.num_nodes import maybe_num_nodes
29 |
30 | from typing import Callable, Union
31 | from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
32 |
33 | import torch
34 | from torch import Tensor
35 | import torch.nn.functional as F
36 | from torch_sparse import SparseTensor, matmul
37 | from torch_geometric.nn.conv import MessagePassing
38 |
39 |
40 | @torch.jit._overload
41 | def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
42 | add_self_loops=True, dtype=None):
43 | # type: (Tensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> PairTensor # noqa
44 | pass
45 |
46 |
47 | @torch.jit._overload
48 | def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
49 | add_self_loops=True, dtype=None):
50 | # type: (SparseTensor, OptTensor, Optional[int], bool, bool, Optional[int]) -> SparseTensor # noqa
51 | pass
52 |
53 |
54 | def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
55 | add_self_loops=True, dtype=None):
56 | fill_value = 2. if improved else 1.
57 |
58 | if isinstance(edge_index, SparseTensor):
59 | adj_t = edge_index
60 | if not adj_t.has_value():
61 | adj_t = adj_t.fill_value(1., dtype=dtype)
62 | if add_self_loops:
63 | adj_t = fill_diag(adj_t, fill_value)
64 | deg = sum(adj_t, dim=1)
65 |
66 | deg_inv_sqrt = deg.pow_(-0.5)
67 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
68 | adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
69 | adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))
70 | return adj_t
71 |
72 | else:
73 | num_nodes = maybe_num_nodes(edge_index, num_nodes)
74 |
75 | if edge_weight is None:
76 | edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
77 | device=edge_index.device)
78 |
79 | if add_self_loops:
80 | edge_index, tmp_edge_weight = add_remaining_self_loops(
81 | edge_index, edge_weight, fill_value, num_nodes)
82 | assert tmp_edge_weight is not None
83 | edge_weight = tmp_edge_weight
84 |
85 | row, col = edge_index[0], edge_index[1]
86 | deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
87 | deg_inv_sqrt = deg.pow_(-0.5)
88 | deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
89 | return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
90 |
91 |
92 | class SAGEConv_3d(MessagePassing):
93 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on
94 | Large Graphs" `_ paper
95 |
96 | .. math::
97 | \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot
98 | \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j
99 |
100 | Args:
101 | in_channels (int or tuple): Size of each input sample. A tuple
102 | corresponds to the sizes of source and target dimensionalities.
103 | out_channels (int): Size of each output sample.
104 | normalize (bool, optional): If set to :obj:`True`, output features
105 | will be :math:`\ell_2`-normalized, *i.e.*,
106 | :math:`\frac{\mathbf{x}^{\prime}_i}
107 | {\| \mathbf{x}^{\prime}_i \|_2}`.
108 | (default: :obj:`False`)
109 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
110 | an additive bias. (default: :obj:`True`)
111 | **kwargs (optional): Additional arguments of
112 | :class:`torch_geometric.nn.conv.MessagePassing`.
113 | """
114 |
115 | def __init__(self, in_channels: Union[int, Tuple[int, int]],
116 | out_channels: int, normalize: bool = False,
117 | bias: bool = True, **kwargs): # yapf: disable
118 | super(SAGEConv_3d, self).__init__(aggr='mean', **kwargs)
119 |
120 | self.in_channels = in_channels
121 | self.out_channels = out_channels
122 | self.normalize = normalize
123 |
124 | if isinstance(in_channels, int):
125 | in_channels = (in_channels, in_channels)
126 |
127 | self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
128 | self.lin_r = Linear(in_channels[1], out_channels, bias=False)
129 |
130 | self.reset_parameters()
131 |
132 | def reset_parameters(self):
133 | self.lin_l.reset_parameters()
134 | self.lin_r.reset_parameters()
135 |
136 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
137 | size: Size = None, layer=None) -> Tensor:
138 | """"""
139 | if layer == None or layer == 0:
140 | x_r = x
141 | if isinstance(x, Tensor):
142 | x: OptPairTensor = (x, x)
143 | out = self.propagate(edge_index, x=x, size=size, layer=None)
144 | out = self.lin_l(out)
145 | # propagate_type: (x: OptPairTensor)
146 |         elif layer > 1:  # feature decomposition path
147 |             x_r = x
148 |             L = math.floor(self.out_channels / layer)  # width of each feature slice
149 |             parts = []
150 |             for i in range(layer - 1):
151 |                 parts.append(L)
152 |             parts.append(x.shape[-1] - L * (layer - 1))  # last slice takes the remainder
153 |             parts = tuple(parts)
154 |             s_x = x.split(parts, dim=-1)  # split the feature matrix along the feature dimension
155 |             # s_x_p: OptPairTensor = (s_x[0], s_x[0])
156 |             out = self.propagate(edge_index, x=s_x[0], size=size)
157 |             for i in range(1, layer):
158 |                 # s_x_p: OptPairTensor = (s_x[i], s_x[i])
159 |                 out_l = self.propagate(edge_index, x=s_x[i], size=size)
160 |                 out = torch.cat((out, out_l), dim=-1)  # aggregate slice by slice and concatenate
161 |             out = self.lin_l(out)
162 | else:
163 | x_r = x
164 | out = self.propagate(edge_index, x=x, size=size, layer=layer)
165 | out = self.lin_l(out)
166 | if x_r is not None:
167 | out += self.lin_r(x_r)
168 |
169 | if self.normalize:
170 | out = F.normalize(out, p=2., dim=-1)
171 |
172 | return out
173 |
174 | def message(self, x_j: Tensor) -> Tensor:
175 | return x_j
176 |
177 | def message_and_aggregate(self, adj_t: SparseTensor,
178 | x: OptPairTensor) -> Tensor:
179 | adj_t = adj_t.set_value(None, layout=None)
180 | return matmul(adj_t, x[0], reduce=self.aggr)
181 |
182 | def __repr__(self):
183 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
184 | self.out_channels)
185 |
186 |
187 | class SAGEConv_3d_com(MessagePassing):
188 | r"""The GraphSAGE operator from the `"Inductive Representation Learning on
189 | Large Graphs" `_ paper
190 |
191 | .. math::
192 | \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W_2} \cdot
193 | \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j
194 |
195 | Args:
196 | in_channels (int or tuple): Size of each input sample. A tuple
197 | corresponds to the sizes of source and target dimensionalities.
198 | out_channels (int): Size of each output sample.
199 | normalize (bool, optional): If set to :obj:`True`, output features
200 | will be :math:`\ell_2`-normalized, *i.e.*,
201 | :math:`\frac{\mathbf{x}^{\prime}_i}
202 | {\| \mathbf{x}^{\prime}_i \|_2}`.
203 | (default: :obj:`False`)
204 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
205 | an additive bias. (default: :obj:`True`)
206 | **kwargs (optional): Additional arguments of
207 | :class:`torch_geometric.nn.conv.MessagePassing`.
208 | """
209 |
210 | def __init__(self, in_channels: Union[int, Tuple[int, int]],
211 | out_channels: int, normalize: bool = False,
212 | bias: bool = True, **kwargs): # yapf: disable
213 | super(SAGEConv_3d_com, self).__init__(aggr='mean', **kwargs)
214 |
215 | self.in_channels = in_channels
216 | self.out_channels = out_channels
217 | self.normalize = normalize
218 |
219 | if isinstance(in_channels, int):
220 | in_channels = (in_channels, in_channels)
221 |
222 | self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
223 | self.lin_r = Linear(in_channels[1], out_channels, bias=False)
224 |
225 | self.reset_parameters()
226 |
227 | def reset_parameters(self):
228 | self.lin_l.reset_parameters()
229 | self.lin_r.reset_parameters()
230 |
231 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
232 | size: Size = None, layer=None) -> Tensor:
233 | """"""
234 | if layer == None or layer == 0:
235 | x_r = x
236 | x = self.lin_l(x)
237 | if isinstance(x, Tensor):
238 | x: OptPairTensor = (x, x)
239 | out = self.propagate(edge_index, x=x, size=size, layer=layer)
240 |
241 |
242 | # propagate_type: (x: OptPairTensor)
243 | elif layer > 1:
244 | x_r = x
245 | x = self.lin_l(x)
246 | L = math.floor(self.out_channels / layer)
247 | parts = []
248 | for i in range(layer - 1):
249 | parts.append(L)
250 | parts.append(x.shape[-1] - L * (layer - 1))
251 | parts = tuple(parts)
252 | s_x = x.split(parts, dim=-1)
253 | # s_x_p: OptPairTensor = (s_x[0], s_x[0])
254 | out = self.propagate(edge_index, x=s_x[0], size=size)
255 |
256 | for i in range(1, layer):
257 | # s_x_p: OptPairTensor = (s_x[i], s_x[i])
258 | out_l = self.propagate(edge_index, x=s_x[i], size=size)
259 | out = torch.cat((out, out_l), dim=-1)
260 | # out = self.lin_l(out)
261 | # x_r = x
262 | else:
263 | x_r = x
264 | x = self.lin_l(x)
265 | out = self.propagate(edge_index, x=x, size=size, layer=layer)
266 |
267 | if x_r is not None:
268 | out += self.lin_r(x_r)
269 |
270 | if self.normalize:
271 | out = F.normalize(out, p=2., dim=-1)
272 |
273 | return out
274 |
275 | def message(self, x_j: Tensor) -> Tensor:
276 | return x_j
277 |
278 | def message_and_aggregate(self, adj_t: SparseTensor,
279 | x: OptPairTensor) -> Tensor:
280 | adj_t = adj_t.set_value(None, layout=None)
281 | return matmul(adj_t, x[0], reduce=self.aggr)
282 |
283 | def __repr__(self):
284 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
285 | self.out_channels)
286 |
287 |
288 | class GCNConv_3d(MessagePassing):
289 | r"""The graph convolutional operator from the `"Semi-supervised
290 | Classification with Graph Convolutional Networks"
291 |     <https://arxiv.org/abs/1609.02907>`_ paper
292 |
293 | .. math::
294 | \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
295 | \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta},
296 |
297 | where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
298 | adjacency matrix with inserted self-loops and
299 | :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
300 | The adjacency matrix can include other values than :obj:`1` representing
301 | edge weights via the optional :obj:`edge_weight` tensor.
302 |
303 | Its node-wise formulation is given by:
304 |
305 | .. math::
306 | \mathbf{x}^{\prime}_i = \mathbf{\Theta} \sum_{j}
307 | \frac{1}{\sqrt{\hat{d}_j \hat{d}_i}} \mathbf{x}_j
308 |
309 | with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where
310 | :math:`e_{j,i}` denotes the edge weight from source node :obj:`i` to target
311 | node :obj:`j` (default: :obj:`1`)
312 |
313 | Args:
314 | in_channels (int): Size of each input sample.
315 | out_channels (int): Size of each output sample.
316 | improved (bool, optional): If set to :obj:`True`, the layer computes
317 | :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
318 | (default: :obj:`False`)
319 | cached (bool, optional): If set to :obj:`True`, the layer will cache
320 | the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
321 | \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
322 | cached version for further executions.
323 | This parameter should only be set to :obj:`True` in transductive
324 | learning scenarios. (default: :obj:`False`)
325 | add_self_loops (bool, optional): If set to :obj:`False`, will not add
326 | self-loops to the input graph. (default: :obj:`True`)
327 | normalize (bool, optional): Whether to add self-loops and apply
328 | symmetric normalization. (default: :obj:`True`)
329 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
330 | an additive bias. (default: :obj:`True`)
331 | **kwargs (optional): Additional arguments of
332 | :class:`torch_geometric.nn.conv.MessagePassing`.
333 | """
334 |
335 | _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
336 | _cached_adj_t: Optional[SparseTensor]
337 |
338 | def __init__(self, in_channels: int, out_channels: int,
339 | improved: bool = False, cached: bool = False,
340 | add_self_loops: bool = True, normalize: bool = True,
341 | bias: bool = True, **kwargs):
342 |
343 | kwargs.setdefault('aggr', 'add')
344 | super(GCNConv_3d, self).__init__(**kwargs)
345 |
346 | self.in_channels = in_channels
347 | self.out_channels = out_channels
348 | self.improved = improved
349 | self.cached = cached
350 | self.add_self_loops = add_self_loops
351 | self.normalize = normalize
352 |
353 | self._cached_edge_index = None
354 | self._cached_adj_t = None
355 |
356 | self.weight = Parameter(torch.Tensor(in_channels, out_channels))
357 |
358 | if bias:
359 | self.bias = Parameter(torch.Tensor(out_channels))
360 | else:
361 | self.register_parameter('bias', None)
362 |
363 | self.reset_parameters()
364 |
365 | def reset_parameters(self):
366 | glorot(self.weight)
367 | zeros(self.bias)
368 | self._cached_edge_index = None
369 | self._cached_adj_t = None
370 |
371 | def forward(self, x: Tensor, edge_index: Adj,
372 | edge_weight: OptTensor = None, layer=1) -> Tensor:
373 | """"""
374 |
375 | if self.normalize:
376 | if isinstance(edge_index, Tensor):
377 | cache = self._cached_edge_index
378 | if cache is None:
379 | edge_index, edge_weight = gcn_norm( # yapf: disable
380 | edge_index, edge_weight, x.size(self.node_dim),
381 | self.improved, self.add_self_loops, dtype=x.dtype)
382 | if self.cached:
383 | self._cached_edge_index = (edge_index, edge_weight)
384 | else:
385 | edge_index, edge_weight = cache[0], cache[1]
386 |
387 | elif isinstance(edge_index, SparseTensor):
388 | cache = self._cached_adj_t
389 | if cache is None:
390 | edge_index = gcn_norm( # yapf: disable
391 | edge_index, edge_weight, x.size(self.node_dim),
392 | self.improved, self.add_self_loops, dtype=x.dtype)
393 | if self.cached:
394 | self._cached_adj_t = edge_index
395 | else:
396 | edge_index = cache
397 |
398 | x = torch.matmul(x, self.weight)
399 |         if layer is not None and layer > 1:  # feature decomposition path
400 |             # propagate_type: (x: Tensor, edge_weight: OptTensor)
401 |             L = math.floor(self.out_channels / layer)  # width of each feature slice
402 | parts = []
403 | for i in range(layer - 1):
404 | parts.append(L)
405 | parts.append(x.shape[1] - L * (layer - 1))
406 | parts = tuple(parts)
407 | s_x = x.split(parts, dim=-1)
408 | # s_x = torch.chunk(x,layer,-1)
409 | # ew = torch.chunk(edge_weight,layer,-1)
410 | # s_x_p: OptPairTensor = (s_x[0], s_x[0])
411 | out = self.propagate(edge_index, x=s_x[0], edge_weight=edge_weight,
412 | size=None)
413 | for i in range(1, layer):
414 | # s_x_p: OptPairTensor = (s_x[i], s_x[i])
415 | out_l = self.propagate(edge_index, x=s_x[i], edge_weight=edge_weight,
416 | size=None)
417 | out = torch.cat((out, out_l), dim=-1)
418 | else:
419 | out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
420 | size=None)
421 |
422 | if self.bias is not None:
423 | out += self.bias
424 |
425 | return out
426 |
427 | def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
428 | if edge_weight is None:
429 | return x_j
430 | else:
431 | return edge_weight.view(-1, 1) * x_j
432 |
433 | def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
434 | return matmul(adj_t, x, reduce=self.aggr)
435 |
436 | def __repr__(self):
437 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
438 | self.out_channels)
439 |
440 |
441 | class GATConv_3d(MessagePassing):
442 | r"""The graph attentional operator from the `"Graph Attention Networks"
443 |     <https://arxiv.org/abs/1710.10903>`_ paper
444 |
445 | .. math::
446 | \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
447 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
448 |
449 | where the attention coefficients :math:`\alpha_{i,j}` are computed as
450 |
451 | .. math::
452 | \alpha_{i,j} =
453 | \frac{
454 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
455 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
456 | \right)\right)}
457 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
458 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
459 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
460 | \right)\right)}.
461 |
462 | Args:
463 | in_channels (int or tuple): Size of each input sample. A tuple
464 | corresponds to the sizes of source and target dimensionalities.
465 | out_channels (int): Size of each output sample.
466 | heads (int, optional): Number of multi-head-attentions.
467 | (default: :obj:`1`)
468 | concat (bool, optional): If set to :obj:`False`, the multi-head
469 | attentions are averaged instead of concatenated.
470 | (default: :obj:`True`)
471 | negative_slope (float, optional): LeakyReLU angle of the negative
472 | slope. (default: :obj:`0.2`)
473 | dropout (float, optional): Dropout probability of the normalized
474 | attention coefficients which exposes each node to a stochastically
475 | sampled neighborhood during training. (default: :obj:`0`)
476 | add_self_loops (bool, optional): If set to :obj:`False`, will not add
477 | self-loops to the input graph. (default: :obj:`True`)
478 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
479 | an additive bias. (default: :obj:`True`)
480 | **kwargs (optional): Additional arguments of
481 | :class:`torch_geometric.nn.conv.MessagePassing`.
482 | """
483 | _alpha: OptTensor
484 |
485 | def __init__(self, in_channels: Union[int, Tuple[int, int]],
486 | out_channels: int, heads: int = 1, concat: bool = True,
487 | negative_slope: float = 0.2, dropout: float = 0.,
488 | add_self_loops: bool = True, bias: bool = True, **kwargs):
489 | kwargs.setdefault('aggr', 'add')
490 | super(GATConv_3d, self).__init__(node_dim=0, **kwargs)
491 |
492 | self.in_channels = in_channels
493 | self.out_channels = out_channels
494 | self.heads = heads
495 | self.concat = concat
496 | self.negative_slope = negative_slope
497 | self.dropout = dropout
498 | self.add_self_loops = add_self_loops
499 |
500 | if isinstance(in_channels, int):
501 | self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
502 | self.lin_r = self.lin_l
503 | else:
504 | self.lin_l = Linear(in_channels[0], heads * out_channels, False)
505 | self.lin_r = Linear(in_channels[1], heads * out_channels, False)
506 |
507 | self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
508 | self.att_r = Parameter(torch.Tensor(1, heads, out_channels))
509 |
510 | if bias and concat:
511 | self.bias = Parameter(torch.Tensor(heads * out_channels))
512 | elif bias and not concat:
513 | self.bias = Parameter(torch.Tensor(out_channels))
514 | else:
515 | self.register_parameter('bias', None)
516 |
517 | self._alpha = None
518 |
519 | self.reset_parameters()
520 |
521 | def reset_parameters(self):
522 | glorot(self.lin_l.weight)
523 | glorot(self.lin_r.weight)
524 | glorot(self.att_l)
525 | glorot(self.att_r)
526 | zeros(self.bias)
527 |
528 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
529 | size: Size = None, return_attention_weights=None, layer=None):
530 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor # noqa
531 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor # noqa
532 | # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
533 | # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
534 | r"""
535 |
536 | Args:
537 | return_attention_weights (bool, optional): If set to :obj:`True`,
538 | will additionally return the tuple
539 | :obj:`(edge_index, attention_weights)`, holding the computed
540 | attention weights for each edge. (default: :obj:`None`)
541 | """
542 | H, C = self.heads, self.out_channels
543 |
544 | x_l: OptTensor = None
545 | x_r: OptTensor = None
546 | alpha_l: OptTensor = None
547 | alpha_r: OptTensor = None
548 | if isinstance(x, Tensor):
549 | assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
550 | x_l = x_r = self.lin_l(x).view(-1, H, C)
551 | alpha_l = (x_l * self.att_l).sum(dim=-1)
552 | alpha_r = (x_r * self.att_r).sum(dim=-1)
553 | else:
554 | x_l, x_r = x[0], x[1]
555 | assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
556 | x_l = self.lin_l(x_l).view(-1, H, C)
557 | alpha_l = (x_l * self.att_l).sum(dim=-1)
558 | if x_r is not None:
559 | x_r = self.lin_r(x_r).view(-1, H, C)
560 | alpha_r = (x_r * self.att_r).sum(dim=-1)
561 |
562 | assert x_l is not None
563 | assert alpha_l is not None
564 |
565 | if self.add_self_loops:
566 | if isinstance(edge_index, Tensor):
567 | num_nodes = x_l.size(0)
568 | if x_r is not None:
569 | num_nodes = min(num_nodes, x_r.size(0))
570 | if size is not None:
571 | num_nodes = min(size[0], size[1])
572 | edge_index, _ = remove_self_loops(edge_index)
573 | edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
574 | elif isinstance(edge_index, SparseTensor):
575 | edge_index = set_diag(edge_index)
576 |
577 | # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)
578 | if layer != None and (layer > 1):
579 | # propagate_type: (x: Tensor, edge_weight: OptTensor)
580 | L = math.floor(self.out_channels / layer)
581 | parts = []
582 | for i in range(layer - 1):
583 | parts.append(L)
584 | parts.append(x_l.shape[-1] - L * (layer - 1))
585 | parts = tuple(parts)
586 | xl = x_l.split(parts, dim=-1)
587 | xr = x_r.split(parts, dim=-1)
588 | out = self.propagate(edge_index, x=(xl[0], xr[0]),
589 | alpha=(alpha_l, alpha_r), size=size)
590 | for i in range(1, layer):
591 | out1 = self.propagate(edge_index, x=(xl[i], xr[i]),
592 | alpha=(alpha_l, alpha_r), size=size)
593 | out = torch.cat((out, out1), dim=-1)
594 | else:
595 | out = self.propagate(edge_index, x=(x_l, x_r),
596 | alpha=(alpha_l, alpha_r), size=size)
597 | alpha = self._alpha
598 | self._alpha = None
599 |
600 | if self.concat:
601 | out = out.view(-1, self.heads * self.out_channels)
602 | else:
603 | out = out.mean(dim=1)
604 |
605 | if self.bias is not None:
606 | out += self.bias
607 |
608 | if isinstance(return_attention_weights, bool):
609 | assert alpha is not None
610 | if isinstance(edge_index, Tensor):
611 | return out, (edge_index, alpha)
612 | elif isinstance(edge_index, SparseTensor):
613 | return out, edge_index.set_value(alpha, layout='coo')
614 | else:
615 | return out
616 |
617 | def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
618 | index: Tensor, ptr: OptTensor,
619 | size_i: Optional[int]) -> Tensor:
620 | alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
621 | alpha = F.leaky_relu(alpha, self.negative_slope)
622 | alpha = softmax(alpha, index, ptr, size_i)
623 | self._alpha = alpha
624 | alpha = F.dropout(alpha, p=self.dropout, training=self.training)
625 | return x_j * alpha.unsqueeze(-1)
626 |
627 | def __repr__(self):
628 | return '{}({}, {}, heads={})'.format(self.__class__.__name__,
629 | self.in_channels,
630 | self.out_channels, self.heads)
631 |
632 |
633 | class GINConv_3d(MessagePassing):
634 | r"""The graph isomorphism operator from the `"How Powerful are
635 | Graph Neural Networks?" `_ paper
636 |
637 | .. math::
638 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
639 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)
640 |
641 | or
642 |
643 | .. math::
644 | \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} +
645 | (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),
646 |
647 |     here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP.
648 |
649 | Args:
650 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
651 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to
652 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by
653 | :class:`torch.nn.Sequential`.
654 | eps (float, optional): (Initial) :math:`\epsilon`-value.
655 | (default: :obj:`0.`)
656 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
657 | will be a trainable parameter. (default: :obj:`False`)
658 | **kwargs (optional): Additional arguments of
659 | :class:`torch_geometric.nn.conv.MessagePassing`.
660 | """
661 |
662 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
663 | **kwargs):
664 | kwargs.setdefault('aggr', 'add')
665 | super(GINConv_3d, self).__init__(**kwargs)
666 | self.nn = nn
667 | self.initial_eps = eps
668 | if train_eps:
669 | self.eps = torch.nn.Parameter(torch.Tensor([eps]))
670 | else:
671 | self.register_buffer('eps', torch.Tensor([eps]))
672 | self.reset_parameters()
673 |
674 | def reset_parameters(self):
675 | reset(self.nn)
676 | self.eps.data.fill_(self.initial_eps)
677 |
678 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
679 | size: Size = None, layer=None) -> Tensor:
680 | """"""
681 |
682 | # propagate_type: (x: OptPairTensor)
683 | # out = self.propagate(edge_index, x=x, size=size, layer=layer)
684 |
685 | if layer == None or layer == 0:
686 | if isinstance(x, Tensor):
687 | x: OptPairTensor = (x, x)
688 | out = self.propagate(edge_index, x=x, size=size)
689 | x_r = x[1]
690 | elif layer == 1:
691 | out = self.propagate(edge_index, x=x, size=size)
692 | x_r = x
693 | else:
694 | L = math.floor(x.shape[-1] / layer)
695 | parts = []
696 | for i in range(layer - 1):
697 | parts.append(L)
698 | parts.append(x.shape[1] - L * (layer - 1))
699 | parts = tuple(parts)
700 | xl = x.split(parts, dim=-1)
701 | out = self.propagate(edge_index, x=xl[0], size=size)
702 | for i in range(1, layer):
703 | out1 = self.propagate(edge_index, x=xl[i], size=size)
704 | out = torch.cat((out, out1), dim=-1)
705 | x_r = x
706 |
707 | if x_r is not None:
708 | out += (1 + self.eps) * x_r
709 |
710 | return self.nn(out)
711 |
712 | def message(self, x_j: Tensor) -> Tensor:
713 | return x_j
714 |
715 | def message_and_aggregate(self, adj_t: SparseTensor,
716 | x: OptPairTensor) -> Tensor:
717 | adj_t = adj_t.set_value(None, layout=None)
718 | return matmul(adj_t, x[0], reduce=self.aggr)
719 |
720 | def __repr__(self):
721 | return '{}(nn={})'.format(self.__class__.__name__, self.nn)
722 |
723 |
724 | class GINConv_3d_com(MessagePassing):
725 | r"""The graph isomorphism operator from the `"How Powerful are
726 | Graph Neural Networks?" `_ paper
727 |
728 | .. math::
729 | \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot
730 | \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right)
731 |
732 | or
733 |
734 | .. math::
735 | \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} +
736 | (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right),
737 |
738 |     here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *i.e.* an MLP.
739 |
740 | Args:
741 | nn (torch.nn.Module): A neural network :math:`h_{\mathbf{\Theta}}` that
742 | maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to
743 | shape :obj:`[-1, out_channels]`, *e.g.*, defined by
744 | :class:`torch.nn.Sequential`.
745 | eps (float, optional): (Initial) :math:`\epsilon`-value.
746 | (default: :obj:`0.`)
747 | train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon`
748 | will be a trainable parameter. (default: :obj:`False`)
749 | **kwargs (optional): Additional arguments of
750 | :class:`torch_geometric.nn.conv.MessagePassing`.
751 | """
752 |
753 | def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
754 | **kwargs):
755 | kwargs.setdefault('aggr', 'add')
756 | super(GINConv_3d_com, self).__init__(**kwargs)
757 | self.nn = nn
758 | self.initial_eps = eps
759 | if train_eps:
760 | self.eps = torch.nn.Parameter(torch.Tensor([eps]))
761 | else:
762 | self.register_buffer('eps', torch.Tensor([eps]))
763 | self.reset_parameters()
764 |
765 | def reset_parameters(self):
766 | reset(self.nn)
767 | self.eps.data.fill_(self.initial_eps)
768 |
769 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
770 | size: Size = None, layer=None) -> Tensor:
771 | """"""
772 |
773 | # propagate_type: (x: OptPairTensor)
774 | # out = self.propagate(edge_index, x=x, size=size, layer=layer)
775 | x = self.nn(x)
776 | if layer == None or layer == 0:
777 | if isinstance(x, Tensor):
778 | x: OptPairTensor = (x, x)
779 | out = self.propagate(edge_index, x=x, size=size)
780 | x_r = x[1]
781 | elif layer == 1:
782 | out = self.propagate(edge_index, x=x, size=size)
783 | x_r = x
784 | else:
785 | L = math.floor(x.shape[-1] / layer)
786 | parts = []
787 | for i in range(layer - 1):
788 | parts.append(L)
789 | parts.append(x.shape[1] - L * (layer - 1))
790 | parts = tuple(parts)
791 | xl = x.split(parts, dim=-1)
792 | out = self.propagate(edge_index, x=xl[0], size=size)
793 | for i in range(1, layer):
794 | out1 = self.propagate(edge_index, x=xl[i], size=size)
795 | out = torch.cat((out, out1), dim=-1)
796 | x_r = x
797 |
798 | if x_r is not None:
799 | out += (1 + self.eps) * x_r
800 |
801 | return out
802 |
803 | def message(self, x_j: Tensor) -> Tensor:
804 | return x_j
805 |
806 | def message_and_aggregate(self, adj_t: SparseTensor,
807 | x: OptPairTensor) -> Tensor:
808 | adj_t = adj_t.set_value(None, layout=None)
809 | return matmul(adj_t, x[0], reduce=self.aggr)
810 |
811 | def __repr__(self):
812 | return '{}(nn={})'.format(self.__class__.__name__, self.nn)
813 |
--------------------------------------------------------------------------------
/model/inf_model.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import argparse
3 | from torch.nn import Sequential, Linear, ReLU
4 | import torch
5 | import torch.nn.functional as F
6 | from torch_geometric.datasets import Planetoid
7 | import torch_geometric.transforms as T
8 | from torch_geometric.nn import GCNConv, ChebConv # noqa
9 | import time
10 | from model.gnn_model_1_6 import GATConv_3d, GINConv_3d, GCNConv_3d, SAGEConv_3d, GINConv_3d_com, SAGEConv_3d_com
11 |
12 |
13 | class GCNNet(torch.nn.Module):
14 | def __init__(self, in_channels: int, out_channels: int, num_classes: int):
15 | super(GCNNet, self).__init__()
16 | self.conv1 = GCNConv_3d(in_channels, out_channels, cached=True,
17 | normalize=True)
18 | self.conv2 = GCNConv_3d(out_channels, num_classes, cached=True,
19 | normalize=True)
20 |
21 | def forward(self, x, edge_index, edge_attr=None, layer=(0, 0)):
22 | x = F.relu(self.conv1(x, edge_index, edge_attr, layer=layer[0]))
23 | x = F.dropout(x, training=self.training)
24 | x = self.conv2(x, edge_index, edge_attr, layer=layer[1])
25 | return F.log_softmax(x, dim=1)
26 |
27 |
28 | class GATNet(torch.nn.Module):
29 | def __init__(self, in_channels: int, out_channels: int, num_classes: int):
30 | super(GATNet, self).__init__()
31 | self.conv1 = GATConv_3d(in_channels, out_channels, heads=1, dropout=0.6)
32 | # On the Pubmed dataset, use heads=8 in conv2.
33 | self.conv2 = GATConv_3d(out_channels, num_classes, heads=1, concat=False,
34 | dropout=0.6)
35 |
36 | def forward(self, x, edge_index, layer=(0, 0)):
37 | x = F.dropout(x, p=0.6, training=self.training)
38 | x = F.elu(self.conv1(x, edge_index, layer=layer[0]))
39 | x = F.dropout(x, p=0.6, training=self.training)
40 | x = self.conv2(x, edge_index, layer=layer[1])
41 |
42 | return F.log_softmax(x, dim=1)
43 |
44 |
45 | class SAGE_agg1(torch.nn.Module):
46 | def __init__(self, in_channels, hidden_channels, num_classes: int):
47 | super(SAGE_agg1, self).__init__()
48 | self.conv1 = SAGEConv_3d(in_channels, hidden_channels)
49 |         self.conv2 = SAGEConv_3d(hidden_channels, num_classes)
50 |
51 | def forward(self, x, edge_index, layer=(0, 0)):
52 | x = self.conv1(x, edge_index, layer=layer[0])
53 | x = x.relu()
54 | x = F.dropout(x, p=0.5, training=self.training)
55 | x = self.conv2(x, edge_index, layer=layer[1])
56 |
57 | return F.log_softmax(x, dim=1)
58 |
59 |
60 | class SAGE_com1(torch.nn.Module):
61 | def __init__(self, in_channels, hidden_channels, num_classes: int):
62 | super(SAGE_com1, self).__init__()
63 | self.conv1 = SAGEConv_3d_com(in_channels, hidden_channels)
64 | self.conv2 = SAGEConv_3d_com(hidden_channels, num_classes)
65 |
66 | def forward(self, x, edge_index, layer=(0, 0)):
67 | x = self.conv1(x, edge_index, layer=layer[0])
68 | x = x.relu()
69 | x = F.dropout(x, p=0.5, training=self.training)
70 | x = self.conv2(x, edge_index, layer=layer[1])
71 |
72 | return F.log_softmax(x, dim=1)
73 |
74 |
75 | class GIN_agg(torch.nn.Module):
76 | def __init__(self, in_channels, hidden_channels, num_classes: int):
77 | super(GIN_agg, self).__init__()
78 |
79 | num_features = in_channels
80 | dim = hidden_channels
81 |
82 | nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
83 | self.conv1 = GINConv_3d(nn1)
84 | self.bn1 = torch.nn.BatchNorm1d(dim)
85 |
86 | nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
87 | self.conv2 = GINConv_3d(nn2)
88 | self.bn2 = torch.nn.BatchNorm1d(dim)
89 |
90 | nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
91 | self.conv3 = GINConv_3d(nn3)
92 | self.bn3 = torch.nn.BatchNorm1d(dim)
93 |
94 | nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
95 | self.conv4 = GINConv_3d(nn4)
96 | self.bn4 = torch.nn.BatchNorm1d(dim)
97 |
98 | nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
99 | self.conv5 = GINConv_3d(nn5)
100 | self.bn5 = torch.nn.BatchNorm1d(dim)
101 |
102 | self.fc1 = Linear(dim, dim)
103 | self.fc2 = Linear(dim, num_classes)
104 |
105 | def forward(self, x, edge_index, layer=(0, 0)):
106 | x = F.relu(self.conv1(x, edge_index, layer=layer[0]))
107 | x = self.bn1(x)
108 | x = F.relu(self.conv2(x, edge_index, layer=layer[1]))
109 | x = self.bn2(x)
110 | x = F.relu(self.conv3(x, edge_index, layer=layer[1]))
111 | x = self.bn3(x)
112 | x = F.relu(self.conv4(x, edge_index, layer=layer[1]))
113 | x = self.bn4(x)
114 | x = F.relu(self.conv5(x, edge_index, layer=layer[1]))
115 | x = self.bn5(x)
116 | x = F.relu(self.fc1(x))
117 | x = F.dropout(x, p=0.5, training=self.training)
118 | x = self.fc2(x)
119 | return F.log_softmax(x, dim=-1)
120 |
121 |
122 | class GIN_com(torch.nn.Module):
123 | def __init__(self, in_channels, hidden_channels, num_classes: int):
124 |         super(GIN_com, self).__init__()
125 |
126 | num_features = in_channels
127 | dim = hidden_channels
128 |
129 | nn1 = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
130 | self.conv1 = GINConv_3d_com(nn1)
131 | self.bn1 = torch.nn.BatchNorm1d(dim)
132 |
133 | nn2 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
134 | self.conv2 = GINConv_3d_com(nn2)
135 | self.bn2 = torch.nn.BatchNorm1d(dim)
136 |
137 | nn3 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
138 | self.conv3 = GINConv_3d_com(nn3)
139 | self.bn3 = torch.nn.BatchNorm1d(dim)
140 |
141 | nn4 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
142 | self.conv4 = GINConv_3d_com(nn4)
143 | self.bn4 = torch.nn.BatchNorm1d(dim)
144 |
145 | nn5 = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
146 | self.conv5 = GINConv_3d_com(nn5)
147 | self.bn5 = torch.nn.BatchNorm1d(dim)
148 |
149 | self.fc1 = Linear(dim, dim)
150 | self.fc2 = Linear(dim, num_classes)
151 |
152 | def forward(self, x, edge_index, layer=(0, 0)):
153 | x = F.relu(self.conv1(x, edge_index, layer=layer[0]))
154 | x = self.bn1(x)
155 |         x = F.relu(self.conv2(x, edge_index, layer=layer[1]))
156 |         x = self.bn2(x)
157 |         x = F.relu(self.conv3(x, edge_index, layer=layer[1]))
158 |         x = self.bn3(x)
159 |         x = F.relu(self.conv4(x, edge_index, layer=layer[1]))
160 |         x = self.bn4(x)
161 | x = F.relu(self.conv5(x, edge_index, layer=layer[1]))
162 | x = self.bn5(x)
163 | x = F.relu(self.fc1(x))
164 | x = F.dropout(x, p=0.5, training=self.training)
165 | x = self.fc2(x)
166 | return F.log_softmax(x, dim=-1)
167 |
--------------------------------------------------------------------------------
/model/inits.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 |
5 |
6 | def uniform(size, tensor):
7 | if tensor is not None:
8 | bound = 1.0 / math.sqrt(size)
9 | tensor.data.uniform_(-bound, bound)
10 |
11 |
12 | def kaiming_uniform(tensor, fan, a):
13 | if tensor is not None:
14 | bound = math.sqrt(6 / ((1 + a**2) * fan))
15 | tensor.data.uniform_(-bound, bound)
16 |
17 |
18 | def glorot(tensor):
19 | if tensor is not None:
20 | stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
21 | tensor.data.uniform_(-stdv, stdv)
22 |
23 |
24 | def glorot_orthogonal(tensor, scale):
25 | if tensor is not None:
26 | torch.nn.init.orthogonal_(tensor.data)
27 | scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var())
28 | tensor.data *= scale.sqrt()
29 |
30 |
31 | def zeros(tensor):
32 | if tensor is not None:
33 | tensor.data.fill_(0)
34 |
35 |
36 | def ones(tensor):
37 | if tensor is not None:
38 | tensor.data.fill_(1)
39 |
40 |
41 | def normal(tensor, mean, std):
42 | if tensor is not None:
43 | tensor.data.normal_(mean, std)
44 |
45 |
46 | def reset(nn):
47 | def _reset(item):
48 | if hasattr(item, 'reset_parameters'):
49 | item.reset_parameters()
50 |
51 | if nn is not None:
52 | if hasattr(nn, 'children') and len(list(nn.children())) > 0:
53 | for item in nn.children():
54 | _reset(item)
55 | else:
56 | _reset(nn)
57 |
--------------------------------------------------------------------------------
/pyg1.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUAA-CI-LAB/GNN-Feature-Decomposition/e1980bac20c200b324d6534da1b33f4fab556c78/result.png
--------------------------------------------------------------------------------
/test/test_gnn_layer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os.path as osp
3 | import argparse
4 | import torch
5 | import sys
6 | # from memory_profiler import profile
7 |
8 | from torch.nn import Sequential, Linear, ReLU
9 |
10 | sys.path.append("..")
11 | from torch_geometric.nn import GCNConv, GATConv, SAGEConv, GINConv, EdgeConv
12 | from model.gnn_model_1_6 import GATConv_3d, GINConv_3d, GCNConv_3d, SAGEConv_3d, GINConv_3d_com, SAGEConv_3d_com
13 | import torch_geometric.transforms as T
14 | from torch_geometric.datasets import Reddit
15 | from torch_geometric.datasets import Planetoid
16 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset
17 |
18 | parser = argparse.ArgumentParser(description='manual to this script')
19 | parser.add_argument("--gpu", type=int, default="0", help='0:use cpu 1:use gpu')
20 | parser.add_argument("--hidden", type=int, default=8, help='hidden number')
21 | parser.add_argument("--agg", type=str, default="gas", help='gas = scatter spmm = sparse mul')
22 | parser.add_argument("--m", type=str, default="GCN", help='model name: GCN,GIN,GAT,SAGE')
23 | parser.add_argument("--layer", type=int, default=0, help='layer represent the number of layers cutting along '
24 | 'feature dimension, None = optpair x , 1 = single x')
25 | parser.add_argument("--order", type=str, default='agg', help='agg or com')
26 | parser.add_argument("--data", type=str, default='rd', help='cr = cora,pb=pubmed,cs = cisteer, pt=ogbn-proteins, rd=reddit')
27 |
28 | args = parser.parse_args()
29 | layer = args.layer
30 | hidden = args.hidden
31 |
32 | data_use = args.data
33 | path = ''
34 | dataset = ''
35 | if data_use == 'rd':
36 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', 'Reddit')
37 | if args.agg == "gas":
38 | dataset = Reddit(path)
39 | print('Reddit dataset, gas model!')
40 | else:
41 | dataset = Reddit(path, transform=T.ToSparseTensor())
42 | print('Reddit dataset, spmm model')
43 | elif data_use == 'cr':
44 | dataset = 'Cora'
45 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
46 | if args.agg == "gas":
47 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
48 | print('Cora dataset, gas model!')
49 | else:
50 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
51 | print('Cora dataset, spmm model')
52 | elif data_use == 'pb':
53 | dataset = 'pubmed'
54 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
55 | if args.agg == "gas":
56 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
57 |         print('Pubmed dataset, gas model!')
58 | else:
59 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
60 |         print('Pubmed dataset, spmm model')
61 | elif data_use == 'cs':
62 | dataset = 'CiteSeer'
63 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
64 | if args.agg == "gas":
65 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
66 | print('Citeseer dataset, gas model!')
67 | else:
68 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
69 | print('Citeseer dataset, spmm model')
70 | elif data_use == 'pt':
71 | if args.agg == "gas":
72 | dataset = PygNodePropPredDataset(name='ogbn-proteins',root='../dataset')
73 | print('Ogbn-proteins dataset, gas model!')
74 | else:
75 | dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset',transform=T.ToSparseTensor())
76 | print('Ogbn-proteins dataset, spmm model!')
77 |
78 | a = time.time()
79 | data = dataset[0]
80 | b = time.time()
81 | print('Dataset process time : ', b - a)
82 |
83 | device = torch.device('cuda' if args.gpu else 'cpu')
84 | model = ''
85 |
86 | if args.m == "GCN":
87 | print('using gcn')
88 | if data_use == 'pt':
89 | num_features = 1
90 | else:
91 | num_features = dataset.num_features
92 | model = GCNConv_3d(num_features, hidden)
93 | elif args.m == "GAT":
94 | print('using gat')
95 | if data_use == 'pt':
96 | num_features = 1
97 | else:
98 | num_features = dataset.num_features
99 | model = GATConv_3d(num_features, hidden)
100 | elif args.m == "SAGE":
101 | if args.order == 'agg':
102 | print('using 1st agg graphsage')
103 | if data_use == 'pt':
104 | num_features = 1
105 | else:
106 | num_features = dataset.num_features
107 | model = SAGEConv_3d(num_features, hidden)
108 | else:
109 | print('using 1st com graphsage')
110 | if data_use == 'pt':
111 | num_features = 1
112 | else:
113 | num_features = dataset.num_features
114 | model = SAGEConv_3d_com(num_features, hidden)
115 | elif args.m == "GIN":
116 | if args.order == 'agg':
117 | if data_use == 'pt':
118 | num_features = 1
119 | else:
120 | num_features = dataset.num_features
121 | nn1 = Sequential(Linear(num_features, args.hidden), ReLU(), Linear(args.hidden, args.hidden))
122 |         print('using 1st agg GIN, hid=', args.hidden)
123 | model = GINConv_3d(nn1)
124 | else:
125 | if data_use == 'pt':
126 | num_features = 1
127 | else:
128 | num_features = dataset.num_features
129 | nn1 = Sequential(Linear(num_features, args.hidden), ReLU(), Linear(args.hidden, args.hidden))
130 |         print('using 1st com GIN, hid=', args.hidden)
131 | model = GINConv_3d_com(nn1)
132 | else:
133 |     sys.exit('Error: model ' + args.m + ' does not exist')
134 |
135 | model = model.to(device)
136 |
137 | x = data.x.to(device)
138 | y = data.y.squeeze().to(device)
139 | if args.agg == "gas":
140 |     edge_index = data.edge_index.to(device)  # gather-and-scatter consumes the COO edge list
141 | else:
142 |     edge_index = data.adj_t.to(device)  # spmm consumes the SparseTensor adjacency
143 |
144 |
145 | print('using ', device)
146 | print('layers : ', args.layer)
147 |
148 |
149 | @torch.no_grad()
150 | def test():
151 |     # single timed inference pass; autograd is disabled by the decorator above
152 | model.eval()
153 | # with torch.autograd.profiler.profile(use_cuda=True if args.gpu else False, profile_memory=True) as prof:
154 | t0 = time.time()
155 | out = model(x, edge_index, layer=layer)
156 | t1 = time.time()
157 | #print(prof.key_averages().table(sort_by="self_cpu_time_total" if args.gpu == 0 else "self_cuda_time_total"))
158 | print('Inference time:', t1 - t0)
159 | print('Total time:', t1 - t0 + b - a)
160 |
161 |
162 | if __name__ == '__main__':
163 | test()
164 |
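165 | # The `layer` argument is forwarded to the *_3d layers, which slice the node
166 | # feature matrix along the feature dimension and aggregate slice by slice,
167 | # capping the peak memory of the aggregation stage. A standalone sketch of the
168 | # idea follows test_gnn_total.py at the end of this listing.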
--------------------------------------------------------------------------------
/test/test_gnn_total.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os.path as osp
3 | import argparse
4 | import torch
5 | import sys
6 |
7 | # (torch.nn's Sequential/Linear/ReLU are not needed in this script: the GIN
8 | # MLPs are constructed inside model.inf_model)
9 |
10 | sys.path.append("..")
11 |
12 | import torch_geometric.transforms as T
13 | from torch_geometric.datasets import Reddit
14 | from torch_geometric.datasets import Planetoid
15 | # end-to-end inference models with feature-decomposition support
16 | from model.inf_model import GCNNet, GATNet, SAGE_agg1, SAGE_com1, GIN_agg, GIN_com
17 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset
18 |
19 | parser = argparse.ArgumentParser(description='command-line options for this script')
20 | parser.add_argument("--gpu", type=int, default=0, help='0: use CPU, 1: use GPU')
21 | parser.add_argument("--hidden", type=int, default=8, help='hidden layer size')
22 | parser.add_argument("--agg", type=str, default="gas", help='aggregation backend: gas = gather-and-scatter, spmm = sparse matrix multiplication')
23 | parser.add_argument("--m", type=str, default="GCN", help='model name: GCN, GIN, GAT, SAGE')
24 | parser.add_argument("--layer", type=str, default="1,1", help='comma-separated slice counts for the two GNN layers, '
25 |                                                              'e.g. "32,41"; 1 = single (undecomposed) x')
26 | parser.add_argument("--order", type=str, default='com', help='execution order: aggregate first (agg) or combine first (com)')
27 | parser.add_argument("--data", type=str, default='rd', help='cr = cora, pb = pubmed, cs = citeseer, pt = ogbn-proteins, rd = reddit')
28 | args = parser.parse_args()
29 | hidden = args.hidden
30 | layer = args.layer
31 |
32 | # e.g. "32,41" -> [32, 41]: one slice count for each of the two GNN layers
33 | parts = layer.split(',')
34 | layer = [int(parts[0]), int(parts[1])]
35 |
36 | data_use = args.data
37 |
38 | path = ''
39 | dataset = ''
40 | if data_use == 'rd':
41 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset', 'Reddit')
42 | if args.agg == "gas":
43 | dataset = Reddit(path)
44 | print('Reddit dataset, gas model!')
45 | else:
46 | dataset = Reddit(path, transform=T.ToSparseTensor())
47 | print('Reddit dataset, spmm model')
48 | elif data_use == 'cr':
49 | dataset = 'Cora'
50 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
51 | if args.agg == "gas":
52 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
53 | print('Cora dataset, gas model!')
54 | else:
55 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
56 | print('Cora dataset, spmm model')
57 | elif data_use == 'pb':
58 | dataset = 'pubmed'
59 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
60 | if args.agg == "gas":
61 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
62 |         print('PubMed dataset, gas model!')
63 | else:
64 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
65 |         print('PubMed dataset, spmm model')
66 | elif data_use == 'cs':
67 | dataset = 'CiteSeer'
68 | path = osp.join(osp.dirname(osp.realpath(__file__)), '../dataset')
69 | if args.agg == "gas":
70 | dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
71 |         print('Citeseer dataset, gas model!')
72 | else:
73 | dataset = Planetoid(path, dataset, transform=T.ToSparseTensor())
74 | print('Citeseer dataset, spmm model')
75 | elif data_use == 'pt':
76 |     if args.agg == "gas":
77 |         dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset')
78 |         print('Ogbn-proteins dataset, gas model!')
79 |     else:
80 |         dataset = PygNodePropPredDataset(name='ogbn-proteins', root='../dataset', transform=T.ToSparseTensor())
81 |         print('Ogbn-proteins dataset, spmm model!')
82 |
83 | a = time.time()
84 | data = dataset[0]
85 |
86 | b = time.time()
87 | print('Dataset process time : ', b - a)
88 |
89 | device = torch.device('cuda' if args.gpu else 'cpu')
90 |
91 | model = ''
92 | if args.m == "GCN":
93 | print('using gcn')
94 | if data_use == 'pt':
95 | num_features = 1
96 | else:
97 | num_features = dataset.num_features
98 | model = GCNNet(num_features, hidden, dataset.num_classes)
99 | elif args.m == "GAT":
100 | print('using gat')
101 | if data_use == 'pt':
102 | num_features = 1
103 | else:
104 | num_features = dataset.num_features
105 | model = GATNet(num_features, hidden, dataset.num_classes)
106 | elif args.m == "SAGE":
107 | if args.order == 'agg':
108 | print('using 1st agg sage')
109 | if data_use == 'pt':
110 | num_features = 1
111 | else:
112 | num_features = dataset.num_features
113 | model = SAGE_agg1(num_features, hidden, dataset.num_classes)
114 | else:
115 | print('using 1st com sage')
116 | if data_use == 'pt':
117 | num_features = 1
118 | else:
119 | num_features = dataset.num_features
120 | model = SAGE_com1(num_features, hidden, dataset.num_classes)
121 | elif args.m == "GIN":
122 | if args.order == 'agg':
123 |         print('using 1st agg GIN, hid=', args.hidden)
124 | if data_use == 'pt':
125 | num_features = 1
126 | else:
127 | num_features = dataset.num_features
128 | model = GIN_agg(num_features, hidden, dataset.num_classes)
129 | else:
130 |         print('using 1st com GIN, hid=', args.hidden)
131 | if data_use == 'pt':
132 | num_features = 1
133 | else:
134 | num_features = dataset.num_features
135 | model = GIN_com(num_features, hidden, dataset.num_classes)
136 | else:
137 |     sys.exit('Error: model ' + args.m + ' does not exist')
138 |
139 | model = model.to(device)
140 |
141 |
142 | x = data.x.to(device)
143 | y = data.y.squeeze().to(device)
144 |
145 | if args.agg == "gas":
146 | edge_index = data.edge_index.to(device)
147 | else:
148 | edge_index = data.adj_t.to(device)
149 |
150 |
151 | print('using ', device)
152 | print('layers : ', args.layer)
153 |
154 | @torch.no_grad()
155 | def test():
156 | model.eval()
157 |
158 | # with torch.autograd.profiler.profile(use_cuda=True if args.gpu else False, profile_memory=True) as prof:
159 | t0 = time.time()
160 | out = model(x, edge_index, layer=layer)
161 | t1 = time.time()
162 | # print(prof.key_averages().table(sort_by="self_cpu_time_total" if args.gpu == 0 else "self_cuda_time_total"))
163 | print('Inference time:', t1 - t0)
164 | print('Total time:', t1 - t0 + b - a)
165 |
166 |
167 | if __name__ == '__main__':
168 | test()
169 |
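170 | # Timing note: CUDA kernels launch asynchronously, so with --gpu=1 the wall-clock
171 | # window around model(...) should bracket torch.cuda.synchronize() calls to avoid
172 | # under-reporting latency; CPU runs (--gpu=0, the default) are unaffected.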
--------------------------------------------------------------------------------
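Below is a minimal, self-contained sketch of the feature decomposition idea implemented by the *_3d layers and the inf_model networks above, reduced to a plain mean aggregator. It is illustrative only: the function name `decomposed_mean_aggregate` and its signature are ours, not part of this repository's API, and the real layers additionally handle the GCN/GAT/GIN/SAGE-specific weighting. The key point is that the per-edge message tensor is materialized one feature slice at a time, which is what bounds peak memory.

    import torch

    def decomposed_mean_aggregate(x, edge_index, num_slices):
        # x: [num_nodes, num_feats] node features; edge_index: [2, num_edges] COO edges
        src, dst = edge_index
        ones = torch.ones(dst.size(0), dtype=x.dtype)
        # in-degree per destination node, clamped so isolated nodes divide by 1
        deg = torch.zeros(x.size(0), dtype=x.dtype).index_add_(0, dst, ones).clamp_(min=1)
        out = torch.empty_like(x)
        col = 0
        for chunk in x.chunk(num_slices, dim=1):   # slice the feature dimension
            msgs = chunk[src]                      # [num_edges, slice_width], never [num_edges, num_feats]
            agg = torch.zeros(x.size(0), chunk.size(1), dtype=x.dtype).index_add_(0, dst, msgs)
            out[:, col:col + chunk.size(1)] = agg / deg.unsqueeze(1)
            col += chunk.size(1)
        return out

    # Slicing changes the schedule, not the result:
    x = torch.randn(100, 64)
    edge_index = torch.randint(0, 100, (2, 500))
    assert torch.allclose(decomposed_mean_aggregate(x, edge_index, 1),
                          decomposed_mean_aggregate(x, edge_index, 8))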