├── .gitignore
├── images
└── graph-network.png
├── __pycache__
├── blocks.cpython-36.pyc
├── graphs.cpython-36.pyc
├── modules.cpython-36.pyc
└── utils.cpython-36.pyc
├── .idea
├── libraries
│ └── R_User_Library.xml
├── modules.xml
├── misc.xml
├── graph_net_pytorch.iml
└── workspace.xml
├── utils.py
├── modules.py
├── demo.py
├── README.md
├── graphs.py
└── blocks.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/images/graph-network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/images/graph-network.png
--------------------------------------------------------------------------------
/__pycache__/blocks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/blocks.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/graphs.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/graphs.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TQCAI/graph_nets_pytorch/HEAD/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/.idea/libraries/R_User_Library.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import blocks
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from random import randint
7 | from graphs import GraphsTuple
8 |
9 |
def data_dicts_to_graphs_tuple(graph_dicts: dict):
    """Convert a dict of array-like graph fields into a ``GraphsTuple``.

    Every non-``None`` value is copied into a new tensor with
    ``torch.tensor``; ``None`` values are passed through unchanged. The
    caller's dict is NOT modified (a new dict is built instead).

    Args:
        graph_dicts: Mapping whose keys are exactly the ``GraphsTuple`` field
            names ("nodes", "edges", "receivers", "senders", "globals") and
            whose values are array-like data (e.g. numpy arrays) or ``None``.

    Returns:
        A ``GraphsTuple`` holding the converted values.

    Raises:
        TypeError: if a ``GraphsTuple`` field is missing from ``graph_dicts``
            or an unknown key is present (raised by the namedtuple constructor).
    """
    converted = {
        key: None if value is None else torch.tensor(value)
        for key, value in graph_dicts.items()
    }
    return GraphsTuple(**converted)
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import blocks
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from random import randint
7 | from graphs import GraphsTuple
8 |
class GraphNetwork(nn.Module):
    """Graph network composed of edge, node and global update blocks.

    All three sub-blocks are constructed from the example ``graph`` passed to
    the constructor, and each is registered as a submodule.

    NOTE(review): ``forward`` applies only the edge block followed by the
    node block; ``_global_block`` is constructed but never applied.
    """

    def __init__(self, graph):
        super(GraphNetwork, self).__init__()
        # These attribute names are also the submodule / state_dict keys,
        # so they (and their registration order) are kept stable.
        self._edge_block = blocks.EdgeBlock(graph)
        self._node_block = blocks.NodeBlock(graph)
        self._global_block = blocks.GlobalBlock(graph)

    def forward(self, graph):
        """Update the edges of ``graph``, then update its nodes from those edges."""
        graph_with_new_edges = self._edge_block(graph)
        return self._node_block(graph_with_new_edges)
--------------------------------------------------------------------------------
/.idea/graph_net_pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import networkx as nx
3 | import numpy as np
4 | import modules
5 | import utils
6 |
def get_graph_data_dict(num_nodes, num_edges, global_size=4, node_size=5,
                        edge_size=6):
    """Build a random data dict describing a single graph.

    Feature values are drawn from ``np.random.rand`` (uniform on [0, 1)) and
    cast to float32. Sender and receiver indices are drawn uniformly from
    ``range(num_nodes)`` as int32. Random draws happen in the order globals,
    nodes, edges, senders, receivers.

    Args:
        num_nodes: Number of nodes in the graph.
        num_edges: Number of edges in the graph.
        global_size: Width of the global feature vector (default 4).
        node_size: Width of each node's feature vector (default 5).
        edge_size: Width of each edge's feature vector (default 6).

    Returns:
        A dict with keys "globals" of shape (global_size,), "nodes" of shape
        (num_nodes, node_size), "edges" of shape (num_edges, edge_size), and
        "senders" / "receivers" of shape (num_edges,).
    """
    return {
        "globals": np.random.rand(global_size).astype(np.float32),
        "nodes": np.random.rand(num_nodes, node_size).astype(np.float32),
        "edges": np.random.rand(num_edges, edge_size).astype(np.float32),
        "senders": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
        "receivers": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
    }
18 |
19 |
# Script body: build a random graph, run it through a GraphNetwork and print
# the graph before and after the forward pass.
graph_dicts = get_graph_data_dict(num_nodes=9, num_edges=25)
input_graphs = utils.data_dicts_to_graphs_tuple(graph_dicts)
print('input_graphs', input_graphs, sep='\n')

