├── images
│   ├── match_aero.png
│   ├── match_chair.png
│   └── match_sheep.png
├── data
│   ├── split
│   │   └── voc2011_pairs.npz
│   ├── base_dataset.py
│   ├── data_loader_multigraph.py
│   ├── SPair71k.py
│   ├── willow_obj.py
│   └── pascal_voc.py
├── experiments
│   ├── voc_unfiltered.json
│   ├── willow
│   │   ├── pretrain_nofinetune.json
│   │   ├── nopretrain_finetune.json
│   │   ├── pretrain_finetune.json
│   │   └── voc_pretrain.json
│   ├── spair.json
│   ├── voc_unfiltered_multimatching.json
│   └── voc_basic.json
├── Pipfile
├── BB_GM
│   ├── affinity_layer.py
│   ├── sconv_archs.py
│   └── model.py
├── utils
│   ├── decorators.py
│   ├── dup_stdout_manager.py
│   ├── build_graphs.py
│   ├── backbone.py
│   ├── feature_align.py
│   ├── config.py
│   ├── evaluation_metric.py
│   ├── utils.py
│   ├── latex_utils.py
│   └── visualization.py
├── download_data.sh
├── .gitignore
├── README.md
├── eval.py
└── train_eval.py
/images/match_aero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/blackbox-deep-graph-matching/HEAD/images/match_aero.png -------------------------------------------------------------------------------- /images/match_chair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/blackbox-deep-graph-matching/HEAD/images/match_chair.png -------------------------------------------------------------------------------- /images/match_sheep.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/blackbox-deep-graph-matching/HEAD/images/match_sheep.png -------------------------------------------------------------------------------- /data/split/voc2011_pairs.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martius-lab/blackbox-deep-graph-matching/HEAD/data/split/voc2011_pairs.npz -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | class BaseDataset: 2 | def __init__(self): 3 | pass 4 | 5 | def get_k_samples(self, idx, k, mode, cls=None, shuffle=True): 6 | raise NotImplementedError 7 | -------------------------------------------------------------------------------- /experiments/voc_unfiltered.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "train_sampling": "all", 4 | "eval_sampling": "all", 5 | "visualize": true, 6 | "model_dir": "results/voc_all_keypoints" 7 | } 8 | -------------------------------------------------------------------------------- /experiments/willow/pretrain_nofinetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "DATASET_NAME": "WillowObject", 4 | "evaluate_only": true, 5 | "warmstart_path": "results/cached_weights/params/0010", 6 | "EVAL": { 7 | "SAMPLES": 100 8 | } 9 | } -------------------------------------------------------------------------------- /experiments/willow/nopretrain_finetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "DATASET_NAME": "WillowObject", 4 | "TRAIN": { 5 | "EPOCH_ITERS": 100, 6 | "lr_schedule": "short_halving" 7 | }, 8 | "EVAL": { 9 | "SAMPLES": 100 10 | } 11 | }
-------------------------------------------------------------------------------- /experiments/willow/pretrain_finetune.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "DATASET_NAME": "WillowObject", 4 | "TRAIN": { 5 | "EPOCH_ITERS": 200, 6 | "lr_schedule": "short_halving" 7 | }, 8 | "warmstart_path": "results/cached_weights/params/0010", 9 | "EVAL": { 10 | "SAMPLES": 100 11 | } 12 | } -------------------------------------------------------------------------------- /experiments/willow/voc_pretrain.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "model_dir": "results/cached_weights", 4 | "save_checkpoint": true, 5 | "exclude_willow_classes": true, 6 | "TRAIN": { 7 | "lr_schedule": "long_nodrop" 8 | }, 9 | "EVAL": { 10 | "SAMPLES": 50 11 | }, 12 | "visualize": false 13 | } -------------------------------------------------------------------------------- /experiments/spair.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "DATASET_NAME": "SPair71k", 4 | "model_dir": "results/spair", 5 | "TRAIN": { 6 | "EPOCH_ITERS": 400, 7 | "difficulty_params": { 8 | } 9 | }, 10 | "EVAL": { 11 | "SAMPLES": null, 12 | "difficulty_params": { 13 | } 14 | }, 15 | "visualize": true, 16 | "visualization_params": { 17 | "reduced_vis": true, 18 | "produce_pdf": false 19 | }, 20 | "combine_classes": false 21 | } -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | pillow = "<7" 10 | scipy = "==1.4.1" 11 | easydict = "==1.9" 12 | torch = "==1.4.0" 13 | torchvision = "==0.5.0" 14 | lpmp_py = {git = "https://github.com/lpmp/lpmp.git"} 15 | torch-geometric = {git = "https://github.com/rusty1s/pytorch_geometric.git",ref = "master"} 16 | torch-sparse = "*" 17 | torch-scatter = "*" 18 | torch-cluster = "*" 19 | torch-spline-conv = "*" 20 | 21 | [requires] 22 | python_version = "3.6" 23 | 24 | 25 | 26 | 27 | -------------------------------------------------------------------------------- /BB_GM/affinity_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class InnerProductWithWeightsAffinity(nn.Module): 6 | def __init__(self, input_dim, output_dim): 7 | super(InnerProductWithWeightsAffinity, self).__init__() 8 | self.d = output_dim 9 | self.A = torch.nn.Linear(input_dim, output_dim) 10 | 11 | def _forward(self, X, Y, weights): 12 | assert X.shape[1] == Y.shape[1] == self.d, (X.shape[1], Y.shape[1], self.d) 13 | coefficients = torch.tanh(self.A(weights)) 14 | res = torch.matmul(X * coefficients, Y.transpose(0, 1)) 15 | res = torch.nn.functional.softplus(res) - 0.5 16 | return res 17 | 18 | def forward(self, Xs, Ys, Ws): 19 | return [self._forward(X, Y, W) for X, Y, W in zip(Xs, Ys, Ws)] 20 | -------------------------------------------------------------------------------- /experiments/voc_unfiltered_multimatching.json: -------------------------------------------------------------------------------- 1 | { 2 | "default_json": "experiments/voc_basic.json", 3 | "train_sampling": "all", 4 | 
"eval_sampling": "all", 5 | "BB_GM": { 6 | "solver_name": "multigraph", 7 | "solver_params": { 8 | "maxIter": 200, 9 | "innerIteration": 10, 10 | "presolveIterations": 30, 11 | "primalCheckingTriplets": 100, 12 | "multigraphMatchingRoundingMethod": "MCF_PS", 13 | "tighten": "", 14 | "tightenIteration": 50, 15 | "tightenInterval": 20, 16 | "tightenConstraintsPercentage": 0.1, 17 | "tightenReparametrization": "uniform:0.5" 18 | } 19 | }, 20 | "EVAL": { 21 | "num_graphs_in_matching_instance": 5 22 | }, 23 | "TRAIN": { 24 | "num_graphs_in_matching_instance": 2 25 | }, 26 | "model_dir": "results/voc_all_multimatching" 27 | } 28 | -------------------------------------------------------------------------------- /experiments/voc_basic.json: -------------------------------------------------------------------------------- 1 | { 2 | "BATCH_SIZE": 8, 3 | "DATASET_NAME": "PascalVOC", 4 | "exclude_willow_classes": false, 5 | "EVAL": { 6 | "SAMPLES": 1000, 7 | "num_graphs_in_matching_instance": 2 8 | }, 9 | "BB_GM": { 10 | "lambda_val": 80.0, 11 | "solver_name": "lpmp", 12 | "solver_params": { 13 | "timeout": 1000, 14 | "primalComputationInterval": 10, 15 | "maxIter": 100 16 | } 17 | }, 18 | "train_sampling": "intersection", 19 | "eval_sampling": "intersection", 20 | "save_checkpoint": false, 21 | "RANDOM_SEED": 123, 22 | "warmstart_path": null, 23 | "TRAIN": { 24 | "EPOCH_ITERS": 2000, 25 | "LR": 0.002, 26 | "lr_schedule": "long_halving", 27 | "num_graphs_in_matching_instance": 2 28 | }, 29 | "STATISTIC_STEP": 20, 30 | "visualize": true, 31 | "visualization_params": { 32 | "reduced_vis": true 33 | }, 34 | "evaluate_only": false, 35 | "model_dir": "results/voc_basic" 36 | } 37 | -------------------------------------------------------------------------------- /utils/decorators.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from functools import update_wrapper, partial 3 | 4 | import torch 5 | 6 | 7 | class Decorator(ABC): 8 | def __init__(self, f): 9 | self.func = f 10 | update_wrapper(self, f, updated=[]) # updated=[] so that 'self' attributes are not overwritten 11 | 12 | @abstractmethod 13 | def __call__(self, *args, **kwargs): 14 | pass 15 | 16 | def __get__(self, instance, owner): 17 | new_f = partial(self.__call__, instance) 18 | update_wrapper(new_f, self.func) 19 | return new_f 20 | 21 | 22 | def to_numpy(x): 23 | if isinstance(x, torch.Tensor): 24 | return x.cpu().detach().numpy() 25 | elif isinstance(x, list): 26 | return [to_numpy(_) for _ in x] 27 | else: 28 | return x 29 | 30 | 31 | # noinspection PyPep8Naming 32 | class input_to_numpy(Decorator): 33 | def __call__(self, *args, **kwargs): 34 | new_args = [to_numpy(arg) for arg in args] 35 | new_kwargs = {key: to_numpy(value) for key, value in kwargs.items()} 36 | return self.func(*new_args, **new_kwargs) 37 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | mkdir ./data/downloaded 5 | mkdir ./data/downloaded/PascalVOC 6 | mkdir ./data/downloaded/WILLOW 7 | mkdir ./data/downloaded/SPair-71k 8 | 9 | cd ./data/downloaded/PascalVOC 10 | 11 | echo -e "\e[1mGetting PascalVOC annotations...\e[0m" 12 | wget https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/shape/poselets/voc2011_keypoints_Feb2012.tgz 13 | tar xzf voc2011_keypoints_Feb2012.tgz 14 | echo -e "\e[32m... 
done\e[0m" 15 | 16 | echo -e "\e[1mGetting PascalVOC data\e[0m" 17 | wget http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar 18 | tar xf VOCtrainval_25-May-2011.tar 19 | mv TrainVal/VOCdevkit/VOC2011 ./ 20 | rmdir TrainVal/VOCdevkit 21 | rmdir TrainVal 22 | echo -e "\e[32m... done\e[0m" 23 | 24 | echo -e "\e[1mGetting WILLOW data\e[0m" 25 | cd ../WILLOW 26 | wget http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip 27 | unzip WILLOW-ObjectClass_dataset.zip 28 | echo -e "\e[32m... done\e[0m" 29 | 30 | 31 | echo -e "\e[1mGetting SPair-71k data\e[0m" 32 | cd ../SPair-71k 33 | wget http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz 34 | tar xzf SPair-71k.tar.gz 35 | mv SPair-71k/* ./ 36 | echo -e "\e[32m... done\e[0m" -------------------------------------------------------------------------------- /utils/dup_stdout_manager.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class DupStdoutFileWriter(object): 5 | def __init__(self, stdout, path, mode): 6 | self.path = path 7 | self._content = "" 8 | self._stdout = stdout 9 | self._file = open(path, mode) 10 | 11 | def write(self, msg): 12 | while "\n" in msg: 13 | pos = msg.find("\n") 14 | self._content += msg[: pos + 1] 15 | self.flush() 16 | msg = msg[pos + 1 :] 17 | self._content += msg 18 | if len(self._content) > 1000: 19 | self.flush() 20 | 21 | def flush(self): 22 | self._stdout.write(self._content) 23 | self._stdout.flush() 24 | self._file.write(self._content) 25 | self._file.flush() 26 | self._content = "" 27 | 28 | def __del__(self): 29 | self._file.close() 30 | 31 | 32 | class DupStdoutFileManager(object): 33 | def __init__(self, path, mode="w+"): 34 | self.path = path 35 | self.mode = mode 36 | 37 | def __enter__(self): 38 | self._stdout = sys.stdout 39 | self._file = DupStdoutFileWriter(self._stdout, self.path, self.mode) 40 | sys.stdout = self._file 41 | 42 | def __exit__(self, exc_type, exc_value, traceback): 43 | sys.stdout = self._stdout 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | data/downloaded 107 | 108 | data/cache 109 | 110 | output/ 111 | 112 | \.idea/ 113 | 114 | results/ -------------------------------------------------------------------------------- /utils/build_graphs.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import Delaunay 2 | from scipy.spatial.qhull import QhullError 3 | 4 | import itertools 5 | import numpy as np 6 | 7 | 8 | def locations_to_features_diffs(x_1, y_1, x_2, y_2): 9 | res = np.array([0.5 + 0.5 * (x_1 - x_2) / 256.0, 0.5 + 0.5 * (y_1 - y_2) / 256.0]) 10 | return res 11 | 12 | 13 | def build_graphs(P_np: np.ndarray, n: int, n_pad: int = None, edge_pad: int = None): 14 | 15 | A = delaunay_triangulate(P_np[0:n, :]) 16 | edge_num = int(np.sum(A, axis=(0, 1))) 17 | 18 | if n_pad is None: 19 | n_pad = n 20 | if edge_pad is None: 21 | edge_pad = edge_num 22 | assert n_pad >= n 23 | assert edge_pad >= edge_num 24 | 25 | edge_list = [[], []] 26 | features = [] 27 | for i in range(n): 28 | for j in range(n): 29 | if A[i, j] == 1: 30 | edge_list[0].append(i) 31 | edge_list[1].append(j) 32 | features.append(locations_to_features_diffs(*P_np[i], *P_np[j])) 33 | 34 | if not features: 35 | features = np.zeros(shape=(0, 2)) 36 | 37 | return np.array(edge_list, dtype=np.int), np.array(features) 38 | 39 | 40 | def delaunay_triangulate(P: np.ndarray): 41 | """ 42 | Perform delaunay triangulation on point set P. 43 | :param P: point set 44 | :return: adjacency matrix A 45 | """ 46 | n = P.shape[0] 47 | if n < 3: 48 | A = np.ones((n, n)) - np.eye(n) 49 | else: 50 | try: 51 | d = Delaunay(P) 52 | A = np.zeros((n, n)) 53 | for simplex in d.simplices: 54 | for pair in itertools.permutations(simplex, 2): 55 | A[pair] = 1 56 | except QhullError as err: 57 | print("Delaunay triangulation error detected. 
Return fully-connected graph.") 58 | print("Traceback:") 59 | print(err) 60 | A = np.ones((n, n)) - np.eye(n) 61 | return A 62 | -------------------------------------------------------------------------------- /utils/backbone.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision import models 3 | 4 | 5 | class VGG16_base(nn.Module): 6 | def __init__(self, batch_norm=True): 7 | super(VGG16_base, self).__init__() 8 | self.node_layers, self.edge_layers, self.final_layers = self.get_backbone(batch_norm) 9 | 10 | def forward(self, *input): 11 | raise NotImplementedError 12 | 13 | @staticmethod 14 | def get_backbone(batch_norm): 15 | """ 16 | Get pretrained VGG16 models for feature extraction. 17 | :return: feature sequence 18 | """ 19 | if batch_norm: 20 | model = models.vgg16_bn(pretrained=True) 21 | else: 22 | model = models.vgg16(pretrained=True) 23 | 24 | conv_layers = nn.Sequential(*list(model.features.children())) 25 | 26 | conv_list = node_list = edge_list = [] 27 | 28 | # get the output of relu4_2(node features) and relu5_1(edge features) 29 | cnt_m, cnt_r = 1, 0 30 | for layer, module in enumerate(conv_layers): 31 | if isinstance(module, nn.Conv2d): 32 | cnt_r += 1 33 | if isinstance(module, nn.MaxPool2d): 34 | cnt_r = 0 35 | cnt_m += 1 36 | conv_list += [module] 37 | 38 | if cnt_m == 4 and cnt_r == 2 and isinstance(module, nn.ReLU): 39 | node_list = conv_list 40 | conv_list = [] 41 | elif cnt_m == 5 and cnt_r == 1 and isinstance(module, nn.ReLU): 42 | edge_list = conv_list 43 | conv_list = [] 44 | 45 | assert len(node_list) > 0 and len(edge_list) > 0 46 | 47 | # Set the layers as a nn.Sequential module 48 | node_layers = nn.Sequential(*node_list) 49 | edge_layers = nn.Sequential(*edge_list) 50 | final_layers = nn.Sequential(*conv_list, nn.AdaptiveMaxPool2d(1, 1)) 51 | 52 | return node_layers, edge_layers, final_layers 53 | 54 | 55 | class VGG16_bn(VGG16_base): 56 | def __init__(self): 57 | super(VGG16_bn, self).__init__(True) 58 | -------------------------------------------------------------------------------- /utils/feature_align.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def feature_align(raw_feature: Tensor, P: Tensor, ns_t: Tensor, ori_size: tuple, device=None): 6 | """ 7 | Perform feature align from the raw feature map. 8 | :param raw_feature: raw feature map 9 | :param P: point set containing point coordinates 10 | :param ns_t: number of exact points in the point set 11 | :param ori_size: size of the original image 12 | :param device: device. 
If not specified, it will be the same as the input 13 | :return: F, the aligned node feature tensor 14 | """ 15 | 16 | if device is None: 17 | device = raw_feature.device 18 | 19 | f_dim = raw_feature.shape[-1] 20 | ori_size_t = torch.tensor(ori_size, dtype=torch.float32, device=device) 21 | step = ori_size_t[0] / f_dim 22 | 23 | channel_num = raw_feature.shape[1] 24 | n_max = P.shape[1] 25 | bs = raw_feature.shape[0] 26 | 27 | p_calc = (P - step / 2) / step 28 | p_floor = p_calc.floor() 29 | shifts = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], device=device) 30 | p_shifted = torch.stack([p_floor + shift for shift in shifts]) 31 | 32 | p_shifted_clamped = p_shifted.clamp(0, f_dim - 1) 33 | p_shifted_flat = p_shifted_clamped[..., 1] * f_dim + p_shifted_clamped[..., 0] 34 | 35 | w_feat = 1 - (p_calc - p_shifted).abs() 36 | w_feat_mul = w_feat[..., 0] * w_feat[..., 1] 37 | 38 | raw_features_flat = raw_feature.flatten(2, 3) 39 | 40 | # mask to disregard information in keypoints that don't matter (meaning that for the given image the number of keypoints is smaller than the maximum number in the batch) 41 | mask = torch.zeros(bs, n_max, device=device) 42 | for i in range(bs): 43 | mask[i][0 : ns_t[i]] = 1 44 | mask = mask.unsqueeze(1).expand(bs, channel_num, n_max) 45 | 46 | raw_f_exp = raw_features_flat.unsqueeze(0).expand(4, bs, channel_num, f_dim ** 2) 47 | p_flat_exp = p_shifted_flat.unsqueeze(2).expand(4, bs, channel_num, n_max).long() 48 | features = raw_f_exp.gather(3, p_flat_exp) 49 | w_exp = w_feat_mul.unsqueeze(2).expand(4, bs, channel_num, n_max) 50 | f = torch.sum(features * w_exp, dim=0) * mask 51 | return f 52 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | """Graph matching config system.""" 2 | 3 | import os 4 | from easydict import EasyDict as edict 5 | 6 | __C = edict() 7 | # Consumers can get config by: from utils.config import cfg 8 | 9 | cfg = __C 10 | 11 | __C.combine_classes = False 12 | 13 | # VOC2011-Keypoint Dataset 14 | __C.VOC2011 = edict() 15 | __C.VOC2011.KPT_ANNO_DIR = "./data/downloaded/PascalVOC/annotations/" # keypoint annotation 16 | __C.VOC2011.ROOT_DIR = "./data/downloaded/PascalVOC/VOC2011/" # original VOC2011 dataset 17 | __C.VOC2011.SET_SPLIT = "./data/split/voc2011_pairs.npz" # set split path 18 | __C.VOC2011.CLASSES = [ 19 | "aeroplane", 20 | "bicycle", 21 | "bird", 22 | "boat", 23 | "bottle", 24 | "bus", 25 | "car", 26 | "cat", 27 | "chair", 28 | "cow", 29 | "diningtable", 30 | "dog", 31 | "horse", 32 | "motorbike", 33 | "person", 34 | "pottedplant", 35 | "sheep", 36 | "sofa", 37 | "train", 38 | "tvmonitor", 39 | ] 40 | 41 | # Willow-Object Dataset 42 | __C.WILLOW = edict() 43 | __C.WILLOW.ROOT_DIR = "./data/downloaded/WILLOW/WILLOW-ObjectClass" 44 | __C.WILLOW.CLASSES = ["Car", "Duck", "Face", "Motorbike", "Winebottle"] 45 | __C.WILLOW.KPT_LEN = 10 46 | __C.WILLOW.TRAIN_NUM = 20 47 | __C.WILLOW.TRAIN_OFFSET = 0 48 | 49 | # SPair Dataset 50 | __C.SPair = edict() 51 | __C.SPair.ROOT_DIR = "./data/downloaded/SPair-71k" 52 | __C.SPair.size = "large" 53 | __C.SPair.CLASSES = [ 54 | "aeroplane", 55 | "bicycle", 56 | "bird", 57 | "boat", 58 | "bottle", 59 | "bus", 60 | "car", 61 | "cat", 62 | "chair", 63 | "cow", 64 | "dog", 65 | "horse", 66 | "motorbike", 67 | "person", 68 | "pottedplant", 69 | "sheep", 70 | "train", 71 | "tvmonitor", 72 | ] 73 | 74 | 75 | # 76 | # Training options 77 | # 78 | 79 | __C.TRAIN = edict() 80 | __C.TRAIN.difficulty_params = {} 81 | # Iterations per epoch
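82 | # EPOCH_ITERS itself (like the other TRAIN settings such as LR and lr_schedule) is supplied by the experiment JSONs, e.g. experiments/voc_basic.json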
83 | 84 | __C.EVAL = edict() 85 | __C.EVAL.difficulty_params = {} 86 | 87 | # Mean and std to normalize images 88 | __C.NORM_MEANS = [0.485, 0.456, 0.406] 89 | __C.NORM_STD = [0.229, 0.224, 0.225] 90 | 91 | # Data cache path 92 | __C.CACHE_PATH = "data/cache" 93 | 94 | # random seed used for data loading 95 | __C.RANDOM_SEED = 123 96 | -------------------------------------------------------------------------------- /BB_GM/sconv_archs.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | import torch 3 | import torch.nn.functional as F 4 | from torch_geometric.nn import SplineConv 5 | 6 | 7 | class SConv(torch.nn.Module): 8 | def __init__(self, input_features, output_features): 9 | super(SConv, self).__init__() 10 | 11 | self.in_channels = input_features 12 | self.num_layers = 2 13 | self.convs = torch.nn.ModuleList() 14 | 15 | for _ in range(self.num_layers): 16 | conv = SplineConv(input_features, output_features, dim=2, kernel_size=5, aggr="max") 17 | self.convs.append(conv) 18 | input_features = output_features 19 | 20 | input_features = output_features 21 | self.out_channels = input_features 22 | self.reset_parameters() 23 | 24 | def reset_parameters(self): 25 | for conv in self.convs: 26 | conv.reset_parameters() 27 | 28 | def forward(self, data): 29 | x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr 30 | xs = [x] 31 | 32 | for conv in self.convs[:-1]: 33 | xs += [F.relu(conv(xs[-1], edge_index, edge_attr))] 34 | 35 | xs += [self.convs[-1](xs[-1], edge_index, edge_attr)] 36 | return xs[-1] 37 | 38 | 39 | class SiameseSConvOnNodes(torch.nn.Module): 40 | def __init__(self, input_node_dim): 41 | super(SiameseSConvOnNodes, self).__init__() 42 | self.num_node_features = input_node_dim 43 | self.mp_network = SConv( 44 | input_features=self.num_node_features, output_features=self.num_node_features) 45 | 46 | def forward(self, graph): 47 | old_features = graph.x 48 | result = self.mp_network(graph) 49 | graph.x = old_features + 0.1 * result 50 | return graph 51 | 52 | 53 | class SiameseNodeFeaturesToEdgeFeatures(torch.nn.Module): 54 | def __init__(self, total_num_nodes): 55 | super(SiameseNodeFeaturesToEdgeFeatures, self).__init__() 56 | self.num_edge_features = total_num_nodes 57 | 58 | def forward(self, graph): 59 | orig_graphs = graph.to_data_list() 60 | orig_graphs = [self.vertex_attr_to_edge_attr(graph) for graph in orig_graphs] 61 | return orig_graphs 62 | 63 | def vertex_attr_to_edge_attr(self, graph): 64 | """Assigns the difference of node features to each edge""" 65 | flat_edges = graph.edge_index.transpose(0, 1).reshape(-1) 66 | vertex_attrs = torch.index_select(graph.x, dim=0, index=flat_edges) 67 | 68 | new_shape = (graph.edge_index.shape[1], 2, vertex_attrs.shape[1]) 69 | vertex_attrs_reshaped = vertex_attrs.reshape(new_shape).transpose(0, 1) 70 | new_edge_attrs = vertex_attrs_reshaped[0] - vertex_attrs_reshaped[1] 71 | graph.edge_attr = new_edge_attrs 72 | return graph 73 | -------------------------------------------------------------------------------- /utils/evaluation_metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def f1_score(tp, fp, fn): 5 | """ 6 | F1 score (harmonic mean of precision and recall) between predicted permutation matrix and ground truth permutation matrix. 
7 | :param tp: number of true positives 8 | :param fp: number of false positives 9 | :param fn: number of false negatives 10 | :return: F1 score 11 | """ 12 | device = tp.device 13 | 14 | const = torch.tensor(1e-7, device=device) 15 | precision = tp / (tp + fp + const) 16 | recall = tp / (tp + fn + const) 17 | f1 = 2 * precision * recall / (precision + recall + const) 18 | return f1 19 | 20 | 21 | def get_pos_neg(pmat_pred, pmat_gt): 22 | """ 23 | Calculates number of true positives, false positives and false negatives 24 | :param pmat_pred: predicted permutation matrix 25 | :param pmat_gt: ground truth permutation matrix 26 | :return: tp, fp, fn 27 | """ 28 | device = pmat_pred.device 29 | pmat_gt = pmat_gt.to(device) 30 | 31 | tp = torch.sum(pmat_pred * pmat_gt).float() 32 | fp = torch.sum(pmat_pred * (1 - pmat_gt)).float() 33 | fn = torch.sum((1 - pmat_pred) * pmat_gt).float() 34 | return tp, fp, fn 35 | 36 | 37 | def get_pos_neg_from_lists(pmat_pred_list, pmat_gt_list): 38 | device = pmat_pred_list[0].device 39 | tp = torch.zeros(1, device=device) 40 | fp = torch.zeros(1, device=device) 41 | fn = torch.zeros(1, device=device) 42 | for pmat_pred, pmat_gt in zip(pmat_pred_list, pmat_gt_list): 43 | _tp, _fp, _fn = get_pos_neg(pmat_pred, pmat_gt) 44 | tp += _tp 45 | fp += _fp 46 | fn += _fn 47 | return tp, fp, fn 48 | 49 | 50 | def matching_accuracy_from_lists(pmat_pred_list, pmat_gt_list): 51 | device = pmat_pred_list[0].device 52 | match_num = torch.zeros(1, device=device) 53 | total_num = torch.zeros(1, device=device) 54 | for pmat_pred, pmat_gt in zip(pmat_pred_list, pmat_gt_list): 55 | _, _match_num, _total_num = matching_accuracy(pmat_pred, pmat_gt) 56 | match_num += _match_num 57 | total_num += _total_num 58 | return match_num / total_num, match_num, total_num 59 | 60 | 61 | def matching_accuracy(pmat_pred, pmat_gt): 62 | """ 63 | Matching Accuracy between predicted permutation matrix and ground truth permutation matrix. 64 | :param pmat_pred: predicted permutation matrix 65 | :param pmat_gt: ground truth permutation matrix 66 | :return: matching accuracy, matched num of pairs, total num of pairs 67 | """ 68 | device = pmat_pred.device 69 | batch_num = pmat_pred.shape[0] 70 | 71 | pmat_gt = pmat_gt.to(device) 72 | 73 | assert torch.all((pmat_pred == 0) + (pmat_pred == 1)), "pmat_pred can only contain 0/1 elements." 74 | assert torch.all((pmat_gt == 0) + (pmat_gt == 1)), "pmat_gt should only contain 0/1 elements." 
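75 | # Together with the binary checks above, the row/column-sum checks below guarantee both inputs are (partial) permutation matrices: at most one 1 per row and per column.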
76 | assert torch.all(torch.sum(pmat_pred, dim=-1) <= 1) and torch.all(torch.sum(pmat_pred, dim=-2) <= 1) 77 | assert torch.all(torch.sum(pmat_gt, dim=-1) <= 1) and torch.all(torch.sum(pmat_gt, dim=-2) <= 1) 78 | 79 | match_num = 0 80 | total_num = 0 81 | 82 | for b in range(batch_num): 83 | match_num += torch.sum(pmat_pred[b] * pmat_gt[b]) 84 | total_num += torch.sum(pmat_gt[b]) 85 | 86 | return match_num / total_num, match_num, total_num 87 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers 2 | 3 | This repository contains a PyTorch implementation of the paper: [Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers.](https://arxiv.org/abs/2003.11657) 4 | 5 | It also contains the configuration files to reproduce the numbers reported in the paper for the following experiments: 6 | * **PascalVOC** Using keypoint-intersection filtering and unfiltered keypoints. Also with the multi-matching solver in postprocessing. 7 | * **Willow** Pre-training on PascalVOC and fine-tuning on Willow can be controlled separately. 8 | * **SPair-71k** With the default *intersection* keypoint filtering. 9 | 10 | See also the [LPMP repository](https://github.com/LPMP/LPMP) with the combinatorial solvers for graph matching and multi-graph matching as well as the corresponding PyTorch modules. The solvers were made differentiable via [blackbox-backprop](https://github.com/martius-lab/blackbox-backprop) ([Differentiation of Blackbox Combinatorial Solvers](https://openreview.net/forum?id=BkevoJSYPB)). 11 | 12 | Sheep | Chair | Airplane 13 | :-------------------------:|:-------------------------:|:-------------------------: 14 | ![alt text](images/match_sheep.png "Sheep matching example") | ![alt text](images/match_chair.png "Chair matching example") | ![alt text](images/match_aero.png "Airplane matching example") 15 | 16 | 17 | ## Get started 18 | 1. Check if gcc-9, g++-9, and cmake are available (for building `lpmp_py`). 19 | 1. Check if findutils (>=4.7.0) is available. 20 | 1. Check if hdf5 is installed (``apt install libhdf5-serial-dev``). 21 | 1. Check if CUDA 10.1 and cuDNN 7 are available. 22 | 1. Check if texlive-latex-extra is installed (``apt install texlive-latex-extra``). 23 | 1. Run ``pipenv install`` (at your own risk with `--skip-lock` to save some time). 24 | 1. Run ``chmod +x ./download_data.sh && ./download_data.sh``. 25 | 1. Try running a training example; if the import of torch_geometric fails, follow [this.](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html) 26 | 27 | ## Training 28 | 29 | Run training and evaluation with 30 | 31 | ``` 32 | python3 -m pipenv shell 33 | python3 train_eval.py path/to/your/json 34 | ``` 35 | 36 | where ``path/to/your/json`` is the path to your configuration file. Configurations that reproduce the scores reported in the paper are in ``./experiments``. 37 | 38 | ### Willow 39 | In order to run Willow with an architecture pretrained on PascalVOC, you need to create a snapshot to warm-start with. For this purpose, run ``python3 train_eval.py experiments/willow/voc_pretrain.json``. Then enter the path to the checkpoint into `pretrain_[no]finetune.json` in the field `warmstart_path`. 
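40 | 41 | As a sketch, the full warm-start workflow looks as follows (the checkpoint directory follows from `model_dir` in `voc_pretrain.json`; the epoch number of your best checkpoint may differ from the `0010` preset in the provided configs): 42 | 43 | ``` 44 | python3 train_eval.py experiments/willow/voc_pretrain.json 45 | ls results/cached_weights/params/ 46 | python3 train_eval.py experiments/willow/pretrain_finetune.json 47 | ``` 48 | 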
49 | 50 | ## Troubleshooting 51 | * **NaNs or significantly worse scores** Check your installation of torch_geometric, torch_sparse, torch_scatter, torch_cluster and torch_spline_conv. Go to the repositories, check the latest installation instructions, and make sure to compile locally. 52 | 53 | ## Citation 54 | 55 | ```text 56 | @article{rolinek2020deep, 57 | title={Deep Graph Matching via Blackbox Differentiation of Combinatorial Solvers}, 58 | author={Michal Rolínek and Paul Swoboda and Dominik Zietlow and Anselm Paulus and Vít Musil and Georg Martius}, 59 | year={2020}, 60 | eprint={2003.11657}, 61 | archivePrefix={arXiv}, 62 | primaryClass={cs.LG} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from utils.config import cfg 7 | from utils.evaluation_metric import matching_accuracy, f1_score, get_pos_neg 8 | 9 | 10 | def eval_model(model, dataloader, eval_epoch=None, verbose=False): 11 | print("Start evaluation...") 12 | since = time.time() 13 | 14 | device = next(model.parameters()).device 15 | 16 | if eval_epoch is not None: 17 | model_path = str(Path(cfg.OUTPUT_PATH) / "params" / "params_{:04}.pt".format(eval_epoch)) 18 | print("Loading model parameters from {}".format(model_path)) 19 | model.load_state_dict(torch.load(model_path)) 20 | 21 | was_training = model.training 22 | model.eval() 23 | 24 | ds = dataloader.dataset 25 | ds.set_num_graphs(cfg.EVAL.num_graphs_in_matching_instance) 26 | classes = ds.classes 27 | cls_cache = ds.cls 28 | 29 | accs = torch.zeros(len(classes), device=device) 30 | f1_scores = torch.zeros(len(classes), device=device) 31 | 32 | for i, cls in enumerate(classes): 33 | if verbose: 34 | print("Evaluating class {}: {}/{}".format(cls, i, len(classes))) 35 | 36 | running_since = time.time() 37 | iter_num = 0 38 | 39 | ds.set_cls(cls) 40 | acc_match_num = torch.zeros(1, device=device) 41 | acc_total_num = torch.zeros(1, device=device) 42 | tp = torch.zeros(1, device=device) 43 | fp = torch.zeros(1, device=device) 44 | fn = torch.zeros(1, device=device) 45 | for k, inputs in enumerate(dataloader): 46 | data_list = [_.cuda() for _ in inputs["images"]] 47 | 48 | points_gt = [_.cuda() for _ in inputs["Ps"]] 49 | n_points_gt = [_.cuda() for _ in inputs["ns"]] 50 | edges = [_.to("cuda") for _ in inputs["edges"]] 51 | perm_mat_list = [perm_mat.cuda() for perm_mat in inputs["gt_perm_mat"]] 52 | 53 | batch_num = data_list[0].size(0) 54 | 55 | iter_num = iter_num + 1 56 | 57 | visualize = k == 0 and cfg.visualize 58 | visualization_params = {**cfg.visualization_params, **dict(string_info=cls, true_matchings=perm_mat_list)} 59 | with torch.set_grad_enabled(False): 60 | s_pred_list = model( 61 | data_list, 62 | points_gt, 63 | edges, 64 | n_points_gt, 65 | perm_mat_list, 66 | visualize_flag=visualize, 67 | visualization_params=visualization_params, 68 | ) 69 | 70 | 71 | _, _acc_match_num, _acc_total_num = matching_accuracy(s_pred_list[0], perm_mat_list[0]) 72 | _tp, _fp, _fn = get_pos_neg(s_pred_list[0], perm_mat_list[0]) 73 | 74 | acc_match_num += _acc_match_num 75 | acc_total_num += _acc_total_num 76 | tp += _tp 77 | fp += _fp 78 | fn += _fn 79 | 80 | if iter_num % cfg.STATISTIC_STEP == 0 and verbose: 81 | running_speed = cfg.STATISTIC_STEP * batch_num / (time.time() - running_since) 82 | print("Class {:<8} Iteration {:<4} {:>4.2f} samples/s".format(cls, 
iter_num, running_speed)) 83 | running_since = time.time() 84 | 85 | accs[i] = acc_match_num / acc_total_num 86 | f1_scores[i] = f1_score(tp, fp, fn) 87 | if verbose: 88 | print("Class {} acc = {:.4f} F1 = {:.4f}".format(cls, accs[i], f1_scores[i])) 89 | 90 | time_elapsed = time.time() - since 91 | print("Evaluation complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60)) 92 | 93 | model.train(mode=was_training) 94 | ds.cls = cls_cache 95 | 96 | print("Matching accuracy") 97 | for cls, single_acc, f1_sc in zip(classes, accs, f1_scores): 98 | print("{} = {:.4f}, {:.4f}".format(cls, single_acc, f1_sc)) 99 | print("average = {:.4f}, {:.4f}".format(torch.mean(accs), torch.mean(f1_scores))) 100 | 101 | return accs, f1_scores 102 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations as comb 2 | import torch 3 | import collections 4 | import json 5 | from copy import deepcopy 6 | import ast 7 | from warnings import warn 8 | import sys 9 | import time 10 | import os 11 | 12 | JSON_FILE_KEY = 'default_json' 13 | 14 | class UnNormalize(object): 15 | def __init__(self, mean, std): 16 | self.mean = mean 17 | self.std = std 18 | 19 | def __call__(self, tensor): 20 | for index, (t, m, s) in enumerate(zip(tensor, self.mean, self.std)): 21 | tensor[index] = t * s + m 22 | return tensor 23 | 24 | 25 | def n_and_l_iter_parallel(n, l, enum=False): 26 | def lexico_iter_list(lex_list): 27 | for lex in lex_list: 28 | yield lexico_iter(lex) 29 | if enum: 30 | yield lexico_iter(range(len(lex_list[0]))) 31 | 32 | for zipped in zip(*n, *lexico_iter_list(l)): 33 | yield zipped 34 | 35 | 36 | def lexico_iter(lex): 37 | return comb(lex, 2) 38 | 39 | 40 | def torch_to_numpy_list(list_of_tensors): 41 | return [x.cpu().detach().numpy() for x in list_of_tensors] 42 | 43 | 44 | def numpy_to_torch_list(list_of_np_arrays, device, dtype): 45 | return [torch.from_numpy(x).to(dtype).to(device) for x in list_of_np_arrays] 46 | 47 | 48 | class ParamDict(dict): 49 | """ An immutable dict where elements can be accessed with a dot""" 50 | __getattr__ = dict.__getitem__ 51 | 52 | def __delattr__(self, item): 53 | raise TypeError("Setting object not mutable after settings are fixed!") 54 | 55 | def __setattr__(self, key, value): 56 | raise TypeError("Setting object not mutable after settings are fixed!") 57 | 58 | def __setitem__(self, key, value): 59 | raise TypeError("Setting object not mutable after settings are fixed!") 60 | 61 | def __deepcopy__(self, memo): 62 | """ In order to support deepcopy""" 63 | return ParamDict([(deepcopy(k, memo), deepcopy(v, memo)) for k, v in self.items()]) 64 | 65 | def __repr__(self): 66 | return json.dumps(self, indent=4, sort_keys=True) 67 | 68 | def recursive_objectify(nested_dict): 69 | "Turns a nested_dict into a nested ParamDict" 70 | result = deepcopy(nested_dict) 71 | for k, v in result.items(): 72 | if isinstance(v, collections.Mapping): 73 | result[k] = recursive_objectify(v) 74 | return ParamDict(result) 75 | 76 | 77 | class SafeDict(dict): 78 | """ A dict prohibiting init from a list of pairs containing duplicates""" 79 | def __init__(self, *args, **kwargs): 80 | if args and args[0] and not isinstance(args[0], dict): 81 | keys, _ = zip(*args[0]) 82 | duplicates = [item for item, count in collections.Counter(keys).items() if count > 1] 83 | if duplicates: 84 | raise TypeError("Keys {} repeated in json 
parsing".format(duplicates)) 85 | super().__init__(*args, **kwargs) 86 | 87 | def load_json(file): 88 | """ Safe load of a json file (doubled entries raise exception)""" 89 | with open(file, 'r') as f: 90 | data = json.load(f, object_pairs_hook=SafeDict) 91 | return data 92 | 93 | 94 | def update_recursive(d, u, defensive=False): 95 | for k, v in u.items(): 96 | if defensive and k not in d: 97 | raise KeyError("Updating a non-existing key") 98 | if isinstance(v, collections.Mapping): 99 | d[k] = update_recursive(d.get(k, {}), v) 100 | else: 101 | d[k] = v 102 | return d 103 | 104 | def is_json_file(cmd_line): 105 | try: 106 | return os.path.isfile(cmd_line) 107 | except Exception as e: 108 | warn('JSON parsing suppressed exception: ', e) 109 | return False 110 | 111 | 112 | def is_parseable_dict(cmd_line): 113 | try: 114 | res = ast.literal_eval(cmd_line) 115 | return isinstance(res, dict) 116 | except Exception as e: 117 | warn('Dict literal eval suppressed exception: ', e) 118 | return False 119 | 120 | def update_params_from_cmdline(cmd_line=None, default_params=None, custom_parser=None, verbose=True): 121 | """ Updates default settings based on command line input. 122 | 123 | :param cmd_line: Expecting (same format as) sys.argv 124 | :param default_params: Dictionary of default params 125 | :param custom_parser: callable that returns a dict of params on success 126 | and None on failure (suppress exceptions!) 127 | :param verbose: Boolean to determine if final settings are pretty printed 128 | :return: Immutable nested dict with (deep) dot access. Priority: default_params < default_json < cmd_line 129 | """ 130 | if not cmd_line: 131 | cmd_line = sys.argv 132 | 133 | if default_params is None: 134 | default_params = {} 135 | 136 | if len(cmd_line) < 2: 137 | cmd_params = {} 138 | elif custom_parser and custom_parser(cmd_line): # Custom parsing, typically for flags 139 | cmd_params = custom_parser(cmd_line) 140 | elif len(cmd_line) == 2 and is_json_file(cmd_line[1]): 141 | cmd_params = load_json(cmd_line[1]) 142 | elif len(cmd_line) == 2 and is_parseable_dict(cmd_line[1]): 143 | cmd_params = ast.literal_eval(cmd_line[1]) 144 | else: 145 | raise ValueError('Failed to parse command line') 146 | 147 | update_recursive(default_params, cmd_params) 148 | 149 | if JSON_FILE_KEY in default_params: 150 | json_params = load_json(default_params[JSON_FILE_KEY]) 151 | if 'default_json' in json_params: 152 | json_base = load_json(json_params[JSON_FILE_KEY]) 153 | else: 154 | json_base = {} 155 | update_recursive(json_base, json_params) 156 | update_recursive(default_params, json_base) 157 | 158 | update_recursive(default_params, cmd_params) 159 | final_params = recursive_objectify(default_params) 160 | if verbose: 161 | print(final_params) 162 | 163 | update_params_from_cmdline.start_time = time.time() 164 | return final_params 165 | 166 | update_params_from_cmdline.start_time = None -------------------------------------------------------------------------------- /BB_GM/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utils.backbone 4 | from BB_GM.affinity_layer import InnerProductWithWeightsAffinity 5 | from BB_GM.sconv_archs import SiameseSConvOnNodes, SiameseNodeFeaturesToEdgeFeatures 6 | from lpmp_py import GraphMatchingModule 7 | from lpmp_py import MultiGraphMatchingModule 8 | from utils.config import cfg 9 | from utils.feature_align import feature_align 10 | from utils.utils import lexico_iter 11 | from 
utils.visualization import easy_visualize 12 | 13 | 14 | def normalize_over_channels(x): 15 | channel_norms = torch.norm(x, dim=1, keepdim=True) 16 | return x / channel_norms 17 | 18 | 19 | def concat_features(embeddings, num_vertices): 20 | res = torch.cat([embedding[:, :num_v] for embedding, num_v in zip(embeddings, num_vertices)], dim=-1) 21 | return res.transpose(0, 1) 22 | 23 | 24 | class Net(utils.backbone.VGG16_bn): 25 | def __init__(self): 26 | super(Net, self).__init__() 27 | self.message_pass_node_features = SiameseSConvOnNodes(input_node_dim=1024) 28 | self.build_edge_features_from_node_features = SiameseNodeFeaturesToEdgeFeatures( 29 | total_num_nodes=self.message_pass_node_features.num_node_features 30 | ) 31 | self.global_state_dim = 1024 32 | self.vertex_affinity = InnerProductWithWeightsAffinity( 33 | self.global_state_dim, self.message_pass_node_features.num_node_features) 34 | self.edge_affinity = InnerProductWithWeightsAffinity( 35 | self.global_state_dim, 36 | self.build_edge_features_from_node_features.num_edge_features) 37 | 38 | def forward( 39 | self, 40 | images, 41 | points, 42 | graphs, 43 | n_points, 44 | perm_mats, 45 | visualize_flag=False, 46 | visualization_params=None, 47 | ): 48 | 49 | global_list = [] 50 | orig_graph_list = [] 51 | for image, p, n_p, graph in zip(images, points, n_points, graphs): 52 | # extract feature 53 | nodes = self.node_layers(image) 54 | edges = self.edge_layers(nodes) 55 | 56 | global_list.append(self.final_layers(edges)[0].reshape((nodes.shape[0], -1))) 57 | nodes = normalize_over_channels(nodes) 58 | edges = normalize_over_channels(edges) 59 | 60 | # arrange features 61 | U = concat_features(feature_align(nodes, p, n_p, (256, 256)), n_p) 62 | F = concat_features(feature_align(edges, p, n_p, (256, 256)), n_p) 63 | node_features = torch.cat((U, F), dim=-1) 64 | graph.x = node_features 65 | 66 | graph = self.message_pass_node_features(graph) 67 | orig_graph = self.build_edge_features_from_node_features(graph) 68 | orig_graph_list.append(orig_graph) 69 | 70 | global_weights_list = [ 71 | torch.cat([global_src, global_tgt], axis=-1) for global_src, global_tgt in lexico_iter(global_list) 72 | ] 73 | global_weights_list = [normalize_over_channels(g) for g in global_weights_list] 74 | 75 | unary_costs_list = [ 76 | self.vertex_affinity([item.x for item in g_1], [item.x for item in g_2], global_weights) 77 | for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list), global_weights_list) 78 | ] 79 | 80 | # Similarities to costs 81 | unary_costs_list = [[-x for x in unary_costs] for unary_costs in unary_costs_list] 82 | 83 | if self.training: 84 | unary_costs_list = [ 85 | [ 86 | x + 1.0*gt[:dim_src, :dim_tgt] # Add margin with alpha = 1.0 87 | for x, gt, dim_src, dim_tgt in zip(unary_costs, perm_mat, ns_src, ns_tgt) 88 | ] 89 | for unary_costs, perm_mat, (ns_src, ns_tgt) in zip(unary_costs_list, perm_mats, lexico_iter(n_points)) 90 | ] 91 | 92 | quadratic_costs_list = [ 93 | self.edge_affinity([item.edge_attr for item in g_1], [item.edge_attr for item in g_2], global_weights) 94 | for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list), global_weights_list) 95 | ] 96 | 97 | # Similarities to costs 98 | quadratic_costs_list = [[-0.5 * x for x in quadratic_costs] for quadratic_costs in quadratic_costs_list] 99 | 100 | if cfg.BB_GM.solver_name == "lpmp": 101 | all_edges = [[item.edge_index for item in graph] for graph in orig_graph_list] 102 | gm_solvers = [ 103 | GraphMatchingModule( 104 | all_left_edges, 105 | 
all_right_edges, 106 | ns_src, 107 | ns_tgt, 108 | cfg.BB_GM.lambda_val, 109 | cfg.BB_GM.solver_params, 110 | ) 111 | for (all_left_edges, all_right_edges), (ns_src, ns_tgt) in zip( 112 | lexico_iter(all_edges), lexico_iter(n_points) 113 | ) 114 | ] 115 | matchings = [ 116 | gm_solver(unary_costs, quadratic_costs) 117 | for gm_solver, unary_costs, quadratic_costs in zip(gm_solvers, unary_costs_list, quadratic_costs_list) 118 | ] 119 | elif cfg.BB_GM.solver_name == "multigraph": 120 | all_edges = [[item.edge_index for item in graph] for graph in orig_graph_list] 121 | gm_solver = MultiGraphMatchingModule( 122 | all_edges, n_points, cfg.BB_GM.lambda_val, cfg.BB_GM.solver_params) 123 | matchings = gm_solver(unary_costs_list, quadratic_costs_list) 124 | else: 125 | raise ValueError(f"Unknown solver {cfg.BB_GM.solver_name}") 126 | 127 | if visualize_flag: 128 | easy_visualize( 129 | orig_graph_list, 130 | points, 131 | n_points, 132 | images, 133 | unary_costs_list, 134 | quadratic_costs_list, 135 | matchings, 136 | **visualization_params, 137 | ) 138 | 139 | return matchings 140 | -------------------------------------------------------------------------------- /utils/latex_utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | from shutil import copyfile 4 | from subprocess import run, PIPE 5 | from tempfile import TemporaryDirectory 6 | 7 | def subsection(section_name, content): 8 | begin = '\\begin{{subsection}}{{{}}}\n'.format(section_name) 9 | end = '\\end{subsection}\n' 10 | return '{}\n\\leavevmode\n\n\\medskip\n{}{}'.format(begin, content, end) 11 | 12 | 13 | def add_subsection_from_figures(section_name, file_list, common_scale=0.7): 14 | content = '\n'.join([include_figure(filename, common_scale) for filename in file_list]) 15 | return subsection(section_name, content) 16 | 17 | 18 | def include_figure(filename, scale_linewidth=1.0): 19 | return '\\includegraphics[width={}\linewidth]{{\detokenize{{{}}}}}\n'.format(scale_linewidth, filename) 20 | 21 | 22 | def section(section_name, content): 23 | begin = '\\begin{{section}}{{{}}}\n'.format(section_name) 24 | end = '\\end{section}\n' 25 | return '{}{}{}'.format(begin, content, end) 26 | 27 | 28 | class LatexFile(object): 29 | def __init__(self, title): 30 | self.title = title 31 | self.date = str(datetime.datetime.today()).split()[0] 32 | self.sections = [] 33 | 34 | def add_section_from_figures(self, name, list_of_filenames, common_scale=1.0): 35 | begin = '\\begin{center}' 36 | end = '\\end{center}' 37 | content = '\n'.join([begin] + [include_figure(filename, common_scale) 38 | for filename in list_of_filenames] + [end]) 39 | self.sections.append(section(name, content)) 40 | 41 | def add_subsection_from_figures(self, section_name, file_list, common_scale=1.0): 42 | content = '\n'.join([include_figure(filename, common_scale) for filename in file_list]) 43 | return self.sections.append(subsection(section_name, content)) 44 | 45 | def add_section_from_dataframe(self, name, dataframe): 46 | begin = '\\begin{center}' 47 | end = '\\end{center}' 48 | section_content = '\n'.join([begin, dataframe.to_latex(), end]) 49 | self.sections.append(section(name, section_content)) 50 | 51 | def add_section_from_python_script(self, name, python_file): 52 | with open(python_file) as f: 53 | raw = f.read() 54 | content = '\\begin{{lstlisting}}[language=Python]\n {}\\end{{lstlisting}}'.format(raw) 55 | self.sections.append(section(name, content)) 56 | 57 | def 
add_generic_section(self, name, content): 58 | self.sections.append(section(name, content)) 59 | 60 | def add_section_from_json(self, name, json_file): 61 | with open(json_file) as f: 62 | raw = f.read() 63 | content = '\\begin{{lstlisting}}[language=json]\n {}\\end{{lstlisting}}'.format(raw) 64 | self.sections.append(section(name, content)) 65 | 66 | def produce_pdf(self, output_file): 67 | full_content = '\n'.join(self.sections) 68 | title_str = LATEX_TITLE.format(self.title) 69 | date_str = LATEX_DATE.format(self.date) 70 | whole_latex = '\n'.join([LATEX_BEGIN, title_str, date_str, full_content, LATEX_END]) 71 | with TemporaryDirectory() as tmpdir: 72 | latex_file = os.path.join(tmpdir, 'latex.tex') 73 | with open(latex_file, 'w') as f: 74 | f.write(whole_latex) 75 | run(['pdflatex', latex_file], cwd=tmpdir, check=True, stdout=PIPE) 76 | output_tmp = os.path.join(tmpdir, 'latex.pdf') 77 | copyfile(output_tmp, output_file) 78 | 79 | LATEX_BEGIN = ''' 80 | \\documentclass{amsart} 81 | \\usepackage{graphicx} 82 | \\usepackage{framed} 83 | \\usepackage{verbatim} 84 | \\usepackage{booktabs} 85 | \\usepackage{url} 86 | \\usepackage{underscore} 87 | \\usepackage{listings} 88 | \\usepackage{mdframed} 89 | \\usepackage[margin=1.0in]{geometry} 90 | 91 | \\usepackage{color} 92 | \\usepackage{xcolor} 93 | 94 | \\definecolor{eclipseStrings}{RGB}{42,0.0,255} 95 | \\definecolor{eclipseKeywords}{RGB}{127,0,85} 96 | \\colorlet{numb}{magenta!60!black} 97 | 98 | \\lstdefinelanguage{json}{ 99 | basicstyle=\\normalfont\\ttfamily, 100 | commentstyle=\\color{eclipseStrings}, % style of comment 101 | stringstyle=\\color{eclipseKeywords}, % style of strings 102 | numbers=left, 103 | numberstyle=\scriptsize, 104 | stepnumber=1, 105 | numbersep=8pt, 106 | showstringspaces=false, 107 | breaklines=true, 108 | frame=lines, 109 | backgroundcolor=\\color{white}, %only if you like 110 | string=[s]{"}{"}, 111 | comment=[l]{:\\ "}, 112 | morecomment=[l]{:"}, 113 | literate= 114 | *{0}{{{\\color{numb}0}}}{1} 115 | {1}{{{\\color{numb}1}}}{1} 116 | {2}{{{\\color{numb}2}}}{1} 117 | {3}{{{\\color{numb}3}}}{1} 118 | {4}{{{\\color{numb}4}}}{1} 119 | {5}{{{\\color{numb}5}}}{1} 120 | {6}{{{\\color{numb}6}}}{1} 121 | {7}{{{\\color{numb}7}}}{1} 122 | {8}{{{\\color{numb}8}}}{1} 123 | {9}{{{\\color{numb}9}}}{1} 124 | } 125 | 126 | \\definecolor{codegreen}{rgb}{0,0.6,0} 127 | \\definecolor{codegray}{rgb}{0.5,0.5,0.5} 128 | \\definecolor{codepurple}{rgb}{0.58,0,0.82} 129 | \\definecolor{backcolour}{rgb}{0.95,0.95,0.92} 130 | 131 | \\lstdefinestyle{mystyle}{ 132 | backgroundcolor=\\color{backcolour}, 133 | commentstyle=\\color{codegreen}, 134 | keywordstyle=\\color{magenta}, 135 | numberstyle=\\tiny\\color{codegray}, 136 | stringstyle=\\color{codepurple}, 137 | basicstyle=\\footnotesize, 138 | breakatwhitespace=false, 139 | breaklines=true, 140 | captionpos=b, 141 | keepspaces=true, 142 | numbers=left, 143 | numbersep=5pt, 144 | showspaces=false, 145 | showstringspaces=false, 146 | showtabs=false, 147 | tabsize=2 148 | } 149 | 150 | \\lstset{style=mystyle} 151 | 152 | \\graphicspath{{./figures/}} 153 | \\newtheorem{theorem}{Theorem}[section] 154 | \\newtheorem{conj}[theorem]{Conjecture} 155 | \\newtheorem{lemma}[theorem]{Lemma} 156 | \\newtheorem{prop}[theorem]{Proposition} 157 | \\newtheorem{cor}[theorem]{Corollary} 158 | \\def \\qbar {\\overline{\\mathbb{Q}}} 159 | \\theoremstyle{definition} 160 | \\newtheorem{definition}[theorem]{Definition} 161 | \\newtheorem{example}[theorem]{Example} 162 | \\newtheorem{xca}[theorem]{Exercise} 163 
| \\theoremstyle{remark} 164 | \\newtheorem{remark}[theorem]{Remark} 165 | \\numberwithin{equation}{section} 166 | \\begin{document}\n 167 | ''' 168 | 169 | LATEX_TITLE = '\\title{{{}}}\n' 170 | LATEX_DATE = '\\date{{{}}}\n \\maketitle\n' 171 | LATEX_END = '\\end{document}' 172 | -------------------------------------------------------------------------------- /data/data_loader_multigraph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.utils.data import Dataset 4 | from torchvision import transforms 5 | import numpy as np 6 | import random 7 | from data.pascal_voc import PascalVOC 8 | from data.willow_obj import WillowObject 9 | from data.SPair71k import SPair71k 10 | from utils.build_graphs import build_graphs 11 | 12 | from utils.config import cfg 13 | from torch_geometric.data import Data, Batch 14 | 15 | datasets = {"PascalVOC": PascalVOC, 16 | "WillowObject": WillowObject, 17 | "SPair71k": SPair71k} 18 | 19 | class GMDataset(Dataset): 20 | def __init__(self, name, length, **args): 21 | self.name = name 22 | self.ds = datasets[name](**args) 23 | self.true_epochs = length is None 24 | self.length = ( 25 | self.ds.total_size if self.true_epochs else length 26 | ) # NOTE image pairs are sampled randomly, so there is no exact definition of dataset size 27 | if self.true_epochs: 28 | print(f"Initializing {self.ds.sets}-set with all {self.length} examples.") 29 | else: 30 | print(f"Initializing {self.ds.sets}-set. Randomly sampling {self.length} examples.") 31 | # length here represents the iterations between two checkpoints 32 | # if length is None the length is set to the size of the ds 33 | self.obj_size = self.ds.obj_resize 34 | self.classes = self.ds.classes 35 | self.cls = None 36 | self.num_graphs_in_matching_instance = None 37 | 38 | def set_cls(self, cls): 39 | if cls == "none": 40 | cls = None 41 | self.cls = cls 42 | if self.true_epochs: # Update length of dataset for dataloader according to class 43 | self.length = self.ds.total_size if cls is None else self.ds.size_by_cls[cls] 44 | 45 | def set_num_graphs(self, num_graphs_in_matching_instance): 46 | self.num_graphs_in_matching_instance = num_graphs_in_matching_instance 47 | 48 | def __len__(self): 49 | return self.length 50 | 51 | def __getitem__(self, idx): 52 | sampling_strategy = cfg.train_sampling if self.ds.sets == "train" else cfg.eval_sampling 53 | if self.num_graphs_in_matching_instance is None: 54 | raise ValueError("Num_graphs has to be set to an integer value.") 55 | 56 | idx = idx if self.true_epochs else None 57 | anno_list, perm_mat_list = self.ds.get_k_samples(idx, k=self.num_graphs_in_matching_instance, cls=self.cls, mode=sampling_strategy) 58 | for perm_mat in perm_mat_list: 59 | if ( 60 | not perm_mat.size 61 | or (perm_mat.size < 2 * 2 and sampling_strategy == "intersection") 62 | and not self.true_epochs 63 | ): 64 | # 'and not self.true_epochs' because we assume all data is valid when sampling a true epoch 65 | next_idx = None if idx is None else idx + 1 66 | return self.__getitem__(next_idx) 67 | 68 | points_gt = [np.array([(kp["x"], kp["y"]) for kp in anno_dict["keypoints"]]) for anno_dict in anno_list] 69 | n_points_gt = [len(p_gt) for p_gt in points_gt] 70 | 71 | graph_list = [] 72 | for p_gt, n_p_gt in zip(points_gt, n_points_gt): 73 | edge_indices, edge_features = build_graphs(p_gt, n_p_gt) 74 | 75 | # Add dummy node features so that their __slices__ are saved when creating a batch 76 | pos = 
torch.tensor(p_gt).to(torch.float32) / 256.0 77 | assert (pos > -1e-5).all(), p_gt 78 | graph = Data( 79 | edge_attr=torch.tensor(edge_features).to(torch.float32), 80 | edge_index=torch.tensor(edge_indices, dtype=torch.long), 81 | x=pos, 82 | pos=pos, 83 | ) 84 | graph.num_nodes = n_p_gt 85 | graph_list.append(graph) 86 | 87 | ret_dict = { 88 | "Ps": [torch.Tensor(x) for x in points_gt], 89 | "ns": [torch.tensor(x) for x in n_points_gt], 90 | "gt_perm_mat": perm_mat_list, 91 | "edges": graph_list, 92 | } 93 | 94 | imgs = [anno["image"] for anno in anno_list] 95 | if imgs[0] is not None: 96 | trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize(cfg.NORM_MEANS, cfg.NORM_STD)]) 97 | imgs = [trans(img) for img in imgs] 98 | ret_dict["images"] = imgs 99 | elif "feat" in anno_list[0]["keypoints"][0]: 100 | feat_list = [np.stack([kp["feat"] for kp in anno_dict["keypoints"]], axis=-1) for anno_dict in anno_list] 101 | ret_dict["features"] = [torch.Tensor(x) for x in feat_list] 102 | 103 | return ret_dict 104 | 105 | 106 | def collate_fn(data: list): 107 | """ 108 | Create mini-batch data for training. 109 | :param data: data dict 110 | :return: mini-batch 111 | """ 112 | 113 | def pad_tensor(inp): 114 | assert type(inp[0]) == torch.Tensor 115 | it = iter(inp) 116 | t = next(it) 117 | max_shape = list(t.shape) 118 | while True: 119 | try: 120 | t = next(it) 121 | for i in range(len(max_shape)): 122 | max_shape[i] = int(max(max_shape[i], t.shape[i])) 123 | except StopIteration: 124 | break 125 | max_shape = np.array(max_shape) 126 | 127 | padded_ts = [] 128 | for t in inp: 129 | pad_pattern = np.zeros(2 * len(max_shape), dtype=np.int64) 130 | pad_pattern[::-2] = max_shape - np.array(t.shape) 131 | pad_pattern = tuple(pad_pattern.tolist()) 132 | padded_ts.append(F.pad(t, pad_pattern, "constant", 0)) 133 | 134 | return padded_ts 135 | 136 | def stack(inp): 137 | if type(inp[0]) == list: 138 | ret = [] 139 | for vs in zip(*inp): 140 | ret.append(stack(vs)) 141 | elif type(inp[0]) == dict: 142 | ret = {} 143 | for kvs in zip(*[x.items() for x in inp]): 144 | ks, vs = zip(*kvs) 145 | for k in ks: 146 | assert k == ks[0], "Key value mismatch." 147 | ret[k] = stack(vs) 148 | elif type(inp[0]) == torch.Tensor: 149 | new_t = pad_tensor(inp) 150 | ret = torch.stack(new_t, 0) 151 | elif type(inp[0]) == np.ndarray: 152 | new_t = pad_tensor([torch.from_numpy(x) for x in inp]) 153 | ret = torch.stack(new_t, 0) 154 | elif type(inp[0]) == str: 155 | ret = inp 156 | elif type(inp[0]) == Data: # Graph from torch.geometric, create a batch 157 | ret = Batch.from_data_list(inp) 158 | else: 159 | raise ValueError("Cannot handle type {}".format(type(inp[0]))) 160 | return ret 161 | 162 | ret = stack(data) 163 | return ret 164 | 165 | 166 | def worker_init_fix(worker_id): 167 | """ 168 | Init dataloader workers with fixed seed. 169 | """ 170 | random.seed(cfg.RANDOM_SEED + worker_id) 171 | np.random.seed(cfg.RANDOM_SEED + worker_id) 172 | 173 | 174 | def worker_init_rand(worker_id): 175 | """ 176 | Init dataloader workers with torch.initial_seed(). 177 | torch.initial_seed() returns different seeds when called from different dataloader threads. 
178 | """ 179 | random.seed(torch.initial_seed()) 180 | np.random.seed(torch.initial_seed() % 2 ** 32) 181 | 182 | 183 | def get_dataloader(dataset, fix_seed=True, shuffle=False): 184 | return torch.utils.data.DataLoader( 185 | dataset, 186 | batch_size=cfg.BATCH_SIZE, 187 | shuffle=shuffle, 188 | num_workers=2, 189 | collate_fn=collate_fn, 190 | pin_memory=False, 191 | worker_init_fn=worker_init_fix if fix_seed else worker_init_rand, 192 | ) 193 | -------------------------------------------------------------------------------- /data/SPair71k.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 | import pickle 5 | import random 6 | 7 | import numpy as np 8 | from PIL import Image 9 | 10 | from utils.config import cfg 11 | 12 | cache_path = cfg.CACHE_PATH 13 | pair_ann_path = cfg.SPair.ROOT_DIR + "/PairAnnotation" 14 | layout_path = cfg.SPair.ROOT_DIR + "/Layout" 15 | image_path = cfg.SPair.ROOT_DIR + "/JPEGImages" 16 | dataset_size = cfg.SPair.size 17 | 18 | sets_translation_dict = dict(train="trn", test="test") 19 | difficulty_params_dict = dict( 20 | trn=cfg.TRAIN.difficulty_params, val=cfg.EVAL.difficulty_params, test=cfg.EVAL.difficulty_params 21 | ) 22 | 23 | 24 | class SPair71k: 25 | def __init__(self, sets, obj_resize): 26 | """ 27 | :param sets: 'train' or 'test' 28 | :param obj_resize: resized object size 29 | """ 30 | self.sets = sets_translation_dict[sets] 31 | self.ann_files = open(os.path.join(layout_path, dataset_size, self.sets + ".txt"), "r").read().split("\n") 32 | self.ann_files = self.ann_files[: len(self.ann_files) - 1] 33 | self.difficulty_params = difficulty_params_dict[self.sets] 34 | self.pair_ann_path = pair_ann_path 35 | self.image_path = image_path 36 | self.classes = list(map(lambda x: os.path.basename(x), glob.glob("%s/*" % image_path))) 37 | self.classes.sort() 38 | self.obj_resize = obj_resize 39 | self.combine_classes = cfg.combine_classes 40 | self.ann_files_filtered, self.ann_files_filtered_cls_dict, self.classes = self.filter_annotations( 41 | self.ann_files, self.difficulty_params 42 | ) 43 | self.total_size = len(self.ann_files_filtered) 44 | self.size_by_cls = {cls: len(ann_list) for cls, ann_list in self.ann_files_filtered_cls_dict.items()} 45 | 46 | def filter_annotations(self, ann_files, difficulty_params): 47 | if len(difficulty_params) > 0: 48 | basepath = os.path.join(self.pair_ann_path, "pickled", self.sets) 49 | if not os.path.exists(basepath): 50 | os.makedirs(basepath) 51 | difficulty_paramas_str = self.diff_dict_to_str(difficulty_params) 52 | try: 53 | filepath = os.path.join(basepath, difficulty_paramas_str + ".pickle") 54 | ann_files_filtered = pickle.load(open(filepath, "rb")) 55 | print( 56 | f"Found filtered annotations for difficulty parameters {difficulty_params} and {self.sets}-set at {filepath}" 57 | ) 58 | except (OSError, IOError) as e: 59 | print( 60 | f"No pickled annotations found for difficulty parameters {difficulty_params} and {self.sets}-set. Filtering..." 
61 | ) 62 | ann_files_filtered_dict = {} 63 | 64 | for ann_file in ann_files: 65 | with open(os.path.join(self.pair_ann_path, self.sets, ann_file + ".json")) as f: 66 | annotation = json.load(f) 67 | diff = {key: annotation[key] for key in self.difficulty_params.keys()} 68 | diff_str = self.diff_dict_to_str(diff) 69 | if diff_str in ann_files_filtered_dict: 70 | ann_files_filtered_dict[diff_str].append(ann_file) 71 | else: 72 | ann_files_filtered_dict[diff_str] = [ann_file] 73 | total_l = 0 74 | for diff_str, file_list in ann_files_filtered_dict.items(): 75 | total_l += len(file_list) 76 | filepath = os.path.join(basepath, diff_str + ".pickle") 77 | pickle.dump(file_list, open(filepath, "wb")) 78 | assert total_l == len(ann_files) 79 | print(f"Done filtering. Saved filtered annotations to {basepath}.") 80 | ann_files_filtered = ann_files_filtered_dict[difficulty_paramas_str] 81 | else: 82 | print(f"No difficulty parameters for {self.sets}-set. Using all available data.") 83 | ann_files_filtered = ann_files 84 | 85 | ann_files_filtered_cls_dict = { 86 | cls: list(filter(lambda x: cls in x, ann_files_filtered)) for cls in self.classes 87 | } 88 | class_len = {cls: len(ann_list) for cls, ann_list in ann_files_filtered_cls_dict.items()} 89 | print(f"Number of annotation pairs matching the difficulty params in {self.sets}-set: {class_len}") 90 | if self.combine_classes: 91 | cls_name = "combined" 92 | ann_files_filtered_cls_dict = {cls_name: ann_files_filtered} 93 | filtered_classes = [cls_name] 94 | print(f"Combining {self.sets}-set classes. Total of {len(ann_files_filtered)} image pairs used.") 95 | else: 96 | filtered_classes = [] 97 | for cls, ann_f in ann_files_filtered_cls_dict.items(): 98 | if len(ann_f) > 0: 99 | filtered_classes.append(cls) 100 | else: 101 | print(f"Excluding class {cls} from {self.sets}-set.") 102 | return ann_files_filtered, ann_files_filtered_cls_dict, filtered_classes 103 | 104 | def diff_dict_to_str(self, diff): 105 | diff_str = "" 106 | keys = ["mirror", "viewpoint_variation", "scale_variation", "truncation", "occlusion"] 107 | for key in keys: 108 | if key in diff.keys(): 109 | diff_str += key 110 | diff_str += str(diff[key]) 111 | return diff_str 112 | 113 | def get_k_samples(self, idx, k, mode, cls=None, shuffle=True): 114 | """ 115 | Randomly get a sample of k objects from the SPair71k dataset 116 | :param idx: Index of datapoint to sample, None for random sampling 117 | :param k: number of datapoints in sample 118 | :param mode: sampling strategy 119 | :param cls: None for random class, or specify for a certain set 120 | :param shuffle: random shuffle the keypoints 121 | :return: (k samples of data, k \choose 2 groundtruth permutation matrices) 122 | """ 123 | if k != 2: 124 | raise NotImplementedError( 125 | f"No strategy implemented to sample {k} graphs from SPair dataset. So far only k=2 is possible."
126 | ) 127 | 128 | if cls is None: 129 | cls = self.classes[random.randrange(0, len(self.classes))] 130 | ann_files = self.ann_files_filtered_cls_dict[cls] 131 | elif type(cls) == int: 132 | cls = self.classes[cls] 133 | ann_files = self.ann_files_filtered_cls_dict[cls] 134 | else: 135 | assert type(cls) == str 136 | ann_files = self.ann_files_filtered_cls_dict[cls] 137 | 138 | # get pre-processed images 139 | 140 | assert len(ann_files) > 0 141 | if idx is None: 142 | ann_file = random.choice(ann_files) + ".json" 143 | else: 144 | ann_file = ann_files[idx] + ".json" 145 | with open(os.path.join(self.pair_ann_path, self.sets, ann_file)) as f: 146 | annotation = json.load(f) 147 | 148 | category = annotation["category"] 149 | if cls is not None and not self.combine_classes: 150 | assert cls == category 151 | assert all(annotation[key] == value for key, value in self.difficulty_params.items()) 152 | 153 | if mode == "intersection": 154 | assert len(annotation["src_kps"]) == len(annotation["trg_kps"]) 155 | num_kps = len(annotation["src_kps"]) 156 | perm_mat_init = np.eye(num_kps) 157 | anno_list = [] 158 | perm_list = [] 159 | 160 | for st in ("src", "trg"): 161 | if shuffle: 162 | perm = np.random.permutation(np.arange(num_kps)) 163 | else: 164 | perm = np.arange(num_kps) 165 | kps = annotation[f"{st}_kps"] 166 | img_path = os.path.join(self.image_path, category, annotation[f"{st}_imname"]) 167 | img, kps = self.rescale_im_and_kps(img_path, kps) 168 | kps_permuted = [kps[i] for i in perm] 169 | anno_dict = dict(image=img, keypoints=kps_permuted) 170 | anno_list.append(anno_dict) 171 | perm_list.append(perm) 172 | 173 | perm_mat = perm_mat_init[perm_list[0]][:, perm_list[1]] 174 | else: 175 | raise NotImplementedError(f"Unknown sampling strategy {mode}") 176 | 177 | return anno_list, [perm_mat] 178 | 179 | def rescale_im_and_kps(self, img_path, kps): 180 | 181 | with Image.open(str(img_path)) as img: 182 | w, h = img.size 183 | img = img.resize(self.obj_resize, resample=Image.BICUBIC) 184 | 185 | keypoint_list = [] 186 | for kp in kps: 187 | x = kp[0] * self.obj_resize[0] / w 188 | y = kp[1] * self.obj_resize[1] / h 189 | keypoint_list.append(dict(x=x, y=y)) 190 | 191 | return img, keypoint_list 192 | 193 | 194 | if __name__ == "__main__": 195 | trn_dataset = SPair71k("train", (256, 256)) 196 | -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.optim as optim 3 | import time 4 | from pathlib import Path 5 | 6 | from data.data_loader_multigraph import GMDataset, get_dataloader 7 | 8 | from utils.evaluation_metric import matching_accuracy_from_lists, f1_score, get_pos_neg_from_lists 9 | 10 | from eval import eval_model 11 | 12 | from BB_GM.model import Net 13 | from utils.config import cfg 14 | 15 | from utils.utils import update_params_from_cmdline 16 | 17 | class HammingLoss(torch.nn.Module): 18 | def forward(self, suggested, target): 19 | errors = suggested * (1.0 - target) + (1.0 - suggested) * target 20 | return errors.mean(dim=0).sum() 21 | 22 | 23 | lr_schedules = { 24 | "long_halving": (10, (2, 4, 6, 8, 10), 0.5), 25 | "short_halving": (2, (1,), 0.5), 26 | "long_nodrop": (10, (10,), 1.0), 27 | "minirun": (1, (10,), 1.0), 28 | } 29 | 30 | 31 | def train_eval_model(model, criterion, optimizer, dataloader, num_epochs, resume=False, start_epoch=0): 32 | print("Start training...") 33 | 34 | since = time.time() 35 | 
dataloader["train"].dataset.set_num_graphs(cfg.TRAIN.num_graphs_in_matching_instance) 36 | dataset_size = len(dataloader["train"].dataset) 37 | 38 | 39 | device = next(model.parameters()).device 40 | print("model on device: {}".format(device)) 41 | 42 | checkpoint_path = Path(cfg.model_dir) / "params" 43 | if not checkpoint_path.exists(): 44 | checkpoint_path.mkdir(parents=True) 45 | 46 | if resume: 47 | params_path = os.path.join(cfg.warmstart_path, f"params.pt") 48 | print("Loading model parameters from {}".format(params_path)) 49 | model.load_state_dict(torch.load(params_path)) 50 | 51 | optim_path = os.path.join(cfg.warmstart_path, f"optim.pt") 52 | print("Loading optimizer state from {}".format(optim_path)) 53 | optimizer.load_state_dict(torch.load(optim_path)) 54 | 55 | # Evaluation only 56 | if cfg.evaluate_only: 57 | assert resume 58 | print(f"Evaluating without training...") 59 | accs, f1_scores = eval_model(model, dataloader["test"]) 60 | acc_dict = { 61 | "acc_{}".format(cls): single_acc for cls, single_acc in zip(dataloader["train"].dataset.classes, accs) 62 | } 63 | f1_dict = { 64 | "f1_{}".format(cls): single_f1_score 65 | for cls, single_f1_score in zip(dataloader["train"].dataset.classes, f1_scores) 66 | } 67 | acc_dict.update(f1_dict) 68 | acc_dict["matching_accuracy"] = torch.mean(accs) 69 | acc_dict["f1_score"] = torch.mean(f1_scores) 70 | 71 | time_elapsed = time.time() - since 72 | print( 73 | "Evaluation complete in {:.0f}h {:.0f}m {:.0f}s".format( 74 | time_elapsed // 3600, (time_elapsed // 60) % 60, time_elapsed % 60 75 | ) 76 | ) 77 | return model, acc_dict 78 | 79 | _, lr_milestones, lr_decay = lr_schedules[cfg.TRAIN.lr_schedule] 80 | scheduler = optim.lr_scheduler.MultiStepLR( 81 | optimizer, milestones=lr_milestones, gamma=lr_decay 82 | ) 83 | 84 | for epoch in range(start_epoch, num_epochs): 85 | print("Epoch {}/{}".format(epoch, num_epochs - 1)) 86 | print("-" * 10) 87 | 88 | model.train() # Set model to training mode 89 | 90 | print("lr = " + ", ".join(["{:.2e}".format(x["lr"]) for x in optimizer.param_groups])) 91 | 92 | epoch_loss = 0.0 93 | running_loss = 0.0 94 | running_acc = 0.0 95 | epoch_acc = 0.0 96 | running_f1 = 0.0 97 | epoch_f1 = 0.0 98 | running_since = time.time() 99 | iter_num = 0 100 | 101 | # Iterate over data. 
102 | for inputs in dataloader["train"]: 103 | data_list = [_.cuda() for _ in inputs["images"]] 104 | points_gt_list = [_.cuda() for _ in inputs["Ps"]] 105 | n_points_gt_list = [_.cuda() for _ in inputs["ns"]] 106 | edges_list = [_.to("cuda") for _ in inputs["edges"]] 107 | perm_mat_list = [perm_mat.cuda() for perm_mat in inputs["gt_perm_mat"]] 108 | 109 | iter_num = iter_num + 1 110 | 111 | # zero the parameter gradients 112 | optimizer.zero_grad() 113 | 114 | with torch.set_grad_enabled(True): 115 | # forward 116 | s_pred_list = model(data_list, points_gt_list, edges_list, n_points_gt_list, perm_mat_list) 117 | 118 | loss = sum([criterion(s_pred, perm_mat) for s_pred, perm_mat in zip(s_pred_list, perm_mat_list)]) 119 | loss /= len(s_pred_list) 120 | 121 | # backward + optimize 122 | loss.backward() 123 | optimizer.step() 124 | 125 | tp, fp, fn = get_pos_neg_from_lists(s_pred_list, perm_mat_list) 126 | f1 = f1_score(tp, fp, fn) 127 | acc, _, __ = matching_accuracy_from_lists(s_pred_list, perm_mat_list) 128 | 129 | # statistics 130 | bs = perm_mat_list[0].size(0) 131 | running_loss += loss.item() * bs # multiply with batch size 132 | epoch_loss += loss.item() * bs 133 | running_acc += acc.item() * bs 134 | epoch_acc += acc.item() * bs 135 | running_f1 += f1.item() * bs 136 | epoch_f1 += f1.item() * bs 137 | 138 | if iter_num % cfg.STATISTIC_STEP == 0: 139 | running_speed = cfg.STATISTIC_STEP * bs / (time.time() - running_since) 140 | loss_avg = running_loss / cfg.STATISTIC_STEP / bs 141 | acc_avg = running_acc / cfg.STATISTIC_STEP / bs 142 | f1_avg = running_f1 / cfg.STATISTIC_STEP / bs 143 | print( 144 | "Epoch {:<4} Iter {:<4} {:>4.2f}sample/s Loss={:<8.4f} Accuracy={:<2.3} F1={:<2.3}".format( 145 | epoch, iter_num, running_speed, loss_avg, acc_avg, f1_avg 146 | ) 147 | ) 148 | 149 | running_acc = 0.0 150 | running_f1 = 0.0 151 | running_loss = 0.0 152 | running_since = time.time() 153 | 154 | epoch_loss = epoch_loss / dataset_size 155 | epoch_acc = epoch_acc / dataset_size 156 | epoch_f1 = epoch_f1 / dataset_size 157 | 158 | if cfg.save_checkpoint: 159 | base_path = Path(checkpoint_path / "{:04}".format(epoch + 1)) 160 | Path(base_path).mkdir(parents=True, exist_ok=True) 161 | path = str(base_path / "params.pt") 162 | torch.save(model.state_dict(), path) 163 | torch.save(optimizer.state_dict(), str(base_path / "optim.pt")) 164 | 165 | print( 166 | "Over whole epoch {:<4} -------- Loss: {:.4f} Accuracy: {:.3f} F1: {:.3f}".format( 167 | epoch, epoch_loss, epoch_acc, epoch_f1 168 | ) 169 | ) 170 | print() 171 | 172 | # Eval in each epoch 173 | accs, f1_scores = eval_model(model, dataloader["test"]) 174 | acc_dict = { 175 | "acc_{}".format(cls): single_acc for cls, single_acc in zip(dataloader["train"].dataset.classes, accs) 176 | } 177 | f1_dict = { 178 | "f1_{}".format(cls): single_f1_score 179 | for cls, single_f1_score in zip(dataloader["train"].dataset.classes, f1_scores) 180 | } 181 | acc_dict.update(f1_dict) 182 | acc_dict["matching_accuracy"] = torch.mean(accs) 183 | acc_dict["f1_score"] = torch.mean(f1_scores) 184 | 185 | scheduler.step() 186 | 187 | time_elapsed = time.time() - since 188 | print( 189 | "Training complete in {:.0f}h {:.0f}m {:.0f}s".format( 190 | time_elapsed // 3600, (time_elapsed // 60) % 60, time_elapsed % 60 191 | ) 192 | ) 193 | 194 | return model, acc_dict 195 | 196 | 197 | if __name__ == "__main__": 198 | from utils.dup_stdout_manager import DupStdoutFileManager 199 | 200 | cfg = update_params_from_cmdline(default_params=cfg) 201 | import json 202 | import 
os 203 | 204 | os.makedirs(cfg.model_dir, exist_ok=True) 205 | with open(os.path.join(cfg.model_dir, "settings.json"), "w") as f: 206 | json.dump(cfg, f) 207 | 208 | torch.manual_seed(cfg.RANDOM_SEED) 209 | 210 | dataset_len = {"train": cfg.TRAIN.EPOCH_ITERS * cfg.BATCH_SIZE, "test": cfg.EVAL.SAMPLES} 211 | image_dataset = { 212 | x: GMDataset(cfg.DATASET_NAME, sets=x, length=dataset_len[x], obj_resize=(256, 256)) for x in ("train", "test") 213 | } 214 | dataloader = {x: get_dataloader(image_dataset[x], fix_seed=(x == "test")) for x in ("train", "test")} 215 | 216 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 217 | 218 | model = Net() 219 | model = model.cuda() 220 | 221 | 222 | criterion = HammingLoss() 223 | 224 | backbone_params = list(model.node_layers.parameters()) + list(model.edge_layers.parameters()) 225 | backbone_params += list(model.final_layers.parameters()) 226 | 227 | backbone_ids = [id(item) for item in backbone_params] 228 | 229 | new_params = [param for param in model.parameters() if id(param) not in backbone_ids] 230 | opt_params = [ 231 | dict(params=backbone_params, lr=cfg.TRAIN.LR * 0.01), 232 | dict(params=new_params, lr=cfg.TRAIN.LR), 233 | ] 234 | optimizer = optim.Adam(opt_params) 235 | 236 | if not Path(cfg.model_dir).exists(): 237 | Path(cfg.model_dir).mkdir(parents=True) 238 | 239 | num_epochs, _, __ = lr_schedules[cfg.TRAIN.lr_schedule] 240 | with DupStdoutFileManager(str(Path(cfg.model_dir) / ("train_log.log"))) as _: 241 | model, accs = train_eval_model( 242 | model, 243 | criterion, 244 | optimizer, 245 | dataloader, 246 | num_epochs=num_epochs, 247 | resume=cfg.warmstart_path is not None, 248 | start_epoch=0, 249 | ) 250 | -------------------------------------------------------------------------------- /utils/visualization.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import networkx as nx 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | from utils import latex_utils as lu 11 | from torch_geometric.data import Data 12 | from torch_geometric.utils.convert import to_networkx 13 | 14 | from utils.config import cfg 15 | from utils.decorators import input_to_numpy 16 | from utils.utils import UnNormalize, n_and_l_iter_parallel, lexico_iter 17 | 18 | 19 | colors = [ 20 | (0.368, 0.507, 0.71), 21 | (0.881, 0.611, 0.142), 22 | (0.56, 0.692, 0.195), 23 | (0.923, 0.386, 0.209), 24 | (0.528, 0.471, 0.701), 25 | (0.772, 0.432, 0.102), 26 | (0.364, 0.619, 0.782), 27 | (0.572, 0.586, 0.0), 28 | ] 29 | 30 | 31 | def visualize_graph( 32 | graph, pos, im, suffix, idx, vis_dir, mode="full", edge_colors=None, node_colors=None, true_graph=None 33 | ): 34 | im = np.rollaxis(im, axis=0, start=3) 35 | 36 | network = to_networkx(graph) 37 | 38 | plt.figure() 39 | plt.imshow(im) 40 | 41 | if mode == "only_edges": 42 | true_network = to_networkx(true_graph) 43 | nx.draw_networkx_edges( 44 | true_network, 45 | pos=pos, 46 | arrowstyle="-", 47 | style="dashed", 48 | alpha=0.8, 49 | node_size=15, 50 | edge_color="white", 51 | arrowsize=1, 52 | connectionstyle="arc3,rad=0.2", 53 | ) 54 | nx.draw_networkx( 55 | network, 56 | pos=pos, 57 | cmap=plt.get_cmap("inferno"), 58 | node_color="white", 59 | node_size=15, 60 | linewidths=1, 61 | arrowstyle="-", 62 | edge_color=edge_colors, 63 | arrowsize=1, 64 | with_labels=False, 65 | connectionstyle="arc3,rad=0.2", 66 | vmin=0.0, 67 | vmax=1.0, 68 | width=2.0, 69 | ) 70 
| suffix = suffix + "_" + str(int(time.time() * 100)) 71 | elif mode == "triang": 72 | 73 | node_colors = np.linspace(0, 1, len(network.nodes)) 74 | nx.draw_networkx( 75 | network, 76 | pos=pos, 77 | cmap=plt.get_cmap("Set1"), 78 | node_color=node_colors, 79 | node_size=15, 80 | linewidths=5, 81 | arrowsize=1, 82 | with_labels=False, 83 | vmin=0.0, 84 | vmax=1.0, 85 | arrowstyle="-", 86 | ) 87 | elif mode == "full": 88 | 89 | node_colors = np.linspace(0, 1, len(network.nodes)) 90 | edge_labels = {graph.edge_index[:, i]: f"{i}" for i in range(graph.edge_index.shape[1])} 91 | 92 | nx.draw_networkx( 93 | network, pos=pos, cmap=plt.get_cmap("Set1"), node_color=node_colors, node_size=100, linewidths=10 94 | ) 95 | nx.draw_networkx_edge_labels(network, pos=pos, edge_labels=edge_labels, label_pos=0.3) 96 | elif mode == "only_nodes": 97 | node_colors = np.linspace(0, 1, len(network.nodes)) 98 | nx.draw_networkx_nodes( 99 | network, pos=pos, cmap=plt.get_cmap("Set1"), node_color=node_colors, node_size=15, linewidths=1 100 | ) 101 | elif mode == "nograph": 102 | pass 103 | else: 104 | raise NotImplementedError 105 | 106 | filename = os.path.join(vis_dir, f"{idx}_{suffix}.png") 107 | plt.savefig(filename) 108 | plt.close() 109 | abs_filename = os.path.abspath(filename) 110 | return abs_filename 111 | 112 | 113 | @input_to_numpy 114 | def easy_visualize( 115 | graphs, 116 | positions, 117 | n_points, 118 | images, 119 | unary_costs, 120 | quadratic_costs, 121 | matchings, 122 | true_matchings, 123 | string_info, 124 | reduced_vis, 125 | produce_pdf=True, 126 | ): 127 | """ 128 | 129 | :param graphs: [num_graphs, bs, ...] 130 | :param positions: [num_graphs, bs, 2, max_n_p] 131 | :param n_points: [num_graphs, bs, n_p] 132 | :param images: [num_graphs, bs, size, size] 133 | :param unary_costs: [num_graphs \choose 2, bs, max_n_p, max_n_p] 134 | :param quadratic_costs: [num_graphs \choose 2, bs, max_n_p, max_n_p] 135 | :param matchings: [num_graphs \choose 2, bs, max_n_p, max_n_p] 136 | """ 137 | positions = [[p[:num] for p, num in zip(pos, n_p)] for pos, n_p in zip(positions, n_points)] 138 | matchings = [ 139 | [m[:n_p_x, :n_p_y] for m, n_p_x, n_p_y in zip(match, n_p_x_batch, n_p_y_batch)] 140 | for match, (n_p_x_batch, n_p_y_batch) in zip(matchings, lexico_iter(n_points)) 141 | ] 142 | true_matchings = [ 143 | [m[:n_p_x, :n_p_y] for m, n_p_x, n_p_y in zip(match, n_p_x_batch, n_p_y_batch)] 144 | for match, (n_p_x_batch, n_p_y_batch) in zip(true_matchings, lexico_iter(n_points)) 145 | ] 146 | 147 | visualization_string = "visualization" 148 | latex_file = lu.LatexFile(visualization_string) 149 | vis_dir = os.path.join(cfg.model_dir, visualization_string) 150 | unnorm = UnNormalize(cfg.NORM_MEANS, cfg.NORM_STD) 151 | images = [[unnorm(im) for im in im_b] for im_b in images] 152 | 153 | if not os.path.exists(vis_dir): 154 | os.makedirs(vis_dir) 155 | 156 | batch = zip( 157 | zip(*graphs), 158 | zip(*positions), 159 | zip(*images), 160 | zip(*unary_costs), 161 | zip(*quadratic_costs), 162 | zip(*matchings), 163 | zip(*true_matchings), 164 | ) 165 | for b, (graph_l, pos_l, im_l, unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l) in enumerate(batch): 166 | if not reduced_vis: 167 | files_single = [] 168 | for i, (graph, pos, im) in enumerate(zip(graph_l, pos_l, im_l)): 169 | f_single = visualize_graph(graph, pos, im, suffix=f"single_{i}", idx=b, vis_dir=vis_dir) 170 | f_single_simple = visualize_graph( 171 | graph, pos, im, suffix=f"single_simple_{i}", idx=b, vis_dir=vis_dir, mode="triang" 172 | 
) 173 | files_single.append(f_single) 174 | files_single.append(f_single_simple) 175 | latex_file.add_section_from_figures( 176 | name=f"Single Graphs ({b})", list_of_filenames=files_single, common_scale=0.7 177 | ) 178 | 179 | files_mge = [] 180 | for ( 181 | unary_c, 182 | quadratic_c, 183 | matching, 184 | true_matching, 185 | (graph_src, graph_tgt), 186 | (pos_src, pos_tgt), 187 | (im_src, im_tgt), 188 | (i, j), 189 | ) in n_and_l_iter_parallel( 190 | n=[unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l], l=[graph_l, pos_l, im_l], enum=True 191 | ): 192 | im_mge, p_mge, graph_mge, edges_corrct_mge, node_colors_mge, true_graph = merge_images_and_graphs( 193 | graph_src, graph_tgt, pos_src, pos_tgt, im_src, im_tgt, new_edges=matching, true_edges=true_matching 194 | ) 195 | f_mge = visualize_graph( 196 | graph_mge, 197 | p_mge, 198 | im_mge, 199 | suffix=f"mge_{i}-{j}", 200 | idx=b, 201 | vis_dir=vis_dir, 202 | mode="only_edges", 203 | edge_colors=[colors[2] if corr else colors[3] for corr in edges_corrct_mge], 204 | node_colors=node_colors_mge, 205 | true_graph=true_graph, 206 | ) 207 | files_mge.append(f_mge) 208 | 209 | if not reduced_vis: 210 | f_mge_nodes = visualize_graph( 211 | graph_mge, 212 | p_mge, 213 | im_mge, 214 | suffix=f"mge_nodes_{i}-{j}", 215 | idx=b, 216 | vis_dir=vis_dir, 217 | mode="only_nodes", 218 | edge_colors=[colors[2] if corr else colors[3] for corr in edges_corrct_mge], 219 | node_colors=node_colors_mge, 220 | true_graph=true_graph, 221 | ) 222 | files_mge.append(f_mge_nodes) 223 | costs_and_matchings = dict( 224 | unary_cost=unary_c, quadratic_cost=quadratic_c, matchings=matching, true_matching=true_matching 225 | ) 226 | for key, value in costs_and_matchings.items(): 227 | latex_file.add_section_from_dataframe( 228 | name=f"{key} ({b}, {i}-{j})", dataframe=pd.DataFrame(value).round(2) 229 | ) 230 | 231 | latex_file.add_section_from_figures(name=f"Matched Graphs ({b})", list_of_filenames=files_mge, common_scale=0.7) 232 | 233 | time = "{date:%Y-%m-%d_%H-%M-%S}".format(date=datetime.datetime.now()) 234 | suffix = f"{string_info}_{time}" 235 | output_file = os.path.join(vis_dir, f"{visualization_string}_{suffix}.pdf") 236 | if produce_pdf: 237 | latex_file.produce_pdf(output_file=output_file) 238 | 239 | 240 | def merge_images_and_graphs(graph_src, graph_tgt, p_src, p_tgt, im_src, im_tgt, new_edges, true_edges): 241 | pos_offset = (im_src.shape[1], 0) 242 | merged_pos = np.concatenate([p_src, p_tgt + np.array([pos_offset] * p_tgt.shape[0])]) 243 | merged_im = np.concatenate([im_src, im_tgt], 2) 244 | merged_graph, edges_correct, node_colors, true_graph = merge_graphs(graph_src, graph_tgt, new_edges, true_edges) 245 | return merged_im, merged_pos, merged_graph, edges_correct, node_colors, true_graph 246 | 247 | 248 | def merge_graphs(graph1, graph2, new_edges, true_edges): 249 | merged_x = torch.cat([graph1.x, graph2.x], 0) 250 | 251 | def color_gen(): 252 | for i in np.linspace(0.4, 1, max(new_edges.shape)): 253 | yield i 254 | 255 | edge_list = [[], []] 256 | true_edge_list = [[], []] 257 | edges_correct = [] 258 | node_colors = np.zeros(merged_x.shape[0]) 259 | offset = new_edges.shape[0] 260 | color = color_gen() 261 | for i in range(new_edges.shape[0]): 262 | for j in range(new_edges.shape[1]): 263 | if new_edges[i, j] == 1: 264 | edge_list[0].append(i) 265 | edge_list[1].append(j + offset) 266 | edges_correct.append(true_edges[i, j]) 267 | if true_edges[i, j]: 268 | c = next(color) 269 | node_colors[i], node_colors[j + offset] = c, c 270 | 
true_edge_list[0].append(i) 271 | true_edge_list[1].append(j + offset) 272 | 273 | new_edges = torch.tensor(edge_list, device=graph1.edge_index.device) 274 | true_edges = torch.tensor(true_edge_list, device=graph1.edge_index.device) 275 | merged_graph = Data(x=merged_x, edge_index=new_edges) 276 | true_merged_graph = Data(x=merged_x, edge_index=true_edges) 277 | return merged_graph, np.array(edges_correct), node_colors, true_merged_graph 278 | -------------------------------------------------------------------------------- /data/willow_obj.py: -------------------------------------------------------------------------------- 1 | import random 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import scipy.io as sio 6 | from PIL import Image 7 | 8 | from data.base_dataset import BaseDataset 9 | from utils.config import cfg 10 | from utils.utils import lexico_iter 11 | 12 | 13 | class WillowObject(BaseDataset): 14 | def __init__(self, sets, obj_resize): 15 | """ 16 | :param sets: 'train' or 'test' 17 | :param obj_resize: resized object size 18 | """ 19 | super(WillowObject, self).__init__() 20 | self.classes = cfg.WILLOW.CLASSES 21 | self.kpt_len = [cfg.WILLOW.KPT_LEN for _ in cfg.WILLOW.CLASSES] 22 | 23 | self.root_path = Path(cfg.WILLOW.ROOT_DIR) 24 | self.obj_resize = obj_resize 25 | 26 | assert sets in ["train", "test"], "No match found for dataset {}".format(sets) 27 | self.split_offset = cfg.WILLOW.TRAIN_OFFSET 28 | self.train_len = cfg.WILLOW.TRAIN_NUM 29 | self.sets = sets 30 | 31 | self.mat_list = [] 32 | for cls_name in self.classes: 33 | assert type(cls_name) is str 34 | cls_mat_list = [p for p in (self.root_path / cls_name).glob("*.mat")] 35 | ori_len = len(cls_mat_list) 36 | assert ori_len > 0, "No data found for WILLOW Object Class. Is the dataset installed correctly?"
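# --- Editor's note (not part of the original source): the branch below
# implements a circular train/test split of the per-class .mat files. With
# illustrative numbers ori_len = 40, split_offset = 0 and train_len = 20, the
# first case applies: files [0, 20) form the train split and the remainder the
# test split. When split_offset % ori_len + train_len exceeds ori_len, the
# second case stitches the train split together from the tail and the head of
# the file list, and the test split becomes the remaining middle segment.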
37 | if self.split_offset % ori_len + self.train_len <= ori_len: 38 | if sets == "train": 39 | self.mat_list.append( 40 | cls_mat_list[self.split_offset % ori_len : (self.split_offset + self.train_len) % ori_len] 41 | ) 42 | else: 43 | self.mat_list.append( 44 | cls_mat_list[: self.split_offset % ori_len] 45 | + cls_mat_list[(self.split_offset + self.train_len) % ori_len :] 46 | ) 47 | else: 48 | if sets == "train": 49 | self.mat_list.append( 50 | cls_mat_list[: (self.split_offset + self.train_len) % ori_len - ori_len] 51 | + cls_mat_list[self.split_offset % ori_len :] 52 | ) 53 | else: 54 | self.mat_list.append( 55 | cls_mat_list[ 56 | (self.split_offset + self.train_len) % ori_len - ori_len : self.split_offset % ori_len 57 | ] 58 | ) 59 | 60 | def get_k_samples(self, idx, k, mode, cls=None, shuffle=True, num_iterations=200): 61 | """ 62 | Randomly get a sample of k objects from the WILLOW-object dataset 63 | :param idx: Index of datapoint to sample, None for random sampling 64 | :param k: number of datapoints in sample 65 | :param mode: sampling strategy 66 | :param cls: None for random class, or specify for a certain set 67 | :param shuffle: random shuffle the keypoints 68 | :param num_iterations: maximum number of iterations for sampling a datapoint 69 | :return: (k samples of data, k \choose 2 groundtruth permutation matrices) 70 | """ 71 | if idx is not None: 72 | raise NotImplementedError("No indexed sampling implemented for willow.") 73 | if cls is None: 74 | cls = random.randrange(0, len(self.classes)) 75 | elif type(cls) == str: 76 | cls = self.classes.index(cls) 77 | assert type(cls) == int and 0 <= cls < len(self.classes) 78 | 79 | if mode == "superset" and k == 2: 80 | anno_list, perm_mat = self.get_pair_superset(cls=cls, shuffle=shuffle, num_iterations=num_iterations) 81 | return anno_list, [perm_mat] 82 | 83 | anno_list = [] 84 | for xml_name in random.sample(self.mat_list[cls], k): 85 | anno_dict = self.__get_anno_dict(xml_name, cls) 86 | if shuffle: 87 | random.shuffle(anno_dict["keypoints"]) 88 | anno_list.append(anno_dict) 89 | 90 | perm_mat_list = [ 91 | np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) for anno_pair in lexico_iter(anno_list) 92 | ] 93 | for n, (s1, s2) in enumerate(lexico_iter(anno_list)): 94 | row_list = [] 95 | col_list = [] 96 | for i, keypoint in enumerate(s1["keypoints"]): 97 | for j, _keypoint in enumerate(s2["keypoints"]): 98 | if keypoint["name"] == _keypoint["name"]: 99 | perm_mat_list[n][i, j] = 1 100 | row_list.append(i) 101 | col_list.append(j) 102 | break 103 | if mode == "all": 104 | pass 105 | elif mode == "rectangle" and k == 2: # so far only implemented for k = 2 106 | row_list.sort() 107 | perm_mat_list[n] = perm_mat_list[n][row_list, :] 108 | s1["keypoints"] = [s1["keypoints"][i] for i in row_list] 109 | assert perm_mat_list[n].size == len(s1["keypoints"]) * len(s2["keypoints"]) 110 | elif mode == "intersection" and k == 2: # so far only implemented for k = 2 111 | row_list.sort() 112 | col_list.sort() 113 | perm_mat_list[n] = perm_mat_list[n][row_list, :] 114 | perm_mat_list[n] = perm_mat_list[n][:, col_list] 115 | s1["keypoints"] = [s1["keypoints"][i] for i in row_list] 116 | s2["keypoints"] = [s2["keypoints"][j] for j in col_list] 117 | else: 118 | raise NotImplementedError(f"Unknown sampling strategy {mode}") 119 | 120 | return anno_list, perm_mat_list 121 | 122 | def get_pair_superset(self, cls=None, shuffle=True, num_iterations=200): 123 | """ 124 | Randomly get a pair of objects from the WILLOW-object 
keypoints dataset 125 | :param cls: None for random class, or specify for a certain set 126 | :param shuffle: random shuffle the keypoints 127 | :return: (pair of data, groundtruth permutation matrix) 128 | """ 129 | if cls is None: 130 | cls = random.randrange(0, len(self.classes)) 131 | elif type(cls) == str: 132 | cls = self.classes.index(cls) 133 | assert type(cls) == int and 0 <= cls < len(self.classes) 134 | 135 | anno_pair = None 136 | 137 | anno_dict_1 = self.__get_anno_dict(random.sample(self.mat_list[cls], 1)[0], cls) 138 | if shuffle: 139 | random.shuffle(anno_dict_1["keypoints"]) 140 | keypoints_1 = set([kp["name"] for kp in anno_dict_1["keypoints"]]) 141 | 142 | for xml_name in random.sample(self.mat_list[cls], min(len(self.mat_list[cls]), num_iterations)): 143 | anno_dict_2 = self.__get_anno_dict(xml_name, cls) 144 | if shuffle: 145 | random.shuffle(anno_dict_2["keypoints"]) 146 | keypoints_2 = set([kp["name"] for kp in anno_dict_2["keypoints"]]) 147 | if keypoints_1.issubset(keypoints_2): 148 | anno_pair = [anno_dict_1, anno_dict_2] 149 | break 150 | 151 | if anno_pair is None: 152 | return self.get_pair_superset(cls, shuffle, num_iterations) 153 | 154 | perm_mat = np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) 155 | row_list = [] 156 | col_list = [] 157 | for i, keypoint in enumerate(anno_pair[0]["keypoints"]): 158 | for j, _keypoint in enumerate(anno_pair[1]["keypoints"]): 159 | if keypoint["name"] == _keypoint["name"]: 160 | perm_mat[i, j] = 1 161 | row_list.append(i) 162 | col_list.append(j) 163 | break 164 | 165 | assert len(row_list) == len(anno_pair[0]["keypoints"]) 166 | 167 | return anno_pair, perm_mat 168 | 169 | def get_pair(self, cls=None, shuffle=True): 170 | """ 171 | Randomly get a pair of objects from WILLOW-object dataset 172 | :param cls: None for random class, or specify for a certain set 173 | :param shuffle: random shuffle the keypoints 174 | :return: (pair of data, groundtruth permutation matrix) 175 | """ 176 | if cls is None: 177 | cls = random.randrange(0, len(self.classes)) 178 | elif type(cls) == str: 179 | cls = self.classes.index(cls) 180 | assert type(cls) == int and 0 <= cls < len(self.classes) 181 | 182 | anno_pair = [] 183 | for mat_name in random.sample(self.mat_list[cls], 2): 184 | anno_dict = self.__get_anno_dict(mat_name, cls) 185 | if shuffle: 186 | random.shuffle(anno_dict["keypoints"]) 187 | anno_pair.append(anno_dict) 188 | 189 | perm_mat = np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) 190 | row_list = [] 191 | col_list = [] 192 | for i, keypoint in enumerate(anno_pair[0]["keypoints"]): 193 | for j, _keypoint in enumerate(anno_pair[1]["keypoints"]): 194 | if keypoint["name"] == _keypoint["name"]: 195 | perm_mat[i, j] = 1 196 | row_list.append(i) 197 | col_list.append(j) 198 | break 199 | row_list.sort() 200 | col_list.sort() 201 | perm_mat = perm_mat[row_list, :] 202 | perm_mat = perm_mat[:, col_list] 203 | anno_pair[0]["keypoints"] = [anno_pair[0]["keypoints"][i] for i in row_list] 204 | anno_pair[1]["keypoints"] = [anno_pair[1]["keypoints"][j] for j in col_list] 205 | 206 | return anno_pair, perm_mat 207 | 208 | def __get_anno_dict(self, mat_file, cls): 209 | """ 210 | Get an annotation dict from .mat annotation 211 | """ 212 | assert mat_file.exists(), "{} does not exist.".format(mat_file) 213 | 214 | img_name = mat_file.stem + ".png" 215 | img_file = mat_file.parent / img_name 216 | 217 | struct = sio.loadmat(mat_file.open("rb")) 218 | kpts = struct["pts_coord"] 219 | 220 | with 
Image.open(str(img_file)) as img: 221 | ori_sizes = img.size 222 | obj = img.resize(self.obj_resize, resample=Image.BICUBIC) 223 | xmin = 0 224 | ymin = 0 225 | w = ori_sizes[0] 226 | h = ori_sizes[1] 227 | 228 | keypoint_list = [] 229 | for idx, keypoint in enumerate(np.split(kpts, kpts.shape[1], axis=1)): 230 | attr = {"name": idx} 231 | attr["x"] = float(keypoint[0]) * self.obj_resize[0] / w 232 | attr["y"] = float(keypoint[1]) * self.obj_resize[1] / h 233 | keypoint_list.append(attr) 234 | 235 | anno_dict = dict() 236 | anno_dict["image"] = obj 237 | anno_dict["keypoints"] = keypoint_list 238 | anno_dict["bounds"] = xmin, ymin, w, h 239 | anno_dict["ori_sizes"] = ori_sizes 240 | anno_dict["cls"] = cls 241 | 242 | return anno_dict 243 | -------------------------------------------------------------------------------- /data/pascal_voc.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import random 3 | import xml.etree.ElementTree as ET 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from utils.config import cfg 10 | from utils.utils import lexico_iter 11 | 12 | anno_path = cfg.VOC2011.KPT_ANNO_DIR 13 | img_path = cfg.VOC2011.ROOT_DIR + "JPEGImages" 14 | ori_anno_path = cfg.VOC2011.ROOT_DIR + "Annotations" 15 | set_path = cfg.VOC2011.SET_SPLIT 16 | cache_path = cfg.CACHE_PATH 17 | 18 | KPT_NAMES = { 19 | "cat": [ 20 | "L_B_Elbow", 21 | "L_B_Paw", 22 | "L_EarBase", 23 | "L_Eye", 24 | "L_F_Elbow", 25 | "L_F_Paw", 26 | "Nose", 27 | "R_B_Elbow", 28 | "R_B_Paw", 29 | "R_EarBase", 30 | "R_Eye", 31 | "R_F_Elbow", 32 | "R_F_Paw", 33 | "TailBase", 34 | "Throat", 35 | "Withers", 36 | ], 37 | "bottle": ["L_Base", "L_Neck", "L_Shoulder", "L_Top", "R_Base", "R_Neck", "R_Shoulder", "R_Top"], 38 | "horse": [ 39 | "L_B_Elbow", 40 | "L_B_Paw", 41 | "L_EarBase", 42 | "L_Eye", 43 | "L_F_Elbow", 44 | "L_F_Paw", 45 | "Nose", 46 | "R_B_Elbow", 47 | "R_B_Paw", 48 | "R_EarBase", 49 | "R_Eye", 50 | "R_F_Elbow", 51 | "R_F_Paw", 52 | "TailBase", 53 | "Throat", 54 | "Withers", 55 | ], 56 | "motorbike": [ 57 | "B_WheelCenter", 58 | "B_WheelEnd", 59 | "ExhaustPipeEnd", 60 | "F_WheelCenter", 61 | "F_WheelEnd", 62 | "HandleCenter", 63 | "L_HandleTip", 64 | "R_HandleTip", 65 | "SeatBase", 66 | "TailLight", 67 | ], 68 | "boat": [ 69 | "Hull_Back_Bot", 70 | "Hull_Back_Top", 71 | "Hull_Front_Bot", 72 | "Hull_Front_Top", 73 | "Hull_Mid_Left_Bot", 74 | "Hull_Mid_Left_Top", 75 | "Hull_Mid_Right_Bot", 76 | "Hull_Mid_Right_Top", 77 | "Mast_Top", 78 | "Sail_Left", 79 | "Sail_Right", 80 | ], 81 | "tvmonitor": [ 82 | "B_Bottom_Left", 83 | "B_Bottom_Right", 84 | "B_Top_Left", 85 | "B_Top_Right", 86 | "F_Bottom_Left", 87 | "F_Bottom_Right", 88 | "F_Top_Left", 89 | "F_Top_Right", 90 | ], 91 | "cow": [ 92 | "L_B_Elbow", 93 | "L_B_Paw", 94 | "L_EarBase", 95 | "L_Eye", 96 | "L_F_Elbow", 97 | "L_F_Paw", 98 | "Nose", 99 | "R_B_Elbow", 100 | "R_B_Paw", 101 | "R_EarBase", 102 | "R_Eye", 103 | "R_F_Elbow", 104 | "R_F_Paw", 105 | "TailBase", 106 | "Throat", 107 | "Withers", 108 | ], 109 | "chair": [ 110 | "BackRest_Top_Left", 111 | "BackRest_Top_Right", 112 | "Leg_Left_Back", 113 | "Leg_Left_Front", 114 | "Leg_Right_Back", 115 | "Leg_Right_Front", 116 | "Seat_Left_Back", 117 | "Seat_Left_Front", 118 | "Seat_Right_Back", 119 | "Seat_Right_Front", 120 | ], 121 | "car": [ 122 | "L_B_RoofTop", 123 | "L_B_WheelCenter", 124 | "L_F_RoofTop", 125 | "L_F_WheelCenter", 126 | "L_HeadLight", 127 | "L_SideviewMirror", 128 | "L_TailLight", 129 | "R_B_RoofTop", 
130 | "R_B_WheelCenter", 131 | "R_F_RoofTop", 132 | "R_F_WheelCenter", 133 | "R_HeadLight", 134 | "R_SideviewMirror", 135 | "R_TailLight", 136 | ], 137 | "person": [ 138 | "B_Head", 139 | "HeadBack", 140 | "L_Ankle", 141 | "L_Ear", 142 | "L_Elbow", 143 | "L_Eye", 144 | "L_Foot", 145 | "L_Hip", 146 | "L_Knee", 147 | "L_Shoulder", 148 | "L_Toes", 149 | "L_Wrist", 150 | "Nose", 151 | "R_Ankle", 152 | "R_Ear", 153 | "R_Elbow", 154 | "R_Eye", 155 | "R_Foot", 156 | "R_Hip", 157 | "R_Knee", 158 | "R_Shoulder", 159 | "R_Toes", 160 | "R_Wrist", 161 | ], 162 | "diningtable": [ 163 | "Bot_Left_Back", 164 | "Bot_Left_Front", 165 | "Bot_Right_Back", 166 | "Bot_Right_Front", 167 | "Top_Left_Back", 168 | "Top_Left_Front", 169 | "Top_Right_Back", 170 | "Top_Right_Front", 171 | ], 172 | "dog": [ 173 | "L_B_Elbow", 174 | "L_B_Paw", 175 | "L_EarBase", 176 | "L_Eye", 177 | "L_F_Elbow", 178 | "L_F_Paw", 179 | "Nose", 180 | "R_B_Elbow", 181 | "R_B_Paw", 182 | "R_EarBase", 183 | "R_Eye", 184 | "R_F_Elbow", 185 | "R_F_Paw", 186 | "TailBase", 187 | "Throat", 188 | "Withers", 189 | ], 190 | "bird": [ 191 | "Beak_Base", 192 | "Beak_Tip", 193 | "Left_Eye", 194 | "Left_Wing_Base", 195 | "Left_Wing_Tip", 196 | "Leg_Center", 197 | "Lower_Neck_Base", 198 | "Right_Eye", 199 | "Right_Wing_Base", 200 | "Right_Wing_Tip", 201 | "Tail_Tip", 202 | "Upper_Neck_Base", 203 | ], 204 | "bicycle": [ 205 | "B_WheelCenter", 206 | "B_WheelEnd", 207 | "B_WheelIntersection", 208 | "CranksetCenter", 209 | "F_WheelCenter", 210 | "F_WheelEnd", 211 | "F_WheelIntersection", 212 | "HandleCenter", 213 | "L_HandleTip", 214 | "R_HandleTip", 215 | "SeatBase", 216 | ], 217 | "train": [ 218 | "Base_Back_Left", 219 | "Base_Back_Right", 220 | "Base_Front_Left", 221 | "Base_Front_Right", 222 | "Roof_Back_Left", 223 | "Roof_Back_Right", 224 | "Roof_Front_Middle", 225 | ], 226 | "sheep": [ 227 | "L_B_Elbow", 228 | "L_B_Paw", 229 | "L_EarBase", 230 | "L_Eye", 231 | "L_F_Elbow", 232 | "L_F_Paw", 233 | "Nose", 234 | "R_B_Elbow", 235 | "R_B_Paw", 236 | "R_EarBase", 237 | "R_Eye", 238 | "R_F_Elbow", 239 | "R_F_Paw", 240 | "TailBase", 241 | "Throat", 242 | "Withers", 243 | ], 244 | "aeroplane": [ 245 | "Bot_Rudder", 246 | "Bot_Rudder_Front", 247 | "L_Stabilizer", 248 | "L_WingTip", 249 | "Left_Engine_Back", 250 | "Left_Engine_Front", 251 | "Left_Wing_Base", 252 | "NoseTip", 253 | "Nose_Bottom", 254 | "Nose_Top", 255 | "R_Stabilizer", 256 | "R_WingTip", 257 | "Right_Engine_Back", 258 | "Right_Engine_Front", 259 | "Right_Wing_Base", 260 | "Top_Rudder", 261 | ], 262 | "sofa": [ 263 | "Back_Base_Left", 264 | "Back_Base_Right", 265 | "Back_Top_Left", 266 | "Back_Top_Right", 267 | "Front_Base_Left", 268 | "Front_Base_Right", 269 | "Handle_Front_Left", 270 | "Handle_Front_Right", 271 | "Handle_Left_Junction", 272 | "Handle_Right_Junction", 273 | "Left_Junction", 274 | "Right_Junction", 275 | ], 276 | "pottedplant": ["Bottom_Left", "Bottom_Right", "Top_Back_Middle", "Top_Front_Middle", "Top_Left", "Top_Right"], 277 | "bus": ["L_B_Base", "L_B_RoofTop", "L_F_Base", "L_F_RoofTop", "R_B_Base", "R_B_RoofTop", "R_F_Base", "R_F_RoofTop"], 278 | } 279 | 280 | 281 | class PascalVOC: 282 | def __init__(self, sets, obj_resize): 283 | """ 284 | :param sets: 'train' or 'test' 285 | :param obj_resize: resized object size 286 | """ 287 | self.classes = cfg.VOC2011.CLASSES 288 | self.kpt_len = [len(KPT_NAMES[_]) for _ in cfg.VOC2011.CLASSES] 289 | 290 | self.classes_kpts = {cls: len(KPT_NAMES[cls]) for cls in self.classes} 291 | 292 | self.anno_path = Path(anno_path) 293 | 
self.img_path = Path(img_path) 294 | self.ori_anno_path = Path(ori_anno_path) 295 | self.obj_resize = obj_resize 296 | self.sets = sets 297 | 298 | assert sets in ["train", "test"], "No match found for dataset {}".format(sets) 299 | cache_name = "voc_db_" + sets + ".pkl" 300 | self.cache_path = Path(cache_path) 301 | self.cache_file = self.cache_path / cache_name 302 | if self.cache_file.exists(): 303 | with self.cache_file.open(mode="rb") as f: 304 | self.xml_list = pickle.load(f) 305 | print("xml list loaded from {}".format(self.cache_file)) 306 | else: 307 | print("Caching xml list to {}...".format(self.cache_file)) 308 | self.cache_path.mkdir(exist_ok=True, parents=True) 309 | with np.load(set_path, allow_pickle=True) as f: 310 | self.xml_list = f[sets] 311 | before_filter = sum([len(k) for k in self.xml_list]) 312 | self.filter_list() 313 | after_filter = sum([len(k) for k in self.xml_list]) 314 | with self.cache_file.open(mode="wb") as f: 315 | pickle.dump(self.xml_list, f) 316 | print("Filtered {} images to {}. Annotation saved.".format(before_filter, after_filter)) 317 | 318 | def filter_list(self): 319 | """ 320 | Filter out 'truncated', 'occluded' and 'difficult' images following the practice of previous works. 321 | In addition, this dataset has uncleaned labels (in the person category); they are omitted as suggested by the README. 322 | """ 323 | for cls_id in range(len(self.classes)): 324 | to_del = [] 325 | for xml_name in self.xml_list[cls_id]: 326 | xml_comps = xml_name.split("/")[-1][: -len(".xml")].split("_") # slice off the ".xml" suffix; str.strip(".xml") would strip characters, not the suffix 327 | ori_xml_name = "_".join(xml_comps[:-1]) + ".xml" 328 | voc_idx = int(xml_comps[-1]) 329 | xml_file = self.ori_anno_path / ori_xml_name 330 | assert xml_file.exists(), "{} does not exist.".format(xml_file) 331 | tree = ET.parse(xml_file.open()) 332 | root = tree.getroot() 333 | obj = root.findall("object")[voc_idx - 1] 334 | 335 | difficult = obj.find("difficult") 336 | if difficult is not None: 337 | difficult = int(difficult.text) 338 | occluded = obj.find("occluded") 339 | if occluded is not None: 340 | occluded = int(occluded.text) 341 | truncated = obj.find("truncated") 342 | if truncated is not None: 343 | truncated = int(truncated.text) 344 | if difficult or occluded or truncated: 345 | to_del.append(xml_name) 346 | continue 347 | 348 | # Exclude uncleaned images 349 | if self.classes[cls_id] == "person" and int(xml_comps[0]) > 2008: 350 | to_del.append(xml_name) 351 | continue 352 | 353 | # Exclude overlapping images in Willow 354 | if cfg.exclude_willow_classes: 355 | if ( 356 | self.sets == "train" 357 | and (self.classes[cls_id] == "motorbike" or self.classes[cls_id] == "car") 358 | and int(xml_comps[0]) == 2007 359 | ): 360 | to_del.append(xml_name) 361 | continue 362 | 363 | for x in to_del: 364 | self.xml_list[cls_id].remove(x) 365 | 366 | def get_k_samples(self, idx, k, mode, cls=None, shuffle=True, num_iterations=200): 367 | """ 368 | Randomly get a sample of k objects from VOC-Berkeley keypoints dataset 369 | :param idx: Index of datapoint to sample, None for random sampling 370 | :param k: number of datapoints in sample 371 | :param mode: sampling strategy 372 | :param cls: None for random class, or specify for a certain set 373 | :param shuffle: random shuffle the keypoints 374 | :param num_iterations: maximum number of iterations for sampling a datapoint 375 | :return: (k samples of data, k \choose 2 groundtruth permutation matrices) 376 | """ 377 | if idx is not None: 378 | raise NotImplementedError("No indexed sampling implemented for PVOC.") 379 | if 
cls is None: 380 | cls = random.randrange(0, len(self.classes)) 381 | elif type(cls) == str: 382 | cls = self.classes.index(cls) 383 | assert type(cls) == int and 0 <= cls < len(self.classes) 384 | 385 | if mode == "superset" and k == 2: # superset sampling only valid for pairs 386 | anno_list, perm_mat = self.get_pair_superset(cls=cls, shuffle=shuffle, num_iterations=num_iterations) 387 | return anno_list, [perm_mat] 388 | elif mode == "intersection": 389 | for i in range(num_iterations): 390 | xml_used = list(random.sample(self.xml_list[cls], 2)) 391 | anno_dict_1, anno_dict_2 = [self.__get_anno_dict(xml, cls) for xml in xml_used] 392 | kp_names_1 = [keypoint["name"] for keypoint in anno_dict_1["keypoints"]] 393 | kp_names_2 = [keypoint["name"] for keypoint in anno_dict_2["keypoints"]] 394 | kp_names_filtered = set(kp_names_1).intersection(kp_names_2) 395 | anno_dict_1["keypoints"] = [kp for kp in anno_dict_1["keypoints"] if kp["name"] in kp_names_2] 396 | anno_dict_2["keypoints"] = [kp for kp in anno_dict_2["keypoints"] if kp["name"] in kp_names_1] 397 | 398 | anno_list = [anno_dict_1, anno_dict_2] 399 | for j in range(num_iterations): 400 | if j > 2 * len(self.xml_list[cls]) or len(anno_list) == k: 401 | break 402 | xml = random.choice(self.xml_list[cls]) 403 | anno_dict = self.__get_anno_dict(xml, cls) 404 | anno_dict["keypoints"] = [kp for kp in anno_dict["keypoints"] if kp["name"] in kp_names_filtered] 405 | if len(anno_dict["keypoints"]) > len(kp_names_filtered) // 2 and xml not in xml_used: 406 | xml_used.append(xml) 407 | anno_list.append(anno_dict) 408 | if len(anno_list) == k: # k samples found that match restrictions 409 | break 410 | assert len(anno_list) == k 411 | elif mode == "all": 412 | anno_list = [] 413 | for xml_name in random.sample(self.xml_list[cls], k): 414 | anno_dict = self.__get_anno_dict(xml_name, cls) 415 | if shuffle: 416 | random.shuffle(anno_dict["keypoints"]) 417 | anno_list.append(anno_dict) 418 | 419 | if shuffle: 420 | for anno_dict in anno_list: 421 | random.shuffle(anno_dict["keypoints"]) 422 | 423 | # build permutation matrices 424 | perm_mat_list = [ 425 | np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) for anno_pair in lexico_iter(anno_list) 426 | ] 427 | for n, (s1, s2) in enumerate(lexico_iter(anno_list)): 428 | for i, keypoint in enumerate(s1["keypoints"]): 429 | for j, _keypoint in enumerate(s2["keypoints"]): 430 | if keypoint["name"] == _keypoint["name"]: 431 | perm_mat_list[n][i, j] = 1 432 | 433 | return anno_list, perm_mat_list 434 | 435 | def get_pair_superset(self, cls=None, shuffle=True, num_iterations=200): 436 | """ 437 | Randomly get a pair of objects from VOC-Berkeley keypoints dataset using superset sampling 438 | :param cls: None for random class, or specify for a certain set 439 | :param shuffle: random shuffle the keypoints 440 | :return: (pair of data, groundtruth permutation matrix) 441 | """ 442 | if cls is None: 443 | cls = random.randrange(0, len(self.classes)) 444 | elif type(cls) == str: 445 | cls = self.classes.index(cls) 446 | assert type(cls) == int and 0 <= cls < len(self.classes) 447 | 448 | anno_pair = None 449 | 450 | anno_dict_1 = self.__get_anno_dict(random.sample(self.xml_list[cls], 1)[0], cls) 451 | if shuffle: 452 | random.shuffle(anno_dict_1["keypoints"]) 453 | keypoints_1 = set([kp["name"] for kp in anno_dict_1["keypoints"]]) 454 | 455 | for xml_name in random.sample(self.xml_list[cls], min(len(self.xml_list[cls]), num_iterations)): 456 | anno_dict_2 = self.__get_anno_dict(xml_name, cls) 
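# --- Editor's note (not part of the original source): "superset" sampling
# keeps drawing candidate partner images until every keypoint name of the
# first image also occurs in the candidate (the issubset check below), so that
# each source keypoint has a ground-truth match; the assertion
# len(row_list) == len(anno_pair[0]["keypoints"]) further down relies on
# exactly this property.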
457 | if shuffle: 458 | random.shuffle(anno_dict_2["keypoints"]) 459 | keypoints_2 = set([kp["name"] for kp in anno_dict_2["keypoints"]]) 460 | if keypoints_1.issubset(keypoints_2): 461 | anno_pair = [anno_dict_1, anno_dict_2] 462 | break 463 | 464 | if anno_pair is None: 465 | return self.get_pair_superset(cls, shuffle, num_iterations) 466 | 467 | perm_mat = np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) 468 | row_list = [] 469 | col_list = [] 470 | for i, keypoint in enumerate(anno_pair[0]["keypoints"]): 471 | for j, _keypoint in enumerate(anno_pair[1]["keypoints"]): 472 | if keypoint["name"] == _keypoint["name"]: 473 | perm_mat[i, j] = 1 474 | row_list.append(i) 475 | col_list.append(j) 476 | break 477 | 478 | assert len(row_list) == len(anno_pair[0]["keypoints"]) 479 | 480 | return anno_pair, perm_mat 481 | 482 | def __get_anno_dict(self, xml_name, cls): 483 | """ 484 | Get an annotation dict from xml file 485 | """ 486 | xml_file = self.anno_path / xml_name 487 | assert xml_file.exists(), "{} does not exist.".format(xml_file) 488 | 489 | tree = ET.parse(xml_file.open()) 490 | root = tree.getroot() 491 | 492 | img_name = root.find("./image").text + ".jpg" 493 | img_file = self.img_path / img_name 494 | bounds = root.find("./visible_bounds").attrib 495 | 496 | h = float(bounds["height"]) 497 | w = float(bounds["width"]) 498 | xmin = float(bounds["xmin"]) 499 | ymin = float(bounds["ymin"]) 500 | 501 | with Image.open(str(img_file)) as img: 502 | ori_sizes = img.size 503 | obj = img.resize(self.obj_resize, resample=Image.BICUBIC, box=(xmin, ymin, xmin + w, ymin + h)) 504 | 505 | keypoint_list = [] 506 | for keypoint in root.findall("./keypoints/keypoint"): 507 | attr = keypoint.attrib 508 | attr["x"] = (float(attr["x"]) - xmin) * self.obj_resize[0] / w 509 | attr["y"] = (float(attr["y"]) - ymin) * self.obj_resize[1] / h 510 | if -1e-5 < attr["x"] < self.obj_resize[0] + 1e-5 and -1e-5 < attr["y"] < self.obj_resize[1] + 1e-5: 511 | keypoint_list.append(attr) 512 | 513 | anno_dict = dict() 514 | anno_dict["image"] = obj 515 | anno_dict["keypoints"] = keypoint_list 516 | anno_dict["bounds"] = xmin, ymin, w, h 517 | anno_dict["ori_sizes"] = ori_sizes 518 | anno_dict["cls"] = self.classes[cls] 519 | 520 | return anno_dict 521 | --------------------------------------------------------------------------------
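Editor's appendix — a minimal sketch of how the data pipeline above fits together, closely mirroring the `__main__` block of train_eval.py. It assumes the datasets have been fetched via download_data.sh, that the script is started with one of the experiments/*.json configs, and that utils.utils.update_params_from_cmdline resolves them exactly as in train_eval.py; the epoch length of 100 and the choice of k = 2 graphs per matching instance are illustrative, not repo defaults.

import torch

from data.data_loader_multigraph import GMDataset, get_dataloader
from utils.config import cfg
from utils.utils import update_params_from_cmdline

if __name__ == "__main__":
    # Load defaults plus the overrides from an experiments/*.json given on the command line.
    cfg = update_params_from_cmdline(default_params=cfg)
    torch.manual_seed(cfg.RANDOM_SEED)

    # length=100 (illustrative) samples 100 random matching instances per "epoch";
    # length=None would instead iterate once over the full dataset (a "true epoch").
    image_dataset = {x: GMDataset(cfg.DATASET_NAME, sets=x, length=100, obj_resize=(256, 256)) for x in ("train", "test")}
    dataloader = {x: get_dataloader(image_dataset[x], fix_seed=(x == "test")) for x in ("train", "test")}

    # Every matching instance must know how many graphs it contains (k >= 2);
    # otherwise GMDataset.__getitem__ raises a ValueError.
    for x in ("train", "test"):
        dataloader[x].dataset.set_num_graphs(2)

    batch = next(iter(dataloader["train"]))
    # "Ps": padded keypoint coordinates per graph, "ns": keypoint counts,
    # "gt_perm_mat": one ground-truth permutation matrix per pair of graphs,
    # "edges": one torch_geometric Batch per graph, "images": normalized tensors.
    print([p.shape for p in batch["Ps"]], len(batch["gt_perm_mat"]))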