├── README.md
├── __init__.py
├── layers.py
├── models.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
# Graph Matching Networks

This is a PyTorch reimplementation of the ICML 2019 paper [Graph Matching Networks for Learning the Similarity of Graph Structured Objects](https://arxiv.org/abs/1904.12787) (Li et al.).

**Note: I'm not one of the original authors.**

## Requirements
- torch >= 1.6.0
- torch_geometric ([Installation guide](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html))
- torch_scatter (see the torch_geometric installation guide)
- scikit-learn (used for the accuracy and F1 metrics in `utils.py`)

## Instructions
- Install the requirements.
- Import the models from `models.py`.

## References
- [Official Colab notebook](https://github.com/deepmind/deepmind-research/tree/master/graph_matching_networks) (using TensorFlow 1.x and Sonnet)
- [A repo that adapts the official notebook to Python files](https://github.com/chang2000/tfGMN)
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
import torch
from torch_geometric.nn import MessagePassing, BatchNorm
from torch_scatter import scatter_mean

from .utils import batch_block_pair_attention


class GraphConvolution(MessagePassing):
    """Message-passing layer used by ``GenericGNN``.

    Node features are projected to ``out_channels``. A message is computed
    from every (target ``x_i``, source ``x_j``) pair of projected features,
    and messages are aggregated per target node (``aggr``, sum by default).
    The aggregate is then combined with the layer's *input* features and
    batch-normalised.
    """

    def __init__(self, in_channels, out_channels, args, aggr="add"):
        super().__init__(aggr=aggr)
        self.args = args
        self.lin_node = torch.nn.Linear(in_channels, out_channels)
        self.lin_message = torch.nn.Linear(out_channels * 2, out_channels)
        # Consumes [aggregated messages (out_channels) | input features (in_channels)].
        self.lin_passing = torch.nn.Linear(out_channels + in_channels, out_channels)
        self.batch_norm = BatchNorm(out_channels)

    def forward(self, x, edge_index):
        """Return updated node features of shape (num_nodes, out_channels)."""
        x_transformed = self.lin_node(x)
        return self.propagate(edge_index, x=x_transformed, original_x=x)

    def message(self, x_i, x_j):
        return self.lin_message(torch.cat([x_i, x_j], dim=1))

    def update(self, aggr_out, original_x):
        # Join along the feature axis so the width is out_channels + in_channels,
        # which is exactly what lin_passing was built for.
        aggr_out = self.lin_passing(torch.cat([aggr_out, original_x], dim=1))
        return self.batch_norm(aggr_out)


class GraphMatchingConvolution(MessagePassing):
    """Graph-matching layer used by ``GraphMatchingNetwork``.

    Works like ``GraphConvolution``, but the node update also receives the
    cross-graph matching term ``x - attention(x, partner graph)``.

    Nodes must be laid out as ``GraphMatchingNetwork.combine_pair_embedding``
    produces them: with ``P`` pairs, graph ``k`` in ``batch`` is paired with
    graph ``k + P``, and nodes are sorted by graph id.
    """

    def __init__(self, in_channels, out_channels, args, aggr="add"):
        super().__init__(aggr=aggr)
        self.args = args
        self.lin_node = torch.nn.Linear(in_channels, out_channels)
        self.lin_message = torch.nn.Linear(out_channels * 2, out_channels)
        # Consumes [aggregated messages (out_channels) | matching term (in_channels)].
        self.lin_passing = torch.nn.Linear(out_channels + in_channels, out_channels)
        self.batch_norm = BatchNorm(out_channels)

    def forward(self, x, edge_index, batch):
        """Return ``(updated node features, edge_index, batch)``.

        ``edge_index`` and ``batch`` are passed through unchanged so that
        layers can be chained.
        """
        x_transformed = self.lin_node(x)
        out = self.propagate(edge_index, x=x_transformed, original_x=x, batch=batch)
        return out, edge_index, batch

    def message(self, x_i, x_j):
        return self.lin_message(torch.cat([x_i, x_j], dim=1))

    def update(self, aggr_out, original_x, batch):
        # ``batch`` numbers both halves of every pair, so the number of pairs
        # is half the number of graphs. Passing the total graph count here
        # would pair every graph with an empty partition, and the attention
        # term would then always be zero.
        n_pairs = (int(batch.max()) + 1) // 2
        cross_graph_attention = batch_block_pair_attention(original_x, batch, n_pairs)
        attention_input = original_x - cross_graph_attention
        aggr_out = self.lin_passing(torch.cat([aggr_out, attention_input], dim=1))
        return self.batch_norm(aggr_out)


class GraphAggregator(torch.nn.Module):
    """Gated readout mapping node features to one vector per graph.

    Each node contributes ``sigmoid(lin_gate(x)) * lin(x)``, the gating used
    by Li et al. (2019). Contributions are averaged per graph (the paper sums)
    and then passed through a final linear layer.
    """

    def __init__(self, in_channels, out_channels, args):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.lin_gate = torch.nn.Linear(in_channels, out_channels)
        self.lin_final = torch.nn.Linear(out_channels, out_channels)
        self.args = args

    def forward(self, x, edge_index, batch):
        """Return a (num_graphs, out_channels) tensor. ``edge_index`` is unused."""
        # Independent per-feature gates in [0, 1]. A softmax over the feature
        # axis would force the gates to sum to 1 and shrink every state.
        gates = torch.sigmoid(self.lin_gate(x))
        x_states = self.lin(x) * gates
        x_states = scatter_mean(x_states, batch, dim=0)
        return self.lin_final(x_states)
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------

import torch
from torch_geometric.data import Data
from torch.nn import functional as F

from .layers import GraphConvolution, GraphAggregator, GraphMatchingConvolution
from .utils import adj_matrix_to_edge_index
from .utils import create_batch, trim_feats
from .utils import acc_f1


class _PairClassifierBase(torch.nn.Module):
    """Loss/metric bookkeeping shared by the graph-pair classifiers below."""

    def __init__(self, args):
        super().__init__()
        self.args = args
        # sklearn's 'binary' F1 only applies to two-class problems.
        self.f1_average = 'micro' if args.n_classes > 2 else 'binary'

    def compute_metrics(self, outputs, labels, split, backpropagate=False):
        """Compute NLL loss, accuracy and F1 for a batch of logits.

        Args:
            outputs: raw logits of shape (batch, n_classes).
            labels: integer class labels of shape (batch,).
            split: name of the data split. Currently unused; kept for callers.
            backpropagate: if True, call ``loss.backward()`` on the loss.

        Returns:
            ``({'loss': ..., 'acc': ..., 'f1': ...}, batch_size)``.
        """
        log_probs = F.log_softmax(outputs, dim=1)
        loss = F.nll_loss(log_probs, labels)
        if backpropagate:
            loss.backward()
        acc, f1 = acc_f1(
            log_probs, labels, average=self.f1_average, logging=self.args.logging, verbose=True)
        metrics = {'loss': loss, 'acc': acc, 'f1': f1}
        return metrics, log_probs.shape[0]

    def init_metric_dict(self):
        """Return sentinel metrics that any real evaluation improves on."""
        return {'acc': -1, 'f1': -1}

    def has_improved(self, m1, m2):
        """Return True if metrics ``m2`` have a higher accuracy than ``m1``."""
        return m1["acc"] < m2["acc"]


class GenericGNN(_PairClassifierBase):
    """Siamese baseline: embed each graph independently, then classify the pair.

    ``args`` must provide ``n_classes``, ``feat_dim``, ``dim``, ``num_layers``
    and ``logging``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.layers = torch.nn.ModuleList()
        self.layers.append(GraphConvolution(args.feat_dim, args.dim, args))
        for _ in range(args.num_layers - 1):
            self.layers.append(GraphConvolution(args.dim, args.dim, args))
        self.aggregator = GraphAggregator(args.dim, args.dim, args)
        # The classifier sees the concatenation of both graph embeddings.
        self.cls = torch.nn.Linear(args.dim * 2, args.n_classes)

    def compute_emb(self, feats, adjs, sizes):
        """Embed a batch of padded graphs into a (num_graphs, dim) tensor.

        ``feats`` is indexed as (graph, node, feature), and ``sizes[i]`` is the
        number of real nodes in graph ``i``. ``adjs`` is presumably a padded
        (graph, node, node) batch. TODO: confirm this against the data loader.
        """
        device = feats.device
        edge_index = adj_matrix_to_edge_index(adjs, device=device, sizes=sizes)
        batch = create_batch(sizes).to(device)
        feats = trim_feats(feats, sizes)
        for layer in self.layers:
            feats = layer(feats, edge_index)
        return self.aggregator(feats, edge_index, batch)

    def forward(self, feats_1, adjs_1, feats_2, adjs_2, sizes_1, sizes_2):
        """Return class logits of shape (num_pairs, n_classes)."""
        emb_1 = self.compute_emb(feats_1, adjs_1, sizes_1)
        emb_2 = self.compute_emb(feats_2, adjs_2, sizes_2)
        return self.cls(torch.cat((emb_1, emb_2), dim=1))


class GraphMatchingNetwork(_PairClassifierBase):
    """Graph Matching Network: both graphs of a pair are propagated together.

    Each layer can then attend across the pair. ``args`` must provide
    ``n_classes``, ``feat_dim``, ``dim``, ``num_layers``, ``logging`` and
    ``device``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.layers = torch.nn.ModuleList()
        self.layers.append(GraphMatchingConvolution(args.feat_dim, args.dim, args))
        for _ in range(args.num_layers - 1):
            self.layers.append(GraphMatchingConvolution(args.dim, args.dim, args))
        self.aggregator = GraphAggregator(args.dim, args.dim, args)
        self.cls = torch.nn.Linear(args.dim * 2, args.n_classes)

    def compute_emb(self, feats, edge_index, batch):
        """Run all matching layers and the readout.

        Returns ``(graph_embeddings, edge_index, batch)``.
        """
        for layer in self.layers:
            feats, edge_index, batch = layer(feats, edge_index, batch)
        feats = self.aggregator(feats, edge_index, batch)
        return feats, edge_index, batch

    def combine_pair_embedding(self, feats_1, adjs_1, feats_2, adjs_2, sizes_1, sizes_2):
        """Stack both sides of every pair into one disjoint batch of graphs.

        With ``P`` pairs, first-side graphs get batch ids ``0..P-1`` and their
        partners get ``P..2P-1``; this is the pairing
        ``GraphMatchingConvolution`` expects. Second-side node indices are
        shifted past all first-side nodes so ``edge_index`` addresses the
        stacked feature matrix.

        Returns:
            ``(feats, edge_index, batch)`` on ``args.device``.

        Raises:
            ValueError: if the two sides hold different numbers of graphs.
        """
        if len(sizes_1) != len(sizes_2):
            raise ValueError(
                f"expected the same number of graphs on both sides, got {len(sizes_1)} and {len(sizes_2)}")
        device = self.args.device
        # Trim each side separately so the two sides may use different padding.
        nodes_1 = trim_feats(feats_1, sizes_1)
        nodes_2 = trim_feats(feats_2, sizes_2)
        feats = torch.cat([nodes_1, nodes_2], dim=0)
        edge_index_1 = adj_matrix_to_edge_index(adjs_1, device=device, sizes=sizes_1)
        edge_index_2 = adj_matrix_to_edge_index(
            adjs_2, device=device, sizes=sizes_2, offset=nodes_1.shape[0])
        edge_index = torch.cat([edge_index_1, edge_index_2], dim=1)
        batch = create_batch(torch.cat([sizes_1, sizes_2], dim=0))
        return feats.to(device), edge_index, batch.to(device)

    def forward(self, feats_1, adjs_1, feats_2, adjs_2, sizes_1, sizes_2):
        """Return class logits of shape (num_pairs, n_classes)."""
        feats, edge_index, batch = self.combine_pair_embedding(
            feats_1, adjs_1, feats_2, adjs_2, sizes_1, sizes_2)
        emb, _, _ = self.compute_emb(feats, edge_index, batch)
        n_pairs = len(sizes_1)
        emb_1, emb_2 = emb[:n_pairs], emb[n_pairs:]
        return self.cls(torch.cat((emb_1, emb_2), dim=1))
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import functional as F
from sklearn.metrics import accuracy_score, f1_score


def acc_f1(output, labels, average='binary', logging=None, verbose=True):
    """Return ``(accuracy, f1)`` of the arg-max predictions in ``output``.

    ``output`` holds per-class scores of shape (batch, n_classes). ``logging``
    and ``verbose`` are accepted for API compatibility but are unused.
    """
    preds = output.argmax(dim=1).type_as(labels).cpu()
    labels = labels.cpu()
    # sklearn expects (y_true, y_pred); the order matters for non-symmetric averages.
    accuracy = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average=average)
    return accuracy, f1


def dynamic_partition(data, partitions, num_partitions):
    """Split the rows of ``data`` by id.

    Element ``i`` of the result holds the rows where ``partitions == i``;
    ids with no rows yield empty tensors.
    """
    return [data[partitions == i] for i in range(num_partitions)]


def adj_matrix_to_edge_index(adj_matrix, device=None, sizes=None, offset=0):
    """Convert dense adjacency matrices into an int64 ``edge_index`` of shape (2, E).

    Args:
        adj_matrix: one (N, N) matrix or a padded batch (B, N_max, N_max).
            Every nonzero entry ``[i, j]`` becomes the edge ``(i, j)``, so a
            symmetric matrix yields edges in both directions.
        device: optional device for the returned tensor.
        sizes: optional number of real nodes per matrix (length B, or length 1
            for a single matrix). Rows and columns past a graph's size are
            treated as padding.
        offset: value added to every node index.

    Node indices of graph ``b`` are shifted by the total size of graphs
    ``0..b-1`` (plus ``offset``), matching the node stacking of
    ``trim_feats``. A self-loop is added for every real node.

    Raises:
        ValueError: if ``adj_matrix`` is not 2-D/3-D, or if ``sizes`` does
            not have one entry per matrix.
    """
    adj = adj_matrix.detach()
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    if adj.dim() != 3:
        raise ValueError(f"expected a 2D or 3D adjacency tensor, got shape {tuple(adj_matrix.shape)}")
    if sizes is None:
        sizes = [adj.shape[1]] * adj.shape[0]
    else:
        sizes = [int(s) for s in sizes]
        if len(sizes) != adj.shape[0]:
            raise ValueError(f"got {len(sizes)} sizes for {adj.shape[0]} adjacency matrices")

    blocks = []
    start = offset
    for graph_adj, size in zip(adj, sizes):
        mask = graph_adj[:size, :size] != 0
        # Self-loops let every node see its own state during message passing;
        # OR-ing avoids duplicates when the diagonal is already set.
        mask = mask | torch.eye(size, dtype=torch.bool, device=mask.device)
        src, dst = mask.nonzero(as_tuple=True)
        blocks.append(torch.stack([src, dst]) + start)
        start += size

    if blocks:
        edge_index = torch.cat(blocks, dim=1)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.int64, device=adj.device)
    if device is not None:
        edge_index = edge_index.to(device)
    return edge_index


def create_batch(sizes):
    """Return the PyG ``batch`` vector: graph id ``k`` repeated ``sizes[k]`` times (int64, CPU)."""
    counts = torch.as_tensor([int(s) for s in sizes], dtype=torch.int64)
    return torch.repeat_interleave(torch.arange(len(counts)), counts)


def trim_feats(feats, sizes):
    """Drop padding and stack the real node features of every graph.

    ``feats`` is indexed as (graph, node, feature); graph ``i`` contributes
    its first ``sizes[i]`` rows. Returns a (sum(sizes), feature) tensor with
    the same dtype and device as ``feats``.
    """
    sizes = [int(s) for s in sizes]
    if not sizes:
        return feats.new_zeros((0, feats.shape[-1]))
    return torch.cat([feats[i, :size, :] for i, size in enumerate(sizes)], dim=0)


def pairwise_cosine_similarity(a, b):
    """Return the (len(a), len(b)) matrix of cosine similarities between rows of ``a`` and ``b``."""
    return torch.matmul(F.normalize(a, dim=1), F.normalize(b, dim=1).T)


def compute_cross_attention(x_i, x_j):
    """Attend from each graph's nodes to the other graph's nodes.

    Returns ``(att_i, att_j)``:
    - Row ``k`` of ``att_i`` is a softmax(cosine)-weighted average of the rows of ``x_j``.
    - ``att_j`` is the same quantity in the opposite direction, from ``x_j`` to ``x_i``.
    """
    a = pairwise_cosine_similarity(x_i, x_j)
    a_i = F.softmax(a, dim=1)
    a_j = F.softmax(a, dim=0)
    att_i = torch.matmul(a_i, x_j)
    att_j = torch.matmul(a_j.T, x_i)
    return att_i, att_j


# Backward-compatible alias for the original (misspelled) name.
compute_crosss_attention = compute_cross_attention


def batch_block_pair_attention(data, batch, n_graphs):
    """Cross-graph attention for a batch of graph pairs.

    Args:
        data: (num_nodes, dim) node features, sorted by graph id (as
            ``create_batch`` produces).
        batch: graph id of every node. Graph ``i`` is paired with graph
            ``i + n_graphs``.
        n_graphs: number of *pairs*, i.e. half the number of graphs in ``batch``.

    Returns:
        A tensor shaped like ``data``. Each row is the attention-weighted
        average of the partner graph's node features.
    """
    results = [None] * (2 * n_graphs)
    partitions = dynamic_partition(data, batch, 2 * n_graphs)
    for i in range(n_graphs):
        attention_x, attention_y = compute_cross_attention(partitions[i], partitions[i + n_graphs])
        results[i] = attention_x
        results[i + n_graphs] = attention_y
    # Concatenating in graph-id order restores node order because nodes are sorted by graph id.
    return torch.cat(results, dim=0).view(data.shape)
--------------------------------------------------------------------------------