├── .coveragerc ├── .gitignore ├── .style.yapf ├── .travis.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── dgmc ├── __init__.py ├── models │ ├── __init__.py │ ├── dgmc.py │ ├── gin.py │ ├── mlp.py │ ├── rel.py │ └── spline.py └── utils │ ├── __init__.py │ └── data.py ├── docs ├── .nojekyll ├── Makefile ├── index.html ├── requirements.txt └── source │ ├── conf.py │ ├── index.rst │ └── modules │ ├── models.rst │ └── utils.rst ├── examples ├── dbp15k.py ├── pascal.py ├── pascal_pf.py └── willow.py ├── figures ├── best_car.png ├── best_duck.png ├── best_motorbike.png ├── overview.png └── worst_duck.png ├── readthedocs.yml ├── setup.cfg ├── setup.py └── test ├── models ├── test_dgmc.py ├── test_gin.py ├── test_mlp.py ├── test_rel.py └── test_spline.py └── utils └── test_data.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=dgmc 3 | [report] 4 | exclude_lines = 5 | pragma: no cover 6 | raise 7 | except 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | _ext/ 3 | build/ 4 | dist/ 5 | .cache/ 6 | data/ 7 | .eggs/ 8 | *.egg-info/ 9 | .coverage 10 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | based_on_style = pep8 3 | split_before_named_assigns = False 4 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | jobs: 2 | include: 3 | - os: linux 4 | language: python 5 | python: 3.7 6 | install: 7 | - pip install numpy 8 | - pip install torch==1.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html 9 | - pip install torch-scatter==latest+cpu -f 
https://pytorch-geometric.com/whl/torch-1.5.0.html 10 | - pip install torch-sparse==latest+cpu -f https://pytorch-geometric.com/whl/torch-1.5.0.html 11 | - pip install torch-spline-conv==latest+cpu -f https://pytorch-geometric.com/whl/torch-1.5.0.html 12 | - pip install torch-geometric 13 | - pip install pycodestyle 14 | - pip install flake8 15 | - pip install codecov 16 | - pip install sphinx 17 | - pip install sphinx_rtd_theme 18 | script: 19 | - python -c "import torch; print(torch.__version__)" 20 | - pycodestyle --ignore=E731 . 21 | - flake8 . 22 | - python setup.py install 23 | - python setup.py test 24 | - cd docs && make clean && make html && cd .. 25 | after_success: 26 | - codecov 27 | notifications: 28 | email: false 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Matthias Fey 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | 3 | recursive-include figures * 4 | recursive-include examples * 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [build-image]: https://travis-ci.org/rusty1s/deep-graph-matching-consensus.svg?branch=master 2 | [build-url]: https://travis-ci.org/rusty1s/deep-graph-matching-consensus 3 | [docs-image]: https://readthedocs.org/projects/deep-graph-matching-consensus/badge/?version=latest 4 | [docs-url]: https://deep-graph-matching-consensus.readthedocs.io/en/latest/?badge=latest 5 | [coverage-image]: https://codecov.io/gh/rusty1s/deep-graph-matching-consensus/branch/master/graph/badge.svg 6 | [coverage-url]: https://codecov.io/github/rusty1s/deep-graph-matching-consensus?branch=master 7 | 8 |

Deep Graph Matching Consensus

