├── .gitignore ├── images └── graph-network.png ├── __pycache__ ├── blocks.cpython-36.pyc ├── graphs.cpython-36.pyc ├── modules.cpython-36.pyc └── utils.cpython-36.pyc ├── .idea ├── libraries │ └── R_User_Library.xml ├── modules.xml ├── misc.xml ├── graph_net_pytorch.iml └── workspace.xml ├── utils.py ├── modules.py ├── demo.py ├── README.md ├── graphs.py └── blocks.py /.gitignore: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/graph-network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/images/graph-network.png -------------------------------------------------------------------------------- /__pycache__/blocks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/blocks.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/graphs.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/graphs.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/modules.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/modules.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/utils.cpython-36.pyc 
# ===== /.idea/libraries/R_User_Library.xml, /.idea/modules.xml, /.idea/misc.xml =====
# (IDE metadata; contents not captured in this dump)

# ===== /utils.py =====
import torch
import torch.nn as nn
import blocks
import torch.nn.functional as F
import numpy as np
from random import randint
from graphs import GraphsTuple


def data_dicts_to_graphs_tuple(graph_dicts: dict) -> GraphsTuple:
    """Convert a dict of array-likes into a `GraphsTuple` of torch tensors.

    Args:
        graph_dicts: mapping from GraphsTuple field name (``nodes``, ``edges``,
            ``receivers``, ``senders``, ``globals``) to an array-like value.
            It must contain exactly those keys, since they are forwarded as
            keyword arguments to `GraphsTuple`.

    Returns:
        A new `GraphsTuple` whose fields are `torch.tensor` copies of the values.

    The input dict is left unmodified; an earlier version overwrote its values
    in place, which silently changed the caller's data.
    """
    return GraphsTuple(**{name: torch.tensor(value) for name, value in graph_dicts.items()})


# ===== /modules.py =====
import torch
import torch.nn as nn
import blocks
import torch.nn.functional as F
import numpy as np
from random import randint
from graphs import GraphsTuple


class GraphNetwork(nn.Module):
    """Graph network that applies an edge block followed by a node block.

    Args:
        graph: template graph passed to each block's constructor.

    NOTE(review): `_global_block` is constructed and registered as a submodule,
    but `forward` does not apply it, so the output graph's globals are the input
    globals unchanged.
    """

    def __init__(self, graph):
        super(GraphNetwork, self).__init__()
        self._edge_block = blocks.EdgeBlock(graph)
        self._node_block = blocks.NodeBlock(graph)
        self._global_block = blocks.GlobalBlock(graph)

    def forward(self, graph):
        """Return `graph` with edges updated by the edge block, then nodes by the node block."""
        return self._node_block(self._edge_block(graph))

# ===== /.idea/graph_net_pytorch.iml =====
# (IDE metadata; contents not captured in this dump)

# ===== /demo.py =====
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import modules
import utils


def get_graph_data_dict(num_nodes, num_edges, global_size=4, node_size=5, edge_size=6):
    """Build a random single-graph data dict.

    Args:
        num_nodes: number of nodes; edge endpoints are drawn uniformly from
            ``[0, num_nodes)``.
        num_edges: number of edges.
        global_size, node_size, edge_size: feature widths (defaults 4, 5, 6,
            the values previously hard-coded here).

    Returns:
        dict with float32 ``globals`` of shape (global_size,), ``nodes`` of shape
        (num_nodes, node_size), ``edges`` of shape (num_edges, edge_size), and
        int32 ``senders`` / ``receivers`` of shape (num_edges,).
    """
    return {
        "globals": np.random.rand(global_size).astype(np.float32),
        "nodes": np.random.rand(num_nodes, node_size).astype(np.float32),
        "edges": np.random.rand(num_edges, edge_size).astype(np.float32),
        "senders": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
        "receivers": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
    }


def main(num_nodes=9, num_edges=25, seed=None):
    """Run one forward pass of `modules.GraphNetwork` on a random graph and print both graphs.

    Args:
        num_nodes, num_edges: size of the random input graph.
        seed: if given, seeds NumPy's global RNG so the input data is
            reproducible (layer weights are still initialised by torch's RNG).

    Returns:
        The output graph produced by the network.
    """
    if seed is not None:
        np.random.seed(seed)

    graph_dicts = get_graph_data_dict(num_nodes=num_nodes, num_edges=num_edges)
    input_graphs = utils.data_dicts_to_graphs_tuple(graph_dicts)

    print('input_graphs')
    print(input_graphs)

    graph_network = modules.GraphNetwork(input_graphs)
    output_graphs = graph_network(input_graphs)

    print('output_graphs')
    print(output_graphs)
    return output_graphs


