├── .gitignore ├── README.md ├── datasets.py ├── environment.yml ├── layers.py ├── models.py ├── outputs.txt ├── requirement.txt ├── run.sh ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | build/ 4 | dist/ 5 | alpha/ 6 | runs/ 7 | .cache/ 8 | .eggs/ 9 | *.egg-info/ 10 | .coverage 11 | .coverage.* 12 | .vscode 13 | .idea 14 | .code 15 | *.pyc 16 | .DS_Store 17 | desktop.ini 18 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Matching Networks 2 | 3 | Graph Matching Networks for Learning the Similarity of Graph Structured Objects 4 | 5 | Adapted with https://arxiv.org/abs/1904.12787 6 | 7 | Check all the env requirements and use `./run.sh` to run. 8 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import six 2 | import abc 3 | import contextlib 4 | import random 5 | import collections 6 | import copy 7 | 8 | import numpy as np 9 | import networkx as nx 10 | import tensorflow as tf 11 | GraphData = collections.namedtuple( 12 | "GraphData", 13 | ["from_idx", "to_idx", "node_features", "edge_features", "graph_idx", "n_graphs"], 14 | ) 15 | 16 | """A general Interface""" 17 | class GraphSimilarityDataset(object): 18 | """Base class for all the graph similarity learning datasets. 19 | 20 | This class defines some common interfaces a graph similarity dataset can have, 21 | in particular the functions that creates iterators over pairs and triplets. 22 | """ 23 | 24 | @abc.abstractmethod 25 | def triplets(self, batch_size): 26 | """Create an iterator over triplets. 27 | 28 | Args: 29 | batch_size: int, number of triplets in a batch. 30 | 31 | Yields: 32 | graphs: a `GraphData` instance. The batch of triplets put together. Each 33 | triplet has 3 graphs (x, y, z). Here the first graph is duplicated once 34 | so the graphs for each triplet are ordered as (x, y, x, z) in the batch. 35 | The batch contains `batch_size` number of triplets, hence `4*batch_size` 36 | many graphs. 37 | """ 38 | pass 39 | 40 | @abc.abstractmethod 41 | def pairs(self, batch_size): 42 | """Create an iterator over pairs. 43 | 44 | Args: 45 | batch_size: int, number of pairs in a batch. 46 | 47 | Yields: 48 | graphs: a `GraphData` instance. The batch of pairs put together. Each 49 | pair has 2 graphs (x, y). The batch contains `batch_size` number of 50 | pairs, hence `2*batch_size` many graphs. 51 | labels: [batch_size] int labels for each pair, +1 for similar, -1 for not. 
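    Example (an illustrative sketch of consuming this interface through the
    concrete `GraphEditDistanceDataset` defined below; the constructor values
    are arbitrary placeholders, not recommended settings):

      dataset = GraphEditDistanceDataset(
          n_nodes_range=(20, 20), p_edge_range=(0.2, 0.2),
          n_changes_positive=1, n_changes_negative=2)
      pair_iter = dataset.pairs(batch_size=4)
      graphs, labels = next(pair_iter)
      # graphs.n_graphs == 8 (two graphs per pair), labels.shape == (4,)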
52 | """ 53 | pass 54 | 55 | """Graph Edit Distance Task""" 56 | 57 | # Graph Manipulation Functions 58 | def permute_graph_nodes(g): 59 | """Permute node ordering of a graph, returns a new graph.""" 60 | n = g.number_of_nodes() 61 | new_g = nx.Graph() 62 | new_g.add_nodes_from(range(n)) 63 | perm = np.random.permutation(n) 64 | edges = g.edges() 65 | new_edges = [] 66 | for x, y in edges: 67 | new_edges.append((perm[x], perm[y])) 68 | new_g.add_edges_from(new_edges) 69 | return new_g 70 | 71 | def substitute_random_edges(g, n): 72 | """Substitutes n edges from graph g with another n randomly picked edges.""" 73 | g = copy.deepcopy(g) 74 | n_nodes = g.number_of_nodes() 75 | edges = list(g.edges()) 76 | # sample n edges without replacement 77 | e_remove = [ 78 | edges[i] for i in np.random.choice(np.arange(len(edges)), n, replace=False) 79 | ] 80 | edge_set = set(edges) 81 | e_add = set() 82 | while len(e_add) < n: 83 | e = np.random.choice(n_nodes, 2, replace=False) 84 | # make sure e does not exist and is not already chosen to be added 85 | if ( 86 | (e[0], e[1]) not in edge_set 87 | and (e[1], e[0]) not in edge_set 88 | and (e[0], e[1]) not in e_add 89 | and (e[1], e[0]) not in e_add 90 | ): 91 | e_add.add((e[0], e[1])) 92 | 93 | for i, j in e_remove: 94 | g.remove_edge(i, j) 95 | for i, j in e_add: 96 | g.add_edge(i, j) 97 | return g 98 | 99 | 100 | class GraphEditDistanceDataset(GraphSimilarityDataset): 101 | """Graph edit distance dataset.""" 102 | 103 | def __init__( 104 | self, 105 | n_nodes_range, 106 | p_edge_range, 107 | n_changes_positive, 108 | n_changes_negative, 109 | permute=True, 110 | ): 111 | """Constructor. 112 | 113 | Args: 114 | n_nodes_range: a tuple (n_min, n_max). The minimum and maximum number of 115 | nodes in a graph to generate. 116 | p_edge_range: a tuple (p_min, p_max). The minimum and maximum edge 117 | probability. 118 | n_changes_positive: the number of edge substitutions for a pair to be 119 | considered positive (similar). 120 | n_changes_negative: the number of edge substitutions for a pair to be 121 | considered negative (not similar). 122 | permute: if True (default), permute node orderings in addition to 123 | changing edges; if False, the node orderings across a pair or triplet of 124 | graphs will be the same, useful for visualization. 
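    Example (an illustrative check of the edit-distance semantics using the
    module-level helper functions above; the generated edge sets are random,
    but the edge count is preserved by the substitutions):

      base = nx.erdos_renyi_graph(20, 0.2)
      pos = substitute_random_edges(base, 1)   # differs by 1 edge substitution
      neg = substitute_random_edges(base, 2)   # differs by 2 edge substitutions
      same = permute_graph_nodes(base)         # isomorphic, only node ids permuted
      assert base.number_of_edges() == pos.number_of_edges() == neg.number_of_edges()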
125 | """ 126 | self._n_min, self._n_max = n_nodes_range 127 | self._p_min, self._p_max = p_edge_range 128 | self._k_pos = n_changes_positive 129 | self._k_neg = n_changes_negative 130 | self._permute = permute 131 | 132 | def _get_graph(self): 133 | """Generate one graph.""" 134 | n_nodes = np.random.randint(self._n_min, self._n_max + 1) 135 | p_edge = np.random.uniform(self._p_min, self._p_max) 136 | 137 | # do a little bit of filtering 138 | n_trials = 100 139 | for _ in range(n_trials): 140 | g = nx.erdos_renyi_graph(n_nodes, p_edge) 141 | if nx.is_connected(g): 142 | return g 143 | 144 | raise ValueError("Failed to generate a connected graph.") 145 | 146 | def _get_pair(self, positive): 147 | """Generate one pair of graphs.""" 148 | g = self._get_graph() 149 | if self._permute: 150 | permuted_g = permute_graph_nodes(g) 151 | else: 152 | permuted_g = g 153 | n_changes = self._k_pos if positive else self._k_neg 154 | changed_g = substitute_random_edges(g, n_changes) 155 | return permuted_g, changed_g 156 | 157 | def _get_triplet(self): 158 | """Generate one triplet of graphs.""" 159 | g = self._get_graph() 160 | if self._permute: 161 | permuted_g = permute_graph_nodes(g) 162 | else: 163 | permuted_g = g 164 | pos_g = substitute_random_edges(g, self._k_pos) 165 | neg_g = substitute_random_edges(g, self._k_neg) 166 | return permuted_g, pos_g, neg_g 167 | 168 | def triplets(self, batch_size): 169 | """Yields batches of triplet data.""" 170 | while True: 171 | batch_graphs = [] 172 | for _ in range(batch_size): 173 | g1, g2, g3 = self._get_triplet() 174 | batch_graphs.append((g1, g2, g1, g3)) 175 | yield self._pack_batch(batch_graphs) 176 | 177 | def pairs(self, batch_size): 178 | """Yields batches of pair data.""" 179 | while True: 180 | batch_graphs = [] 181 | batch_labels = [] 182 | positive = True 183 | for _ in range(batch_size): 184 | g1, g2 = self._get_pair(positive) 185 | batch_graphs.append((g1, g2)) 186 | batch_labels.append(1 if positive else -1) 187 | positive = not positive 188 | 189 | packed_graphs = self._pack_batch(batch_graphs) 190 | labels = np.array(batch_labels, dtype=np.int32) 191 | yield packed_graphs, labels 192 | 193 | def _pack_batch(self, graphs): 194 | """Pack a batch of graphs into a single `GraphData` instance. 195 | 196 | Args: 197 | graphs: a list of generated networkx graphs. 198 | 199 | Returns: 200 | graph_data: a `GraphData` instance, with node and edge indices properly 201 | shifted. 
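    Example (an illustrative sketch; two copies of a 3-node path graph packed
    into one batch):

      g = nx.path_graph(3)                # edges (0, 1) and (1, 2)
      packed = self._pack_batch([(g, g)])
      # packed.from_idx  -> [0, 1, 3, 4]  (second copy shifted by 3 nodes)
      # packed.to_idx    -> [1, 2, 4, 5]
      # packed.graph_idx -> [0, 0, 0, 1, 1, 1]
      # packed.n_graphs  -> 2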
202 | """ 203 | graphs = tf.nest.flatten(graphs) 204 | from_idx = [] 205 | to_idx = [] 206 | graph_idx = [] 207 | 208 | n_total_nodes = 0 209 | n_total_edges = 0 210 | for i, g in enumerate(graphs): 211 | n_nodes = g.number_of_nodes() 212 | n_edges = g.number_of_edges() 213 | edges = np.array(g.edges(), dtype=np.int32) 214 | # shift the node indices for the edges 215 | from_idx.append(edges[:, 0] + n_total_nodes) 216 | to_idx.append(edges[:, 1] + n_total_nodes) 217 | graph_idx.append(np.ones(n_nodes, dtype=np.int32) * i) 218 | 219 | n_total_nodes += n_nodes 220 | n_total_edges += n_edges 221 | 222 | return GraphData( 223 | from_idx=np.concatenate(from_idx, axis=0), 224 | to_idx=np.concatenate(to_idx, axis=0), 225 | # this task only cares about the structures, the graphs have no features 226 | node_features=np.ones((n_total_nodes, 1), dtype=np.float32), 227 | edge_features=np.ones((n_total_edges, 1), dtype=np.float32), 228 | graph_idx=np.concatenate(graph_idx, axis=0), 229 | n_graphs=len(graphs), 230 | ) 231 | 232 | # Use Fixed datasets for evaluation 233 | @contextlib.contextmanager 234 | def reset_random_state(seed): 235 | """This function creates a context that uses the given seed.""" 236 | np_rnd_state = np.random.get_state() 237 | rnd_state = random.getstate() 238 | np.random.seed(seed) 239 | random.seed(seed + 1) 240 | try: 241 | yield 242 | finally: 243 | random.setstate(rnd_state) 244 | np.random.set_state(np_rnd_state) 245 | 246 | 247 | class FixedGraphEditDistanceDataset(GraphEditDistanceDataset): 248 | """A fixed dataset of pairs or triplets for the graph edit distance task. 249 | 250 | This dataset can be used for evaluation. 251 | """ 252 | 253 | def __init__( 254 | self, 255 | n_nodes_range, 256 | p_edge_range, 257 | n_changes_positive, 258 | n_changes_negative, 259 | dataset_size, 260 | permute=True, 261 | seed=1234, 262 | ): 263 | super(FixedGraphEditDistanceDataset, self).__init__( 264 | n_nodes_range, 265 | p_edge_range, 266 | n_changes_positive, 267 | n_changes_negative, 268 | permute=permute, 269 | ) 270 | self._dataset_size = dataset_size 271 | self._seed = seed 272 | 273 | def triplets(self, batch_size): 274 | """Yield triplets.""" 275 | 276 | if hasattr(self, "_triplets"): 277 | triplets = self._triplets 278 | else: 279 | # get a fixed set of triplets 280 | with reset_random_state(self._seed): 281 | triplets = [] 282 | for _ in range(self._dataset_size): 283 | g1, g2, g3 = self._get_triplet() 284 | triplets.append((g1, g2, g1, g3)) 285 | self._triplets = triplets 286 | 287 | ptr = 0 288 | while ptr + batch_size <= len(triplets): 289 | batch_graphs = triplets[ptr : ptr + batch_size] 290 | yield self._pack_batch(batch_graphs) 291 | ptr += batch_size 292 | 293 | def pairs(self, batch_size): 294 | """Yield pairs and labels.""" 295 | 296 | if hasattr(self, "_pairs") and hasattr(self, "_labels"): 297 | pairs = self._pairs 298 | labels = self._labels 299 | else: 300 | # get a fixed set of pairs first 301 | with reset_random_state(self._seed): 302 | pairs = [] 303 | labels = [] 304 | positive = True 305 | for _ in range(self._dataset_size): 306 | pairs.append(self._get_pair(positive)) 307 | labels.append(1 if positive else -1) 308 | positive = not positive 309 | labels = np.array(labels, dtype=np.int32) 310 | 311 | self._pairs = pairs 312 | self._labels = labels 313 | 314 | ptr = 0 315 | while ptr + batch_size <= len(pairs): 316 | batch_graphs = pairs[ptr : ptr + batch_size] 317 | packed_batch = self._pack_batch(batch_graphs) 318 | yield packed_batch, labels[ptr : ptr + 
batch_size] 319 | ptr += batch_size 320 | 321 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gmn 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - ca-certificates=2020.6.24=0 7 | - certifi=2020.6.20=py36_0 8 | - ld_impl_linux-64=2.33.1=h53a641e_7 9 | - libedit=3.1.20191231=h14c3975_1 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.1.0=hdf63c60_0 12 | - libstdcxx-ng=9.1.0=hdf63c60_0 13 | - ncurses=6.2=he6710b0_1 14 | - openssl=1.1.1g=h7b6447c_0 15 | - pip=20.2.2=py36_0 16 | - python=3.6.10=h7579374_2 17 | - readline=8.0=h7b6447c_0 18 | - setuptools=49.6.0=py36_0 19 | - sqlite=3.32.3=h62c20be_0 20 | - tk=8.6.10=hbc83047_0 21 | - wheel=0.34.2=py36_0 22 | - xz=5.2.5=h7b6447c_0 23 | - zlib=1.2.11=h7b6447c_3 24 | - pip: 25 | - absl-py==0.9.0 26 | - astor==0.8.1 27 | - cloudpickle==1.5.0 28 | - contextlib2==0.6.0.post1 29 | - cycler==0.10.0 30 | - decorator==4.4.2 31 | - dm-sonnet==1.34 32 | - dm-tree==0.1.5 33 | - gast==0.2.2 34 | - google-pasta==0.2.0 35 | - grpcio==1.31.0 36 | - h5py==2.10.0 37 | - importlib-metadata==1.7.0 38 | - keras-applications==1.0.8 39 | - keras-preprocessing==1.1.2 40 | - kiwisolver==1.2.0 41 | - markdown==3.2.2 42 | - matplotlib==3.1.1 43 | - networkx==2.3 44 | - numpy==1.16.4 45 | - protobuf==3.13.0 46 | - pyparsing==2.4.7 47 | - python-dateutil==2.8.1 48 | - semantic-version==2.8.5 49 | - six==1.12.0 50 | - tensorboard==1.14.0 51 | - tensorflow==1.14.0 52 | - tensorflow-estimator==1.14.0 53 | - tensorflow-probability==0.7.0rc0 54 | - termcolor==1.1.0 55 | - werkzeug==1.0.1 56 | - wrapt==1.12.1 57 | - zipp==3.1.0 58 | 59 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | import sonnet as snt 2 | import tensorflow as tf 3 | from utils import * 4 | import os 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | 7 | 8 | """Graph Embedding Network""" 9 | class GraphEncoder(snt.AbstractModule): 10 | """Encoder module that projects node and edge features to some embeddings.""" 11 | 12 | def __init__( 13 | self, node_hidden_sizes=None, edge_hidden_sizes=None, name="graph-encoder" 14 | ): 15 | """Constructor. 16 | 17 | Args: 18 | node_hidden_sizes: if provided should be a list of ints, hidden sizes of 19 | node encoder network, the last element is the size of the node outputs. 20 | If not provided, node features will pass through as is. 21 | edge_hidden_sizes: if provided should be a list of ints, hidden sizes of 22 | edge encoder network, the last element is the size of the edge outptus. 23 | If not provided, edge features will pass through as is. 24 | name: name of this module. 25 | """ 26 | super(GraphEncoder, self).__init__(name=name) 27 | 28 | # this also handles the case of an empty list 29 | self._node_hidden_sizes = node_hidden_sizes if node_hidden_sizes else None 30 | self._edge_hidden_sizes = edge_hidden_sizes 31 | 32 | def _build(self, node_features, edge_features=None): 33 | """Encode node and edge features. 34 | 35 | Args: 36 | node_features: [n_nodes, node_feat_dim] float tensor. 37 | edge_features: if provided, should be [n_edges, edge_feat_dim] float 38 | tensor. 39 | 40 | Returns: 41 | node_outputs: [n_nodes, node_embedding_dim] float tensor, node embeddings. 
42 | edge_outputs: if edge_features is not None and edge_hidden_sizes is not 43 | None, this is [n_edges, edge_embedding_dim] float tensor, edge 44 | embeddings; otherwise just the input edge_features. 45 | """ 46 | if self._node_hidden_sizes is None: 47 | node_outputs = node_features 48 | else: 49 | node_outputs = snt.nets.MLP( 50 | self._node_hidden_sizes, name="node-feature-mlp" 51 | )(node_features) 52 | 53 | if edge_features is None or self._edge_hidden_sizes is None: 54 | edge_outputs = edge_features 55 | else: 56 | edge_outputs = snt.nets.MLP( 57 | self._edge_hidden_sizes, name="edge-feature-mlp" 58 | )(edge_features) 59 | 60 | return node_outputs, edge_outputs 61 | 62 | 63 | """The Message Passing Layer""" 64 | def graph_prop_once( 65 | node_states, 66 | from_idx, 67 | to_idx, 68 | message_net, 69 | aggregation_module=tf.unsorted_segment_sum, 70 | edge_features=None, 71 | ): 72 | """One round of propagation (message passing) in a graph. 73 | 74 | Args: 75 | node_states: [n_nodes, node_state_dim] float tensor, node state vectors, one 76 | row for each node. 77 | from_idx: [n_edges] int tensor, index of the from nodes. 78 | to_idx: [n_edges] int tensor, index of the to nodes. 79 | message_net: a network that maps concatenated edge inputs to message 80 | vectors. 81 | aggregation_module: a module that aggregates messages on edges to aggregated 82 | messages for each node. Should be a callable and can be called like the 83 | following, 84 | `aggregated_messages = aggregation_module(messages, to_idx, n_nodes)`, 85 | where messages is [n_edges, edge_message_dim] tensor, to_idx is the index 86 | of the to nodes, i.e. where each message should go to, and n_nodes is an 87 | int which is the number of nodes to aggregate into. 88 | edge_features: if provided, should be a [n_edges, edge_feature_dim] float 89 | tensor, extra features for each edge. 90 | 91 | Returns: 92 | aggregated_messages: an [n_nodes, edge_message_dim] float tensor, the 93 | aggregated messages, one row for each node. 94 | """ 95 | from_states = tf.gather(node_states, from_idx) 96 | to_states = tf.gather(node_states, to_idx) 97 | 98 | edge_inputs = [from_states, to_states] 99 | if edge_features is not None: 100 | edge_inputs.append(edge_features) 101 | 102 | edge_inputs = tf.concat(edge_inputs, axis=-1) 103 | messages = message_net(edge_inputs) 104 | 105 | return aggregation_module(messages, to_idx, tf.shape(node_states)[0]) 106 | 107 | class GraphPropLayer(snt.AbstractModule): 108 | """Implementation of a graph propagation (message passing) layer.""" 109 | 110 | def __init__( 111 | self, 112 | node_state_dim, 113 | edge_hidden_sizes, 114 | node_hidden_sizes, 115 | edge_net_init_scale=0.1, 116 | node_update_type="residual", 117 | use_reverse_direction=True, 118 | reverse_dir_param_different=True, 119 | layer_norm=False, 120 | name="graph-net", 121 | ): 122 | """Constructor. 123 | 124 | Args: 125 | node_state_dim: int, dimensionality of node states. 126 | edge_hidden_sizes: list of ints, hidden sizes for the edge message 127 | net, the last element in the list is the size of the message vectors. 128 | node_hidden_sizes: list of ints, hidden sizes for the node update 129 | net. 130 | edge_net_init_scale: initialization scale for the edge networks. This 131 | is typically set to a small value such that the gradient does not blow 132 | up. 133 | node_update_type: type of node updates, one of {mlp, gru, residual}. 134 | use_reverse_direction: set to True to also propagate messages in the 135 | reverse direction. 
136 | reverse_dir_param_different: set to True to have the messages computed 137 | using a different set of parameters than for the forward direction. 138 | layer_norm: set to True to use layer normalization in a few places. 139 | name: name of this module. 140 | """ 141 | super(GraphPropLayer, self).__init__(name=name) 142 | 143 | self._node_state_dim = node_state_dim 144 | self._edge_hidden_sizes = edge_hidden_sizes[:] 145 | 146 | # output size is node_state_dim 147 | self._node_hidden_sizes = node_hidden_sizes[:] + [node_state_dim] 148 | self._edge_net_init_scale = edge_net_init_scale 149 | self._node_update_type = node_update_type 150 | 151 | self._use_reverse_direction = use_reverse_direction 152 | self._reverse_dir_param_different = reverse_dir_param_different 153 | 154 | self._layer_norm = layer_norm 155 | 156 | def _compute_aggregated_messages( 157 | self, node_states, from_idx, to_idx, edge_features=None 158 | ): 159 | """Compute aggregated messages for each node. 160 | 161 | Args: 162 | node_states: [n_nodes, input_node_state_dim] float tensor, node states. 163 | from_idx: [n_edges] int tensor, from node indices for each edge. 164 | to_idx: [n_edges] int tensor, to node indices for each edge. 165 | edge_features: if not None, should be [n_edges, edge_embedding_dim] 166 | tensor, edge features. 167 | 168 | Returns: 169 | aggregated_messages: [n_nodes, aggregated_message_dim] float tensor, the 170 | aggregated messages for each node. 171 | """ 172 | self._message_net = snt.nets.MLP( 173 | self._edge_hidden_sizes, 174 | initializers={ 175 | "w": tf.variance_scaling_initializer(scale=self._edge_net_init_scale), 176 | "b": tf.zeros_initializer(), 177 | }, 178 | name="message-mlp", 179 | ) 180 | 181 | aggregated_messages = graph_prop_once( 182 | node_states, 183 | from_idx, 184 | to_idx, 185 | self._message_net, 186 | aggregation_module=tf.unsorted_segment_sum, 187 | edge_features=edge_features, 188 | ) 189 | 190 | # optionally compute message vectors in the reverse direction 191 | if self._use_reverse_direction: 192 | if self._reverse_dir_param_different: 193 | self._reverse_message_net = snt.nets.MLP( 194 | self._edge_hidden_sizes, 195 | initializers={ 196 | "w": tf.variance_scaling_initializer( 197 | scale=self._edge_net_init_scale 198 | ), 199 | "b": tf.zeros_initializer(), 200 | }, 201 | name="reverse-message-mlp", 202 | ) 203 | else: 204 | self._reverse_message_net = self._message_net 205 | 206 | reverse_aggregated_messages = graph_prop_once( 207 | node_states, 208 | to_idx, 209 | from_idx, 210 | self._reverse_message_net, 211 | aggregation_module=tf.unsorted_segment_sum, 212 | edge_features=edge_features, 213 | ) 214 | 215 | aggregated_messages += reverse_aggregated_messages 216 | 217 | if self._layer_norm: 218 | aggregated_messages = snt.LayerNorm()(aggregated_messages) 219 | 220 | return aggregated_messages 221 | 222 | def _compute_node_update(self, node_states, node_state_inputs, node_features=None): 223 | """Compute node updates. 224 | 225 | Args: 226 | node_states: [n_nodes, node_state_dim] float tensor, the input node 227 | states. 228 | node_state_inputs: a list of tensors used to compute node updates. Each 229 | element tensor should have shape [n_nodes, feat_dim], where feat_dim can 230 | be different. These tensors will be concatenated along the feature 231 | dimension. 232 | node_features: extra node features if provided, should be of size 233 | [n_nodes, extra_node_feat_dim] float tensor, can be used to implement 234 | different types of skip connections. 
235 | 236 | Returns: 237 | new_node_states: [n_nodes, node_state_dim] float tensor, the new node 238 | state tensor. 239 | 240 | Raises: 241 | ValueError: if node update type is not supported. 242 | """ 243 | if self._node_update_type in ("mlp", "residual"): 244 | node_state_inputs.append(node_states) 245 | if node_features is not None: 246 | node_state_inputs.append(node_features) 247 | 248 | if len(node_state_inputs) == 1: 249 | node_state_inputs = node_state_inputs[0] 250 | else: 251 | node_state_inputs = tf.concat(node_state_inputs, axis=-1) 252 | 253 | if self._node_update_type == "gru": 254 | _, new_node_states = snt.GRU(self._node_state_dim)( 255 | node_state_inputs, node_states 256 | ) 257 | return new_node_states 258 | else: 259 | mlp_output = snt.nets.MLP(self._node_hidden_sizes, name="node-mlp")( 260 | node_state_inputs 261 | ) 262 | if self._layer_norm: 263 | mlp_output = snt.LayerNorm()(mlp_output) 264 | if self._node_update_type == "mlp": 265 | return mlp_output 266 | elif self._node_update_type == "residual": 267 | return node_states + mlp_output 268 | else: 269 | raise ValueError("Unknown node update type %s" % self._node_update_type) 270 | 271 | def _build( 272 | self, node_states, from_idx, to_idx, edge_features=None, node_features=None 273 | ): 274 | """Run one propagation step. 275 | 276 | Args: 277 | node_states: [n_nodes, input_node_state_dim] float tensor, node states. 278 | from_idx: [n_edges] int tensor, from node indices for each edge. 279 | to_idx: [n_edges] int tensor, to node indices for each edge. 280 | edge_features: if not None, should be [n_edges, edge_embedding_dim] 281 | tensor, edge features. 282 | node_features: extra node features if provided, should be of size 283 | [n_nodes, extra_node_feat_dim] float tensor, can be used to implement 284 | different types of skip connections. 285 | 286 | Returns: 287 | node_states: [n_nodes, node_state_dim] float tensor, new node states. 288 | """ 289 | aggregated_messages = self._compute_aggregated_messages( 290 | node_states, from_idx, to_idx, edge_features=edge_features 291 | ) 292 | 293 | return self._compute_node_update( 294 | node_states, [aggregated_messages], node_features=node_features 295 | ) 296 | 297 | """Aggregator""" 298 | 299 | 300 | AGGREGATION_TYPE = { 301 | "sum": tf.unsorted_segment_sum, 302 | "mean": tf.unsorted_segment_mean, 303 | "sqrt_n": tf.unsorted_segment_sqrt_n, 304 | "max": tf.unsorted_segment_max, 305 | } 306 | 307 | 308 | class GraphAggregator(snt.AbstractModule): 309 | """This module computes graph representations by aggregating from parts.""" 310 | 311 | def __init__( 312 | self, 313 | node_hidden_sizes, 314 | graph_transform_sizes=None, 315 | gated=True, 316 | aggregation_type="sum", 317 | name="graph-aggregator", 318 | ): 319 | """Constructor. 320 | 321 | Args: 322 | node_hidden_sizes: the hidden layer sizes of the node transformation nets. 323 | The last element is the size of the aggregated graph representation. 324 | graph_transform_sizes: sizes of the transformation layers on top of the 325 | graph representations. The last element of this list is the final 326 | dimensionality of the output graph representations. 327 | gated: set to True to do gated aggregation, False not to. 328 | aggregation_type: one of {sum, max, mean, sqrt_n}. 329 | name: name of this module. 
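    Example (an illustrative configuration; the sizes are placeholders rather
    than tuned values):

      aggregator = GraphAggregator(
          node_hidden_sizes=[128],       # aggregated graph vector is 128-d
          graph_transform_sizes=[128],   # final graph representation is 128-d
          gated=True,
          aggregation_type="sum")
      # graph_states = aggregator(node_states, graph_idx, n_graphs)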
330 | """ 331 | super(GraphAggregator, self).__init__(name=name) 332 | 333 | self._node_hidden_sizes = node_hidden_sizes 334 | self._graph_transform_sizes = graph_transform_sizes 335 | self._graph_state_dim = node_hidden_sizes[-1] 336 | self._gated = gated 337 | self._aggregation_type = aggregation_type 338 | self._aggregation_op = AGGREGATION_TYPE[aggregation_type] 339 | 340 | def _build(self, node_states, graph_idx, n_graphs): 341 | """Compute aggregated graph representations. 342 | 343 | Args: 344 | node_states: [n_nodes, node_state_dim] float tensor, node states of a 345 | batch of graphs concatenated together along the first dimension. 346 | graph_idx: [n_nodes] int tensor, graph ID for each node. 347 | n_graphs: integer, number of graphs in this batch. 348 | 349 | Returns: 350 | graph_states: [n_graphs, graph_state_dim] float tensor, graph 351 | representations, one row for each graph. 352 | """ 353 | node_hidden_sizes = self._node_hidden_sizes 354 | if self._gated: 355 | node_hidden_sizes[-1] = self._graph_state_dim * 2 356 | 357 | node_states_g = snt.nets.MLP(node_hidden_sizes, name="node-state-g-mlp")( 358 | node_states 359 | ) 360 | 361 | if self._gated: 362 | gates = tf.nn.sigmoid(node_states_g[:, : self._graph_state_dim]) 363 | node_states_g = node_states_g[:, self._graph_state_dim :] * gates 364 | 365 | graph_states = self._aggregation_op(node_states_g, graph_idx, n_graphs) 366 | 367 | # unsorted_segment_max does not handle empty graphs in the way we want 368 | # it assigns the lowest possible float to empty segments, we want to reset 369 | # them to zero. 370 | if self._aggregation_type == "max": 371 | # reset everything that's smaller than -1e5 to 0. 372 | graph_states *= tf.cast(graph_states > -1e5, tf.float32) 373 | 374 | # transform the reduced graph states further 375 | 376 | # pylint: disable=g-explicit-length-test 377 | if ( 378 | self._graph_transform_sizes is not None 379 | and len(self._graph_transform_sizes) > 0 380 | ): 381 | graph_states = snt.nets.MLP( 382 | self._graph_transform_sizes, name="graph-transform-mlp" 383 | )(graph_states) 384 | 385 | return graph_states 386 | 387 | 388 | """Graph Matching Networks Related Layers""" 389 | 390 | class GraphPropMatchingLayer(GraphPropLayer): 391 | """A graph propagation layer that also does cross graph matching. 392 | 393 | It assumes the incoming graph data is batched and paired, i.e. graph 0 and 1 394 | forms the first pair and graph 2 and 3 are the second pair etc., and computes 395 | cross-graph attention-based matching for each pair. 396 | """ 397 | 398 | def _build( 399 | self, 400 | node_states, 401 | from_idx, 402 | to_idx, 403 | graph_idx, 404 | n_graphs, 405 | similarity="dotproduct", 406 | edge_features=None, 407 | node_features=None, 408 | ): 409 | """Run one propagation step with cross-graph matching. 410 | 411 | Args: 412 | node_states: [n_nodes, node_state_dim] float tensor, node states. 413 | from_idx: [n_edges] int tensor, from node indices for each edge. 414 | to_idx: [n_edges] int tensor, to node indices for each edge. 415 | graph_idx: [n_onodes] int tensor, graph id for each node. 416 | n_graphs: integer, number of graphs in the batch. 417 | similarity: type of similarity to use for the cross graph attention. 418 | edge_features: if not None, should be [n_edges, edge_feat_dim] tensor, 419 | extra edge features. 420 | node_features: if not None, should be [n_nodes, node_feat_dim] tensor, 421 | extra node features. 
422 | 423 | Returns: 424 | node_states: [n_nodes, node_state_dim] float tensor, new node states. 425 | 426 | Raises: 427 | ValueError: if some options are not provided correctly. 428 | """ 429 | aggregated_messages = self._compute_aggregated_messages( 430 | node_states, from_idx, to_idx, edge_features=edge_features 431 | ) 432 | 433 | # new stuff here 434 | cross_graph_attention = batch_block_pair_attention( 435 | node_states, graph_idx, n_graphs, similarity=similarity 436 | ) 437 | attention_input = node_states - cross_graph_attention 438 | 439 | return self._compute_node_update( 440 | node_states, 441 | [aggregated_messages, attention_input], 442 | node_features=node_features, 443 | ) 444 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import sonnet as snt 2 | import tensorflow as tf 3 | from layers import * 4 | import os 5 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 6 | 7 | 8 | class GraphEmbeddingNet(snt.AbstractModule): 9 | """A graph to embedding mapping network.""" 10 | 11 | def __init__( 12 | self, 13 | encoder, 14 | aggregator, 15 | node_state_dim, 16 | edge_hidden_sizes, 17 | node_hidden_sizes, 18 | n_prop_layers, 19 | share_prop_params=False, 20 | edge_net_init_scale=0.1, 21 | node_update_type="residual", 22 | use_reverse_direction=True, 23 | reverse_dir_param_different=True, 24 | layer_norm=False, 25 | name="graph-embedding-net", 26 | ): 27 | """Constructor. 28 | Args: 29 | encoder: GraphEncoder, encoder that maps features to embeddings. 30 | aggregator: GraphAggregator, aggregator that produces graph 31 | representations. 32 | node_state_dim: dimensionality of node states. 33 | edge_hidden_sizes: sizes of the hidden layers of the edge message nets. 34 | node_hidden_sizes: sizes of the hidden layers of the node update nets. 35 | n_prop_layers: number of graph propagation layers. 36 | share_prop_params: set to True to share propagation parameters across all 37 | graph propagation layers, False not to. 38 | edge_net_init_scale: scale of initialization for the edge message nets. 39 | node_update_type: type of node updates, one of {mlp, gru, residual}. 40 | use_reverse_direction: set to True to also propagate messages in the 41 | reverse direction. 42 | reverse_dir_param_different: set to True to have the messages computed 43 | using a different set of parameters than for the forward direction. 44 | layer_norm: set to True to use layer normalization in a few places. 45 | name: name of this module. 
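    Example (an illustrative wiring of the three modules; the sizes below are
    placeholders, with the one real constraint that the encoder's node output
    size matches node_state_dim so the residual node update is well defined):

      encoder = GraphEncoder(node_hidden_sizes=[32], edge_hidden_sizes=[16])
      aggregator = GraphAggregator(node_hidden_sizes=[128],
                                   graph_transform_sizes=[128])
      model = GraphEmbeddingNet(
          encoder, aggregator,
          node_state_dim=32,
          edge_hidden_sizes=[64, 64],
          node_hidden_sizes=[64],
          n_prop_layers=5)
      # graph_vectors = model(node_features, edge_features,
      #                       from_idx, to_idx, graph_idx, n_graphs)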
46 | """ 47 | super(GraphEmbeddingNet, self).__init__(name=name) 48 | 49 | self._encoder = encoder 50 | self._aggregator = aggregator 51 | self._node_state_dim = node_state_dim 52 | self._edge_hidden_sizes = edge_hidden_sizes 53 | self._node_hidden_sizes = node_hidden_sizes 54 | self._n_prop_layers = n_prop_layers 55 | self._share_prop_params = share_prop_params 56 | self._edge_net_init_scale = edge_net_init_scale 57 | self._node_update_type = node_update_type 58 | self._use_reverse_direction = use_reverse_direction 59 | self._reverse_dir_param_different = reverse_dir_param_different 60 | self._layer_norm = layer_norm 61 | 62 | self._prop_layers = [] 63 | self._layer_class = GraphPropLayer 64 | 65 | def _build_layer(self, layer_id): 66 | """Build one layer in the network.""" 67 | return self._layer_class( 68 | self._node_state_dim, 69 | self._edge_hidden_sizes, 70 | self._node_hidden_sizes, 71 | edge_net_init_scale=self._edge_net_init_scale, 72 | node_update_type=self._node_update_type, 73 | use_reverse_direction=self._use_reverse_direction, 74 | reverse_dir_param_different=self._reverse_dir_param_different, 75 | layer_norm=self._layer_norm, 76 | name="graph-prop-%d" % layer_id, 77 | ) 78 | 79 | def _apply_layer( 80 | self, layer, node_states, from_idx, to_idx, graph_idx, n_graphs, edge_features 81 | ): 82 | """Apply one layer on the given inputs.""" 83 | del graph_idx, n_graphs 84 | return layer(node_states, from_idx, to_idx, edge_features=edge_features) 85 | 86 | def _build( 87 | self, node_features, edge_features, from_idx, to_idx, graph_idx, n_graphs 88 | ): 89 | """Compute graph representations. 90 | 91 | Args: 92 | node_features: [n_nodes, node_feat_dim] float tensor. 93 | edge_features: [n_edges, edge_feat_dim] float tensor. 94 | from_idx: [n_edges] int tensor, index of the from node for each edge. 95 | to_idx: [n_edges] int tensor, index of the to node for each edge. 96 | graph_idx: [n_nodes] int tensor, graph id for each node. 97 | n_graphs: int, number of graphs in the batch. 98 | 99 | Returns: 100 | graph_representations: [n_graphs, graph_representation_dim] float tensor, 101 | graph representations. 102 | """ 103 | if len(self._prop_layers) < self._n_prop_layers: 104 | # build the layers 105 | for i in range(self._n_prop_layers): 106 | if i == 0 or not self._share_prop_params: 107 | layer = self._build_layer(i) 108 | else: 109 | layer = self._prop_layers[0] 110 | self._prop_layers.append(layer) 111 | 112 | node_features, edge_features = self._encoder(node_features, edge_features) 113 | node_states = node_features 114 | 115 | layer_outputs = [node_states] 116 | 117 | for layer in self._prop_layers: 118 | # node_features could be wired in here as well, leaving it out for now as 119 | # it is already in the inputs 120 | node_states = self._apply_layer( 121 | layer, node_states, from_idx, to_idx, graph_idx, n_graphs, edge_features 122 | ) 123 | layer_outputs.append(node_states) 124 | 125 | # these tensors may be used e.g. for visualization 126 | self._layer_outputs = layer_outputs 127 | return self._aggregator(node_states, graph_idx, n_graphs) 128 | 129 | def reset_n_prop_layers(self, n_prop_layers): 130 | """Set n_prop_layers to the provided new value. 131 | 132 | This allows us to train with certain number of propagation layers and 133 | evaluate with a different number of propagation layers. 134 | 135 | This only works if n_prop_layers is smaller than the number used for 136 | training, or when share_prop_params is set to True, in which case this can 137 | be arbitrarily large. 
138 | 139 | Args: 140 | n_prop_layers: the new number of propagation layers to set. 141 | """ 142 | self._n_prop_layers = n_prop_layers 143 | 144 | @property 145 | def n_prop_layers(self): 146 | return self._n_prop_layers 147 | 148 | def get_layer_outputs(self): 149 | """Get the outputs at each layer.""" 150 | if hasattr(self, "_layer_outputs"): 151 | return self._layer_outputs 152 | else: 153 | raise ValueError("No layer outputs available.") 154 | 155 | class GraphMatchingNet(GraphEmbeddingNet): 156 | """Graph matching net. 157 | 158 | This class uses graph matching layers instead of the simple graph prop layers. 159 | 160 | It assumes the incoming graph data is batched and paired, i.e. graph 0 and 1 161 | forms the first pair and graph 2 and 3 are the second pair etc., and computes 162 | cross-graph attention-based matching for each pair. 163 | """ 164 | 165 | def __init__( 166 | self, 167 | encoder, 168 | aggregator, 169 | node_state_dim, 170 | edge_hidden_sizes, 171 | node_hidden_sizes, 172 | n_prop_layers, 173 | share_prop_params=False, 174 | edge_net_init_scale=0.1, 175 | node_update_type="residual", 176 | use_reverse_direction=True, 177 | reverse_dir_param_different=True, 178 | layer_norm=False, 179 | similarity="dotproduct", 180 | name="graph-matching-net", 181 | ): 182 | super(GraphMatchingNet, self).__init__( 183 | encoder, 184 | aggregator, 185 | node_state_dim, 186 | edge_hidden_sizes, 187 | node_hidden_sizes, 188 | n_prop_layers, 189 | share_prop_params=share_prop_params, 190 | edge_net_init_scale=edge_net_init_scale, 191 | node_update_type=node_update_type, 192 | use_reverse_direction=use_reverse_direction, 193 | reverse_dir_param_different=reverse_dir_param_different, 194 | layer_norm=layer_norm, 195 | name=name, 196 | ) 197 | self._similarity = similarity 198 | self._layer_class = GraphPropMatchingLayer 199 | 200 | def _apply_layer( 201 | self, layer, node_states, from_idx, to_idx, graph_idx, n_graphs, edge_features 202 | ): 203 | """Apply one layer on the given inputs.""" 204 | return layer( 205 | node_states, 206 | from_idx, 207 | to_idx, 208 | graph_idx, 209 | n_graphs, 210 | similarity=self._similarity, 211 | edge_features=edge_features, 212 | ) 213 | -------------------------------------------------------------------------------- /outputs.txt: -------------------------------------------------------------------------------- 1 | iter 100, loss 0.8873, grad_scale 10.0000, param_scale 24.2240, graph_vec_scale 10.3391, sim_pos -0.5914, sim_neg -1.1151, sim_diff 0.5237, time 11.99s 2 | iter 200, loss 0.8771, grad_scale 9.9898, param_scale 24.2670, graph_vec_scale 17.1087, sim_pos -0.6368, sim_neg -1.2004, sim_diff 0.5636, time 7.75s 3 | iter 300, loss 0.8830, grad_scale 9.9773, param_scale 24.3027, graph_vec_scale 17.0012, sim_pos -0.5818, sim_neg -1.0533, sim_diff 0.4715, time 7.87s 4 | iter 400, loss 0.8372, grad_scale 10.0000, param_scale 24.3126, graph_vec_scale 11.9274, sim_pos -0.5922, sim_neg -1.2563, sim_diff 0.6641, time 7.73s 5 | iter 500, loss 0.8818, grad_scale 9.9879, param_scale 24.3395, graph_vec_scale 10.5516, sim_pos -0.5558, sim_neg -1.0426, sim_diff 0.4868, time 7.82s 6 | iter 600, loss 0.8757, grad_scale 9.9834, param_scale 24.3923, graph_vec_scale 12.8394, sim_pos -0.5530, sim_neg -1.0178, sim_diff 0.4647, time 7.70s 7 | iter 700, loss 0.8613, grad_scale 10.0000, param_scale 24.4811, graph_vec_scale 13.1900, sim_pos -0.5879, sim_neg -1.0978, sim_diff 0.5099, time 7.77s 8 | iter 800, loss 0.8571, grad_scale 9.9728, param_scale 24.5740, graph_vec_scale 
9.2787, sim_pos -0.5790, sim_neg -1.1338, sim_diff 0.5549, time 7.97s 9 | iter 900, loss 0.8237, grad_scale 10.0000, param_scale 24.6664, graph_vec_scale 11.6841, sim_pos -0.6255, sim_neg -1.2683, sim_diff 0.6429, time 7.78s 10 | iter 1000, loss 0.8638, grad_scale 9.9662, param_scale 24.7718, graph_vec_scale 17.8530, sim_pos -0.5702, sim_neg -1.0790, sim_diff 0.5088, val/pair_auc 0.6324, val/triplet_acc 0.6670, time 14.69s 11 | iter 1100, loss 0.8409, grad_scale 9.9880, param_scale 24.9303, graph_vec_scale 15.7435, sim_pos -0.6286, sim_neg -1.2096, sim_diff 0.5811, time 7.82s 12 | iter 1200, loss 0.8522, grad_scale 10.0000, param_scale 25.1148, graph_vec_scale 7.6839, sim_pos -0.6068, sim_neg -1.1389, sim_diff 0.5321, time 8.09s 13 | iter 1300, loss 0.8138, grad_scale 10.0000, param_scale 25.1907, graph_vec_scale 4.7778, sim_pos -0.7042, sim_neg -1.3391, sim_diff 0.6349, time 7.81s 14 | iter 1400, loss 0.8215, grad_scale 10.0000, param_scale 25.2458, graph_vec_scale 6.9688, sim_pos -0.6505, sim_neg -1.2704, sim_diff 0.6199, time 7.74s 15 | iter 1500, loss 0.8152, grad_scale 10.0000, param_scale 25.3281, graph_vec_scale 6.8125, sim_pos -0.7034, sim_neg -1.3504, sim_diff 0.6470, time 7.75s 16 | iter 1600, loss 0.8066, grad_scale 10.0000, param_scale 25.4271, graph_vec_scale 7.2998, sim_pos -0.7107, sim_neg -1.3687, sim_diff 0.6580, time 7.86s 17 | iter 1700, loss 0.8151, grad_scale 10.0000, param_scale 25.4918, graph_vec_scale 8.7095, sim_pos -0.6957, sim_neg -1.3460, sim_diff 0.6503, time 7.85s 18 | iter 1800, loss 0.8053, grad_scale 10.0000, param_scale 25.6026, graph_vec_scale 6.1122, sim_pos -0.7249, sim_neg -1.3746, sim_diff 0.6497, time 7.78s 19 | iter 1900, loss 0.8008, grad_scale 10.0000, param_scale 25.7393, graph_vec_scale 5.9124, sim_pos -0.6897, sim_neg -1.3541, sim_diff 0.6644, time 8.07s 20 | iter 2000, loss 0.8250, grad_scale 10.0000, param_scale 25.8922, graph_vec_scale 6.9056, sim_pos -0.7086, sim_neg -1.3000, sim_diff 0.5913, val/pair_auc 0.6538, val/triplet_acc 0.6890, time 10.58s 21 | iter 2100, loss 0.8122, grad_scale 10.0000, param_scale 26.0320, graph_vec_scale 5.2555, sim_pos -0.7462, sim_neg -1.3755, sim_diff 0.6293, time 7.80s 22 | iter 2200, loss 0.8045, grad_scale 10.0000, param_scale 26.1303, graph_vec_scale 4.3697, sim_pos -0.7543, sim_neg -1.3940, sim_diff 0.6397, time 7.90s 23 | iter 2300, loss 0.7916, grad_scale 10.0000, param_scale 26.1934, graph_vec_scale 3.9951, sim_pos -0.7153, sim_neg -1.3503, sim_diff 0.6350, time 7.78s 24 | iter 2400, loss 0.8278, grad_scale 10.0000, param_scale 26.2812, graph_vec_scale 5.2692, sim_pos -0.7016, sim_neg -1.2960, sim_diff 0.5944, time 7.73s 25 | iter 2500, loss 0.8058, grad_scale 10.0000, param_scale 26.3557, graph_vec_scale 5.8444, sim_pos -0.7284, sim_neg -1.3689, sim_diff 0.6404, time 7.80s 26 | iter 2600, loss 0.7689, grad_scale 10.0000, param_scale 26.4257, graph_vec_scale 3.2281, sim_pos -0.7457, sim_neg -1.5151, sim_diff 0.7695, time 7.84s 27 | iter 2700, loss 0.7969, grad_scale 10.0000, param_scale 26.4954, graph_vec_scale 3.3767, sim_pos -0.7482, sim_neg -1.4360, sim_diff 0.6878, time 7.73s 28 | iter 2800, loss 0.8082, grad_scale 10.0000, param_scale 26.5255, graph_vec_scale 3.0197, sim_pos -0.7027, sim_neg -1.3742, sim_diff 0.6715, time 7.76s 29 | iter 2900, loss 0.7893, grad_scale 10.0000, param_scale 26.5781, graph_vec_scale 2.8958, sim_pos -0.7676, sim_neg -1.4553, sim_diff 0.6876, time 7.72s 30 | iter 3000, loss 0.8431, grad_scale 10.0000, param_scale 26.6646, graph_vec_scale 3.8674, sim_pos -0.7303, 
sim_neg -1.2741, sim_diff 0.5438, val/pair_auc 0.6621, val/triplet_acc 0.6680, time 10.91s 31 | iter 3100, loss 0.7689, grad_scale 10.0000, param_scale 26.8008, graph_vec_scale 4.1606, sim_pos -0.7765, sim_neg -1.5472, sim_diff 0.7707, time 7.77s 32 | iter 3200, loss 0.7840, grad_scale 10.0000, param_scale 26.8969, graph_vec_scale 2.8118, sim_pos -0.8286, sim_neg -1.5254, sim_diff 0.6968, time 7.78s 33 | iter 3300, loss 0.7904, grad_scale 10.0000, param_scale 26.9548, graph_vec_scale 1.9365, sim_pos -0.7651, sim_neg -1.4085, sim_diff 0.6434, time 7.75s 34 | iter 3400, loss 0.7791, grad_scale 10.0000, param_scale 27.0782, graph_vec_scale 1.8720, sim_pos -0.8018, sim_neg -1.4594, sim_diff 0.6575, time 7.76s 35 | iter 3500, loss 0.8122, grad_scale 10.0000, param_scale 27.1524, graph_vec_scale 1.9007, sim_pos -0.7581, sim_neg -1.3476, sim_diff 0.5896, time 7.88s 36 | iter 3600, loss 0.7733, grad_scale 10.0000, param_scale 27.2225, graph_vec_scale 2.7267, sim_pos -0.7728, sim_neg -1.4766, sim_diff 0.7038, time 7.67s 37 | iter 3700, loss 0.7730, grad_scale 10.0000, param_scale 27.3204, graph_vec_scale 2.7635, sim_pos -0.7620, sim_neg -1.4900, sim_diff 0.7279, time 7.78s 38 | iter 3800, loss 0.7557, grad_scale 10.0000, param_scale 27.4433, graph_vec_scale 3.6189, sim_pos -0.7520, sim_neg -1.5714, sim_diff 0.8195, time 7.71s 39 | iter 3900, loss 0.7714, grad_scale 10.0000, param_scale 27.5628, graph_vec_scale 1.8792, sim_pos -0.8008, sim_neg -1.5219, sim_diff 0.7211, time 7.98s 40 | iter 4000, loss 0.7786, grad_scale 10.0000, param_scale 27.6404, graph_vec_scale 2.4379, sim_pos -0.7773, sim_neg -1.4837, sim_diff 0.7063, val/pair_auc 0.6648, val/triplet_acc 0.7350, time 10.61s 41 | iter 4100, loss 0.7738, grad_scale 10.0000, param_scale 27.7126, graph_vec_scale 2.1756, sim_pos -0.7762, sim_neg -1.5106, sim_diff 0.7344, time 7.71s 42 | iter 4200, loss 0.7691, grad_scale 10.0000, param_scale 27.7810, graph_vec_scale 1.4263, sim_pos -0.8066, sim_neg -1.5122, sim_diff 0.7056, time 7.71s 43 | iter 4300, loss 0.7667, grad_scale 10.0000, param_scale 27.8617, graph_vec_scale 1.7901, sim_pos -0.8347, sim_neg -1.5831, sim_diff 0.7484, time 7.75s 44 | iter 4400, loss 0.7857, grad_scale 10.0000, param_scale 27.9286, graph_vec_scale 1.4075, sim_pos -0.8296, sim_neg -1.4867, sim_diff 0.6571, time 7.70s 45 | iter 4500, loss 0.7750, grad_scale 10.0000, param_scale 27.9861, graph_vec_scale 1.4040, sim_pos -0.8384, sim_neg -1.5247, sim_diff 0.6863, time 7.80s 46 | iter 4600, loss 0.7680, grad_scale 10.0000, param_scale 28.0557, graph_vec_scale 1.2411, sim_pos -0.7540, sim_neg -1.4221, sim_diff 0.6681, time 7.76s 47 | iter 4700, loss 0.7698, grad_scale 10.0000, param_scale 28.1328, graph_vec_scale 2.2770, sim_pos -0.7909, sim_neg -1.4841, sim_diff 0.6932, time 7.71s 48 | iter 4800, loss 0.7713, grad_scale 10.0000, param_scale 28.1726, graph_vec_scale 1.2568, sim_pos -0.8339, sim_neg -1.5140, sim_diff 0.6801, time 8.06s 49 | iter 4900, loss 0.7792, grad_scale 10.0000, param_scale 28.2418, graph_vec_scale 1.3821, sim_pos -0.7927, sim_neg -1.4273, sim_diff 0.6346, time 7.72s 50 | iter 5000, loss 0.7895, grad_scale 10.0000, param_scale 28.2885, graph_vec_scale 1.7310, sim_pos -0.7933, sim_neg -1.4116, sim_diff 0.6183, val/pair_auc 0.6756, val/triplet_acc 0.7410, time 10.60s 51 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.1.1 2 | 3 | networkx==2.3 4 | dm-sonnet==1.34 5 
| numpy==1.16.4 6 | tensorflow==1.14 7 | tensorflow-probability==0.7.0rc0 8 | six==1.12 9 | #code end 10 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python -W ignore train.py 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from models import * 3 | from datasets import * 4 | from layers import * 5 | 6 | import time 7 | import random 8 | import copy 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | import os 14 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 15 | 16 | 17 | """Evaluation""" 18 | 19 | def exact_hamming_similarity(x, y): 20 | """Compute the binary Hamming similarity.""" 21 | match = tf.cast(tf.equal(x > 0, y > 0), dtype=tf.float32) 22 | return tf.reduce_mean(match, axis=1) 23 | 24 | 25 | def compute_similarity(config, x, y): 26 | """Compute the distance between x and y vectors. 27 | 28 | The distance will be computed based on the training loss type. 29 | 30 | Args: 31 | config: a config dict. 32 | x: [n_examples, feature_dim] float tensor. 33 | y: [n_examples, feature_dim] float tensor. 34 | 35 | Returns: 36 | dist: [n_examples] float tensor. 37 | 38 | Raises: 39 | ValueError: if loss type is not supported. 40 | """ 41 | if config["training"]["loss"] == "margin": 42 | # similarity is negative distance 43 | return -euclidean_distance(x, y) 44 | elif config["training"]["loss"] == "hamming": 45 | return exact_hamming_similarity(x, y) 46 | else: 47 | raise ValueError("Unknown loss type %s" % config["training"]["loss"]) 48 | 49 | 50 | def auc(scores, labels, **auc_args): 51 | """Compute the AUC for pair classification. 52 | 53 | See `tf.metrics.auc` for more details about this metric. 54 | 55 | Args: 56 | scores: [n_examples] float. Higher scores mean higher preference of being 57 | assigned the label of +1. 58 | labels: [n_examples] int. Labels are either +1 or -1. 59 | **auc_args: other arguments that can be used by `tf.metrics.auc`. 60 | 61 | Returns: 62 | auc: the area under the ROC curve. 63 | """ 64 | scores_max = tf.reduce_max(scores) 65 | scores_min = tf.reduce_min(scores) 66 | # normalize scores to [0, 1] and add a small epislon for safety 67 | scores = (scores - scores_min) / (scores_max - scores_min + 1e-8) 68 | 69 | labels = (labels + 1) / 2 70 | # The following code should be used according to the tensorflow official 71 | # documentation: 72 | # value, _ = tf.metrics.auc(labels, scores, **auc_args) 73 | 74 | # However `tf.metrics.auc` is currently (as of July 23, 2019) buggy so we have 75 | # to use the following: 76 | _, value = tf.metrics.auc(labels, scores, **auc_args) 77 | return value 78 | 79 | """Build the model""" 80 | def reshape_and_split_tensor(tensor, n_splits): 81 | """Reshape and split a 2D tensor along the last dimension. 82 | 83 | Args: 84 | tensor: a [num_examples, feature_dim] tensor. num_examples must be a 85 | multiple of `n_splits`. 86 | n_splits: int, number of splits to split the tensor into. 87 | 88 | Returns: 89 | splits: a list of `n_splits` tensors. The first split is [tensor[0], 90 | tensor[n_splits], tensor[n_splits * 2], ...], the second split is 91 | [tensor[1], tensor[n_splits + 1], tensor[n_splits * 2 + 1], ...], etc.. 
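    Example (illustrative; a batch of 2 pairs, i.e. a [4, D] tensor of graph
    vectors ordered (x0, y0, x1, y1)):

      x, y = reshape_and_split_tensor(graph_vectors, 2)
      # x -> rows [0, 2] of graph_vectors, y -> rows [1, 3], each of shape [2, D]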
92 | """ 93 | feature_dim = tensor.shape.as_list()[-1] 94 | # feature dim must be known, otherwise you can provide that as an input 95 | assert isinstance(feature_dim, int) 96 | tensor = tf.reshape(tensor, [-1, feature_dim * n_splits]) 97 | return tf.split(tensor, n_splits, axis=-1) 98 | 99 | 100 | def build_placeholders(node_feature_dim, edge_feature_dim): 101 | """Build the placeholders needed for the model. 102 | 103 | Args: 104 | node_feature_dim: int. 105 | edge_feature_dim: int. 106 | 107 | Returns: 108 | placeholders: a placeholder name -> placeholder tensor dict. 109 | """ 110 | # `n_graphs` must be specified as an integer, as `tf.dynamic_partition` 111 | # requires so. 112 | return { 113 | "node_features": tf.placeholder(tf.float32, [None, node_feature_dim]), 114 | "edge_features": tf.placeholder(tf.float32, [None, edge_feature_dim]), 115 | "from_idx": tf.placeholder(tf.int32, [None]), 116 | "to_idx": tf.placeholder(tf.int32, [None]), 117 | "graph_idx": tf.placeholder(tf.int32, [None]), 118 | # only used for pairwise training and evaluation 119 | "labels": tf.placeholder(tf.int32, [None]), 120 | } 121 | 122 | 123 | def build_model(config, node_feature_dim, edge_feature_dim): 124 | """Create model for training and evaluation. 125 | 126 | Args: 127 | config: a dictionary of configs, like the one created by the 128 | `get_default_config` function. 129 | node_feature_dim: int, dimensionality of node features. 130 | edge_feature_dim: int, dimensionality of edge features. 131 | 132 | Returns: 133 | tensors: a (potentially nested) name => tensor dict. 134 | placeholders: a (potentially nested) name => tensor dict. 135 | model: a GraphEmbeddingNet or GraphMatchingNet instance. 136 | 137 | Raises: 138 | ValueError: if the specified model or training settings are not supported. 
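    Example (an illustrative sketch; feature dims of 1 match the constant
    features produced by the synthetic graph edit distance data):

      config = get_default_config()
      tensors, placeholders, model = build_model(
          config, node_feature_dim=1, edge_feature_dim=1)
      # tensors["train_step"] is the update op; tensors["metrics"]["training"]
      # and tensors["metrics"]["validation"] hold the monitored scalars.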
139 | """ 140 | encoder = GraphEncoder(**config["encoder"]) 141 | aggregator = GraphAggregator(**config["aggregator"]) 142 | if config["model_type"] == "embedding": 143 | model = GraphEmbeddingNet(encoder, aggregator, **config["graph_embedding_net"]) 144 | elif config["model_type"] == "matching": 145 | model = GraphMatchingNet(encoder, aggregator, **config["graph_matching_net"]) 146 | else: 147 | raise ValueError("Unknown model type: %s" % config["model_type"]) 148 | 149 | training_n_graphs_in_batch = config["training"]["batch_size"] 150 | if config["training"]["mode"] == "pair": 151 | training_n_graphs_in_batch *= 2 152 | elif config["training"]["mode"] == "triplet": 153 | training_n_graphs_in_batch *= 4 154 | else: 155 | raise ValueError("Unknown training mode: %s" % config["training"]["mode"]) 156 | 157 | placeholders = build_placeholders(node_feature_dim, edge_feature_dim) 158 | 159 | # training 160 | model_inputs = placeholders.copy() 161 | del model_inputs["labels"] 162 | model_inputs["n_graphs"] = training_n_graphs_in_batch 163 | graph_vectors = model(**model_inputs) 164 | 165 | if config["training"]["mode"] == "pair": 166 | x, y = reshape_and_split_tensor(graph_vectors, 2) 167 | loss = pairwise_loss( 168 | x, 169 | y, 170 | placeholders["labels"], 171 | loss_type=config["training"]["loss"], 172 | margin=config["training"]["margin"], 173 | ) 174 | 175 | # optionally monitor the similarity between positive and negative pairs 176 | is_pos = tf.cast(tf.equal(placeholders["labels"], 1), tf.float32) 177 | is_neg = 1 - is_pos 178 | n_pos = tf.reduce_sum(is_pos) 179 | n_neg = tf.reduce_sum(is_neg) 180 | sim = compute_similarity(config, x, y) 181 | sim_pos = tf.reduce_sum(sim * is_pos) / (n_pos + 1e-8) 182 | sim_neg = tf.reduce_sum(sim * is_neg) / (n_neg + 1e-8) 183 | else: 184 | x_1, y, x_2, z = reshape_and_split_tensor(graph_vectors, 4) 185 | loss = triplet_loss( 186 | x_1, 187 | y, 188 | x_2, 189 | z, 190 | loss_type=config["training"]["loss"], 191 | margin=config["training"]["margin"], 192 | ) 193 | 194 | sim_pos = tf.reduce_mean(compute_similarity(config, x_1, y)) 195 | sim_neg = tf.reduce_mean(compute_similarity(config, x_2, z)) 196 | 197 | graph_vec_scale = tf.reduce_mean(graph_vectors ** 2) 198 | if config["training"]["graph_vec_regularizer_weight"] > 0: 199 | loss += ( 200 | config["training"]["graph_vec_regularizer_weight"] * 0.5 * graph_vec_scale 201 | ) 202 | 203 | # monitor scale of the parameters and gradients, these are typically helpful 204 | optimizer = tf.train.AdamOptimizer( 205 | learning_rate=config["training"]["learning_rate"] 206 | ) 207 | grads_and_params = optimizer.compute_gradients(loss) 208 | grads, params = zip(*grads_and_params) 209 | grads, _ = tf.clip_by_global_norm(grads, config["training"]["clip_value"]) 210 | train_step = optimizer.apply_gradients(zip(grads, params)) 211 | 212 | grad_scale = tf.global_norm(grads) 213 | param_scale = tf.global_norm(params) 214 | 215 | # evaluation 216 | model_inputs["n_graphs"] = config["evaluation"]["batch_size"] * 2 217 | eval_pairs = model(**model_inputs) 218 | x, y = reshape_and_split_tensor(eval_pairs, 2) 219 | similarity = compute_similarity(config, x, y) 220 | pair_auc = auc(similarity, placeholders["labels"]) 221 | 222 | model_inputs["n_graphs"] = config["evaluation"]["batch_size"] * 4 223 | eval_triplets = model(**model_inputs) 224 | x_1, y, x_2, z = reshape_and_split_tensor(eval_triplets, 4) 225 | sim_1 = compute_similarity(config, x_1, y) 226 | sim_2 = compute_similarity(config, x_2, z) 227 | triplet_acc = 
tf.reduce_mean(tf.cast(sim_1 > sim_2, dtype=tf.float32)) 228 | 229 | return ( 230 | { 231 | "train_step": train_step, 232 | "metrics": { 233 | "training": { 234 | "loss": loss, 235 | "grad_scale": grad_scale, 236 | "param_scale": param_scale, 237 | "graph_vec_scale": graph_vec_scale, 238 | "sim_pos": sim_pos, 239 | "sim_neg": sim_neg, 240 | "sim_diff": sim_pos - sim_neg, 241 | }, 242 | "validation": {"pair_auc": pair_auc, "triplet_acc": triplet_acc,}, 243 | }, 244 | }, 245 | placeholders, 246 | model, 247 | ) 248 | 249 | """Training Pipeline""" 250 | def build_datasets(config): 251 | """Build the training and evaluation datasets.""" 252 | config = copy.deepcopy(config) 253 | 254 | if config["data"]["problem"] == "graph_edit_distance": 255 | dataset_params = config["data"]["dataset_params"] 256 | validation_dataset_size = dataset_params["validation_dataset_size"] 257 | del dataset_params["validation_dataset_size"] 258 | training_set = GraphEditDistanceDataset(**dataset_params) 259 | dataset_params["dataset_size"] = validation_dataset_size 260 | validation_set = FixedGraphEditDistanceDataset(**dataset_params) 261 | else: 262 | raise ValueError("Unknown problem type: %s" % config["data"]["problem"]) 263 | return training_set, validation_set 264 | 265 | 266 | def fill_feed_dict(placeholders, batch): 267 | """Create a feed dict for the given batch of data. 268 | 269 | Args: 270 | placeholders: a dict of placeholders. 271 | batch: a batch of data, should be either a single `GraphData` instance for 272 | triplet training, or a tuple of (graphs, labels) for pairwise training. 273 | 274 | Returns: 275 | feed_dict: a feed_dict that can be used in a session run call. 276 | """ 277 | if isinstance(batch, GraphData): 278 | graphs = batch 279 | labels = None 280 | else: 281 | graphs, labels = batch 282 | 283 | feed_dict = { 284 | placeholders["node_features"]: graphs.node_features, 285 | placeholders["edge_features"]: graphs.edge_features, 286 | placeholders["from_idx"]: graphs.from_idx, 287 | placeholders["to_idx"]: graphs.to_idx, 288 | placeholders["graph_idx"]: graphs.graph_idx, 289 | } 290 | if labels is not None: 291 | feed_dict[placeholders["labels"]] = labels 292 | return feed_dict 293 | 294 | 295 | def evaluate(sess, eval_metrics, placeholders, validation_set, batch_size): 296 | """Evaluate model performance on the given validation set. 297 | 298 | Args: 299 | sess: a `tf.Session` instance used to run the computation. 300 | eval_metrics: a dict containing two tensors 'pair_auc' and 'triplet_acc'. 301 | placeholders: a placeholder dict. 302 | validation_set: a `GraphSimilarityDataset` instance, calling `pairs` and 303 | `triplets` functions with `batch_size` creates iterators over a finite 304 | sequence of batches to evaluate on. 305 | batch_size: number of batches to use for each session run call. 306 | 307 | Returns: 308 | metrics: a dict of metric name => value mapping. 
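    Example (illustrative; mirrors the call made in the training loop below):

      val_metrics = evaluate(sess, tensors["metrics"]["validation"], placeholders,
                             validation_set, config["evaluation"]["batch_size"])
      # val_metrics -> {"pair_auc": ..., "triplet_acc": ...}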
309 |     """
310 |     accumulated_pair_auc = []
311 |     for batch in validation_set.pairs(batch_size):
312 |         feed_dict = fill_feed_dict(placeholders, batch)
313 |         pair_auc = sess.run(eval_metrics["pair_auc"], feed_dict=feed_dict)
314 |         accumulated_pair_auc.append(pair_auc)
315 |
316 |     accumulated_triplet_acc = []
317 |     for batch in validation_set.triplets(batch_size):
318 |         feed_dict = fill_feed_dict(placeholders, batch)
319 |         triplet_acc = sess.run(eval_metrics["triplet_acc"], feed_dict=feed_dict)
320 |         accumulated_triplet_acc.append(triplet_acc)
321 |
322 |     return {
323 |         "pair_auc": np.mean(accumulated_pair_auc),
324 |         "triplet_acc": np.mean(accumulated_triplet_acc),
325 |     }
326 |
327 | """Main run process"""
328 |
329 | config = get_default_config()
330 | config["training"]["n_training_steps"] = 50000
331 | tf.reset_default_graph()
332 |
333 | # Set random seeds
334 | seed = config["seed"]
335 | random.seed(seed)
336 | np.random.seed(seed + 1)
337 | tf.set_random_seed(seed + 2)
338 |
339 | training_set, validation_set = build_datasets(config)
340 |
341 |
342 |
343 | if config["training"]["mode"] == "pair":
344 |     training_data_iter = training_set.pairs(config["training"]["batch_size"])
345 |     first_batch_graphs, _ = next(training_data_iter)
346 | else:
347 |     training_data_iter = training_set.triplets(config["training"]["batch_size"])
348 |     first_batch_graphs = next(training_data_iter)
349 |
350 | node_feature_dim = first_batch_graphs.node_features.shape[-1]
351 | edge_feature_dim = first_batch_graphs.edge_features.shape[-1]
352 |
353 | tensors, placeholders, model = build_model(config, node_feature_dim, edge_feature_dim)
354 |
355 | accumulated_metrics = collections.defaultdict(list)
356 |
357 | t_start = time.time()
358 |
359 | init_ops = (tf.global_variables_initializer(), tf.local_variables_initializer())
360 |
361 | # If we already have a session instance, close it and start a new one
362 | if "sess" in globals():
363 |     sess.close()
364 |
365 | # We will need to keep this session instance around for e.g. visualization.
366 | # But you should probably wrap it in a `with tf.Session() as sess:` context if you
367 | # want to use the code elsewhere.
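# For example (a sketch only, using the names already defined in this script),
# the context-managed variant suggested above could look like:
#
#     with tf.Session() as sess:
#         sess.run(init_ops)
#         for i_iter in range(config["training"]["n_training_steps"]):
#             batch = next(training_data_iter)
#             sess.run(
#                 [tensors["train_step"], tensors["metrics"]["training"]],
#                 feed_dict=fill_feed_dict(placeholders, batch),
#             )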
368 | sess = tf.Session()
369 | sess.run(init_ops)
370 |
371 | for i_iter in range(config["training"]["n_training_steps"]):
372 |     batch = next(training_data_iter)
373 |     _, train_metrics = sess.run(
374 |         [tensors["train_step"], tensors["metrics"]["training"]],
375 |         feed_dict=fill_feed_dict(placeholders, batch),
376 |     )
377 |
378 |     # accumulate over minibatches to reduce variance in the training metrics
379 |     for k, v in train_metrics.items():
380 |         accumulated_metrics[k].append(v)
381 |
382 |     if (i_iter + 1) % config["training"]["print_after"] == 0:
383 |         metrics_to_print = {k: np.mean(v) for k, v in accumulated_metrics.items()}
384 |         info_str = ", ".join(["%s %.4f" % (k, v) for k, v in metrics_to_print.items()])
385 |         # reset the metrics
386 |         accumulated_metrics = collections.defaultdict(list)
387 |
388 |         if (i_iter + 1) // config["training"]["print_after"] % config["training"][
389 |             "eval_after"
390 |         ] == 0:
391 |             eval_metrics = evaluate(
392 |                 sess,
393 |                 tensors["metrics"]["validation"],
394 |                 placeholders,
395 |                 validation_set,
396 |                 config["evaluation"]["batch_size"],
397 |             )
398 |             info_str += ", " + ", ".join(
399 |                 ["%s %.4f" % ("val/" + k, v) for k, v in eval_metrics.items()]
400 |             )
401 |
402 |         print("iter %d, %s, time %.2fs" % (i_iter + 1, info_str, time.time() - t_start))
403 |         t_start = time.time()
404 |
405 |
406 | # Note that, although a bit noisy, the loss is going down, the similarity gap
407 | # between positive and negative pairs is growing, and the evaluation results, i.e. pair AUC and triplet accuracy, are going up as well. Overall, training seems to be working!
408 | #
409 | # You can train this for much longer. We observed improvement in performance even after training for 500,000 steps, but didn't push this much further as it is a synthetic task after all.
410 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # set before importing TensorFlow so the log level takes effect
3 | import tensorflow as tf
4 |
5 |
6 | def pairwise_euclidean_similarity(x, y):
7 |     """Compute the pairwise Euclidean similarity between x and y.
8 |
9 |     This function computes the following similarity value between each pair of x_i
10 |     and y_j: s(x_i, y_j) = -|x_i - y_j|^2.
11 |
12 |     Args:
13 |       x: NxD float tensor.
14 |       y: MxD float tensor.
15 |
16 |     Returns:
17 |       s: NxM float tensor, the pairwise euclidean similarity.
18 |     """
19 |     s = 2 * tf.matmul(x, y, transpose_b=True)
20 |     diag_x = tf.reduce_sum(x * x, axis=-1, keepdims=True)
21 |     diag_y = tf.reshape(tf.reduce_sum(y * y, axis=-1), (1, -1))
22 |     return s - diag_x - diag_y
23 |
24 |
25 | def pairwise_dot_product_similarity(x, y):
26 |     """Compute the dot product similarity between x and y.
27 |
28 |     This function computes the following similarity value between each pair of x_i
29 |     and y_j: s(x_i, y_j) = x_i^T y_j.
30 |
31 |     Args:
32 |       x: NxD float tensor.
33 |       y: MxD float tensor.
34 |
35 |     Returns:
36 |       s: NxM float tensor, the pairwise dot product similarity.
37 |     """
38 |     return tf.matmul(x, y, transpose_b=True)
39 |
40 |
41 | def pairwise_cosine_similarity(x, y):
42 |     """Compute the cosine similarity between x and y.
43 |
44 |     This function computes the following similarity value between each pair of x_i
45 |     and y_j: s(x_i, y_j) = x_i^T y_j / (|x_i||y_j|).
46 |
47 |     Args:
48 |       x: NxD float tensor.
49 |       y: MxD float tensor.
50 |
51 |     Returns:
52 |       s: NxM float tensor, the pairwise cosine similarity.
53 |     """
54 |     x = tf.nn.l2_normalize(x, axis=-1)
55 |     y = tf.nn.l2_normalize(y, axis=-1)
56 |     return tf.matmul(x, y, transpose_b=True)
57 |
58 |
59 | PAIRWISE_SIMILARITY_FUNCTION = {
60 |     "euclidean": pairwise_euclidean_similarity,
61 |     "dotproduct": pairwise_dot_product_similarity,
62 |     "cosine": pairwise_cosine_similarity,
63 | }
64 |
65 |
66 | def get_pairwise_similarity(name):
67 |     """Get pairwise similarity metric by name.
68 |
69 |     Args:
70 |       name: string, name of the similarity metric, one of {dotproduct, cosine,
71 |         euclidean}.
72 |
73 |     Returns:
74 |       similarity: a (x, y) -> sim function.
75 |
76 |     Raises:
77 |       ValueError: if name is not supported.
78 |     """
79 |     if name not in PAIRWISE_SIMILARITY_FUNCTION:
80 |         raise ValueError('Similarity metric name "%s" not supported.' % name)
81 |     else:
82 |         return PAIRWISE_SIMILARITY_FUNCTION[name]
83 |
84 | """Cross Graph Attention"""
85 |
86 | def compute_cross_attention(x, y, sim):
87 |     """Compute cross attention.
88 |
89 |     x_i attends to y_j:
90 |       a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))
91 |     y_j attends to x_i:
92 |       a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))
93 |     attention_x = sum_j a_{i->j} y_j
94 |     attention_y = sum_i a_{j->i} x_i
95 |
96 |     Args:
97 |       x: NxD float tensor.
98 |       y: MxD float tensor.
99 |       sim: a (x, y) -> similarity function.
100 |
101 |     Returns:
102 |       attention_x: NxD float tensor.
103 |       attention_y: MxD float tensor.
104 |     """
105 |     a = sim(x, y)
106 |     a_x = tf.nn.softmax(a, axis=1)  # i->j
107 |     a_y = tf.nn.softmax(a, axis=0)  # j->i
108 |     attention_x = tf.matmul(a_x, y)
109 |     attention_y = tf.matmul(a_y, x, transpose_a=True)
110 |     return attention_x, attention_y
111 |
112 |
113 | def batch_block_pair_attention(data, block_idx, n_blocks, similarity="dotproduct"):
114 |     """Compute batched attention between pairs of blocks.
115 |
116 |     This function partitions the batch data into blocks according to block_idx.
117 |     For each pair of blocks, x = data[block_idx == 2i], and
118 |     y = data[block_idx == 2i+1], we compute
119 |
120 |     x_i attends to y_j:
121 |       a_{i->j} = exp(sim(x_i, y_j)) / sum_j exp(sim(x_i, y_j))
122 |     y_j attends to x_i:
123 |       a_{j->i} = exp(sim(x_i, y_j)) / sum_i exp(sim(x_i, y_j))
124 |
125 |     and
126 |
127 |     attention_x = sum_j a_{i->j} y_j
128 |     attention_y = sum_i a_{j->i} x_i.
129 |
130 |     Args:
131 |       data: NxD float tensor.
132 |       block_idx: N-dim int tensor.
133 |       n_blocks: integer.
134 |       similarity: a string, the similarity metric.
135 |
136 |     Returns:
137 |       attention_output: NxD float tensor, each x_i replaced by attention_x_i.
138 |
139 |     Raises:
140 |       ValueError: if n_blocks is not an integer or not a multiple of 2.
141 |     """
142 |     if not isinstance(n_blocks, int):
143 |         raise ValueError("n_blocks (%s) has to be an integer." % str(n_blocks))
144 |
145 |     if n_blocks % 2 != 0:
146 |         raise ValueError("n_blocks (%d) must be a multiple of 2." % n_blocks)
147 |
148 |     # Equation 9
149 |     sim = get_pairwise_similarity(similarity)
150 |
151 |     results = []
152 |     partitions = tf.dynamic_partition(data, block_idx, n_blocks)
153 |
154 |     # It is rather complicated to allow n_blocks to be a tf tensor and do this in a
155 |     # dynamic loop, and probably unnecessary to do so. Therefore we are
156 |     # restricting n_blocks to be an integer constant here and using a plain for
157 |     # loop.
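    # As a concrete illustration (shapes only, with made-up sizes): if block 2i
    # has 3 nodes and block 2i+1 has 4 nodes, each with D = 32 features, then
    # sim(x, y) is [3, 4], attention_x is [3, 32] and attention_y is [4, 32];
    # concatenating the per-pair results back in block order preserves the
    # overall [N, D] layout of `data`.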
158 |     for i in range(0, n_blocks, 2):
159 |         x = partitions[i]
160 |         y = partitions[i + 1]
161 |         attention_x, attention_y = compute_cross_attention(x, y, sim)
162 |         results.append(attention_x)
163 |         results.append(attention_y)
164 |
165 |     results = tf.concat(results, axis=0)
166 |     # the shape of the first dimension is lost after concat, reset it back
167 |     results.set_shape(data.shape)
168 |     return results
169 |
170 | """Training on Pairs"""
171 |
172 | def euclidean_distance(x, y):
173 |     """This is the squared Euclidean distance."""
174 |     return tf.reduce_sum((x - y) ** 2, axis=-1)
175 |
176 | def approximate_hamming_similarity(x, y):
177 |     """Approximate Hamming similarity."""
178 |     return tf.reduce_mean(tf.tanh(x) * tf.tanh(y), axis=1)
179 |
180 | def pairwise_loss(x, y, labels, loss_type="margin", margin=1.0):
181 |     """Compute pairwise loss.
182 |
183 |     Args:
184 |       x: [N, D] float tensor, representations for N examples.
185 |       y: [N, D] float tensor, representations for another N examples.
186 |       labels: [N] int tensor, with values in -1 or +1. labels[i] = +1 if x[i]
187 |         and y[i] are similar, and -1 otherwise.
188 |       loss_type: margin or hamming.
189 |       margin: float scalar, margin for the margin loss.
190 |
191 |     Returns:
192 |       loss: [N] float tensor. Loss for each pair of representations.
193 |     """
194 |     labels = tf.cast(labels, x.dtype)
195 |     if loss_type == "margin":
196 |         return tf.nn.relu(margin - labels * (1 - euclidean_distance(x, y)))
197 |     elif loss_type == "hamming":
198 |         return 0.25 * (labels - approximate_hamming_similarity(x, y)) ** 2
199 |     else:
200 |         raise ValueError("Unknown loss_type %s" % loss_type)
201 |
202 | """Training on Triplets"""
203 |
204 | def triplet_loss(x_1, y, x_2, z, loss_type="margin", margin=1.0):
205 |     """Compute triplet loss.
206 |
207 |     This function computes loss on a triplet of inputs (x, y, z). A similarity or
208 |     distance value is computed for each pair of (x, y) and (x, z). Since the
209 |     representations for x can be different in the two pairs (as in our matching
210 |     model), we distinguish the two x representations by x_1 and x_2.
211 |
212 |     Args:
213 |       x_1: [N, D] float tensor.
214 |       y: [N, D] float tensor.
215 |       x_2: [N, D] float tensor.
216 |       z: [N, D] float tensor.
217 |       loss_type: margin or hamming.
218 |       margin: float scalar, margin for the margin loss.
219 |
220 |     Returns:
221 |       loss: [N] float tensor. Loss for each triplet of representations.
222 |     """
223 |     if loss_type == "margin":
224 |         return tf.nn.relu(
225 |             margin + euclidean_distance(x_1, y) - euclidean_distance(x_2, z)
226 |         )
227 |     elif loss_type == "hamming":
228 |         return 0.125 * (
229 |             (approximate_hamming_similarity(x_1, y) - 1) ** 2
230 |             + (approximate_hamming_similarity(x_2, z) + 1) ** 2
231 |         )
232 |     else:
233 |         raise ValueError("Unknown loss_type %s" % loss_type)
234 |
235 | """Configs"""
236 | def get_default_config():
237 |     """The default configs."""
238 |     node_state_dim = 32
239 |     graph_rep_dim = 128
240 |     graph_embedding_net_config = dict(
241 |         node_state_dim=node_state_dim,
242 |         edge_hidden_sizes=[node_state_dim * 2, node_state_dim * 2],
243 |         node_hidden_sizes=[node_state_dim * 2],
244 |         n_prop_layers=5,
245 |         # set to False to not share parameters across message passing layers
246 |         share_prop_params=True,
247 |         # initialize the message MLP with small parameter weights to prevent the
248 |         # aggregated message vectors from blowing up; alternatively we could also use
249 |         # e.g. layer normalization to keep the scale of these under control.
250 |         edge_net_init_scale=0.1,
251 |         # other types of update like `mlp` and `residual` can also be used here.
252 |         node_update_type="gru",
253 |         # set to False if your graph already contains edges in both directions.
254 |         use_reverse_direction=True,
255 |         # set to True if your graph is directed
256 |         reverse_dir_param_different=False,
257 |         # we didn't use layer norm in our experiments but sometimes this can help.
258 |         layer_norm=False,
259 |     )
260 |     graph_matching_net_config = graph_embedding_net_config.copy()
261 |     graph_matching_net_config["similarity"] = "dotproduct"
262 |
263 |     return dict(
264 |         encoder=dict(node_hidden_sizes=[node_state_dim], edge_hidden_sizes=None),
265 |         aggregator=dict(
266 |             node_hidden_sizes=[graph_rep_dim],
267 |             graph_transform_sizes=[graph_rep_dim],
268 |             gated=True,
269 |             aggregation_type="sum",
270 |         ),
271 |         graph_embedding_net=graph_embedding_net_config,
272 |         graph_matching_net=graph_matching_net_config,
273 |         # Set to `embedding` to use the graph embedding net.
274 |         model_type="matching",
275 |         data=dict(
276 |             problem="graph_edit_distance",
277 |             dataset_params=dict(
278 |                 # always generate graphs with 20 nodes and p_edge=0.2.
279 |                 n_nodes_range=[20, 20],
280 |                 p_edge_range=[0.2, 0.2],
281 |                 n_changes_positive=1,
282 |                 n_changes_negative=2,
283 |                 validation_dataset_size=1000,
284 |             ),
285 |         ),
286 |         training=dict(
287 |             batch_size=20,
288 |             learning_rate=1e-3,
289 |             mode="pair",
290 |             loss="margin",
291 |             margin=1.0,
292 |             # A small regularizer on the graph vector scales to avoid the graph
293 |             # vectors blowing up. If numerical issues are particularly bad in the
294 |             # model, we can add `snt.LayerNorm` to the outputs of each layer, the
295 |             # aggregated messages and aggregated node representations to
296 |             # keep the network activation scale in a reasonable range.
297 |             graph_vec_regularizer_weight=1e-6,
298 |             # Add gradient clipping to avoid large gradients.
299 |             clip_value=10.0,
300 |             # Increase this to train longer.
301 |             n_training_steps=10000,
302 |             # Print training information every this many training steps.
303 |             print_after=100,
304 |             # Evaluate on validation set every `eval_after * print_after` steps.
305 |             eval_after=10,
306 |         ),
307 |         evaluation=dict(batch_size=20),
308 |         seed=8,
309 |     )
310 |
311 |
--------------------------------------------------------------------------------
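As a quick sanity check of the loss and similarity utilities defined in `utils.py`, a minimal sketch along the following lines should run under TF1-style graph execution (the toy shapes, values, and the `from utils import ...` form are illustrative assumptions, not part of the repository):

import numpy as np
import tensorflow as tf
from utils import pairwise_loss, pairwise_cosine_similarity

tf.reset_default_graph()
x = tf.constant(np.random.randn(5, 32), dtype=tf.float32)  # 5 toy "graph vectors"
y = tf.constant(np.random.randn(5, 32), dtype=tf.float32)
labels = tf.constant([1, -1, 1, -1, 1])  # +1 = similar pair, -1 = dissimilar pair

loss = pairwise_loss(x, y, labels, loss_type="margin", margin=1.0)  # [5] per-pair losses
cos = pairwise_cosine_similarity(x, y)  # [5, 5] pairwise cosine similarities

with tf.Session() as sess:
    mean_loss, cos_shape = sess.run([tf.reduce_mean(loss), tf.shape(cos)])
    print(mean_loss, cos_shape)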