9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- 13 | 14 | [![Build Status][build-image]][build-url] 15 | [![Docs Status][docs-image]][docs-url] 16 | [![Code Coverage][coverage-image]][coverage-url] 17 | 18 | **[Documentation](https://deep-graph-matching-consensus.readthedocs.io)** 19 | 20 | This is a PyTorch implementation of **Deep Graph Matching Consensus**, as described in our paper: 21 | 22 | Matthias Fey, Jan E. Lenssen, Christopher Morris, Jonathan Masci, Nils M. Kriege: [Deep Graph Matching Consensus](https://openreview.net/forum?id=HyeJf1HKvS) *(ICLR 2020)* 23 | 24 | ## Requirements 25 | 26 | * **[PyTorch](https://pytorch.org/get-started/locally/)** (>=1.2.0) 27 | * **[PyTorch Geometric](https://github.com/rusty1s/pytorch_geometric)** (>=1.5.0) 28 | * **[KeOps](https://github.com/getkeops/keops)** (>=1.1.0) 29 | 30 | ## Installation 31 | 32 | ``` 33 | $ python setup.py install 34 | ``` 35 | 36 | Head over to our [documentation](https://deep-graph-matching-consensus.readthedocs.io) for a detailed overview of the `DGMC` module. 37 | 38 | ## Running examples 39 | 40 | We provide training and evaluation procedures for the [PascalVOC with Berkeley annotations](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.PascalVOCKeypoints) dataset, the [WILLOW-ObjectClass](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.WILLOWObjectClass) dataset, the [PascalPF](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.PascalPF) dataset, and the [DBP15K](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.DBP15K) dataset. 41 | Experiments can be run via: 42 | 43 | ``` 44 | $ cd examples/ 45 | $ python pascal.py 46 | $ python willow.py 47 | $ python pascal_pf.py 48 | $ python dbp15k.py --category=zh_en 49 | ``` 50 | 51 |

52 | 53 | 54 | 55 | 56 |

57 | 58 | ## Cite 59 | 60 | Please cite [our paper](https://openreview.net/forum?id=HyeJf1HKvS) if you use this code in your own work: 61 | 62 | ``` 63 | @inproceedings{Fey/etal/2020, 64 | title={Deep Graph Matching Consensus}, 65 | author={Fey, M. and Lenssen, J. E. and Morris, C. and Masci, J. and Kriege, N. M.}, 66 | booktitle={International Conference on Learning Representations (ICLR)}, 67 | year={2020}, 68 | } 69 | ``` 70 | 71 | ## Running tests 72 | 73 | ``` 74 | $ python setup.py test 75 | ``` 76 | -------------------------------------------------------------------------------- /dgmc/__init__.py: -------------------------------------------------------------------------------- 1 | import dgmc.models 2 | import dgmc.utils 3 | from dgmc.models.dgmc import DGMC 4 | 5 | __version__ = '1.0.0' 6 | 7 | __all__ = [ 8 | 'dgmc', 9 | 'DGMC', 10 | '__version__', 11 | ] 12 | -------------------------------------------------------------------------------- /dgmc/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mlp import MLP 2 | from .gin import GIN 3 | from .spline import SplineCNN 4 | from .rel import RelCNN 5 | from .dgmc import DGMC 6 | 7 | __all__ = [ 8 | 'MLP', 9 | 'GIN', 10 | 'SplineCNN', 11 | 'RelCNN', 12 | 'DGMC', 13 | ] 14 | -------------------------------------------------------------------------------- /dgmc/models/dgmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Sequential as Seq, Linear as Lin, ReLU 3 | from torch_scatter import scatter_add 4 | from torch_geometric.utils import to_dense_batch 5 | from torch_geometric.nn.inits import reset 6 | 7 | try: 8 | from pykeops.torch import LazyTensor 9 | except ImportError: 10 | LazyTensor = None 11 | 12 | EPS = 1e-8 13 | 14 | 15 | def masked_softmax(src, mask, dim=-1): 16 | out = src.masked_fill(~mask, float('-inf')) 17 | out = torch.softmax(out, dim=dim) 18 | out =
out.masked_fill(~mask, 0) 19 | return out 20 | 21 | 22 | def to_sparse(x, mask): 23 | return x[mask] 24 | 25 | 26 | def to_dense(x, mask): 27 | out = x.new_zeros(tuple(mask.size()) + (x.size(-1), )) 28 | out[mask] = x 29 | return out 30 | 31 | 32 | class DGMC(torch.nn.Module): 33 | r"""The *Deep Graph Matching Consensus* module which first matches nodes 34 | locally via a graph neural network :math:`\Psi_{\theta_1}`, and then 35 | updates correspondence scores iteratively by reaching for neighborhood 36 | consensus via a second graph neural network :math:`\Psi_{\theta_2}`. 37 | 38 | .. note:: 39 | See the `PyTorch Geometric introductory tutorial 40 | <https://pytorch-geometric.readthedocs.io/en/latest/notes/ 41 | introduction.html>`_ for a detailed overview of the used GNN modules 42 | and the respective data format. 43 | 44 | Args: 45 | psi_1 (torch.nn.Module): The first GNN :math:`\Psi_{\theta_1}` which 46 | takes in node features :obj:`x`, edge connectivity 47 | :obj:`edge_index`, and optional edge features :obj:`edge_attr` and 48 | computes node embeddings. 49 | psi_2 (torch.nn.Module): The second GNN :math:`\Psi_{\theta_2}` which 50 | takes in node features :obj:`x`, edge connectivity 51 | :obj:`edge_index`, and optional edge features :obj:`edge_attr` and 52 | validates for neighborhood consensus. 53 | :obj:`psi_2` needs to hold the attributes :obj:`in_channels` and 54 | :obj:`out_channels` which indicate the dimensionality of randomly 55 | drawn node indicator functions and the output dimensionality of 56 | :obj:`psi_2`, respectively. 57 | num_steps (int): Number of consensus iterations. 58 | k (int, optional): Sparsity parameter. If set to :obj:`-1`, will 59 | not sparsify initial correspondence rankings. (default: :obj:`-1`) 60 | detach (bool, optional): If set to :obj:`True`, will detach the 61 | computation of :math:`\Psi_{\theta_1}` from the current computation 62 | graph.
(default: :obj:`False`) 63 | """ 64 | def __init__(self, psi_1, psi_2, num_steps, k=-1, detach=False): 65 | super(DGMC, self).__init__() 66 | 67 | self.psi_1 = psi_1 68 | self.psi_2 = psi_2 69 | self.num_steps = num_steps 70 | self.k = k 71 | self.detach = detach 72 | self.backend = 'auto' 73 | 74 | self.mlp = Seq( 75 | Lin(psi_2.out_channels, psi_2.out_channels), 76 | ReLU(), 77 | Lin(psi_2.out_channels, 1), 78 | ) 79 | 80 | def reset_parameters(self): 81 | self.psi_1.reset_parameters() 82 | self.psi_2.reset_parameters() 83 | reset(self.mlp) 84 | 85 | def __top_k__(self, x_s, x_t): # pragma: no cover 86 | r"""Memory-efficient top-k correspondence computation.""" 87 | if LazyTensor is not None: 88 | x_s = x_s.unsqueeze(-2) # [..., n_s, 1, d] 89 | x_t = x_t.unsqueeze(-3) # [..., 1, n_t, d] 90 | x_s, x_t = LazyTensor(x_s), LazyTensor(x_t) 91 | S_ij = (-x_s * x_t).sum(dim=-1) 92 | return S_ij.argKmin(self.k, dim=2, backend=self.backend) 93 | else: 94 | x_s = x_s # [..., n_s, d] 95 | x_t = x_t.transpose(-1, -2) # [..., d, n_t] 96 | S_ij = x_s @ x_t 97 | return S_ij.topk(self.k, dim=2)[1] 98 | 99 | def __include_gt__(self, S_idx, s_mask, y): 100 | r"""Includes the ground-truth values in :obj:`y` to the index tensor 101 | :obj:`S_idx`.""" 102 | (B, N_s), (row, col), k = s_mask.size(), y, S_idx.size(-1) 103 | 104 | gt_mask = (S_idx[s_mask][row] != col.view(-1, 1)).all(dim=-1) 105 | 106 | sparse_mask = gt_mask.new_zeros((s_mask.sum(), )) 107 | sparse_mask[row] = gt_mask 108 | 109 | dense_mask = sparse_mask.new_zeros((B, N_s)) 110 | dense_mask[s_mask] = sparse_mask 111 | last_entry = torch.zeros(k, dtype=torch.bool, device=gt_mask.device) 112 | last_entry[-1] = 1 113 | dense_mask = dense_mask.view(B, N_s, 1) * last_entry.view(1, 1, k) 114 | 115 | return S_idx.masked_scatter(dense_mask, col[gt_mask]) 116 | 117 | def forward(self, x_s, edge_index_s, edge_attr_s, batch_s, x_t, 118 | edge_index_t, edge_attr_t, batch_t, y=None): 119 | r""" 120 | Args: 121 | x_s (Tensor): Source
graph node features of shape 122 | :obj:`[batch_size * num_nodes, C_in]`. 123 | edge_index_s (LongTensor): Source graph edge connectivity of shape 124 | :obj:`[2, num_edges]`. 125 | edge_attr_s (Tensor): Source graph edge features of shape 126 | :obj:`[num_edges, D]`. Set to :obj:`None` if the GNNs are not 127 | taking edge features into account. 128 | batch_s (LongTensor): Source graph batch vector of shape 129 | :obj:`[batch_size * num_nodes]` indicating node to graph 130 | assignment. Set to :obj:`None` if operating on single graphs. 131 | x_t (Tensor): Target graph node features of shape 132 | :obj:`[batch_size * num_nodes, C_in]`. 133 | edge_index_t (LongTensor): Target graph edge connectivity of shape 134 | :obj:`[2, num_edges]`. 135 | edge_attr_t (Tensor): Target graph edge features of shape 136 | :obj:`[num_edges, D]`. Set to :obj:`None` if the GNNs are not 137 | taking edge features into account. 138 | batch_t (LongTensor): Target graph batch vector of shape 139 | :obj:`[batch_size * num_nodes]` indicating node to graph 140 | assignment. Set to :obj:`None` if operating on single graphs. 141 | y (LongTensor, optional): Ground-truth matchings of shape 142 | :obj:`[2, num_ground_truths]` to include ground-truth values 143 | when training against sparse correspondences. Ground-truths 144 | are only used in case the model is in training mode. 145 | (default: :obj:`None`) 146 | 147 | Returns: 148 | Initial and refined correspondence matrices :obj:`(S_0, S_L)` 149 | of shapes :obj:`[batch_size * num_nodes, num_nodes]`. The 150 | correspondence matrices are given as dense or sparse matrices.
151 | """ 152 | h_s = self.psi_1(x_s, edge_index_s, edge_attr_s) 153 | h_t = self.psi_1(x_t, edge_index_t, edge_attr_t) 154 | 155 | h_s, h_t = (h_s.detach(), h_t.detach()) if self.detach else (h_s, h_t) 156 | 157 | h_s, s_mask = to_dense_batch(h_s, batch_s, fill_value=0) 158 | h_t, t_mask = to_dense_batch(h_t, batch_t, fill_value=0) 159 | 160 | assert h_s.size(0) == h_t.size(0), 'Encountered unequal batch-sizes' 161 | (B, N_s, C_out), N_t = h_s.size(), h_t.size(1) 162 | R_in, R_out = self.psi_2.in_channels, self.psi_2.out_channels 163 | 164 | if self.k < 1: 165 | # ------ Dense variant ------ # 166 | S_hat = h_s @ h_t.transpose(-1, -2) # [B, N_s, N_t] 167 | S_mask = s_mask.view(B, N_s, 1) & t_mask.view(B, 1, N_t) 168 | S_0 = masked_softmax(S_hat, S_mask, dim=-1)[s_mask] 169 | 170 | for _ in range(self.num_steps): 171 | S = masked_softmax(S_hat, S_mask, dim=-1) 172 | r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype, 173 | device=h_s.device) 174 | r_t = S.transpose(-1, -2) @ r_s 175 | 176 | r_s, r_t = to_sparse(r_s, s_mask), to_sparse(r_t, t_mask) 177 | o_s = self.psi_2(r_s, edge_index_s, edge_attr_s) 178 | o_t = self.psi_2(r_t, edge_index_t, edge_attr_t) 179 | o_s, o_t = to_dense(o_s, s_mask), to_dense(o_t, t_mask) 180 | 181 | D = o_s.view(B, N_s, 1, R_out) - o_t.view(B, 1, N_t, R_out) 182 | S_hat = S_hat + self.mlp(D).squeeze(-1).masked_fill(~S_mask, 0) 183 | 184 | S_L = masked_softmax(S_hat, S_mask, dim=-1)[s_mask] 185 | 186 | return S_0, S_L 187 | else: 188 | # ------ Sparse variant ------ # 189 | S_idx = self.__top_k__(h_s, h_t) # [B, N_s, k] 190 | 191 | # In addition to the top-k, randomly sample negative examples and 192 | # ensure that the ground-truth is included as a sparse entry.
193 | if self.training and y is not None: 194 | rnd_size = (B, N_s, min(self.k, N_t - self.k)) 195 | S_rnd_idx = torch.randint(N_t, rnd_size, dtype=torch.long, 196 | device=S_idx.device) 197 | S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1) 198 | S_idx = self.__include_gt__(S_idx, s_mask, y) 199 | 200 | k = S_idx.size(-1) 201 | tmp_s = h_s.view(B, N_s, 1, C_out) 202 | idx = S_idx.view(B, N_s * k, 1).expand(-1, -1, C_out) 203 | tmp_t = torch.gather(h_t.view(B, N_t, C_out), -2, idx) 204 | S_hat = (tmp_s * tmp_t.view(B, N_s, k, C_out)).sum(dim=-1) 205 | S_0 = S_hat.softmax(dim=-1)[s_mask] 206 | 207 | for _ in range(self.num_steps): 208 | S = S_hat.softmax(dim=-1) 209 | r_s = torch.randn((B, N_s, R_in), dtype=h_s.dtype, 210 | device=h_s.device) 211 | 212 | tmp_t = r_s.view(B, N_s, 1, R_in) * S.view(B, N_s, k, 1) 213 | tmp_t = tmp_t.view(B, N_s * k, R_in) 214 | idx = S_idx.view(B, N_s * k, 1) 215 | r_t = scatter_add(tmp_t, idx, dim=1, dim_size=N_t) 216 | 217 | r_s, r_t = to_sparse(r_s, s_mask), to_sparse(r_t, t_mask) 218 | o_s = self.psi_2(r_s, edge_index_s, edge_attr_s) 219 | o_t = self.psi_2(r_t, edge_index_t, edge_attr_t) 220 | o_s, o_t = to_dense(o_s, s_mask), to_dense(o_t, t_mask) 221 | 222 | o_s = o_s.view(B, N_s, 1, R_out).expand(-1, -1, k, -1) 223 | idx = S_idx.view(B, N_s * k, 1).expand(-1, -1, R_out) 224 | tmp_t = torch.gather(o_t.view(B, N_t, R_out), -2, idx) 225 | D = o_s - tmp_t.view(B, N_s, k, R_out) 226 | S_hat = S_hat + self.mlp(D).squeeze(-1) 227 | 228 | S_L = S_hat.softmax(dim=-1)[s_mask] 229 | S_idx = S_idx[s_mask] 230 | 231 | # Convert sparse layout to `torch.sparse_coo_tensor`.
232 | row = torch.arange(x_s.size(0), device=S_idx.device) 233 | row = row.view(-1, 1).repeat(1, k) 234 | idx = torch.stack([row.view(-1), S_idx.view(-1)], dim=0) 235 | size = torch.Size([x_s.size(0), N_t]) 236 | 237 | S_sparse_0 = torch.sparse_coo_tensor( 238 | idx, S_0.view(-1), size, requires_grad=S_0.requires_grad) 239 | S_sparse_0.__idx__ = S_idx 240 | S_sparse_0.__val__ = S_0 241 | 242 | S_sparse_L = torch.sparse_coo_tensor( 243 | idx, S_L.view(-1), size, requires_grad=S_L.requires_grad) 244 | S_sparse_L.__idx__ = S_idx 245 | S_sparse_L.__val__ = S_L 246 | 247 | return S_sparse_0, S_sparse_L 248 | 249 | def loss(self, S, y, reduction='mean'): 250 | r"""Computes the negative log-likelihood loss on the correspondence 251 | matrix. 252 | 253 | Args: 254 | S (Tensor): Sparse or dense correspondence matrix of shape 255 | :obj:`[batch_size * num_nodes, num_nodes]`. 256 | y (LongTensor): Ground-truth matchings of shape 257 | :obj:`[2, num_ground_truths]`. 258 | reduction (string, optional): Specifies the reduction to apply to 259 | the output: :obj:`'none'|'mean'|'sum'`. 260 | (default: :obj:`'mean'`) 261 | """ 262 | assert reduction in ['none', 'mean', 'sum'] 263 | if not S.is_sparse: 264 | val = S[y[0], y[1]] 265 | else: 266 | assert S.__idx__ is not None and S.__val__ is not None 267 | mask = S.__idx__[y[0]] == y[1].view(-1, 1) 268 | val = S.__val__[[y[0]]][mask] 269 | nll = -torch.log(val + EPS) 270 | return nll if reduction == 'none' else getattr(torch, reduction)(nll) 271 | 272 | def acc(self, S, y, reduction='mean'): 273 | r"""Computes the accuracy of correspondence predictions. 274 | 275 | Args: 276 | S (Tensor): Sparse or dense correspondence matrix of shape 277 | :obj:`[batch_size * num_nodes, num_nodes]`. 278 | y (LongTensor): Ground-truth matchings of shape 279 | :obj:`[2, num_ground_truths]`. 280 | reduction (string, optional): Specifies the reduction to apply to 281 | the output: :obj:`'mean'|'sum'`.
(default: :obj:`'mean'`) 282 | """ 283 | assert reduction in ['mean', 'sum'] 284 | if not S.is_sparse: 285 | pred = S[y[0]].argmax(dim=-1) 286 | else: 287 | assert S.__idx__ is not None and S.__val__ is not None 288 | pred = S.__idx__[y[0], S.__val__[y[0]].argmax(dim=-1)] 289 | 290 | correct = (pred == y[1]).sum().item() 291 | return correct / y.size(1) if reduction == 'mean' else correct 292 | 293 | def hits_at_k(self, k, S, y, reduction='mean'): 294 | r"""Computes the hits@k of correspondence predictions. 295 | 296 | Args: 297 | k (int): The :math:`\mathrm{top}_k` predictions to consider. 298 | S (Tensor): Sparse or dense correspondence matrix of shape 299 | :obj:`[batch_size * num_nodes, num_nodes]`. 300 | y (LongTensor): Ground-truth matchings of shape 301 | :obj:`[2, num_ground_truths]`. 302 | reduction (string, optional): Specifies the reduction to apply to 303 | the output: :obj:`'mean'|'sum'`. (default: :obj:`'mean'`) 304 | """ 305 | assert reduction in ['mean', 'sum'] 306 | if not S.is_sparse: 307 | pred = S[y[0]].argsort(dim=-1, descending=True)[:, :k] 308 | else: 309 | assert S.__idx__ is not None and S.__val__ is not None 310 | perm = S.__val__[y[0]].argsort(dim=-1, descending=True)[:, :k] 311 | pred = torch.gather(S.__idx__[y[0]], -1, perm) 312 | 313 | correct = (pred == y[1].view(-1, 1)).sum().item() 314 | return correct / y.size(1) if reduction == 'mean' else correct 315 | 316 | def __repr__(self): 317 | return ('{}(\n' 318 | ' psi_1={},\n' 319 | ' psi_2={},\n' 320 | ' num_steps={}, k={}\n)').format(self.__class__.__name__, 321 | self.psi_1, self.psi_2, 322 | self.num_steps, self.k) 323 | -------------------------------------------------------------------------------- /dgmc/models/gin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear as Lin 3 | from torch_geometric.nn import GINConv 4 | 5 | from .mlp import MLP 6 | 7 | 8 | class GIN(torch.nn.Module): 9 | def
__init__(self, in_channels, out_channels, num_layers, batch_norm=False, 10 | cat=True, lin=True): 11 | super(GIN, self).__init__() 12 | 13 | self.in_channels = in_channels 14 | self.num_layers = num_layers 15 | self.batch_norm = batch_norm 16 | self.cat = cat 17 | self.lin = lin 18 | 19 | self.convs = torch.nn.ModuleList() 20 | for _ in range(num_layers): 21 | mlp = MLP(in_channels, out_channels, 2, batch_norm, dropout=0.0) 22 | self.convs.append(GINConv(mlp, train_eps=True)) 23 | in_channels = out_channels 24 | 25 | if self.cat: 26 | in_channels = self.in_channels + num_layers * out_channels 27 | else: 28 | in_channels = out_channels 29 | 30 | if self.lin: 31 | self.out_channels = out_channels 32 | self.final = Lin(in_channels, out_channels) 33 | else: 34 | self.out_channels = in_channels 35 | 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | for conv in self.convs: 40 | conv.reset_parameters() 41 | if self.lin: 42 | self.final.reset_parameters() 43 | 44 | def forward(self, x, edge_index, *args): 45 | """""" 46 | xs = [x] 47 | 48 | for conv in self.convs: 49 | xs += [conv(xs[-1], edge_index)] 50 | 51 | x = torch.cat(xs, dim=-1) if self.cat else xs[-1] 52 | x = self.final(x) if self.lin else x 53 | return x 54 | 55 | def __repr__(self): 56 | return ('{}({}, {}, num_layers={}, batch_norm={}, cat={}, ' 57 | 'lin={})').format(self.__class__.__name__, self.in_channels, 58 | self.out_channels, self.num_layers, 59 | self.batch_norm, self.cat, self.lin) 60 | -------------------------------------------------------------------------------- /dgmc/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear as Lin, BatchNorm1d as BN 3 | import torch.nn.functional as F 4 | 5 | 6 | class MLP(torch.nn.Module): 7 | def __init__(self, in_channels, out_channels, num_layers, batch_norm=False, 8 | dropout=0.0): 9 | super(MLP, self).__init__() 10 | 11 | self.in_channels =
in_channels 12 | self.out_channels = out_channels 13 | self.num_layers = num_layers 14 | self.batch_norm = batch_norm 15 | self.dropout = dropout 16 | 17 | self.lins = torch.nn.ModuleList() 18 | self.batch_norms = torch.nn.ModuleList() 19 | for _ in range(num_layers): 20 | self.lins.append(Lin(in_channels, out_channels)) 21 | self.batch_norms.append(BN(out_channels)) 22 | in_channels = out_channels 23 | 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | for lin, batch_norm in zip(self.lins, self.batch_norms): 28 | lin.reset_parameters() 29 | batch_norm.reset_parameters() 30 | 31 | def forward(self, x, *args): 32 | for i, (lin, bn) in enumerate(zip(self.lins, self.batch_norms)): 33 | if i == self.num_layers - 1: 34 | x = F.dropout(x, p=self.dropout, training=self.training) 35 | x = lin(x) 36 | if i < self.num_layers - 1: 37 | x = F.relu(x) 38 | x = bn(x) if self.batch_norm else x 39 | return x 40 | 41 | def __repr__(self): 42 | return '{}({}, {}, num_layers={}, batch_norm={}, dropout={})'.format( 43 | self.__class__.__name__, self.in_channels, self.out_channels, 44 | self.num_layers, self.batch_norm, self.dropout) 45 | -------------------------------------------------------------------------------- /dgmc/models/rel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear as Lin, BatchNorm1d as BN 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import MessagePassing 5 | 6 | 7 | class RelConv(MessagePassing): 8 | def __init__(self, in_channels, out_channels): 9 | super(RelConv, self).__init__(aggr='mean') 10 | 11 | self.in_channels = in_channels 12 | self.out_channels = out_channels 13 | 14 | self.lin1 = Lin(in_channels, out_channels, bias=False) 15 | self.lin2 = Lin(in_channels, out_channels, bias=False) 16 | self.root = Lin(in_channels, out_channels) 17 | 18 | self.reset_parameters() 19 | 20 | def reset_parameters(self): 21 | self.lin1.reset_parameters()
22 | self.lin2.reset_parameters() 23 | self.root.reset_parameters() 24 | 25 | def forward(self, x, edge_index): 26 | """""" 27 | self.flow = 'source_to_target' 28 | out1 = self.propagate(edge_index, x=self.lin1(x)) 29 | self.flow = 'target_to_source' 30 | out2 = self.propagate(edge_index, x=self.lin2(x)) 31 | return self.root(x) + out1 + out2 32 | 33 | def message(self, x_j): 34 | return x_j 35 | 36 | def __repr__(self): 37 | return '{}({}, {})'.format(self.__class__.__name__, self.in_channels, 38 | self.out_channels) 39 | 40 | 41 | class RelCNN(torch.nn.Module): 42 | def __init__(self, in_channels, out_channels, num_layers, batch_norm=False, 43 | cat=True, lin=True, dropout=0.0): 44 | super(RelCNN, self).__init__() 45 | 46 | self.in_channels = in_channels 47 | self.num_layers = num_layers 48 | self.batch_norm = batch_norm 49 | self.cat = cat 50 | self.lin = lin 51 | self.dropout = dropout 52 | 53 | self.convs = torch.nn.ModuleList() 54 | self.batch_norms = torch.nn.ModuleList() 55 | for _ in range(num_layers): 56 | self.convs.append(RelConv(in_channels, out_channels)) 57 | self.batch_norms.append(BN(out_channels)) 58 | in_channels = out_channels 59 | 60 | if self.cat: 61 | in_channels = self.in_channels + num_layers * out_channels 62 | else: 63 | in_channels = out_channels 64 | 65 | if self.lin: 66 | self.out_channels = out_channels 67 | self.final = Lin(in_channels, out_channels) 68 | else: 69 | self.out_channels = in_channels 70 | 71 | self.reset_parameters() 72 | 73 | def reset_parameters(self): 74 | for conv, batch_norm in zip(self.convs, self.batch_norms): 75 | conv.reset_parameters() 76 | batch_norm.reset_parameters() 77 | if self.lin: 78 | self.final.reset_parameters() 79 | 80 | def forward(self, x, edge_index, *args): 81 | """""" 82 | xs = [x] 83 | 84 | for conv, batch_norm in zip(self.convs, self.batch_norms): 85 | x = conv(xs[-1], edge_index) 86 | x = batch_norm(F.relu(x)) if self.batch_norm else F.relu(x) 87 | x = F.dropout(x, p=self.dropout,
training=self.training) 88 | xs.append(x) 89 | 90 | x = torch.cat(xs, dim=-1) if self.cat else xs[-1] 91 | x = self.final(x) if self.lin else x 92 | return x 93 | 94 | def __repr__(self): 95 | return ('{}({}, {}, num_layers={}, batch_norm={}, cat={}, lin={}, ' 96 | 'dropout={})').format(self.__class__.__name__, 97 | self.in_channels, self.out_channels, 98 | self.num_layers, self.batch_norm, 99 | self.cat, self.lin, self.dropout) 100 | -------------------------------------------------------------------------------- /dgmc/models/spline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Linear as Lin 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import SplineConv 5 | 6 | 7 | class SplineCNN(torch.nn.Module): 8 | def __init__(self, in_channels, out_channels, dim, num_layers, cat=True, 9 | lin=True, dropout=0.0): 10 | super(SplineCNN, self).__init__() 11 | 12 | self.in_channels = in_channels 13 | self.dim = dim 14 | self.num_layers = num_layers 15 | self.cat = cat 16 | self.lin = lin 17 | self.dropout = dropout 18 | 19 | self.convs = torch.nn.ModuleList() 20 | for _ in range(num_layers): 21 | conv = SplineConv(in_channels, out_channels, dim, kernel_size=5) 22 | self.convs.append(conv) 23 | in_channels = out_channels 24 | 25 | if self.cat: 26 | in_channels = self.in_channels + num_layers * out_channels 27 | else: 28 | in_channels = out_channels 29 | 30 | if self.lin: 31 | self.out_channels = out_channels 32 | self.final = Lin(in_channels, out_channels) 33 | else: 34 | self.out_channels = in_channels 35 | 36 | self.reset_parameters() 37 | 38 | def reset_parameters(self): 39 | for conv in self.convs: 40 | conv.reset_parameters() 41 | if self.lin: 42 | self.final.reset_parameters() 43 | 44 | def forward(self, x, edge_index, edge_attr, *args): 45 | """""" 46 | xs = [x] 47 | 48 | for conv in self.convs: 49 | xs += [F.relu(conv(xs[-1], edge_index, edge_attr))] 50 | 51 | x =
torch.cat(xs, dim=-1) if self.cat else xs[-1] 52 | x = F.dropout(x, p=self.dropout, training=self.training) 53 | x = self.final(x) if self.lin else x 54 | return x 55 | 56 | def __repr__(self): 57 | return ('{}({}, {}, dim={}, num_layers={}, cat={}, lin={}, ' 58 | 'dropout={})').format(self.__class__.__name__, 59 | self.in_channels, self.out_channels, 60 | self.dim, self.num_layers, self.cat, 61 | self.lin, self.dropout) 62 | -------------------------------------------------------------------------------- /dgmc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import PairDataset, ValidPairDataset 2 | 3 | __all__ = [ 4 | 'PairDataset', 5 | 'ValidPairDataset', 6 | ] 7 | -------------------------------------------------------------------------------- /dgmc/utils/data.py: -------------------------------------------------------------------------------- 1 | import re 2 | from itertools import chain 3 | 4 | import torch 5 | import random 6 | from torch_geometric.data import Data 7 | 8 | 9 | class PairData(Data): # pragma: no cover 10 | def __inc__(self, key, value, *args): 11 | if bool(re.search('index_s', key)): 12 | return self.x_s.size(0) 13 | if bool(re.search('index_t', key)): 14 | return self.x_t.size(0) 15 | else: 16 | return 0 17 | 18 | 19 | class PairDataset(torch.utils.data.Dataset): 20 | r"""Combines two datasets, a source dataset and a target dataset, by 21 | building pairs between separate dataset examples. 22 | 23 | Args: 24 | dataset_s (torch.utils.data.Dataset): The source dataset. 25 | dataset_t (torch.utils.data.Dataset): The target dataset. 26 | sample (bool, optional): If set to :obj:`True`, will sample exactly 27 | one target example for every source example instead of holding the 28 | product of all source and target examples.
(default: :obj:`False`) 29 | """ 30 | def __init__(self, dataset_s, dataset_t, sample=False): 31 | self.dataset_s = dataset_s 32 | self.dataset_t = dataset_t 33 | self.sample = sample 34 | 35 | def __len__(self): 36 | return len(self.dataset_s) if self.sample else len( 37 | self.dataset_s) * len(self.dataset_t) 38 | 39 | def __getitem__(self, idx): 40 | if self.sample: 41 | data_s = self.dataset_s[idx] 42 | data_t = self.dataset_t[random.randint(0, len(self.dataset_t) - 1)] 43 | else: 44 | data_s = self.dataset_s[idx // len(self.dataset_t)] 45 | data_t = self.dataset_t[idx % len(self.dataset_t)] 46 | 47 | return PairData( 48 | x_s=data_s.x, 49 | edge_index_s=data_s.edge_index, 50 | edge_attr_s=data_s.edge_attr, 51 | x_t=data_t.x, 52 | edge_index_t=data_t.edge_index, 53 | edge_attr_t=data_t.edge_attr, 54 | num_nodes=None, 55 | ) 56 | 57 | def __repr__(self): 58 | return '{}({}, {}, sample={})'.format(self.__class__.__name__, 59 | self.dataset_s, self.dataset_t, 60 | self.sample) 61 | 62 | 63 | class ValidPairDataset(torch.utils.data.Dataset): 64 | r"""Combines two datasets, a source dataset and a target dataset, by 65 | building valid pairs between separate dataset examples. 66 | A pair is valid if each node class in the source graph also exists in the 67 | target graph. 68 | 69 | Args: 70 | dataset_s (torch.utils.data.Dataset): The source dataset. 71 | dataset_t (torch.utils.data.Dataset): The target dataset. 72 | sample (bool, optional): If set to :obj:`True`, will sample exactly 73 | one target example for every source example instead of holding the 74 | product of all source and target examples.
(default: :obj:`False`) 75 | """ 76 | def __init__(self, dataset_s, dataset_t, sample=False): 77 | self.dataset_s = dataset_s 78 | self.dataset_t = dataset_t 79 | self.sample = sample 80 | self.pairs, self.cumdeg = self.__compute_pairs__() 81 | 82 | def __compute_pairs__(self): 83 | num_classes = 0 84 | for data in chain(self.dataset_s, self.dataset_t): 85 | num_classes = max(num_classes, data.y.max().item() + 1) 86 | 87 | y_s = torch.zeros((len(self.dataset_s), num_classes), dtype=torch.bool) 88 | y_t = torch.zeros((len(self.dataset_t), num_classes), dtype=torch.bool) 89 | 90 | for i, data in enumerate(self.dataset_s): 91 | y_s[i, data.y] = 1 92 | for i, data in enumerate(self.dataset_t): 93 | y_t[i, data.y] = 1 94 | 95 | y_s = y_s.view(len(self.dataset_s), 1, num_classes) 96 | y_t = y_t.view(1, len(self.dataset_t), num_classes) 97 | 98 | pairs = ((y_s * y_t).sum(dim=-1) == y_s.sum(dim=-1)).nonzero() 99 | cumdeg = pairs[:, 0].bincount().cumsum(dim=0) 100 | 101 | return pairs.tolist(), [0] + cumdeg.tolist() 102 | 103 | def __len__(self): 104 | return len(self.dataset_s) if self.sample else len(self.pairs) 105 | 106 | def __getitem__(self, idx): 107 | if self.sample: 108 | data_s = self.dataset_s[idx] 109 | i = random.randint(self.cumdeg[idx], self.cumdeg[idx + 1] - 1) 110 | data_t = self.dataset_t[self.pairs[i][1]] 111 | else: 112 | data_s = self.dataset_s[self.pairs[idx][0]] 113 | data_t = self.dataset_t[self.pairs[idx][1]] 114 | 115 | y = data_s.y.new_full((data_t.y.max().item() + 1, ), -1) 116 | y[data_t.y] = torch.arange(data_t.num_nodes) 117 | y = y[data_s.y] 118 | 119 | return PairData( 120 | x_s=data_s.x, 121 | edge_index_s=data_s.edge_index, 122 | edge_attr_s=data_s.edge_attr, 123 | x_t=data_t.x, 124 | edge_index_t=data_t.edge_index, 125 | edge_attr_t=data_t.edge_attr, 126 | y=y, 127 | num_nodes=None, 128 | ) 129 | 130 | def __repr__(self): 131 | return '{}({}, {}, sample={})'.format(self.__class__.__name__, 132 | self.dataset_s, self.dataset_t, 133 |
self.sample) 134 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD := sphinx-build 2 | SPHINXPROJ := dgmc 3 | SOURCEDIR := source 4 | BUILDDIR := build 5 | 6 | .PHONY: help Makefile 7 | 8 | %: Makefile 9 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" 10 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Redirect 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | https://download.pytorch.org/whl/cpu/torch-1.4.0%2Bcpu-cp37-cp37m-linux_x86_64.whl 3 | https://s3.eu-central-1.amazonaws.com/pytorch-geometric.com/whl/torch-1.4.0/torch_scatter-latest%2Bcpu-cp37-cp37m-linux_x86_64.whl 4 | https://s3.eu-central-1.amazonaws.com/pytorch-geometric.com/whl/torch-1.4.0/torch_sparse-latest%2Bcpu-cp37-cp37m-linux_x86_64.whl 5 | torch-geometric 6 | sphinx>=3 7 | sphinx_rtd_theme 8 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sphinx_rtd_theme 3 | import doctest 4 | import dgmc 5 | 6 | extensions = [ 7 | 'sphinx.ext.autodoc', 8 | 'sphinx.ext.doctest', 9 | 'sphinx.ext.intersphinx', 10 | 'sphinx.ext.mathjax', 11 | 'sphinx.ext.napoleon', 12 |
'sphinx.ext.viewcode', 13 | 'sphinx.ext.githubpages', 14 | ] 15 | 16 | source_suffix = '.rst' 17 | master_doc = 'index' 18 | 19 | author = 'Matthias Fey' 20 | project = 'deep-graph-matching-consensus' 21 | copyright = '{}, {}'.format(datetime.datetime.now().year, author) 22 | 23 | version = dgmc.__version__ 24 | release = dgmc.__version__ 25 | 26 | html_theme = 'sphinx_rtd_theme' 27 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 28 | 29 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE 30 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)} 31 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/rusty1s/deep-graph-matching-consensus 2 | 3 | Deep Graph Matching Consensus 4 | ============================= 5 | 6 | .. toctree:: 7 | :glob: 8 | :maxdepth: 1 9 | :caption: Package Reference 10 | 11 | modules/models 12 | modules/utils 13 | 14 | .. autoclass:: dgmc.DGMC 15 | :members: 16 | -------------------------------------------------------------------------------- /docs/source/modules/models.rst: -------------------------------------------------------------------------------- 1 | dgmc.models 2 | =========== 3 | 4 | .. autoclass:: dgmc.models.GIN 5 | :members: 6 | :undoc-members: 7 | 8 | .. autoclass:: dgmc.models.SplineCNN 9 | :members: 10 | :undoc-members: 11 | 12 | .. autoclass:: dgmc.models.RelCNN 13 | :members: 14 | :undoc-members: 15 | -------------------------------------------------------------------------------- /docs/source/modules/utils.rst: -------------------------------------------------------------------------------- 1 | dgmc.utils 2 | ========== 3 | 4 | .. autoclass:: dgmc.utils.PairDataset 5 | :members: 6 | 7 | .. 
autoclass:: dgmc.utils.ValidPairDataset 8 | :members: 9 | -------------------------------------------------------------------------------- /examples/dbp15k.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | from torch_geometric.datasets import DBP15K 6 | 7 | from dgmc.models import DGMC, RelCNN 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--category', type=str, required=True) 11 | parser.add_argument('--dim', type=int, default=256) 12 | parser.add_argument('--rnd_dim', type=int, default=32) 13 | parser.add_argument('--num_layers', type=int, default=3) 14 | parser.add_argument('--num_steps', type=int, default=10) 15 | parser.add_argument('--k', type=int, default=10) 16 | args = parser.parse_args() 17 | 18 | 19 | class SumEmbedding(object): 20 | def __call__(self, data): 21 | data.x1, data.x2 = data.x1.sum(dim=1), data.x2.sum(dim=1) 22 | return data 23 | 24 | 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | path = osp.join('..', 'data', 'DBP15K') 27 | data = DBP15K(path, args.category, transform=SumEmbedding())[0].to(device) 28 | 29 | psi_1 = RelCNN(data.x1.size(-1), args.dim, args.num_layers, batch_norm=False, 30 | cat=True, lin=True, dropout=0.5) 31 | psi_2 = RelCNN(args.rnd_dim, args.rnd_dim, args.num_layers, batch_norm=False, 32 | cat=True, lin=True, dropout=0.0) 33 | model = DGMC(psi_1, psi_2, num_steps=None, k=args.k).to(device) 34 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 35 | 36 | 37 | def train(): 38 | model.train() 39 | optimizer.zero_grad() 40 | 41 | _, S_L = model(data.x1, data.edge_index1, None, None, data.x2, 42 | data.edge_index2, None, None, data.train_y) 43 | 44 | loss = model.loss(S_L, data.train_y) 45 | loss.backward() 46 | optimizer.step() 47 | return loss 48 | 49 | 50 | @torch.no_grad() 51 | def test(): 52 | model.eval() 53 | 54 | _, S_L = model(data.x1, data.edge_index1, None, None, 
data.x2, 55 | data.edge_index2, None, None) 56 | 57 | hits1 = model.acc(S_L, data.test_y) 58 | hits10 = model.hits_at_k(10, S_L, data.test_y) 59 | 60 | return hits1, hits10 61 | 62 | 63 | print('Optimize initial feature matching...') 64 | model.num_steps = 0 65 | for epoch in range(1, 201): 66 | if epoch == 101: 67 | print('Refine correspondence matrix...') 68 | model.num_steps = args.num_steps 69 | model.detach = True 70 | 71 | loss = train() 72 | 73 | if epoch % 10 == 0 or epoch > 100: 74 | hits1, hits10 = test() 75 | print((f'{epoch:03d}: Loss: {loss:.4f}, Hits@1: {hits1:.4f}, ' 76 | f'Hits@10: {hits10:.4f}')) 77 | -------------------------------------------------------------------------------- /examples/pascal.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import argparse 4 | import torch 5 | from torch_geometric.datasets import PascalVOCKeypoints as PascalVOC 6 | import torch_geometric.transforms as T 7 | from torch_geometric.data import DataLoader 8 | 9 | from dgmc.utils import ValidPairDataset 10 | from dgmc.models import DGMC, SplineCNN 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--isotropic', action='store_true') 14 | parser.add_argument('--dim', type=int, default=256) 15 | parser.add_argument('--rnd_dim', type=int, default=128) 16 | parser.add_argument('--num_layers', type=int, default=2) 17 | parser.add_argument('--num_steps', type=int, default=10) 18 | parser.add_argument('--lr', type=float, default=0.001) 19 | parser.add_argument('--batch_size', type=int, default=512) 20 | parser.add_argument('--epochs', type=int, default=15) 21 | parser.add_argument('--test_samples', type=int, default=1000) 22 | args = parser.parse_args() 23 | 24 | pre_filter = lambda data: data.pos.size(0) > 0 # noqa 25 | transform = T.Compose([ 26 | T.Delaunay(), 27 | T.FaceToEdge(), 28 | T.Distance() if args.isotropic else T.Cartesian(), 29 | ]) 30 | 31 | train_datasets = [] 32 | 
test_datasets = [] 33 | path = osp.join('..', 'data', 'PascalVOC') 34 | for category in PascalVOC.categories: 35 | dataset = PascalVOC(path, category, train=True, transform=transform, 36 | pre_filter=pre_filter) 37 | train_datasets += [ValidPairDataset(dataset, dataset, sample=True)] 38 | dataset = PascalVOC(path, category, train=False, transform=transform, 39 | pre_filter=pre_filter) 40 | test_datasets += [ValidPairDataset(dataset, dataset, sample=True)] 41 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 42 | train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, 43 | follow_batch=['x_s', 'x_t']) 44 | 45 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 46 | psi_1 = SplineCNN(dataset.num_node_features, args.dim, 47 | dataset.num_edge_features, args.num_layers, cat=False, 48 | dropout=0.5) 49 | psi_2 = SplineCNN(args.rnd_dim, args.rnd_dim, dataset.num_edge_features, 50 | args.num_layers, cat=True, dropout=0.0) 51 | model = DGMC(psi_1, psi_2, num_steps=args.num_steps).to(device) 52 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 53 | 54 | 55 | def generate_y(y_col): 56 | y_row = torch.arange(y_col.size(0), device=device) 57 | return torch.stack([y_row, y_col], dim=0) 58 | 59 | 60 | def train(): 61 | model.train() 62 | 63 | total_loss = 0 64 | for data in train_loader: 65 | optimizer.zero_grad() 66 | data = data.to(device) 67 | S_0, S_L = model(data.x_s, data.edge_index_s, data.edge_attr_s, 68 | data.x_s_batch, data.x_t, data.edge_index_t, 69 | data.edge_attr_t, data.x_t_batch) 70 | y = generate_y(data.y) 71 | loss = model.loss(S_0, y) 72 | loss = model.loss(S_L, y) + loss if model.num_steps > 0 else loss 73 | loss.backward() 74 | optimizer.step() 75 | total_loss += loss.item() * (data.x_s_batch.max().item() + 1) 76 | 77 | return total_loss / len(train_loader.dataset) 78 | 79 | 80 | @torch.no_grad() 81 | def test(dataset): 82 | model.eval() 83 | 84 | loader = DataLoader(dataset, args.batch_size, shuffle=False, 
85 | follow_batch=['x_s', 'x_t']) 86 | 87 | correct = num_examples = 0 88 | while (num_examples < args.test_samples): 89 | for data in loader: 90 | data = data.to(device) 91 | S_0, S_L = model(data.x_s, data.edge_index_s, data.edge_attr_s, 92 | data.x_s_batch, data.x_t, data.edge_index_t, 93 | data.edge_attr_t, data.x_t_batch) 94 | y = generate_y(data.y) 95 | correct += model.acc(S_L, y, reduction='sum') 96 | num_examples += y.size(1) 97 | 98 | if num_examples >= args.test_samples: 99 | return correct / num_examples 100 | 101 | 102 | for epoch in range(1, args.epochs + 1): 103 | loss = train() 104 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') 105 | 106 | accs = [100 * test(test_dataset) for test_dataset in test_datasets] 107 | accs += [sum(accs) / len(accs)] 108 | 109 | print(' '.join([c[:5].ljust(5) for c in PascalVOC.categories] + ['mean'])) 110 | print(' '.join([f'{acc:.1f}'.ljust(5) for acc in accs])) 111 | -------------------------------------------------------------------------------- /examples/pascal_pf.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import random 3 | 4 | import argparse 5 | import torch 6 | from torch_geometric.data import Data, DataLoader 7 | import torch_geometric.transforms as T 8 | from torch_geometric.datasets import PascalPF 9 | 10 | from dgmc.models import DGMC, SplineCNN 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument('--dim', type=int, default=256) 14 | parser.add_argument('--rnd_dim', type=int, default=64) 15 | parser.add_argument('--num_layers', type=int, default=2) 16 | parser.add_argument('--num_steps', type=int, default=10) 17 | parser.add_argument('--lr', type=float, default=0.001) 18 | parser.add_argument('--batch_size', type=int, default=64) 19 | parser.add_argument('--epochs', type=int, default=200) 20 | args = parser.parse_args() 21 | 22 | 23 | class RandomGraphDataset(torch.utils.data.Dataset): 24 | def __init__(self, min_inliers, 
max_inliers, min_outliers, max_outliers, 25 | min_scale=0.9, max_scale=1.2, noise=0.05, transform=None): 26 | 27 | self.min_inliers = min_inliers 28 | self.max_inliers = max_inliers 29 | self.min_outliers = min_outliers 30 | self.max_outliers = max_outliers 31 | self.min_scale = min_scale 32 | self.max_scale = max_scale 33 | self.noise = noise 34 | self.transform = transform 35 | 36 | def __len__(self): 37 | return 1024 38 | 39 | def __getitem__(self, idx): 40 | num_inliers = random.randint(self.min_inliers, self.max_inliers) 41 | num_outliers = random.randint(self.min_outliers, self.max_outliers) 42 | 43 | pos_s = 2 * torch.rand((num_inliers, 2)) - 1 44 | pos_t = pos_s + self.noise * torch.randn_like(pos_s) 45 | 46 | y_s = torch.arange(pos_s.size(0)) 47 | y_t = torch.arange(pos_t.size(0)) 48 | 49 | pos_s = torch.cat([pos_s, 3 - torch.rand((num_outliers, 2))], dim=0) 50 | pos_t = torch.cat([pos_t, 3 - torch.rand((num_outliers, 2))], dim=0) 51 | 52 | data_s = Data(pos=pos_s, y_index=y_s) 53 | data_t = Data(pos=pos_t, y=y_t) 54 | 55 | if self.transform is not None: 56 | data_s = self.transform(data_s) 57 | data_t = self.transform(data_t) 58 | 59 | data = Data(num_nodes=pos_s.size(0)) 60 | for key in data_s.keys: 61 | data['{}_s'.format(key)] = data_s[key] 62 | for key in data_t.keys: 63 | data['{}_t'.format(key)] = data_t[key] 64 | 65 | return data 66 | 67 | 68 | transform = T.Compose([ 69 | T.Constant(), 70 | T.KNNGraph(k=8), 71 | T.Cartesian(), 72 | ]) 73 | train_dataset = RandomGraphDataset(30, 60, 0, 20, transform=transform) 74 | train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, 75 | follow_batch=['x_s', 'x_t']) 76 | 77 | path = osp.join('..', 'data', 'PascalPF') 78 | test_datasets = [PascalPF(path, cat, transform) for cat in PascalPF.categories] 79 | 80 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 81 | psi_1 = SplineCNN(1, args.dim, 2, args.num_layers, cat=False, dropout=0.0) 82 | psi_2 = SplineCNN(args.rnd_dim, args.rnd_dim, 
2, args.num_layers, cat=True, 83 | dropout=0.0) 84 | model = DGMC(psi_1, psi_2, num_steps=args.num_steps).to(device) 85 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 86 | 87 | 88 | def train(): 89 | model.train() 90 | 91 | total_loss = total_examples = total_correct = 0 92 | for i, data in enumerate(train_loader): 93 | optimizer.zero_grad() 94 | data = data.to(device) 95 | S_0, S_L = model(data.x_s, data.edge_index_s, data.edge_attr_s, 96 | data.x_s_batch, data.x_t, data.edge_index_t, 97 | data.edge_attr_t, data.x_t_batch) 98 | y = torch.stack([data.y_index_s, data.y_t], dim=0) 99 | loss = model.loss(S_0, y) 100 | loss = model.loss(S_L, y) + loss if model.num_steps > 0 else loss 101 | loss.backward() 102 | optimizer.step() 103 | total_loss += loss.item() 104 | total_correct += model.acc(S_L, y, reduction='sum') 105 | total_examples += y.size(1) 106 | 107 | return total_loss / len(train_loader), total_correct / total_examples 108 | 109 | 110 | @torch.no_grad() 111 | def test(dataset): 112 | model.eval() 113 | 114 | correct = num_examples = 0 115 | for pair in dataset.pairs: 116 | data_s, data_t = dataset[pair[0]], dataset[pair[1]] 117 | data_s, data_t = data_s.to(device), data_t.to(device) 118 | S_0, S_L = model(data_s.x, data_s.edge_index, data_s.edge_attr, None, 119 | data_t.x, data_t.edge_index, data_t.edge_attr, None) 120 | y = torch.arange(data_s.num_nodes, device=device) 121 | y = torch.stack([y, y], dim=0) 122 | correct += model.acc(S_L, y, reduction='sum') 123 | num_examples += y.size(1) 124 | 125 | return correct / num_examples 126 | 127 | 128 | for epoch in range(1, 33): 129 | loss, acc = train() 130 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {acc:.2f}') 131 | 132 | accs = [100 * test(test_dataset) for test_dataset in test_datasets] 133 | accs += [sum(accs) / len(accs)] 134 | 135 | print(' '.join([c[:5].ljust(5) for c in PascalPF.categories] + ['mean'])) 136 | print(' '.join([f'{acc:.1f}'.ljust(5) for acc in accs])) 137 | 
-------------------------------------------------------------------------------- /examples/willow.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os.path as osp 3 | 4 | import argparse 5 | import torch 6 | import torch_geometric.transforms as T 7 | from torch_geometric.datasets import WILLOWObjectClass as WILLOW 8 | from torch_geometric.datasets import PascalVOCKeypoints as PascalVOC 9 | from torch_geometric.data import DataLoader 10 | 11 | from dgmc.utils import ValidPairDataset, PairDataset 12 | from dgmc.models import DGMC, SplineCNN 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--isotropic', action='store_true') 16 | parser.add_argument('--dim', type=int, default=256) 17 | parser.add_argument('--rnd_dim', type=int, default=128) 18 | parser.add_argument('--num_layers', type=int, default=2) 19 | parser.add_argument('--num_steps', type=int, default=10) 20 | parser.add_argument('--lr', type=float, default=0.001) 21 | parser.add_argument('--batch_size', type=int, default=512) 22 | parser.add_argument('--pre_epochs', type=int, default=15) 23 | parser.add_argument('--epochs', type=int, default=15) 24 | parser.add_argument('--runs', type=int, default=20) 25 | parser.add_argument('--test_samples', type=int, default=100) 26 | args = parser.parse_args() 27 | 28 | pre_filter1 = lambda d: d.num_nodes > 0 # noqa 29 | pre_filter2 = lambda d: d.num_nodes > 0 and d.name[:4] != '2007' # noqa 30 | 31 | transform = T.Compose([ 32 | T.Delaunay(), 33 | T.FaceToEdge(), 34 | T.Distance() if args.isotropic else T.Cartesian(), 35 | ]) 36 | 37 | path = osp.join('..', 'data', 'PascalVOC-WILLOW') 38 | pretrain_datasets = [] 39 | for category in PascalVOC.categories: 40 | dataset = PascalVOC( 41 | path, category, train=True, transform=transform, pre_filter=pre_filter2 42 | if category in ['car', 'motorbike'] else pre_filter1) 43 | pretrain_datasets += [ValidPairDataset(dataset, dataset, sample=True)] 44 
| pretrain_dataset = torch.utils.data.ConcatDataset(pretrain_datasets) 45 | pretrain_loader = DataLoader(pretrain_dataset, args.batch_size, shuffle=True, 46 | follow_batch=['x_s', 'x_t']) 47 | 48 | path = osp.join('..', 'data', 'WILLOW') 49 | datasets = [WILLOW(path, cat, transform) for cat in WILLOW.categories] 50 | 51 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 52 | psi_1 = SplineCNN(dataset.num_node_features, args.dim, 53 | dataset.num_edge_features, args.num_layers, cat=False, 54 | dropout=0.5) 55 | psi_2 = SplineCNN(args.rnd_dim, args.rnd_dim, dataset.num_edge_features, 56 | args.num_layers, cat=True, dropout=0.0) 57 | model = DGMC(psi_1, psi_2, num_steps=args.num_steps).to(device) 58 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 59 | 60 | 61 | def generate_voc_y(y_col): 62 | y_row = torch.arange(y_col.size(0), device=device) 63 | return torch.stack([y_row, y_col], dim=0) 64 | 65 | 66 | def pretrain(): 67 | model.train() 68 | 69 | total_loss = 0 70 | for data in pretrain_loader: 71 | optimizer.zero_grad() 72 | data = data.to(device) 73 | S_0, S_L = model(data.x_s, data.edge_index_s, data.edge_attr_s, 74 | data.x_s_batch, data.x_t, data.edge_index_t, 75 | data.edge_attr_t, data.x_t_batch) 76 | y = generate_voc_y(data.y) 77 | loss = model.loss(S_0, y) 78 | loss = model.loss(S_L, y) + loss if model.num_steps > 0 else loss 79 | loss.backward() 80 | optimizer.step() 81 | total_loss += loss.item() * (data.x_s_batch.max().item() + 1) 82 | 83 | return total_loss / len(pretrain_loader.dataset) 84 | 85 | 86 | print('Pretraining model on PascalVOC...') 87 | for epoch in range(1, args.pre_epochs + 1): 88 | loss = pretrain() 89 | print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}') 90 | state_dict = copy.deepcopy(model.state_dict()) 91 | print('Done!') 92 | 93 | 94 | def generate_y(num_nodes, batch_size): 95 | row = torch.arange(num_nodes * batch_size, device=device) 96 | col = row[:num_nodes].view(1, -1).repeat(batch_size, 1).view(-1) 97 | return 
torch.stack([row, col], dim=0) 98 | 99 | 100 | def train(train_loader, optimizer): 101 | model.train() 102 | 103 | total_loss = 0 104 | for data in train_loader: 105 | optimizer.zero_grad() 106 | data = data.to(device) 107 | S_0, S_L = model(data.x_s, data.edge_index_s, data.edge_attr_s, 108 | data.x_s_batch, data.x_t, data.edge_index_t, 109 | data.edge_attr_t, data.x_t_batch) 110 | num_graphs = data.x_s_batch.max().item() + 1 111 | y = generate_y(num_nodes=10, batch_size=num_graphs) 112 | loss = model.loss(S_0, y) 113 | loss = model.loss(S_L, y) + loss if model.num_steps > 0 else loss 114 | loss.backward() 115 | optimizer.step() 116 | total_loss += loss.item() * num_graphs 117 | 118 | return total_loss / len(train_loader.dataset) 119 | 120 | 121 | @torch.no_grad() 122 | def test(test_dataset): 123 | model.eval() 124 | 125 | test_loader1 = DataLoader(test_dataset, args.batch_size, shuffle=True) 126 | test_loader2 = DataLoader(test_dataset, args.batch_size, shuffle=True) 127 | 128 | correct = num_examples = 0 129 | while (num_examples < args.test_samples): 130 | for data_s, data_t in zip(test_loader1, test_loader2): 131 | data_s, data_t = data_s.to(device), data_t.to(device) 132 | _, S_L = model(data_s.x, data_s.edge_index, data_s.edge_attr, 133 | data_s.batch, data_t.x, data_t.edge_index, 134 | data_t.edge_attr, data_t.batch) 135 | y = generate_y(num_nodes=10, batch_size=data_t.num_graphs) 136 | correct += model.acc(S_L, y, reduction='sum') 137 | num_examples += y.size(1) 138 | 139 | if num_examples >= args.test_samples: 140 | return correct / num_examples 141 | 142 | 143 | def run(i, datasets): 144 | datasets = [dataset.shuffle() for dataset in datasets] 145 | train_datasets = [dataset[:20] for dataset in datasets] 146 | test_datasets = [dataset[20:] for dataset in datasets] 147 | train_datasets = [ 148 | PairDataset(train_dataset, train_dataset, sample=False) 149 | for train_dataset in train_datasets 150 | ] 151 | train_dataset = 
torch.utils.data.ConcatDataset(train_datasets) 152 | train_loader = DataLoader(train_dataset, args.batch_size, shuffle=True, 153 | follow_batch=['x_s', 'x_t']) 154 | 155 | model.load_state_dict(state_dict) 156 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 157 | 158 | for epoch in range(1, 1 + args.epochs): 159 | train(train_loader, optimizer) 160 | 161 | accs = [100 * test(test_dataset) for test_dataset in test_datasets] 162 | 163 | print(f'Run {i:02d}:') 164 | print(' '.join([category.ljust(13) for category in WILLOW.categories])) 165 | print(' '.join([f'{acc:.2f}'.ljust(13) for acc in accs])) 166 | 167 | return accs 168 | 169 | 170 | accs = [run(i, datasets) for i in range(1, 1 + args.runs)] 171 | print('-' * 14 * 5) 172 | accs, stds = torch.tensor(accs).mean(dim=0), torch.tensor(accs).std(dim=0) 173 | print(' '.join([category.ljust(13) for category in WILLOW.categories])) 174 | print(' '.join([f'{a:.2f} ± {s:.2f}'.ljust(13) for a, s in zip(accs, stds)])) 175 | -------------------------------------------------------------------------------- /figures/best_car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/figures/best_car.png -------------------------------------------------------------------------------- /figures/best_duck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/figures/best_duck.png -------------------------------------------------------------------------------- /figures/best_motorbike.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/figures/best_motorbike.png 
-------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/figures/overview.png -------------------------------------------------------------------------------- /figures/worst_duck.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/deep-graph-matching-consensus/cb99870b11755ce610ec8dcd6275f3f1d342f894/figures/worst_duck.png -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.7 8 | system_packages: true 9 | install: 10 | - requirements: docs/requirements.txt 11 | - method: setuptools 12 | path: . 
13 | 14 | formats: [] 15 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | 4 | [aliases] 5 | test = pytest 6 | 7 | [tool:pytest] 8 | addopts = --capture=no --cov 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = '1.0.0' 4 | url = 'https://github.com/rusty1s/deep-graph-matching-consensus' 5 | 6 | install_requires = [] 7 | setup_requires = ['pytest-runner'] 8 | tests_require = ['pytest', 'pytest-cov'] 9 | 10 | setup( 11 | name='dgmc', 12 | version=__version__, 13 | description='Implementation of Deep Graph Matching Consensus in PyTorch', 14 | author='Matthias Fey', 15 | author_email='matthias.fey@tu-dortmund.de', 16 | url=url, 17 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), 18 | keywords=[ 19 | 'pytorch', 20 | 'geometric-deep-learning', 21 | 'graph-neural-networks', 22 | 'graph-matching', 23 | 'neighborhood-consensus', 24 | ], 25 | install_requires=install_requires, 26 | setup_requires=setup_requires, 27 | tests_require=tests_require, 28 | packages=find_packages(), 29 | ) 30 | -------------------------------------------------------------------------------- /test/models/test_dgmc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data, Batch 3 | from dgmc.models import DGMC, GIN 4 | 5 | x = torch.randn(4, 32) 6 | edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) 7 | data = Data(x=x, edge_index=edge_index) 8 | 9 | psi_1 = GIN(data.num_node_features, 16, num_layers=2) 10 | psi_2 = GIN(8, 8, num_layers=2) 11 | 12 | 13 | def set_seed(): 14 | torch.manual_seed(12345) 15 | 16 | 17 | 
def test_dgmc_repr(): 18 | model = DGMC(psi_1, psi_2, num_steps=1) 19 | assert model.__repr__() == ( 20 | 'DGMC(\n' 21 | ' psi_1=GIN(32, 16, num_layers=2, batch_norm=False, cat=True, ' 22 | 'lin=True),\n' 23 | ' psi_2=GIN(8, 8, num_layers=2, batch_norm=False, cat=True, ' 24 | 'lin=True),\n' 25 | ' num_steps=1, k=-1\n)') 26 | model.reset_parameters() 27 | 28 | 29 | def test_dgmc_on_single_graphs(): 30 | set_seed() 31 | model = DGMC(psi_1, psi_2, num_steps=1) 32 | x, e = data.x, data.edge_index 33 | y = torch.arange(data.num_nodes) 34 | y = torch.stack([y, y], dim=0) 35 | 36 | set_seed() 37 | S1_0, S1_L = model(x, e, None, None, x, e, None, None) 38 | loss1 = model.loss(S1_0, y) 39 | loss1.backward() 40 | acc1 = model.acc(S1_0, y) 41 | hits1_1 = model.hits_at_k(1, S1_0, y) 42 | hits1_10 = model.hits_at_k(10, S1_0, y) 43 | hits1_all = model.hits_at_k(data.num_nodes, S1_0, y) 44 | 45 | set_seed() 46 | model.k = data.num_nodes # Test a sparse "dense" variant. 47 | y = torch.arange(data.num_nodes) 48 | y = torch.stack([y, y], dim=0) 49 | S2_0, S2_L = model(x, e, None, None, x, e, None, None, y) 50 | loss2 = model.loss(S2_0, y) 51 | loss2.backward() 52 | acc2 = model.acc(S2_0, y) 53 | hits2_1 = model.hits_at_k(1, S2_0, y) 54 | hits2_10 = model.hits_at_k(10, S2_0, y) 55 | hits2_all = model.hits_at_k(data.num_nodes, S2_0, y) 56 | 57 | assert S1_0.size() == (data.num_nodes, data.num_nodes) 58 | assert S1_L.size() == (data.num_nodes, data.num_nodes) 59 | assert torch.allclose(S1_0, S2_0.to_dense()) 60 | assert torch.allclose(S1_L, S2_L.to_dense()) 61 | assert torch.allclose(loss1, loss2) 62 | assert acc1 == acc2 == hits1_1 == hits2_1 63 | assert hits1_1 <= hits1_10 == hits2_10 <= hits1_all 64 | assert hits1_all == hits2_all == 1.0 65 | 66 | 67 | def test_dgmc_on_multiple_graphs(): 68 | set_seed() 69 | model = DGMC(psi_1, psi_2, num_steps=1) 70 | 71 | batch = Batch.from_data_list([data, data]) 72 | x, e, b = batch.x, batch.edge_index, batch.batch 73 | 74 | set_seed() 75 | 
S1_0, S1_L = model(x, e, None, b, x, e, None, b) 76 | assert S1_0.size() == (batch.num_nodes, data.num_nodes) 77 | assert S1_L.size() == (batch.num_nodes, data.num_nodes) 78 | 79 | set_seed() 80 | model.k = data.num_nodes # Test a sparse "dense" variant. 81 | S2_0, S2_L = model(x, e, None, b, x, e, None, b) 82 | 83 | assert torch.allclose(S1_0, S2_0.to_dense()) 84 | assert torch.allclose(S1_L, S2_L.to_dense()) 85 | 86 | 87 | def test_dgmc_include_gt(): 88 | model = DGMC(psi_1, psi_2, num_steps=1) 89 | 90 | S_idx = torch.tensor([[[0, 1], [1, 2]], [[1, 2], [0, 1]]]) 91 | s_mask = torch.tensor([[True, False], [True, True]]) 92 | y = torch.tensor([[0, 1], [0, 0]]) 93 | 94 | S_idx = model.__include_gt__(S_idx, s_mask, y) 95 | assert S_idx.tolist() == [[[0, 1], [1, 2]], [[1, 0], [0, 1]]] 96 | -------------------------------------------------------------------------------- /test/models/test_gin.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import torch 4 | from dgmc.models import GIN 5 | 6 | 7 | def test_gin(): 8 | model = GIN(16, 32, num_layers=2, batch_norm=True, cat=True, lin=True) 9 | assert model.__repr__() == ('GIN(16, 32, num_layers=2, batch_norm=True, ' 10 | 'cat=True, lin=True)') 11 | 12 | x = torch.randn(100, 16) 13 | edge_index = torch.randint(100, (2, 400), dtype=torch.long) 14 | for cat, lin in product([False, True], [False, True]): 15 | model = GIN(16, 32, 2, True, cat, lin) 16 | out = model(x, edge_index) 17 | assert out.size() == (100, 16 + 2 * 32 if not lin and cat else 32) 18 | assert out.size() == (100, model.out_channels) 19 | -------------------------------------------------------------------------------- /test/models/test_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dgmc.models import MLP 3 | 4 | 5 | def test_mlp(): 6 | model = MLP(16, 32, num_layers=2, batch_norm=True, dropout=0.5) 7 | assert 
model.__repr__() == ('MLP(16, 32, num_layers=2, batch_norm=True' 8 | ', dropout=0.5)') 9 | 10 | x = torch.randn(100, 16) 11 | out = model(x) 12 | assert out.size() == (100, 32) 13 | -------------------------------------------------------------------------------- /test/models/test_rel.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import torch 4 | from dgmc.models import RelCNN 5 | 6 | 7 | def test_rel(): 8 | model = RelCNN(16, 32, num_layers=2, batch_norm=True, cat=True, lin=True, 9 | dropout=0.5) 10 | assert model.__repr__() == ('RelCNN(16, 32, num_layers=2, batch_norm=True' 11 | ', cat=True, lin=True, dropout=0.5)') 12 | assert model.convs[0].__repr__() == 'RelConv(16, 32)' 13 | 14 | x = torch.randn(100, 16) 15 | edge_index = torch.randint(100, (2, 400), dtype=torch.long) 16 | for cat, lin in product([False, True], [False, True]): 17 | model = RelCNN(16, 32, 2, True, cat, lin, 0.5) 18 | out = model(x, edge_index) 19 | assert out.size() == (100, 16 + 2 * 32 if not lin and cat else 32) 20 | assert out.size() == (100, model.out_channels) 21 | -------------------------------------------------------------------------------- /test/models/test_spline.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import torch 4 | from dgmc.models import SplineCNN 5 | 6 | 7 | def test_spline(): 8 | model = SplineCNN(16, 32, dim=3, num_layers=2, cat=True, lin=True, 9 | dropout=0.5) 10 | assert model.__repr__() == ('SplineCNN(16, 32, dim=3, num_layers=2, ' 11 | 'cat=True, lin=True, dropout=0.5)') 12 | 13 | x = torch.randn(100, 16) 14 | edge_index = torch.randint(100, (2, 400), dtype=torch.long) 15 | edge_attr = torch.rand((400, 3)) 16 | for cat, lin in product([False, True], [False, True]): 17 | model = SplineCNN(16, 32, 3, 2, cat, lin, 0.5) 18 | out = model(x, edge_index, edge_attr) 19 | assert out.size() == (100, 16 + 2 * 32 if 
not lin and cat else 32) 20 | assert out.size() == (100, model.out_channels) 21 | -------------------------------------------------------------------------------- /test/utils/test_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_geometric.data import Data 3 | from dgmc.utils import PairDataset, ValidPairDataset 4 | 5 | 6 | def test_pair_dataset(): 7 | x = torch.randn(10, 16) 8 | edge_index = torch.randint(x.size(0), (2, 30), dtype=torch.long) 9 | data = Data(x=x, edge_index=edge_index) 10 | 11 | dataset = PairDataset([data, data], [data, data], sample=True) 12 | assert dataset.__repr__() == ( 13 | 'PairDataset([Data(edge_index=[2, 30], x=[10, 16]), ' 14 | 'Data(edge_index=[2, 30], x=[10, 16])], [' 15 | 'Data(edge_index=[2, 30], x=[10, 16]), ' 16 | 'Data(edge_index=[2, 30], x=[10, 16])], sample=True)') 17 | assert len(dataset) == 2 18 | pair = dataset[0] 19 | assert len(pair) == 4 20 | assert torch.allclose(pair.x_s, x) 21 | assert pair.edge_index_s.tolist() == edge_index.tolist() 22 | assert torch.allclose(pair.x_t, x) 23 | assert pair.edge_index_t.tolist() == edge_index.tolist() 24 | 25 | dataset = PairDataset([data, data], [data, data], sample=False) 26 | assert dataset.__repr__() == ( 27 | 'PairDataset([Data(edge_index=[2, 30], x=[10, 16]), ' 28 | 'Data(edge_index=[2, 30], x=[10, 16])], [' 29 | 'Data(edge_index=[2, 30], x=[10, 16]), ' 30 | 'Data(edge_index=[2, 30], x=[10, 16])], sample=False)') 31 | assert len(dataset) == 4 32 | pair = dataset[0] 33 | assert len(pair) == 4 34 | assert torch.allclose(pair.x_s, x) 35 | assert pair.edge_index_s.tolist() == edge_index.tolist() 36 | assert torch.allclose(pair.x_t, x) 37 | assert pair.edge_index_t.tolist() == edge_index.tolist() 38 | 39 | 40 | def test_valid_pair_dataset(): 41 | x = torch.randn(10, 16) 42 | edge_index = torch.randint(x.size(0), (2, 30), dtype=torch.long) 43 | y = torch.randperm(x.size(0)) 44 | data = Data(x=x, 
edge_index=edge_index, y=y) 45 | 46 | dataset = ValidPairDataset([data, data], [data, data], sample=True) 47 | assert dataset.__repr__() == ( 48 | 'ValidPairDataset([Data(edge_index=[2, 30], x=[10, 16], y=[10]), ' 49 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10])], [' 50 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10]), ' 51 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10])], sample=True)') 52 | assert len(dataset) == 2 53 | pair = dataset[0] 54 | assert len(pair) == 5 55 | assert torch.allclose(pair.x_s, x) 56 | assert pair.edge_index_s.tolist() == edge_index.tolist() 57 | assert torch.allclose(pair.x_t, x) 58 | assert pair.edge_index_t.tolist() == edge_index.tolist() 59 | assert pair.y.tolist() == torch.arange(x.size(0)).tolist() 60 | 61 | dataset = ValidPairDataset([data, data], [data, data], sample=False) 62 | assert dataset.__repr__() == ( 63 | 'ValidPairDataset([Data(edge_index=[2, 30], x=[10, 16], y=[10]), ' 64 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10])], [' 65 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10]), ' 66 | 'Data(edge_index=[2, 30], x=[10, 16], y=[10])], sample=False)') 67 | assert len(dataset) == 4 68 | pair = dataset[0] 69 | assert len(pair) == 5 70 | assert torch.allclose(pair.x_s, x) 71 | assert pair.edge_index_s.tolist() == edge_index.tolist() 72 | assert torch.allclose(pair.x_t, x) 73 | assert pair.edge_index_t.tolist() == edge_index.tolist() 74 | assert pair.y.tolist() == torch.arange(x.size(0)).tolist() 75 | --------------------------------------------------------------------------------