# Guard so importing demo.py does not run the forward pass as a side effect.
if __name__ == "__main__":
    main()

# ===== /README.md =====
# # Graph Nets implemented in PyTorch
#
# [Graph Nets](https://github.com/deepmind/graph_nets) is DeepMind's library for
# building graph networks in TensorFlow and Sonnet. You can find it at
# https://github.com/deepmind/graph_nets
#
# I have implemented `Graph Nets` using the `PyTorch` framework. You can find my
# work at https://github.com/TQCAI/graph_nets_pytorch
#
# #### What are graph networks?
#
# A graph network takes a graph as input and returns a graph as output.
# ===== /README.md (continued) =====
# The input graph has edge- (*E*), node- (*V*), and global-level (**u**)
# attributes. The output graph has the same structure, but updated attributes.
# Graph networks are part of the broader family of "graph neural networks"
# (Scarselli et al., 2009).
#
# To learn more about graph networks, see the arXiv paper: [Relational inductive
# biases, deep learning, and graph networks](https://arxiv.org/abs/1806.01261).
#
# ![Graph network](images/graph-network.png)
#
# ## Usage example
#
# A forward pass is shown in `demo.py`.

# ===== /graphs.py =====
import collections


NODES = "nodes"
EDGES = "edges"
RECEIVERS = "receivers"
SENDERS = "senders"
GLOBALS = "globals"
# Not part of GRAPH_DATA_FIELDS and not used by GraphsTuple in this module.
N_NODE = "n_node"
N_EDGE = "n_edge"

# Feature fields; the default set of fields transformed by `GraphsTuple.map`.
GRAPH_FEATURE_FIELDS = (NODES, EDGES, GLOBALS)
# Index fields: per-edge node indices.
GRAPH_INDEX_FIELDS = (RECEIVERS, SENDERS)
# All GraphsTuple fields, in positional order.
GRAPH_DATA_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS)


class GraphsTuple(collections.namedtuple("GraphsTuple", GRAPH_DATA_FIELDS)):
    """Immutable named tuple holding the fields of a graph.

    Fields: ``nodes``, ``edges``, ``receivers``, ``senders``, ``globals``.
    The blocks in this project expect torch tensors in these fields; the tuple
    itself does not check types.
    """

    def __init__(self, *args, **kwargs):
        del args, kwargs
        # The fields of a `namedtuple` are filled in the `__new__` method.
        # `__init__` does not accept parameters.
        super(GraphsTuple, self).__init__()

    def replace(self, **kwargs):
        """Return a copy of this tuple with the given fields replaced."""
        return self._replace(**kwargs)

    def map(self, field_fn, fields=GRAPH_FEATURE_FIELDS):
        """Applies `field_fn` to the fields `fields` of the instance.

        `field_fn` is applied exactly once per field in `fields`. The result is
        expected to satisfy the `GraphsTuple` requirement w.r.t. `None` fields,
        i.e. the `SENDERS` cannot be `None` if the `EDGES` or `RECEIVERS` are not
        `None`, etc. (this method does not check it).

        Args:
          field_fn: A callable that takes a single argument.
          fields: (iterable of `str`). An iterable of the fields to apply
            `field_fn` to.

        Returns:
          A copy of the instance, with the fields in `fields` replaced by the
          result of applying `field_fn` to them.
        """
        return self.replace(**{k: field_fn(getattr(self, k)) for k in fields})


# ===== /blocks.py =====
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from random import randint
from graphs import GraphsTuple


def broadcast_receiver_nodes_to_edges(graph: GraphsTuple):
    """Return, for each edge, the feature row of its receiver node (rows of `nodes` selected by `receivers`)."""
    return graph.nodes.index_select(index=graph.receivers.long(), dim=0)


def broadcast_sender_nodes_to_edges(graph: GraphsTuple):
    """Return, for each edge, the feature row of its sender node (rows of `nodes` selected by `senders`)."""
    return graph.nodes.index_select(index=graph.senders.long(), dim=0)