# GraphNetwork sizes its layers from the example graph it is given.
graph_network = modules.GraphNetwork(input_graphs)
output_graphs = graph_network(input_graphs)
print('output_graphs', output_graphs, sep='\n')
32 |
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Nets implemented in PyTorch
2 |
3 |
4 |
5 | [Graph Nets](https://github.com/deepmind/graph_nets) is DeepMind's library for
6 | building graph networks in TensorFlow and Sonnet. The original code is available at https://github.com/deepmind/graph_nets.
7 |
8 | I have implemented `Graph Nets` using the `PyTorch` framework. You can find my work at https://github.com/TQCAI/graph_nets_pytorch.
9 |
10 | #### What are graph networks?
11 |
12 | A graph network takes a graph as input and returns a graph as output. The input
13 | graph has edge- (*E* ), node- (*V* ), and global-level (**u**) attributes. The
14 | output graph has the same structure, but updated attributes. Graph networks are
15 | part of the broader family of "graph neural networks" (Scarselli et al., 2009).
16 |
17 | To learn more about graph networks, see DeepMind's arXiv paper: [Relational inductive
18 | biases, deep learning, and graph networks](https://arxiv.org/abs/1806.01261).
19 |
20 | 
21 |
22 |
23 |
24 | ## Usage example
25 |
26 | See `demo.py` for an example forward pass.
--------------------------------------------------------------------------------
/graphs.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import collections
4 |
5 |
# Names of the GraphsTuple fields. N_NODE / N_EDGE are defined here but are
# not fields of the GraphsTuple below.
NODES = "nodes"
EDGES = "edges"
RECEIVERS = "receivers"
SENDERS = "senders"
GLOBALS = "globals"
N_NODE = "n_node"
N_EDGE = "n_edge"

# Fields holding feature tensors (the default targets of GraphsTuple.map).
GRAPH_FEATURE_FIELDS = (NODES, EDGES, GLOBALS)
# Fields holding node indices (the receiver / sender node of each edge).
GRAPH_INDEX_FIELDS = (RECEIVERS, SENDERS)
# Every GraphsTuple field, in positional order.
GRAPH_DATA_FIELDS = (NODES, EDGES, RECEIVERS, SENDERS, GLOBALS)


class GraphsTuple(collections.namedtuple("GraphsTuple", GRAPH_DATA_FIELDS)):
    """Immutable record holding the data of a graph.

    Positional fields, in order: ``nodes``, ``edges``, ``receivers``,
    ``senders``, ``globals``. Since this is a ``namedtuple``, instances cannot
    be modified in place; use :meth:`replace` or :meth:`map` to derive copies.
    """

    def __init__(self, *args, **kwargs):
        # namedtuple populates its fields in __new__, so the constructor
        # arguments are simply discarded here.
        del args, kwargs
        super(GraphsTuple, self).__init__()

    def replace(self, **kwargs):
        """Return a copy with the fields named in ``kwargs`` replaced."""
        return self._replace(**kwargs)

    def map(self, field_fn, fields=GRAPH_FEATURE_FIELDS):
        """Return a copy with ``field_fn`` applied to each field in ``fields``.

        ``field_fn`` is called exactly once per entry of ``fields``, in
        iteration order, with that field's current value. The resulting
        values are not validated here.

        Args:
            field_fn: Callable taking a single field value.
            fields: Iterable of field names to transform; defaults to the
                feature fields (nodes, edges, globals).

        Returns:
            A new ``GraphsTuple`` with the selected fields replaced by the
            results of ``field_fn``.
        """
        updated = {}
        for name in fields:
            updated[name] = field_fn(getattr(self, name))
        return self.replace(**updated)
--------------------------------------------------------------------------------
/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from random import randint
6 | from graphs import GraphsTuple
7 |
8 |
9 | # import utils_tf
10 |
def broadcast_receiver_nodes_to_edges(graph: GraphsTuple):
    """Gather, for every edge, the features of its receiver node.

    Row ``i`` of the result is ``graph.nodes[graph.receivers[i]]``.
    """
    receiver_idx = graph.receivers.long()
    return graph.nodes.index_select(0, receiver_idx)
13 |
14 |
def broadcast_sender_nodes_to_edges(graph: GraphsTuple):
    """Gather, for every edge, the features of its sender node.

    Row ``i`` of the result is ``graph.nodes[graph.senders[i]]``.
    """
    sender_idx = graph.senders.long()
    return graph.nodes.index_select(0, sender_idx)
17 |
18 |
def broadcast_globals_to_edges(graph: GraphsTuple):
    """Tile the global features once per edge.

    Uses ``graph.globals.repeat(num_edges, 1)``; for 1-D globals of width G
    the result has shape (num_edges, G).
    """
    num_edges = graph.edges.size(0)
    return graph.globals.repeat(num_edges, 1)
22 |
23 |
def broadcast_globals_to_nodes(graph: GraphsTuple):
    """Tile the global features once per node.

    Uses ``graph.globals.repeat(num_nodes, 1)``; for 1-D globals of width G
    the result has shape (num_nodes, G).
    """
    num_nodes = graph.nodes.size(0)
    return graph.globals.repeat(num_nodes, 1)
27 |
28 |
class Aggregator(nn.Module):
    """Sum edge features onto nodes.

    Args:
        mode: ``'receivers'`` sums, for every node, the edges whose receiver
            index is that node; ``'senders'`` sums the edges whose sender
            index is that node. Any other value makes ``forward`` raise.
    """

    def __init__(self, mode):
        super(Aggregator, self).__init__()
        self.mode = mode

    def forward(self, graph):
        """Return per-node sums of edge features.

        Args:
            graph: Object with ``edges`` (num_edges, ...), ``nodes``
                (num_nodes, ...) and integer index tensors ``receivers`` /
                ``senders`` of length num_edges.

        Returns:
            Tensor of shape (num_nodes,) + edges.shape[1:], with the same
            dtype and device as ``graph.edges``. Nodes with no matching edge
            get an all-zero row.

        Raises:
            AttributeError: if ``mode`` is neither 'receivers' nor 'senders'.
        """
        if self.mode == 'receivers':
            indices = graph.receivers
        elif self.mode == 'senders':
            indices = graph.senders
        else:
            raise AttributeError("invalid parameter `mode`")
        edges = graph.edges
        num_nodes = graph.nodes.shape[0]
        # Allocate with edges' dtype/device (a bare torch.zeros is always CPU
        # float32), then scatter-add every edge into the row of its endpoint
        # in one vectorized call instead of a Python loop over nodes.
        # index_add needs int64 indices, hence .long().
        summed = edges.new_zeros((num_nodes,) + tuple(edges.shape[1:]))
        return summed.index_add(0, indices.long(), edges)
52 |
53 |
class EdgeBlock(nn.Module):
    """Update every edge from its own features and its neighbourhood.

    For each edge the block concatenates (in this order, each part optional)
    the edge's current features, its receiver node's features, its sender
    node's features and the global features. A single linear layer maps the
    result back to the input edge width.

    Args:
        graph: Example graph, used only to size the linear layer.
        use_edges: Include the edge's own features.
        use_receiver_nodes: Include the receiver node's features.
        use_sender_nodes: Include the sender node's features.
        use_globals: Include the global features.
    """

    def __init__(self,
                 graph: GraphsTuple,
                 use_edges=True,
                 use_receiver_nodes=True,
                 use_sender_nodes=True,
                 use_globals=True):
        super(EdgeBlock, self).__init__()
        self._use_edges = use_edges
        self._use_receiver_nodes = use_receiver_nodes
        self._use_sender_nodes = use_sender_nodes
        self._use_globals = use_globals

        edge_width = graph.edges.shape[-1]
        widths = []
        if self._use_edges:
            widths.append(edge_width)
        if self._use_receiver_nodes:
            widths.append(graph.nodes.shape[-1])
        if self._use_sender_nodes:
            widths.append(graph.nodes.shape[-1])
        if self._use_globals:
            widths.append(graph.globals.shape[-1])
        # Output width equals the input edge width, so the updated graph has
        # the same edge width as the input graph.
        self.linear = nn.Linear(sum(widths), edge_width)

    def forward(self, graph: GraphsTuple):
        pieces = []
        if self._use_edges:
            # Each edge's current features.
            pieces.append(graph.edges)
        if self._use_receiver_nodes:
            # Features of the node each edge points to (its receiver).
            pieces.append(broadcast_receiver_nodes_to_edges(graph))
        if self._use_sender_nodes:
            # Same as above, but for the edge's starting (sender) node.
            pieces.append(broadcast_sender_nodes_to_edges(graph))
        if self._use_globals:
            # Global features repeated once per edge.
            pieces.append(broadcast_globals_to_edges(graph))
        per_edge_inputs = torch.cat(pieces, dim=1)
        return graph.replace(edges=self.linear(per_edge_inputs))
99 |
100 |
class NodeBlock(nn.Module):
    """Update every node from its aggregated edges, its own features and the globals.

    For each node the block concatenates (in this order, each part optional)
    the sum of its incoming edges (edges whose receiver is the node), the sum
    of its outgoing edges (edges whose sender is the node), its current
    features, and the global features. A single linear layer maps the result
    back to the input node width.

    Args:
        graph: Example graph, used only to size the linear layer.
        use_received_edges: Include summed incoming edge features (default True).
        use_sent_edges: Include summed outgoing edge features (default False).
        use_nodes: Include the node's own features (default True).
        use_globals: Include the global features (default True).
    """

    def __init__(self,
                 graph,
                 use_received_edges=True,
                 use_sent_edges=False,
                 use_nodes=True,
                 use_globals=True):
        super(NodeBlock, self).__init__()
        self._use_received_edges = use_received_edges
        self._use_sent_edges = use_sent_edges
        self._use_nodes = use_nodes
        self._use_globals = use_globals

        node_width = graph.nodes.shape[-1]
        widths = []
        if self._use_nodes:
            widths.append(node_width)
        if self._use_received_edges:
            widths.append(graph.edges.shape[-1])
        if self._use_sent_edges:
            widths.append(graph.edges.shape[-1])
        if self._use_globals:
            widths.append(graph.globals.shape[-1])
        # Output width equals the input node width, so the updated graph has
        # the same node width as the input graph.
        self.linear = nn.Linear(sum(widths), node_width)
        # Both aggregators are always constructed, regardless of the flags.
        self._received_edges_aggregator = Aggregator('receivers')
        self._sent_edges_aggregator = Aggregator('senders')

    def forward(self, graph):
        parts = []
        if self._use_received_edges:
            # The edge step brought node information into the edges; here the
            # edges attached to each node are brought back into the node:
            # features of incoming edges, summed per receiver node.
            parts.append(self._received_edges_aggregator(graph))
        if self._use_sent_edges:
            # Features of outgoing edges, summed per sender node.
            parts.append(self._sent_edges_aggregator(graph))
        if self._use_nodes:
            parts.append(graph.nodes)
        if self._use_globals:
            # Global features repeated once per node.
            parts.append(broadcast_globals_to_nodes(graph))
        per_node_inputs = torch.cat(parts, dim=1)
        return graph.replace(nodes=self.linear(per_node_inputs))
152 |
153 |
class GlobalBlock(nn.Module):
    """Update the global features from aggregated edges, nodes and globals.

    The block concatenates (in this order, each part optional):
    - the sum of all edge features (over edges, dim 0);
    - the sum of all node features (over nodes, dim 0);
    - the current global features.
    A single linear layer maps the result back to the input global width.

    Like ``EdgeBlock`` / ``NodeBlock``, the block takes an example ``graph``
    as its first argument, used only to size the linear layer. It handles a
    single graph: ``graph.globals`` must hold exactly ``globals.shape[-1]``
    values (e.g. shape (G,) or (1, G)), and the output keeps that shape.

    Args:
        graph: Example graph with ``edges`` (num_edges, E), ``nodes``
            (num_nodes, N) and ``globals`` (G,) or (1, G).
        use_edges: Include the summed edge features.
        use_nodes: Include the summed node features.
        use_globals: Include the current global features.

    Raises:
        ValueError: if all three ``use_*`` flags are false (nothing to feed
            the linear layer).
    """

    def __init__(self,
                 graph,
                 use_edges=True,
                 use_nodes=True,
                 use_globals=True):
        super(GlobalBlock, self).__init__()
        self._use_edges = use_edges
        self._use_nodes = use_nodes
        self._use_globals = use_globals

        if not (use_edges or use_nodes or use_globals):
            raise ValueError(
                "GlobalBlock needs at least one of use_edges, use_nodes, "
                "use_globals")
        global_width = graph.globals.shape[-1]
        in_features = 0
        if self._use_edges:
            in_features += graph.edges.shape[-1]
        if self._use_nodes:
            in_features += graph.nodes.shape[-1]
        if self._use_globals:
            in_features += global_width
        self.linear = nn.Linear(in_features, global_width)

    def forward(self, graph):
        globals_to_collect = []
        if self._use_edges:
            # Sum-aggregate every edge into one feature vector.
            globals_to_collect.append(torch.sum(graph.edges, dim=0))
        if self._use_nodes:
            # Sum-aggregate every node into one feature vector.
            globals_to_collect.append(torch.sum(graph.nodes, dim=0))
        if self._use_globals:
            # Flatten so (G,) and (1, G) globals both concatenate as 1-D.
            globals_to_collect.append(graph.globals.reshape(-1))
        collected_globals = torch.cat(globals_to_collect, dim=0)
        updated_globals = self.linear(collected_globals)
        # Restore the caller's globals layout, e.g. (G,) or (1, G).
        return graph.replace(globals=updated_globals.reshape(graph.globals.shape))
182 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 | 1562411616157
176 |
177 |
178 | 1562411616157
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 | file://$PROJECT_DIR$/blocks.py
225 | 148
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
--------------------------------------------------------------------------------