├── LICENSE ├── README.md ├── requirements.txt ├── setup.py ├── test ├── gnn_test.py └── requirements.txt └── torch_graphnet ├── __init__.py ├── graph_networks.py └── utils ├── __init__.py └── gn_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2021, David Mulero 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Graph Networks 2 | 3 | [PyTorch](https://pytorch.org/) implementation of DeepMind [Graph Nets](https://github.com/deepmind/graph_nets). 4 | The original code depends on [TensorFlow](https://www.tensorflow.org/) and 5 | [Sonnet](https://sonnet.readthedocs.io/en/latest/index.html). 6 | 7 | This implementation is based on [PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric), 8 | which is a geometric deep learning extension library for [PyTorch](https://pytorch.org/). 9 | 10 | ## Graph Networks 11 | Graph Networks are a general framework that generalizes graph neural networks. 12 | It unifies [Message Passing Neural Networks](https://arxiv.org/pdf/1704.01212v2.pdf) (MPNNs) 13 | and [Non-Local Neural Networks](https://arxiv.org/pdf/1711.07971v3.pdf) (NLNNs), 14 | as well as other variants like [Interaction Networks](https://arxiv.org/abs/1612.00222) (INs) 15 | or [Relation Networks](https://arxiv.org/pdf/1702.05068.pdf) (RNs). 16 | 17 | You can read about Graph Networks in their arXiv paper: 18 | [Battaglia, Peter W., et al. "Relational inductive biases, deep learning, and graph networks." arXiv preprint arXiv:1806.01261 (2018)](https://arxiv.org/pdf/1806.01261.pdf) 19 | 20 | ## Available Models 21 | The following models are available: 22 | - Interaction Network 23 | - Graph Independent 24 | 25 | You can also build your own models using the Blocks: 26 | - Node Model 27 | - Edge Model 28 | 29 | ## Requirements 30 | PyTorch >= 1.8.0 and PyTorch Geometric (pinned to 1.6.3 in `requirements.txt`). 31 | 32 | ## Example 33 | We provide an example (`test/gnn_test.py`) that tests the output against DeepMind's graph_nets.
# -----------------------------------------------------------------------------
# NOTE(dump): the repository dump collapsed the tail of /README.md (blank
# lines 34-38), /requirements.txt and /setup.py onto the same dump line as the
# start of /test/gnn_test.py. Their contents are preserved here as comments.
#
# /requirements.txt
#   numpy>=1.20.1
#   matplotlib>=3.3.4
#   torch>=1.8.0
#   torch-scatter==2.0.6
#   torch-sparse==0.6.8
#   torch-cluster==1.5.9
#   torch-spline-conv==1.2.1
#   torch-geometric==1.6.3
#
# /setup.py
#   #!/usr/bin/env python
#   from setuptools import setup
#
#   setup(
#       name="torch-graphnet",
#       version='0.1',
#       description='PyTorch implementation of Graph Networks',
#       author='dblancom',
#       packages=['torch_graphnet', 'torch_graphnet.utils'],
#       url="https://github.com/dblanm/torch-graphnet",
#       license="BSD 3-Clause License"  # was "Clear BSD License", which does
#                                        # not match the BSD 3-Clause /LICENSE
#   )
# -----------------------------------------------------------------------------
# /test/gnn_test.py
# Copyright 2021 (c) Aalto University - All Rights Reserved
# Author: David Blanco Mulero
#
"""Numerical comparison of torch_graphnet against DeepMind's graph_nets.

Every check builds the same graph and the same linear weights in both
frameworks, then asserts (``np.testing.assert_allclose``) that the outputs and,
for the Interaction Network, the gradients agree.

NOTE: the ``test_*`` functions take a ``graphs_tuple`` argument, so run this
file as a script (``python test/gnn_test.py``). pytest would look for a
``graphs_tuple`` fixture that does not exist.
"""

import graph_nets as gn
import sonnet as snt
import numpy as np
import tensorflow as tf
import torch
import torch_scatter
from torch_geometric.data import Data

from torch_graphnet.utils import (receiver_nodes_to_edges, sender_nodes_to_edges,
                                  received_edges_to_node_aggregator,
                                  sent_edges_to_node_aggregator, context_to_nodes,
                                  context_to_edges)

from torch_graphnet import EdgeModel, NodeModel, InteractionNetwork

seed = 123
np.random.seed(seed)
tf.random.set_seed(seed)
torch.manual_seed(seed)


class GNData(Data):
    """PyG ``Data`` with graph_nets-style accessors and a global context tensor.

    Args:
        x: node features, also exposed as ``nodes``.
        edge_index: [2, E] tensor; row 0 holds the senders, row 1 the receivers.
        edge_attr: edge features.
        global_context: graph-level features, stored as ``global_context``.
        y, pos, normal, face: forwarded unchanged to ``Data``.
        **kwargs: extra attributes, forwarded to ``Data`` one keyword each.
    """

    def __init__(self, x: torch.Tensor = None, edge_index: torch.Tensor = None,
                 edge_attr: torch.Tensor = None, global_context: torch.Tensor = None,
                 y: torch.Tensor = None, pos: torch.Tensor = None,
                 normal: torch.Tensor = None, face: torch.Tensor = None, **kwargs):
        # Unpack the extra keywords. The previous ``kwargs=kwargs`` stored one
        # attribute literally named "kwargs" holding the whole dict.
        super(GNData, self).__init__(x=x, edge_index=edge_index, edge_attr=edge_attr,
                                     y=y, pos=pos, normal=normal, face=face, **kwargs)
        self.global_context = global_context

    @property
    def nodes(self):
        return self.x

    @property
    def senders(self):
        return self.edge_index[0, :]

    @property
    def receivers(self):
        return self.edge_index[1, :]

    @property
    def num_global(self):
        return self.global_context.shape[-1]


def from_graphstuple_to_gndata(graph: gn.graphs.GraphsTuple):
    """Convert a ``GraphsTuple`` of eager TF tensors into a ``GNData``.

    Used here on single graphs obtained via ``gn.utils_tf.get_graph``. Feature
    tensors keep their NumPy dtype. Senders/receivers are cast to int64 and
    stacked into a [2, E] ``edge_index`` (senders first).
    """
    nodes = torch.from_numpy(graph.nodes.numpy())
    edges = torch.from_numpy(graph.edges.numpy())
    senders = torch.from_numpy(graph.senders.numpy()).long()
    receivers = torch.from_numpy(graph.receivers.numpy()).long()
    # Named ``global_context`` rather than ``globals`` to avoid shadowing the builtin.
    global_context = torch.from_numpy(graph.globals.numpy())

    edge_index = torch.stack((senders, receivers)).contiguous()

    return GNData(x=nodes, edge_index=edge_index, edge_attr=edges,
                  global_context=global_context)


def get_model_fn(in_size, out_size):
    """Create a Sonnet and a PyTorch linear layer with identical random weights.

    Returns:
        ``(linear_fn_tf, linear_fn_torch)``. The first is a zero-argument factory
        returning an ``snt.Linear``; it is what gets passed as
        ``edge_model_fn``/``node_model_fn``. The second is a ready
        ``torch.nn.Linear`` whose float64 parameters copy the same NumPy arrays.
    """
    w_init = np.random.rand(in_size, out_size)
    b_init = np.random.rand(out_size)

    def linear_fn_tf():
        return snt.Linear(output_size=out_size,
                          w_init=tf.constant_initializer(w_init),
                          b_init=tf.constant_initializer(b_init))

    linear_fn_torch = torch.nn.Linear(in_size, out_size)
    linear_fn_torch.bias = torch.nn.Parameter(torch.from_numpy(b_init))
    # Sonnet's w_init above is [in, out]; torch.nn.Linear expects [out, in].
    linear_fn_torch.weight = torch.nn.Parameter(torch.from_numpy(w_init.T))

    return linear_fn_tf, linear_fn_torch


def create_graph():
    """Build the two-graph ``GraphsTuple`` shared by all checks.

    Both graphs have 5 nodes with 3 features, 2 features per edge and 3
    global features. Graph 0 (6 edges) is the model input. Graph 1 (5 edges)
    is the regression target in the backward check.
    """
    # Global features for graph 0.
    globals_0 = [1., 2., 3.]

    # Node features for graph 0.
    nodes_0 = [[10., 20., 30.],  # Node 0
               [11., 21., 31.],  # Node 1
               [12., 22., 32.],  # Node 2
               [13., 23., 33.],  # Node 3
               [14., 24., 34.]]  # Node 4

    # Edge features for graph 0.
    edges_0 = [[100., 200.],  # Edge 0
               [101., 201.],  # Edge 1
               [102., 202.],  # Edge 2
               [103., 203.],  # Edge 3
               [104., 204.],  # Edge 4
               [105., 205.]]  # Edge 5

    # The sender and receiver nodes associated with each edge for graph 0.
    senders_0 = [0, 1, 1, 2, 2, 3]  # Index of the sender nodes for the edge i
    receivers_0 = [1, 2, 3, 0, 3, 4]  # Index of the receiver nodes for the edge i

    # Global features for graph 1.
    globals_1 = [1001., 1002., 1003.]

    # Node features for graph 1.
    nodes_1 = [[1010., 1020., 1030.],  # Node 0
               [1012., 1020., 1030.],  # Node 1
               [1013., 1020., 1030.],  # Node 2
               [1014., 1020., 1030.],  # Node 3
               [1015., 1021., 1031.]]  # Node 4

    # Edge features for graph 1.
    edges_1 = [[1100., 1200.],  # Edge 0
               [1101., 1201.],  # Edge 1
               [1102., 1202.],  # Edge 2
               [1102., 1202.],  # Edge 3
               [1103., 1203.]]  # Edge 4

    # The sender and receiver nodes associated with each edge for graph 1.
    senders_1 = [0, 1, 2, 3, 4]
    receivers_1 = [1, 2, 3, 4, 3]

    data_dict_0 = {"globals": np.array(globals_0), "nodes": np.array(nodes_0),
                   "edges": np.array(edges_0),
                   "senders": np.array(senders_0), "receivers": np.array(receivers_0)}
    data_dict_1 = {"globals": np.array(globals_1), "nodes": np.array(nodes_1),
                   "edges": np.array(edges_1),
                   "senders": np.array(senders_1), "receivers": np.array(receivers_1)}
    data_dicts = [data_dict_0, data_dict_1]
    graphs_tuple = gn.utils_tf.data_dicts_to_graphs_tuple(data_dicts)
    return graphs_tuple


def test_gn_utils(graphs_tuple: gn.graphs.GraphsTuple):
    """Compare the aggregate/broadcast helpers with graph_nets on graph 0."""

    def test_aggregators(graph: GNData, graph_gn):
        # graph_nets aggregators with a segment-sum reducer vs. reduce='sum'.
        reducer = tf.math.unsorted_segment_sum

        sedge_to_nodes = gn.blocks.SentEdgesToNodesAggregator(reducer)
        redge_to_nodes = gn.blocks.ReceivedEdgesToNodesAggregator(reducer)

        nodes_sedge_tf = sedge_to_nodes(graph_gn)
        nodes_redge_tf = redge_to_nodes(graph_gn)
        nodes_sedge = sent_edges_to_node_aggregator(graph.nodes, graph.edge_attr,
                                                    graph.senders, reduce='sum')

        nodes_redge = received_edges_to_node_aggregator(graph.nodes, graph.edge_attr,
                                                        graph.receivers, reduce='sum')

        np.testing.assert_allclose(nodes_sedge_tf.numpy(), nodes_sedge.numpy(),
                                   err_msg="SentEdgesToNodesAggregator does not match")
        np.testing.assert_allclose(nodes_redge_tf.numpy(), nodes_redge.numpy(),
                                   err_msg="ReceivedEdgesToNodesAggregator does not match")

    def test_broadcasts(graph: GNData, graph_gn):
        edge_bcast_rnodes_tf = gn.blocks.broadcast_receiver_nodes_to_edges(graph_gn)
        edge_bcast_snodes_tf = gn.blocks.broadcast_sender_nodes_to_edges(graph_gn)
        edge_bcast_globals_tf = gn.blocks.broadcast_globals_to_edges(graph_gn)

        edge_bcast_rnodes = receiver_nodes_to_edges(graph.nodes, graph.receivers)
        edge_bcast_snodes = sender_nodes_to_edges(graph.nodes, graph.senders)
        edge_bcast_globals = context_to_edges(graph.edge_attr, graph.global_context)

        np.testing.assert_allclose(edge_bcast_rnodes_tf.numpy(), edge_bcast_rnodes.numpy(),
                                   err_msg="Broadcast receiver nodes to edges does not match")
        np.testing.assert_allclose(edge_bcast_snodes_tf.numpy(), edge_bcast_snodes.numpy(),
                                   err_msg="Broadcast sender nodes to edges does not match")
        np.testing.assert_allclose(edge_bcast_globals_tf.numpy(), edge_bcast_globals.numpy(),
                                   err_msg="Broadcast global to edges does not match")

    graph_tf = gn.utils_tf.get_graph(graphs_tuple, index=0)
    graph_torch = from_graphstuple_to_gndata(graph_tf)
    test_aggregators(graph_torch, graph_tf)
    test_broadcasts(graph_torch, graph_tf)


def test_blocks(graphs_tuple: gn.graphs.GraphsTuple):
    """Compare ``EdgeModel``/``NodeModel`` with graph_nets' ``EdgeBlock``/``NodeBlock``."""

    def test_edge_block(graph: GNData, graph_gn: gn.graphs.GraphsTuple):
        # phi_edge input size = edge features + receiver-node features
        # + sender-node features.
        in_size = graph.num_edge_features + graph.num_node_features + graph.num_node_features
        out_size = 5

        linear_fn_tf, linear_fn_torch = get_model_fn(in_size, out_size)

        edge_block_tf = gn.blocks.EdgeBlock(edge_model_fn=linear_fn_tf, use_receiver_nodes=True,
                                            use_sender_nodes=True, use_globals=False)
        edge_block_torch = EdgeModel(phi_edge=linear_fn_torch, use_receiver_nodes=True,
                                     use_sender_nodes=True)

        # Compute the output of each model
        out_edge_block_tf = edge_block_tf(graph_gn)
        edges_tf = out_edge_block_tf.edges
        out_edge_block_torch = edge_block_torch(graph.nodes, graph.edge_attr,
                                                graph.edge_index).detach()

        np.testing.assert_allclose(edges_tf.numpy(), out_edge_block_torch.numpy(),
                                   err_msg="Edge block does not match")

    def test_node_block(graph: GNData, graph_gn: gn.graphs.GraphsTuple):
        # phi_node input size = node features + aggregated received-edge features
        # + aggregated sent-edge features.
        in_size = graph.num_node_features + graph.num_edge_features + graph.num_edge_features
        out_size = 5

        linear_fn_tf, linear_fn_torch = get_model_fn(in_size, out_size)

        node_block_tf = gn.blocks.NodeBlock(node_model_fn=linear_fn_tf, use_received_edges=True,
                                            use_sent_edges=True, use_globals=False)
        node_block_torch = NodeModel(phi_node=linear_fn_torch, use_received_edges=True,
                                     use_sent_edges=True, use_context=False)

        # Compute the output of each model
        out_node_block_tf = node_block_tf(graph_gn)
        nodes_tf = out_node_block_tf.nodes
        out_node_block_torch = node_block_torch(graph.nodes, graph.edge_attr,
                                                graph.edge_index).detach()

        np.testing.assert_allclose(nodes_tf.numpy(), out_node_block_torch.numpy(),
                                   err_msg="Node block does not match")

    graph_tf = gn.utils_tf.get_graph(graphs_tuple, index=0)
    graph_torch = from_graphstuple_to_gndata(graph_tf)
    test_edge_block(graph_torch, graph_tf)
    test_node_block(graph_torch, graph_tf)


def test_in(graphs_tuple: gn.graphs.GraphsTuple):
    """Compare the Interaction Network forward pass and gradients with graph_nets.

    Graph 0 is the input and graph 1 provides the node regression targets.
    """

    def test_interaction_network_forward(graph: GNData, graph_gn: gn.graphs.GraphsTuple,
                                         node_fn_tf, node_fn_torch, edge_fn_tf, edge_fn_torch):

        INet_tf = gn.modules.InteractionNetwork(edge_model_fn=edge_fn_tf,
                                                node_model_fn=node_fn_tf)
        INet_torch = InteractionNetwork(phi_edge=edge_fn_torch,
                                        phi_node=node_fn_torch)

        out_graph_tf = INet_tf(graph_gn)
        node_out, edge_out, _ = INet_torch(graph.nodes, graph.edge_attr,
                                           graph.edge_index)

        np.testing.assert_allclose(out_graph_tf.nodes.numpy(), node_out.detach().numpy(),
                                   err_msg="Interaction network Nodes output does not match")
        np.testing.assert_allclose(out_graph_tf.edges.numpy(), edge_out.detach().numpy(),
                                   err_msg="Interaction network Edges output does not match")
        print("Interaction Network forward passed")

        return INet_tf, INet_torch

    def test_interaction_network_backward(inet_tf, inet_torch, graph_in_tf, graph_tgt_tf,
                                          graph_in_torch, graph_tgt_torch):
        # Same loss in both frameworks: mean squared error on the node outputs.
        with tf.GradientTape() as tape:
            loss_tf = tf.reduce_mean(tf.square(inet_tf(graph_in_tf).nodes - graph_tgt_tf.nodes))
        grads = tape.gradient(loss_tf, inet_tf.trainable_variables)
        # NOTE(review): assumes trainable_variables is ordered
        # [edge b, edge w, node b, node w]; re-check if the Sonnet version changes.
        grad_tf_phi_edge_b = grads[0].numpy()
        grad_tf_phi_edge_w = grads[1].numpy()
        grad_tf_phi_node_b = grads[2].numpy()
        grad_tf_phi_node_w = grads[3].numpy()
        node_out, edge_out, _ = inet_torch(graph_in_torch.nodes, graph_in_torch.edge_attr,
                                           graph_in_torch.edge_index)

        criterion = torch.nn.MSELoss()
        loss = criterion(node_out, graph_tgt_torch.nodes)
        loss.backward()
        # Transpose torch's [out, in] weight gradients to Sonnet's [in, out] layout.
        grad_torch_phi_node_w = inet_torch.phi_node.weight.grad.detach().numpy().T
        grad_torch_phi_node_b = inet_torch.phi_node.bias.grad.detach().numpy()
        grad_torch_phi_edge_w = inet_torch.phi_edge.weight.grad.detach().numpy().T
        grad_torch_phi_edge_b = inet_torch.phi_edge.bias.grad.detach().numpy()

        np.testing.assert_allclose(grad_tf_phi_node_w, grad_torch_phi_node_w,
                                   err_msg="Phi node weights gradient does not match")
        np.testing.assert_allclose(grad_tf_phi_node_b, grad_torch_phi_node_b,
                                   err_msg="Phi node bias gradient does not match")
        np.testing.assert_allclose(grad_tf_phi_edge_w, grad_torch_phi_edge_w,
                                   err_msg="Phi edge weights gradient does not match")
        np.testing.assert_allclose(grad_tf_phi_edge_b, grad_torch_phi_edge_b,
                                   err_msg="Phi edge bias gradient does not match")
        print("Interaction network gradient passed")

    graph_tf = gn.utils_tf.get_graph(graphs_tuple, index=0)
    graph_tgt_tf = gn.utils_tf.get_graph(graphs_tuple, index=1)
    graph = from_graphstuple_to_gndata(graph_tf)
    graph_tgt = from_graphstuple_to_gndata(graph_tgt_tf)
    # phi_edge sees [edges, receiver nodes, sender nodes]. phi_node sees
    # [summed received edge outputs, nodes].
    edge_size = graph.num_edge_features + graph.num_node_features + graph.num_node_features
    edge_out = 3
    node_size = graph.num_node_features + edge_out
    node_out = 3

    node_fn_tf, node_fn_torch = get_model_fn(node_size, node_out)
    edge_fn_tf, edge_fn_torch = get_model_fn(edge_size, edge_out)

    INet_tf, INet_torch = test_interaction_network_forward(graph, graph_tf, node_fn_tf,
                                                           node_fn_torch, edge_fn_tf,
                                                           edge_fn_torch)
    test_interaction_network_backward(INet_tf, INet_torch, graph_tf, graph_tgt_tf, graph,
                                      graph_tgt)


if __name__ == "__main__":
    graphs_tuple = create_graph()
    test_gn_utils(graphs_tuple)
    test_blocks(graphs_tuple)
    test_in(graphs_tuple)
    print("Everything passed")

# -----------------------------------------------------------------------------
# NOTE(dump): the same dump line also carried /test/requirements.txt and the
# first part of /torch_graphnet/__init__.py; preserved here as comments.
#
# /test/requirements.txt
#   numpy>=1.20.1
#   matplotlib>=3.3.4
#   torch>=1.8.0
#   torch-scatter==2.0.6
#   torch-sparse==0.6.8
#   torch-cluster==1.5.9
#   torch-spline-conv==1.2.1
#   torch-geometric==1.6.3
#   tensorflow>=2.1.0-rc1
#   dm-sonnet>=2.0.0b0
#   tensorflow_probability>=0.12.1
#   graph_nets
#
# /torch_graphnet/__init__.py (lines 1-7, line 7 continues on the next dump line)
#   # Copyright 2021 (c) Aalto University - All Rights Reserved
#   # Author: David Blanco Mulero
#   #
#   from torch_graphnet.utils import receiver_nodes_to_edges, sender_nodes_to_edges, \
#       received_edges_to_node_aggregator, sent_edges_to_node_aggregator, context_to_nodes, \
#       context_to_edges
#   from torch_graphnet.graph_networks import EdgeModel, NodeModel, InteractionNetwork,
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# NOTE(dump): this dump line began with the end of /torch_graphnet/__init__.py
# (the last names of its "from torch_graphnet.graph_networks import ..."
# statement, line 7), preserved here as a comment:
#       GraphNetwork, GraphIndependent
# -----------------------------------------------------------------------------
# /torch_graphnet/graph_networks.py
# Copyright 2021 (c) Aalto University - All Rights Reserved
# Author: David Blanco Mulero
#
# This code is based on the PyTorch Geometric MetaLayer
# https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=graph%20network#models
# and the Graph Networks: Relational inductive biases, deep learning, and graph networks
# https://github.com/deepmind/graph_nets

import torch
from torch import nn
from torch_scatter import scatter_mean
from torch_geometric.nn import MetaLayer

from torch_graphnet.utils import sender_nodes_to_edges, receiver_nodes_to_edges, \
    context_to_edges, context_to_nodes, received_edges_to_node_aggregator, \
    sent_edges_to_node_aggregator


class EdgeModel(nn.Module):
    """Edge update block.

    ``phi_edge`` receives the concatenation along the last dimension, in this
    order, of:

    - the edge attributes;
    - the receiver-node features, if ``use_receiver_nodes``;
    - the sender-node features, if ``use_sender_nodes``;
    - the context broadcast to every edge, if ``use_context``.

    Args:
        phi_edge: module mapping the concatenated features to the updated edges.
        use_receiver_nodes: include each edge's receiver-node features.
        use_sender_nodes: include each edge's sender-node features.
        use_context: include the global context broadcast to every edge.
    """

    def __init__(self, phi_edge: nn.Module, use_receiver_nodes=True,
                 use_sender_nodes=True, use_context=False):
        super(EdgeModel, self).__init__()

        self.use_receiver_nodes = use_receiver_nodes
        self.use_sender_nodes = use_sender_nodes
        self.use_context = use_context
        self.phi_edge = phi_edge

    def forward(self, nodes, edge_attr, edge_index, context=None):
        """Compute the updated edge features.

        Args:
            nodes: [N, F_x] node features.
            edge_attr: [E, F_e] edge features.
            edge_index: [2, E]; row 0 holds the senders, row 1 the receivers.
            context: global context features. Required when ``use_context`` is
                True. Broadcast with ``context_to_edges``, so presumably a
                single graph's [F_u] or [1, F_u] vector (TODO confirm).

        Returns:
            The output of ``phi_edge`` applied to the concatenated edge inputs.

        Raises:
            ValueError: if ``use_context`` is True and ``context`` is None.
        """
        if self.use_context and context is None:
            raise ValueError("EdgeModel use_context set to True and context not provided")
        x = [edge_attr]
        if self.use_receiver_nodes:
            x.append(receiver_nodes_to_edges(nodes, edge_index[1, :]))
        if self.use_sender_nodes:
            x.append(sender_nodes_to_edges(nodes, edge_index[0, :]))
        if self.use_context:
            # Uses the ``context`` argument; this previously referenced an
            # undefined ``_context`` and raised NameError.
            x.append(context_to_edges(edge_attr, context))

        return self.phi_edge(torch.cat(x, dim=-1))


class NodeModel(torch.nn.Module):
    """Node update block.

    ``phi_node`` receives the concatenation along dim 1, in this order, of:

    - the received-edge aggregate, if ``use_received_edges``;
    - the sent-edge aggregate, if ``use_sent_edges``;
    - the broadcast context, if ``use_context``;
    - the node features themselves.

    Args:
        phi_node: module mapping the concatenated features to the updated nodes.
        use_received_edges: include edges aggregated at their receiver node.
        use_sent_edges: include edges aggregated at their sender node.
        use_context: include the global context broadcast to every node.
        reduce: forwarded as ``reduce`` to the edge-to-node aggregators.
    """

    def __init__(self, phi_node: nn.Module, use_received_edges=True,
                 use_sent_edges=False, use_context=False, reduce: str = 'sum'):
        super(NodeModel, self).__init__()
        self.phi_node = phi_node
        self.use_received_edges = use_received_edges
        self.use_sent_edges = use_sent_edges
        self.use_context = use_context
        self.reduce = reduce

    def forward(self, nodes, edge_attr, edge_index, context=None):
        """Compute the updated node features.

        Args:
            nodes: [N, F_x] node features.
            edge_attr: [E, F_e] edge features.
            edge_index: [2, E]; row 0 holds the senders, row 1 the receivers.
            context: global context features. Required when ``use_context`` is True.

        Returns:
            The output of ``phi_node`` applied to the concatenated node inputs.

        Raises:
            ValueError: if ``use_context`` is True and ``context`` is None.
        """
        if self.use_context and context is None:
            raise ValueError("NodeModel use_context set to True and context not provided")
        x = []
        if self.use_received_edges:
            x.append(received_edges_to_node_aggregator(nodes, edge_attr, edge_index[1, :],
                                                       reduce=self.reduce))
        if self.use_sent_edges:
            x.append(sent_edges_to_node_aggregator(nodes, edge_attr, edge_index[0, :],
                                                   reduce=self.reduce))
        if self.use_context:
            x.append(context_to_nodes(nodes, context))

        # NOTE(review): graph_nets' NodeBlock appears to concatenate
        # [received, sent, nodes, globals]. Here the context precedes the nodes,
        # so weights only map 1:1 onto graph_nets when use_context is False.
        # Kept as-is so existing trained models keep working. Verify before changing.
        x.append(nodes)

        return self.phi_node(torch.cat(x, dim=1))


class GraphNetwork(nn.Module):
    """Base class that stores the edge, node and context update functions.

    ``phi_context`` is stored as ``self.phi_global``. Subclasses implement ``forward``.
    """

    def __init__(self, phi_edge: nn.Module, phi_node: nn.Module, phi_context: nn.Module):
        super(GraphNetwork, self).__init__()
        self.phi_edge = phi_edge
        self.phi_node = phi_node
        self.phi_global = phi_context

    def forward(self, nodes, edge_attr, edge_index, context=None):
        raise NotImplementedError


class InteractionNetwork(GraphNetwork):
    """Interaction Network: an edge update followed by a node update.

    The edge update uses edges + receiver + sender nodes. The node update uses
    the received-edge aggregate (default ``reduce='sum'``) + nodes. No global
    context is used.
    """

    def __init__(self, phi_edge: nn.Module, phi_node: nn.Module):
        super(InteractionNetwork, self).__init__(phi_edge=phi_edge, phi_node=phi_node,
                                                 phi_context=None)
        # TODO: assert that phi_edge's output size matches what phi_node expects.
        # TODO: validate the input sizes of phi_edge/phi_node against the node and
        #       edge feature sizes (requires knowing those shapes up front).

        self.edge_model = EdgeModel(phi_edge=self.phi_edge, use_receiver_nodes=True,
                                    use_sender_nodes=True, use_context=False)
        self.node_model = NodeModel(phi_node=self.phi_node, use_received_edges=True,
                                    use_sent_edges=False, use_context=False)

    def forward(self, nodes, edge_attr, edge_index):
        """Run the edge update, then the node update on the new edges.

        Returns:
            ``(node_out, edge_out, edge_index)``. ``edge_index`` is returned
            unchanged so that several networks can be chained.
        """
        edge_out = self.edge_model(nodes, edge_attr, edge_index)
        node_out = self.node_model(nodes, edge_out, edge_index)

        return node_out, edge_out, edge_index


class GraphIndependent(GraphNetwork):
    """Apply the edge, node and context functions independently (no message passing).

    Any function left as ``None`` defaults to the identity.
    """

    def __init__(self, phi_edge: nn.Module = None, phi_node: nn.Module = None,
                 phi_context: nn.Module = None):
        # nn.Identity instead of a lambda keeps the module picklable (torch.save)
        # while still returning its input unchanged.
        if phi_edge is None:
            phi_edge = nn.Identity()
        if phi_node is None:
            phi_node = nn.Identity()
        if phi_context is None:
            phi_context = nn.Identity()
        super(GraphIndependent, self).__init__(phi_edge, phi_node, phi_context)

    def forward(self, nodes, edge_attr, edge_index, context=None):
        """Return ``(phi_node(nodes), phi_edge(edge_attr), phi_context(context))``.

        ``edge_index`` is unused. When ``context`` is None, the context output
        is None and ``phi_context`` is not called.
        """
        node_out = self.phi_node(nodes)
        edge_out = self.phi_edge(edge_attr)
        context_out = None if context is None else self.phi_global(context)

        return node_out, edge_out, context_out

# TODO: optionally build the phi_* modules from an example graph by inferring their
# input size (edge features + receiver/sender node features + global features).
# -----------------------------------------------------------------------------
# NOTE(dump): this dump line also carried /torch_graphnet/utils/__init__.py,
# preserved here as a comment:
#
#   # Copyright 2021 (c) Aalto University - All Rights Reserved
#   # Author: David Blanco Mulero
#   #
#
#   from torch_graphnet.utils.gn_utils import receiver_nodes_to_edges, sender_nodes_to_edges, \
#       received_edges_to_node_aggregator, sent_edges_to_node_aggregator, context_to_nodes, \
#       context_to_edges
# -----------------------------------------------------------------------------
# /torch_graphnet/utils/gn_utils.py
# Copyright 2021 (c) Aalto University - All Rights Reserved
# Author: David Blanco Mulero
#
import torch

try:
    import torch_scatter
except ImportError:  # Only needed for reduce='max' / 'min' in the aggregators.
    torch_scatter = None


@torch.jit.script
def receiver_nodes_to_edges(nodes: torch.Tensor, receivers: torch.Tensor):
    """Gather, for every edge, the features of its receiver node.

    Args:
        nodes: [N, F_x] node features.
        receivers: [E] receiver node index of each edge.

    Returns: tensor of shape [E, F_x]
    """
    return nodes[receivers, :]


@torch.jit.script
def sender_nodes_to_edges(nodes: torch.Tensor, senders: torch.Tensor):
    """Gather, for every edge, the features of its sender node.

    Args:
        nodes: [N, F_x] node features.
        senders: [E] sender node index of each edge.

    Returns: tensor of shape [E, F_x]
    """
    return nodes[senders, :]


@torch.jit.script
def context_to_edges(edge_attr: torch.Tensor, global_context: torch.Tensor):
    """Broadcast the global context to every edge.

    Only ``edge_attr.shape[0]`` (the number of edges E) is used.

    Args:
        edge_attr: [E, F_e] edge features.
        global_context: [F_u] or [1, F_u] context of a single graph.

    Returns: tensor of shape [E, F_u]
    """
    return global_context.repeat(edge_attr.shape[0], 1)


@torch.jit.script
def context_to_nodes(nodes: torch.Tensor, global_context: torch.Tensor):
    """Broadcast the global context to every node.

    Only ``nodes.shape[0]`` (the number of nodes N) is used.

    Args:
        nodes: [N, F_x] node features.
        global_context: [F_u] or [1, F_u] context of a single graph.

    Returns: tensor of shape [N, F_u]
    """
    return global_context.repeat(nodes.shape[0], 1)


def _edges_to_nodes(edge_attr, index, num_nodes: int, reduce: str):
    """Reduce the rows of ``edge_attr`` into ``num_nodes`` rows selected by ``index``."""
    if reduce in ('sum', 'add'):
        return scatter_sum(edge_attr, index, num_nodes)
    if reduce == 'mean':
        total = scatter_sum(edge_attr, index, num_nodes)
        ones = torch.ones(index.shape[0], 1, dtype=edge_attr.dtype, device=edge_attr.device)
        # clamp(min=1) keeps nodes without edges at 0 instead of dividing by zero.
        count = scatter_sum(ones, index, num_nodes).clamp(min=1)
        return total / count
    if reduce in ('max', 'min'):
        if torch_scatter is None:
            raise ImportError("torch_scatter is required for reduce='%s'" % reduce)
        return torch_scatter.scatter(edge_attr, index, dim=0, dim_size=num_nodes,
                                     reduce=reduce)
    raise ValueError("Unsupported reduce '%s'; expected 'sum', 'add', 'mean', 'max' or 'min'"
                     % reduce)


def received_edges_to_node_aggregator(nodes, edge_attr, receivers, reduce: str):
    """Aggregate the features of every edge into its receiver node.

    Args:
        nodes: node features; only ``nodes.shape[0]`` (N) is used.
        edge_attr: [E, F_e] edge features.
        receivers: [E] int64 receiver index of each edge.
        reduce: 'sum' (or 'add'), 'mean', 'max' or 'min'. This was previously
            ignored and always summed.

    Returns: tensor of shape [N, F_e]. With 'sum'/'mean', nodes that receive
        no edge get zeros.

    Raises:
        ValueError: if ``reduce`` is not supported.
    """
    return _edges_to_nodes(edge_attr, receivers, nodes.shape[0], reduce)


def sent_edges_to_node_aggregator(nodes, edge_attr, senders, reduce: str):
    """Aggregate the features of every edge into its sender node.

    Args:
        nodes: node features; only ``nodes.shape[0]`` (N) is used.
        edge_attr: [E, F_e] edge features.
        senders: [E] int64 sender index of each edge.
        reduce: 'sum' (or 'add'), 'mean', 'max' or 'min'. This was previously
            ignored and always summed.

    Returns: tensor of shape [N, F_e]. With 'sum'/'mean', nodes that send no
        edge get zeros.

    Raises:
        ValueError: if ``reduce`` is not supported.
    """
    return _edges_to_nodes(edge_attr, senders, nodes.shape[0], reduce)


@torch.jit.script
def scatter_sum(src: torch.Tensor, idx: torch.Tensor, out_segments: int):
    """Sum the rows of ``src`` into ``out_segments`` output rows selected by ``idx``.

    Args:
        src: [E, F] values.
        idx: [E] int64 destination row of each value.
        out_segments: number of output rows; rows no index points to stay 0.

    Returns: tensor of shape [out_segments, F]
    """
    out = torch.zeros(out_segments, src.shape[1], dtype=src.dtype, device=src.device)
    # Native index_add_ gives the same sums as torch_scatter.scatter_add with a
    # 1-D index, without requiring the extension package.
    return out.index_add_(0, idx, src)