def broadcast_globals_to_edges(graph: GraphsTuple):
    """Repeat the globals once per edge.

    Yields (n_edges, global_size) when `globals` is 1-D, as in demo.py.
    """
    n_edges = graph.edges.shape[0]
    return graph.globals.repeat(n_edges, 1)


def broadcast_globals_to_nodes(graph: GraphsTuple):
    """Repeat the globals once per node.

    Yields (n_nodes, global_size) when `globals` is 1-D.
    """
    n_nodes = graph.nodes.shape[0]
    return graph.globals.repeat(n_nodes, 1)


class Aggregator(nn.Module):
    """Sum edge features onto nodes.

    Args:
        mode: ``'receivers'`` sums, for every node, the edges whose receiver is
            that node; ``'senders'`` does the same using the sender index. Any
            other value makes `forward` raise `AttributeError`.

    `forward` returns a tensor with one row per node; nodes with no matching
    edge get a zero row.
    """

    def __init__(self, mode):
        super(Aggregator, self).__init__()
        self.mode = mode

    def forward(self, graph):
        if self.mode == 'receivers':
            indices = graph.receivers
        elif self.mode == 'senders':
            indices = graph.senders
        else:
            raise AttributeError("invalid parameter `mode`")
        edges = graph.edges
        n_nodes = graph.nodes.shape[0]
        # One vectorised scatter-add instead of a Python loop over nodes with a
        # boolean mask per node (O(n_nodes * n_edges)). `new_zeros` allocates
        # on the edges' device and dtype, so GPU / float64 inputs also work.
        aggregated = edges.new_zeros((n_nodes,) + tuple(edges.shape[1:]))
        return aggregated.index_add(0, indices.long(), edges)


class EdgeBlock(nn.Module):
    """Update each edge from its own features, its endpoint nodes and the globals.

    The selected inputs are concatenated per edge and fed to a single
    `nn.Linear` whose output width equals the input edge width, so the edge
    feature size is preserved.

    Args:
        graph: template graph; only its feature widths are read.
        use_edges, use_receiver_nodes, use_sender_nodes, use_globals: which
            inputs to concatenate.
    """

    def __init__(self,
                 graph: GraphsTuple,
                 use_edges=True,
                 use_receiver_nodes=True,
                 use_sender_nodes=True,
                 use_globals=True):
        super(EdgeBlock, self).__init__()
        self._use_edges = use_edges
        self._use_receiver_nodes = use_receiver_nodes
        self._use_sender_nodes = use_sender_nodes
        self._use_globals = use_globals
        n_features = 0
        out_features = graph.edges.shape[-1]
        if self._use_edges:
            n_features += graph.edges.shape[-1]
        if self._use_receiver_nodes:
            n_features += graph.nodes.shape[-1]
        if self._use_sender_nodes:
            n_features += graph.nodes.shape[-1]
        if self._use_globals:
            n_features += graph.globals.shape[-1]
        self.linear = nn.Linear(n_features, out_features)

    def forward(self, graph: GraphsTuple):
        edges_to_collect = []

        if self._use_edges:
            edges_to_collect.append(graph.edges)  # (n_edges, edge_size)

        if self._use_receiver_nodes:
            # Feature row of the node each edge points to: (n_edges, node_size).
            edges_to_collect.append(broadcast_receiver_nodes_to_edges(graph))

        if self._use_sender_nodes:
            # Same, for the node each edge starts from.
            edges_to_collect.append(broadcast_sender_nodes_to_edges(graph))

        if self._use_globals:
            # (n_edges, global_size), assuming 1-D globals.
            edges_to_collect.append(broadcast_globals_to_edges(graph))

        collected_edges = torch.cat(edges_to_collect, dim=1)
        updated_edges = self.linear(collected_edges)
        return graph.replace(edges=updated_edges)


class NodeBlock(nn.Module):
    """Update each node from its aggregated edges, its own features and the globals.

    The selected inputs are concatenated per node and fed to a single
    `nn.Linear` whose output width equals the input node width.

    Args:
        graph: template graph; only its feature widths are read.
        use_received_edges: include the sum of edges whose receiver is the node.
        use_sent_edges: include the sum of edges whose sender is the node.
        use_nodes, use_globals: include the node's own features / the globals.
    """

    def __init__(self,
                 graph,
                 use_received_edges=True,
                 use_sent_edges=False,
                 use_nodes=True,
                 use_globals=True):
        super(NodeBlock, self).__init__()
        self._use_received_edges = use_received_edges
        self._use_sent_edges = use_sent_edges
        self._use_nodes = use_nodes
        self._use_globals = use_globals
        n_features = 0
        out_features = graph.nodes.shape[-1]
        if self._use_nodes:
            n_features += graph.nodes.shape[-1]
        if self._use_received_edges:
            n_features += graph.edges.shape[-1]
        if self._use_sent_edges:
            n_features += graph.edges.shape[-1]
        if self._use_globals:
            n_features += graph.globals.shape[-1]
        self.linear = nn.Linear(n_features, out_features)
        self._received_edges_aggregator = Aggregator('receivers')
        self._sent_edges_aggregator = Aggregator('senders')

    def forward(self, graph):
        nodes_to_collect = []

        if self._use_received_edges:
            # Sum of the edges arriving at each node: (n_nodes, edge_size).
            # When run after EdgeBlock these are the already-updated edges.
            nodes_to_collect.append(self._received_edges_aggregator(graph))

        if self._use_sent_edges:
            # Sum of the edges leaving each node: (n_nodes, edge_size).
            nodes_to_collect.append(self._sent_edges_aggregator(graph))

        if self._use_nodes:
            nodes_to_collect.append(graph.nodes)

        if self._use_globals:
            # (n_nodes, global_size), assuming 1-D globals.
            nodes_to_collect.append(broadcast_globals_to_nodes(graph))

        collected_nodes = torch.cat(nodes_to_collect, dim=1)
        updated_nodes = self.linear(collected_nodes)  # (n_nodes, node_size)
        return graph.replace(nodes=updated_nodes)


class GlobalBlock(nn.Module):
    """Update the globals from summed edges, summed nodes and the current globals.

    Assumes a single graph: `globals` is 1-D (global_size,) or (1, global_size).
    The selected inputs are concatenated into one vector and fed to an
    `nn.Linear` whose output width equals the global width; the result is
    reshaped back to the input `globals` shape.

    Args:
        graph: template graph; only its feature widths are read. Taken first,
            like `EdgeBlock` / `NodeBlock`. The previous signature had no
            `graph`, so a positional graph was silently bound to `use_edges`.
        use_edges, use_nodes, use_globals: which inputs to concatenate.
    """

    def __init__(self,
                 graph: GraphsTuple,
                 use_edges=True,
                 use_nodes=True,
                 use_globals=True):
        super(GlobalBlock, self).__init__()

        self._use_edges = use_edges
        self._use_nodes = use_nodes
        self._use_globals = use_globals
        n_features = 0
        out_features = graph.globals.shape[-1]
        if self._use_edges:
            n_features += graph.edges.shape[-1]
        if self._use_nodes:
            n_features += graph.nodes.shape[-1]
        if self._use_globals:
            n_features += graph.globals.shape[-1]
        # `forward` previously referenced this model and both aggregators
        # without them ever being defined, so it always raised AttributeError.
        self._global_model = nn.Linear(n_features, out_features)

    @staticmethod
    def _edges_aggregator(graph):
        """Sum all edge feature rows into one vector of length edge_size."""
        return graph.edges.sum(dim=0)

    @staticmethod
    def _nodes_aggregator(graph):
        """Sum all node feature rows into one vector of length node_size."""
        return graph.nodes.sum(dim=0)

    def forward(self, graph):
        globals_to_collect = []

        if self._use_edges:
            globals_to_collect.append(self._edges_aggregator(graph))

        if self._use_nodes:
            globals_to_collect.append(self._nodes_aggregator(graph))

        if self._use_globals:
            globals_to_collect.append(graph.globals.reshape(-1))

        # Every piece is 1-D here, so concatenate along dim 0.
        collected_globals = torch.cat(globals_to_collect, dim=0)
        updated_globals = self._global_model(collected_globals)
        return graph.replace(globals=updated_globals.reshape(graph.globals.shape))

# ===== /.idea/workspace.xml =====
# (IDE metadata; contents not captured in this dump)