├── LICENSE ├── Readme.md ├── eval_pf_pascal.py ├── lib ├── __pycache__ │ ├── constant.cpython-37.pyc │ ├── conv4d.cpython-37.pyc │ ├── dataloader.cpython-37.pyc │ ├── eval_util.cpython-37.pyc │ ├── im_pair_dataset.cpython-37.pyc │ ├── interpolator.cpython-37.pyc │ ├── loss.cpython-37.pyc │ ├── model.cpython-37.pyc │ ├── normalization.cpython-37.pyc │ ├── pf_dataset.cpython-37.pyc │ ├── pf_pascal_dataset.cpython-37.pyc │ ├── pf_willow_dataset.cpython-37.pyc │ ├── point_tnf.cpython-37.pyc │ ├── tools.cpython-37.pyc │ ├── torch_util.cpython-37.pyc │ ├── transformation.cpython-37.pyc │ └── visualisation.cpython-37.pyc ├── constant.py ├── conv4d.py ├── dataloader.py ├── eval_util.py ├── im_pair_dataset.py ├── interpolator.py ├── model.py ├── normalization.py ├── pf_dataset.py ├── pf_pascal_dataset.py ├── plot.py ├── point_tnf.py ├── tools.py ├── torch_util.py ├── transformation.py └── visualisation.py ├── requirements.txt └── run.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2020 Shuda Li 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Release demo 2 | 3 | This package contains a demo for submission #689 Correspondence Networks with Adaptive Neighbourhood Consensus. 4 | 5 | The demo first calculates the PCK@0.1 score on the PF-PASCAL dataset and then exports two sets of visualisations. The first shows the correlation maps and key point predictions; the second illustrates the key point matching results. The visualisations can be found in the output folder. 6 | 7 | 8 | ## Requirements 9 | - Ubuntu 18.04 10 | - Conda 11 | - Python 3.7 12 | - CUDA 9.0 or newer 13 | 14 | ## Installation 15 | 1. Install CUDA 9.0 as well as either Anaconda or Miniconda [link](https://docs.conda.io/en/latest/miniconda.html#linux-installers) 16 | 2. Create a conda environment: `conda create -n 689release python=3.7` 17 | 3. Activate the environment: `conda activate 689release` 18 | 4. 
Run the following commands: 19 | - `conda install pytorch torchvision cudatoolkit=9.0` 20 | - `pip install -r requirements.txt` 21 | - `wget -O ancnet.zip https://www.dropbox.com/s/bjul4f5z7beq3um/ancnet.zip?dl=0 && unzip -q ancnet.zip && rm ancnet.zip` 22 | 23 | ## Usage 24 | To run example code: `python eval_pf_pascal.py` 25 | 26 | ## Quick start 27 | After creating a conda environment, you can simply do $sh run.sh. If you encounter any error, please follow installation and usage sections. 28 | -------------------------------------------------------------------------------- /eval_pf_pascal.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | from os.path import exists 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.utils.data import Dataset, DataLoader 8 | from collections import OrderedDict 9 | 10 | from lib.model import ImMatchNet 11 | from lib.pf_dataset import PFPascalDataset 12 | from lib.normalization import NormalizeImageDict 13 | from lib.torch_util import BatchTensorToVars, str_to_bool 14 | from lib.point_tnf import corr_to_matches 15 | from lib.eval_util import pck_metric 16 | from lib.dataloader import default_collate 17 | from lib.torch_util import collate_custom 18 | from lib import pf_pascal_dataset as pf 19 | from lib import tools 20 | 21 | import argparse 22 | import warnings 23 | from tqdm import tqdm 24 | 25 | warnings.filterwarnings("ignore", category=UserWarning) 26 | 27 | 28 | def main(): 29 | print("NCNet evaluation script - PF Pascal dataset") 30 | 31 | use_cuda = torch.cuda.is_available() 32 | 33 | # Argument parsing 34 | parser = argparse.ArgumentParser(description="Compute PF Pascal matches") 35 | parser.add_argument("--checkpoint", type=str, default="models/ancnet_86_11.pth.tar") 36 | parser.add_argument( 37 | "--vis", 38 | type=int, 39 | default=0, 40 | help="visilisation options: 0 calculate pck; 1 visualise keypoint matches and heat maps; 2 display matched key points", 41 | ) 42 | parser.add_argument("--a", type=float, default=0.1, help="a is the pck@alpha value") 43 | parser.add_argument( 44 | "--num_examples", type=int, default=5, help="the number of matching examples" 45 | ) 46 | 47 | args = parser.parse_args() 48 | 49 | vis = args.vis 50 | alpha = args.a 51 | num_examples = args.num_examples 52 | 53 | if args.checkpoint is not None and args.checkpoint is not "": 54 | print("Loading checkpoint...") 55 | checkpoint = torch.load( 56 | args.checkpoint, map_location=lambda storage, loc: storage 57 | ) 58 | checkpoint["state_dict"] = OrderedDict( 59 | [ 60 | (k.replace("vgg", "model"), v) 61 | for k, v in checkpoint["state_dict"].items() 62 | ] 63 | ) 64 | 65 | args = checkpoint["args"] 66 | else: 67 | print("checkpoint needed.") 68 | exit() 69 | 70 | cnn_image_size = (args.image_size, args.image_size) 71 | 72 | # Create model 73 | print("Creating CNN model...") 74 | model = ImMatchNet( 75 | use_cuda=use_cuda, 76 | feature_extraction_cnn=args.backbone, 77 | checkpoint=checkpoint, 78 | ncons_kernel_sizes=args.ncons_kernel_sizes, 79 | ncons_channels=args.ncons_channels, 80 | pss=args.pss, 81 | noniso=args.noniso, 82 | ) 83 | model.eval() 84 | 85 | print("args.dataset_image_path", args.dataset_image_path) 86 | # Dataset and dataloader 87 | collate_fn = default_collate 88 | csv_file = "image_pairs/test_pairs.csv" 89 | 90 | dataset = PFPascalDataset( 91 | csv_file=os.path.join(args.dataset_image_path, csv_file), 92 | 
dataset_path=args.dataset_image_path, 93 | transform=NormalizeImageDict(["source_image", "target_image"]), 94 | output_size=cnn_image_size, 95 | ) 96 | dataset.pck_procedure = "scnet" 97 | 98 | # Only batch_size=1 is supported for evaluation 99 | batch_size = 1 100 | 101 | dataloader = DataLoader( 102 | dataset, 103 | batch_size=batch_size, 104 | shuffle=False, 105 | num_workers=0, 106 | collate_fn=collate_fn, 107 | ) 108 | 109 | batch_tnf = BatchTensorToVars(use_cuda=use_cuda) 110 | 111 | # initialize vector for storing results 112 | stats = {} 113 | stats["point_tnf"] = {} 114 | stats["point_tnf"]["pck"] = np.zeros((len(dataset), 1)) 115 | 116 | # Compute pck accuracy 117 | total = len(dataloader) 118 | progress = tqdm(dataloader, total=total) 119 | for i, batch in enumerate(progress): 120 | batch = batch_tnf(batch) 121 | batch_start_idx = batch_size * i 122 | corr4d = model(batch) 123 | 124 | # get matches 125 | # note invert_matching_direction doesnt work at all 126 | xA, yA, xB, yB, sB = corr_to_matches( 127 | corr4d, do_softmax=True, invert_matching_direction=False 128 | ) 129 | 130 | matches = (xA, yA, xB, yB) 131 | stats = pck_metric( 132 | batch, batch_start_idx, matches, stats, alpha=alpha, use_cuda=use_cuda 133 | ) 134 | 135 | # Print results 136 | results = stats["point_tnf"]["pck"] 137 | good_idx = np.flatnonzero((results != -1) * ~np.isnan(results)) 138 | print("Total: " + str(results.size)) 139 | print("Valid: " + str(good_idx.size)) 140 | filtered_results = results[good_idx] 141 | print("PCK:", "{:.2%}".format(np.mean(filtered_results))) 142 | 143 | test_csv = "test_pairs.csv" 144 | dataset_val = pf.ImagePairDataset( 145 | transform=NormalizeImageDict(["source_image", "target_image"]), 146 | dataset_image_path=args.dataset_image_path, 147 | dataset_csv_path=os.path.join(args.dataset_image_path, "image_pairs"), 148 | dataset_csv_file=test_csv, 149 | output_size=cnn_image_size, 150 | keypoints_on=True, 151 | original=True, 152 | test=True, 153 | ) 154 | loader_test = DataLoader(dataset_val, batch_size=1, shuffle=True, num_workers=4) 155 | batch_tnf = BatchTensorToVars(use_cuda=use_cuda) 156 | 157 | print("visualise correlation") 158 | tools.visualise_feature( 159 | model, loader_test, batch_tnf, image_size=cnn_image_size, MAX=num_examples 160 | ) 161 | print("visualise pair") 162 | tools.validate( 163 | model, 164 | loader_test, 165 | batch_tnf, 166 | None, 167 | image_scale=args.image_size, 168 | im_fe_ratio=16, 169 | image_size=cnn_image_size, 170 | MAX=num_examples, 171 | display=True, 172 | ) 173 | 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /lib/__pycache__/constant.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/constant.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/conv4d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/conv4d.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/eval_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/eval_util.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/im_pair_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/im_pair_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/interpolator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/interpolator.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/normalization.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/normalization.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/pf_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/pf_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/pf_pascal_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/pf_pascal_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/pf_willow_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/pf_willow_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/point_tnf.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/point_tnf.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/torch_util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/torch_util.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/transformation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/transformation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/__pycache__/visualisation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ActiveVisionLab/ANCNet/7aadbb264375cd4e5300865f2987b6cab485756a/lib/__pycache__/visualisation.cpython-37.pyc -------------------------------------------------------------------------------- /lib/constant.py: -------------------------------------------------------------------------------- 1 | _colors = ['b', 'y', 'g', 'r', 'c', 'm', 'k'] 2 | _markers = ['o','v','s','p','P','*','+','x','D'] 3 | 4 | _eps = 1e-10 5 | -------------------------------------------------------------------------------- /lib/conv4d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | from torch.nn import Module 7 | from torch.nn.modules.conv import _ConvNd 8 | from torch.nn.modules.utils import _quadruple 9 | from torch.autograd import Variable 10 | import torch.nn as nn 11 | from torch.nn import Conv2d 12 | 13 | 14 | def conv4d(data, filters, bias=None, permute_filters=True, use_half=False): 15 | b, c, h, w, d, t = data.size() 16 | 17 | # [8, 1, 25, 25, 25, 25] -> [25, 8, 1, 25, 25, 25] 18 | data = data.permute( 19 | 2, 0, 1, 3, 4, 5 20 | ).contiguous() # permute to avoid making contiguous inside loop 21 | 22 | # Same permutation is done with filters, unless already provided with permutation 23 | if permute_filters: 24 | filters = filters.permute( 25 | 2, 0, 1, 3, 4, 5 26 | ).contiguous() # permute to avoid making contiguous inside loop 27 | c_out = filters.size(1) 28 | if use_half: 29 | output = torch.HalfTensor( 30 | [h, b, c_out, w, d, t], dtype=data.dtype, requires_grad=data.requires_grad 31 | ) 32 | else: 33 | output = torch.zeros( 34 | [h, b, c_out, w, d, t], dtype=data.dtype, requires_grad=data.requires_grad 35 | ) 36 | 37 | kh, _, _, kw, kd, kt = filters.shape 38 | padding = kh // 2 # calc padding size (kernel_size - 1)/2 39 | padding_3d = (kw // 2, kd // 2, kt // 2) 40 | if use_half: 41 | Z = torch.zeros([padding, b, c, w, d, t], dtype=data.dtype).half() 42 | else: 43 | Z = torch.zeros([padding, b, c, w, d, t], dtype=data.dtype) 44 | 45 | if 
data.is_cuda: 46 | Z = Z.cuda(data.get_device()) 47 | output = output.cuda(data.get_device()) 48 | 49 | data_padded = torch.cat((Z, data, Z), 0) # [29, 8, 16, 25, 25, 25] 50 | if bias is not None: 51 | bias = bias / (1 + padding * 2) 52 | # print('bias',bias) 53 | 54 | for i in range(output.size(0)): # loop on first feature dimension 55 | # convolve with center channel of filter (at position=padding) 56 | output[i, :, :, :, :, :] = F.conv3d( 57 | data_padded[i + padding, :, :, :, :, :], 58 | filters[padding, :, :, :, :, :], 59 | bias=bias, 60 | stride=1, 61 | padding=padding_3d, 62 | ) 63 | # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) 64 | for p in range(1, padding + 1): 65 | output[i, :, :, :, :, :] += F.conv3d( 66 | data_padded[i + padding - p, :, :, :, :, :], 67 | filters[padding - p, :, :, :, :, :], 68 | bias=bias, 69 | stride=1, 70 | padding=padding_3d, 71 | ) 72 | output[i, :, :, :, :, :] += F.conv3d( 73 | data_padded[i + padding + p, :, :, :, :, :], 74 | filters[padding + p, :, :, :, :, :], 75 | bias=bias, 76 | stride=1, 77 | padding=padding_3d, 78 | ) 79 | 80 | output = output.permute(1, 2, 0, 3, 4, 5).contiguous() # [8, 16, 25, 25, 25, 25] 81 | return output 82 | 83 | 84 | class Conv4d(_ConvNd): 85 | """Applies a 4D convolution over an input signal composed of several input 86 | planes. 87 | """ 88 | 89 | def __init__( 90 | self, 91 | in_channels, 92 | out_channels, 93 | kernel_size, 94 | pre_permuted_filters=True, 95 | bias=True, 96 | filters=None, 97 | bias_4d=None, 98 | ): 99 | # stride, dilation and groups !=1 functionality not tested 100 | stride = 1 101 | dilation = 1 102 | groups = 1 103 | # zero padding is added automatically in conv4d function to preserve tensor size 104 | padding = 0 105 | kernel_size = _quadruple(kernel_size) 106 | stride = _quadruple(stride) 107 | padding = _quadruple(padding) 108 | dilation = _quadruple(dilation) 109 | super(Conv4d, self).__init__( 110 | in_channels, 111 | out_channels, 112 | kernel_size, 113 | stride, 114 | padding, 115 | dilation, 116 | False, 117 | _quadruple(0), 118 | groups, 119 | bias, 120 | "zero", 121 | ) 122 | # weights will be sliced along one dimension during convolution loop 123 | # make the looping dimension to be the first one in the tensor, 124 | # so that we don't need to call contiguous() inside the loop 125 | self.pre_permuted_filters = pre_permuted_filters 126 | # self.groups=groups 127 | if filters is not None: 128 | self.weight.data = filters 129 | if bias_4d is not None and bias: 130 | self.bias.data = bias_4d 131 | 132 | if self.pre_permuted_filters: 133 | self.weight.data = self.weight.data.permute(2, 0, 1, 3, 4, 5).contiguous() 134 | self.use_half = False 135 | 136 | def forward(self, input): 137 | return conv4d( 138 | input, 139 | self.weight, 140 | bias=self.bias, 141 | permute_filters=not self.pre_permuted_filters, 142 | use_half=self.use_half, 143 | ) # filters pre-permuted in constructor 144 | -------------------------------------------------------------------------------- /lib/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler 4 | import collections 5 | import sys 6 | import traceback 7 | import threading 8 | import numpy as np 9 | import numpy.random 10 | 11 | # from torch._six import string_classes 12 | PY2 = sys.version_info[0] == 2 13 | PY3 = sys.version_info[0] 
== 3 14 | 15 | if PY2: 16 | string_classes = basestring 17 | else: 18 | string_classes = (str, bytes) 19 | 20 | 21 | if sys.version_info[0] == 2: 22 | import Queue as queue 23 | else: 24 | import queue 25 | 26 | 27 | _use_shared_memory = False 28 | """Whether to use shared memory in default_collate""" 29 | 30 | 31 | class ExceptionWrapper(object): 32 | "Wraps an exception plus traceback to communicate across threads" 33 | 34 | def __init__(self, exc_info): 35 | self.exc_type = exc_info[0] 36 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 37 | 38 | 39 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, rng_seed): 40 | global _use_shared_memory 41 | _use_shared_memory = True 42 | 43 | np.random.seed(rng_seed) 44 | torch.set_num_threads(1) 45 | while True: 46 | r = index_queue.get() 47 | if r is None: 48 | data_queue.put(None) 49 | break 50 | idx, batch_indices = r 51 | try: 52 | samples = collate_fn([dataset[i] for i in batch_indices]) 53 | except Exception: 54 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 55 | else: 56 | data_queue.put((idx, samples)) 57 | 58 | 59 | def _pin_memory_loop(in_queue, out_queue, done_event): 60 | while True: 61 | try: 62 | r = in_queue.get() 63 | except: 64 | if done_event.is_set(): 65 | return 66 | raise 67 | if r is None: 68 | break 69 | if isinstance(r[1], ExceptionWrapper): 70 | out_queue.put(r) 71 | continue 72 | idx, batch = r 73 | try: 74 | batch = pin_memory_batch(batch) 75 | except Exception: 76 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 77 | else: 78 | out_queue.put((idx, batch)) 79 | 80 | 81 | numpy_type_map = { 82 | "float64": torch.DoubleTensor, 83 | "float32": torch.FloatTensor, 84 | "float16": torch.HalfTensor, 85 | "int64": torch.LongTensor, 86 | "int32": torch.IntTensor, 87 | "int16": torch.ShortTensor, 88 | "int8": torch.CharTensor, 89 | "uint8": torch.ByteTensor, 90 | } 91 | 92 | 93 | def default_collate(batch): 94 | "Puts each data field into a tensor with outer dimension batch size" 95 | if torch.is_tensor(batch[0]): 96 | out = None 97 | if _use_shared_memory: 98 | # If we're in a background process, concatenate directly into a 99 | # shared memory tensor to avoid an extra copy 100 | numel = sum([x.numel() for x in batch]) 101 | storage = batch[0].storage()._new_shared(numel) 102 | out = batch[0].new(storage) 103 | return torch.stack(batch, 0, out=out) 104 | elif type(batch[0]).__module__ == "numpy": 105 | elem = batch[0] 106 | if type(elem).__name__ == "ndarray": 107 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 108 | if elem.shape == (): # scalars 109 | py_type = float if elem.dtype.name.startswith("float") else int 110 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 111 | elif isinstance(batch[0], int): 112 | return torch.LongTensor(batch) 113 | elif isinstance(batch[0], float): 114 | return torch.DoubleTensor(batch) 115 | elif isinstance(batch[0], string_classes): 116 | return batch 117 | elif isinstance(batch[0], collections.Mapping): 118 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 119 | elif isinstance(batch[0], collections.Sequence): 120 | transposed = zip(*batch) 121 | return [default_collate(samples) for samples in transposed] 122 | 123 | raise TypeError( 124 | ( 125 | "batch must contain tensors, numbers, dicts or lists; found {}".format( 126 | type(batch[0]) 127 | ) 128 | ) 129 | ) 130 | 131 | 132 | def pin_memory_batch(batch): 133 | if torch.is_tensor(batch): 134 | return batch.pin_memory() 135 | elif 
isinstance(batch, string_classes): 136 | return batch 137 | elif isinstance(batch, collections.Mapping): 138 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 139 | elif isinstance(batch, collections.Sequence): 140 | return [pin_memory_batch(sample) for sample in batch] 141 | else: 142 | return batch 143 | 144 | 145 | class DataLoaderIter(object): 146 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 147 | 148 | def __init__(self, loader): 149 | self.dataset = loader.dataset 150 | self.collate_fn = loader.collate_fn 151 | self.batch_sampler = loader.batch_sampler 152 | self.num_workers = loader.num_workers 153 | self.pin_memory = loader.pin_memory 154 | self.done_event = threading.Event() 155 | 156 | self.sample_iter = iter(self.batch_sampler) 157 | 158 | if self.num_workers > 0: 159 | self.index_queue = multiprocessing.SimpleQueue() 160 | self.data_queue = multiprocessing.SimpleQueue() 161 | self.batches_outstanding = 0 162 | self.shutdown = False 163 | self.send_idx = 0 164 | self.rcvd_idx = 0 165 | self.reorder_dict = {} 166 | 167 | self.workers = [ 168 | multiprocessing.Process( 169 | target=_worker_loop, 170 | args=( 171 | self.dataset, 172 | self.index_queue, 173 | self.data_queue, 174 | self.collate_fn, 175 | np.random.randint(0, 4294967296, dtype="uint32"), 176 | ), 177 | ) 178 | for _ in range(self.num_workers) 179 | ] 180 | 181 | for w in self.workers: 182 | w.daemon = True # ensure that the worker exits on process exit 183 | w.start() 184 | 185 | if self.pin_memory: 186 | in_data = self.data_queue 187 | self.data_queue = queue.Queue() 188 | self.pin_thread = threading.Thread( 189 | target=_pin_memory_loop, 190 | args=(in_data, self.data_queue, self.done_event), 191 | ) 192 | self.pin_thread.daemon = True 193 | self.pin_thread.start() 194 | 195 | # prime the prefetch loop 196 | for _ in range(2 * self.num_workers): 197 | self._put_indices() 198 | 199 | def __len__(self): 200 | return len(self.batch_sampler) 201 | 202 | def __next__(self): 203 | if self.num_workers == 0: # same-process loading 204 | indices = next(self.sample_iter) # may raise StopIteration 205 | batch = self.collate_fn([self.dataset[i] for i in indices]) 206 | if self.pin_memory: 207 | batch = pin_memory_batch(batch) 208 | return batch 209 | 210 | # check if the next sample has already been generated 211 | if self.rcvd_idx in self.reorder_dict: 212 | batch = self.reorder_dict.pop(self.rcvd_idx) 213 | return self._process_next_batch(batch) 214 | 215 | if self.batches_outstanding == 0: 216 | self._shutdown_workers() 217 | raise StopIteration 218 | 219 | while True: 220 | assert not self.shutdown and self.batches_outstanding > 0 221 | idx, batch = self.data_queue.get() 222 | self.batches_outstanding -= 1 223 | if idx != self.rcvd_idx: 224 | # store out-of-order samples 225 | self.reorder_dict[idx] = batch 226 | continue 227 | return self._process_next_batch(batch) 228 | 229 | next = __next__ # Python 2 compatibility 230 | 231 | def __iter__(self): 232 | return self 233 | 234 | def _put_indices(self): 235 | assert self.batches_outstanding < 2 * self.num_workers 236 | indices = next(self.sample_iter, None) 237 | if indices is None: 238 | return 239 | self.index_queue.put((self.send_idx, indices)) 240 | self.batches_outstanding += 1 241 | self.send_idx += 1 242 | 243 | def _process_next_batch(self, batch): 244 | self.rcvd_idx += 1 245 | self._put_indices() 246 | if isinstance(batch, ExceptionWrapper): 247 | raise batch.exc_type(batch.exc_msg) 248 | return batch 249 | 250 | 
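# Note on the prefetch/reorder protocol implemented by __next__, _put_indices
# and _process_next_batch above: the constructor primes index_queue with
# 2 * num_workers batch requests, and _put_indices tops the queue back up as
# results are consumed, so at most 2 * num_workers batches are ever in flight.
# Worker processes may finish out of order, so __next__ parks any result whose
# index differs from rcvd_idx in reorder_dict and only yields batches in the
# exact order produced by batch_sampler.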
def __getstate__(self): 251 | # TODO: add limited pickling support for sharing an iterator 252 | # across multiple threads for HOGWILD. 253 | # Probably the best way to do this is by moving the sample pushing 254 | # to a separate thread and then just sharing the data queue 255 | # but signalling the end is tricky without a non-blocking API 256 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 257 | 258 | def _shutdown_workers(self): 259 | if not self.shutdown: 260 | self.shutdown = True 261 | self.done_event.set() 262 | for _ in self.workers: 263 | self.index_queue.put(None) 264 | 265 | def __del__(self): 266 | if self.num_workers > 0: 267 | self._shutdown_workers() 268 | 269 | 270 | class DataLoader(object): 271 | """ 272 | Data loader. Combines a dataset and a sampler, and provides 273 | single- or multi-process iterators over the dataset. 274 | 275 | Arguments: 276 | dataset (Dataset): dataset from which to load the data. 277 | batch_size (int, optional): how many samples per batch to load 278 | (default: 1). 279 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 280 | at every epoch (default: False). 281 | sampler (Sampler, optional): defines the strategy to draw samples from 282 | the dataset. If specified, ``shuffle`` must be False. 283 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 284 | indices at a time. Mutually exclusive with batch_size, shuffle, 285 | sampler, and drop_last. 286 | num_workers (int, optional): how many subprocesses to use for data 287 | loading. 0 means that the data will be loaded in the main process 288 | (default: 0) 289 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 290 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 291 | into CUDA pinned memory before returning them. 292 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 293 | if the dataset size is not divisible by the batch size. If False and 294 | the size of dataset is not divisible by the batch size, then the last batch 295 | will be smaller. 
(default: False) 296 | """ 297 | 298 | def __init__( 299 | self, 300 | dataset, 301 | batch_size=1, 302 | shuffle=False, 303 | sampler=None, 304 | batch_sampler=None, 305 | num_workers=0, 306 | collate_fn=default_collate, 307 | pin_memory=False, 308 | drop_last=False, 309 | ): 310 | self.dataset = dataset 311 | self.batch_size = batch_size 312 | self.num_workers = num_workers 313 | self.collate_fn = collate_fn 314 | self.pin_memory = pin_memory 315 | self.drop_last = drop_last 316 | 317 | if batch_sampler is not None: 318 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 319 | raise ValueError( 320 | "batch_sampler is mutually exclusive with " 321 | "batch_size, shuffle, sampler, and drop_last" 322 | ) 323 | 324 | if sampler is not None and shuffle: 325 | raise ValueError("sampler is mutually exclusive with shuffle") 326 | 327 | if batch_sampler is None: 328 | if sampler is None: 329 | if shuffle: 330 | sampler = RandomSampler(dataset) 331 | else: 332 | sampler = SequentialSampler(dataset) 333 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 334 | 335 | self.sampler = sampler 336 | self.batch_sampler = batch_sampler 337 | 338 | def __iter__(self): 339 | return DataLoaderIter(self) 340 | 341 | def __len__(self): 342 | return len(self.batch_sampler) 343 | -------------------------------------------------------------------------------- /lib/eval_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import numpy as np 4 | import os 5 | from skimage import draw 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from lib.pf_dataset import PFPascalDataset 9 | from lib.point_tnf import ( 10 | PointsToUnitCoords, 11 | PointsToPixelCoords, 12 | bilinearInterpPointTnf, 13 | ) 14 | 15 | 16 | def pck(source_points, warped_points, L_pck, alpha=0.1): 17 | # compute precentage of correct keypoints 18 | batch_size = source_points.size(0) 19 | pck = torch.zeros((batch_size)) 20 | for i in range(batch_size): 21 | p_src = source_points[i, :] 22 | p_wrp = warped_points[i, :] 23 | N_pts = torch.sum(torch.ne(p_src[0, :], -1) * torch.ne(p_src[1, :], -1)) 24 | point_distance = torch.pow( 25 | torch.sum(torch.pow(p_src[:, :N_pts] - p_wrp[:, :N_pts], 2), 0), 0.5 26 | ) 27 | L_pck_mat = L_pck[i].expand_as(point_distance) 28 | correct_points = torch.le(point_distance, L_pck_mat * alpha) 29 | pck[i] = torch.mean(correct_points.float()) 30 | return pck 31 | 32 | 33 | def pck_metric(batch, batch_start_idx, matches, stats, alpha=0.1, use_cuda=True): 34 | 35 | source_im_size = batch["source_im_size"] 36 | target_im_size = batch["target_im_size"] 37 | 38 | source_points = batch["source_points"] # w.r.t. 
the original image coordinate #224 39 | target_points = batch["target_points"] # B x 2 x N 40 | 41 | # warp points with estimated transformations 42 | target_points_norm = PointsToUnitCoords( 43 | target_points, target_im_size 44 | ) # convert from image coordinate to -1, 1 45 | # print('target_im_size',target_im_size) 46 | 47 | # compute points stage 1 only 48 | warped_points_norm = bilinearInterpPointTnf(matches, target_points_norm) 49 | warped_points = PointsToPixelCoords(warped_points_norm, source_im_size) 50 | 51 | L_pck = batch["L_pck"].data 52 | 53 | current_batch_size = batch["source_im_size"].size(0) 54 | indices = range(batch_start_idx, batch_start_idx + current_batch_size) 55 | 56 | # compute PCK 57 | pck_batch = pck(source_points.data, warped_points.data, L_pck, alpha=alpha) 58 | stats["point_tnf"]["pck"][indices] = pck_batch.unsqueeze(1).cpu().numpy() 59 | 60 | return stats 61 | -------------------------------------------------------------------------------- /lib/im_pair_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils.data import Dataset 6 | from skimage import io 7 | import pandas as pd 8 | import numpy as np 9 | from lib.transformation import AffineTnf 10 | 11 | 12 | class ImagePairDataset(Dataset): 13 | 14 | """ 15 | 16 | Image pair dataset used for weak supervision 17 | 18 | 19 | Args: 20 | csv_file (string): Path to the csv file with image names and transformations. 21 | training_image_path (string): Directory with the images. 22 | output_size (2-tuple): Desired output size 23 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 24 | 25 | """ 26 | 27 | def __init__( 28 | self, 29 | dataset_csv_path, 30 | dataset_csv_file, 31 | dataset_image_path, 32 | dataset_size=0, 33 | output_size=(240, 240), 34 | transform=None, 35 | random_crop=False, 36 | ): 37 | self.random_crop = random_crop 38 | self.out_h, self.out_w = output_size 39 | self.train_data = pd.read_csv(os.path.join(dataset_csv_path, dataset_csv_file)) 40 | if dataset_size is not None and dataset_size != 0: 41 | dataset_size = min((dataset_size, len(self.train_data))) 42 | self.train_data = self.train_data.iloc[0:dataset_size, :] 43 | self.img_A_names = self.train_data.iloc[:, 0] 44 | self.img_B_names = self.train_data.iloc[:, 1] 45 | self.set = self.train_data.iloc[:, 2].values 46 | self.flip = self.train_data.iloc[:, 3].values.astype("int") 47 | self.dataset_image_path = dataset_image_path 48 | self.transform = transform 49 | # no cuda as dataset is called from CPU threads in dataloader and produces confilct 50 | self.affineTnf = AffineTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False) 51 | 52 | def __len__(self): 53 | return len(self.img_A_names) 54 | 55 | def __getitem__(self, idx): 56 | # get pre-processed images 57 | image_A, im_size_A = self.get_image(self.img_A_names, idx, self.flip[idx]) 58 | image_B, im_size_B = self.get_image(self.img_B_names, idx, self.flip[idx]) 59 | 60 | image_set = self.set[idx] 61 | 62 | sample = { 63 | "source_image": image_A, 64 | "target_image": image_B, 65 | "source_im_size": im_size_A, 66 | "target_im_size": im_size_B, 67 | "set": image_set, 68 | } 69 | 70 | if self.transform: 71 | sample = self.transform(sample) 72 | 73 | return sample 74 | 75 | def get_image(self, img_name_list, idx, flip): 76 | img_name = os.path.join(self.dataset_image_path, img_name_list.iloc[idx]) 77 | image = io.imread(img_name) 78 | 79 | # if grayscale convert to 3-channel image 80 | if image.ndim == 2: 81 | image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3) 82 | 83 | # do random crop 84 | if self.random_crop: 85 | h, w, c = image.shape 86 | top = np.random.randint(h / 4) 87 | bottom = int(3 * h / 4 + np.random.randint(h / 4)) 88 | left = np.random.randint(w / 4) 89 | right = int(3 * w / 4 + np.random.randint(w / 4)) 90 | image = image[top:bottom, left:right, :] 91 | 92 | # flip horizontally if needed 93 | if flip: 94 | image = np.flip(image, 1) 95 | 96 | # get image size 97 | im_size = np.asarray(image.shape) 98 | 99 | # convert to torch Variable 100 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 101 | image = torch.Tensor(image.astype(np.float32)) 102 | image_var = Variable(image, requires_grad=False) 103 | 104 | # Resize image using bilinear sampling with identity affine tnf 105 | image = self.affineTnf(image_var).data.squeeze(0) 106 | 107 | im_size = torch.Tensor(im_size.astype(np.float32)) 108 | 109 | return (image, im_size) 110 | 111 | -------------------------------------------------------------------------------- /lib/interpolator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import sys 9 | import numpy as np 10 | import numpy.matlib 11 | import pickle 12 | import math 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | class Interpolator(nn.Module): 17 | def __init__(self, im_fe_ratio): 18 | 
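# im_fe_ratio is the downsampling factor from the input image to the feature
# map (the evaluation script passes 16, matching the stride-16 backbone
# features); maxXY / minXY below are cached clamping bounds that getMaxMinXY
# rebuilds whenever the batch size or the number of keypoints changes.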
super(Interpolator, self).__init__() 19 | self.im_fe_ratio = im_fe_ratio 20 | self.maxXY = torch.ones(1, 1, 2) 21 | self.minXY = torch.zeros(1, 1, 2) 22 | 23 | def getMaxMinXY(self, B, N, H, W): 24 | B1, N1, _ = self.maxXY.shape 25 | B2, N2, _ = self.minXY.shape 26 | if B1 == B and N1 == N and B2 == B and N2 == N: 27 | return self.maxXY, self.minXY 28 | else: 29 | self.maxXY = torch.ones(1, 1, 2) 30 | self.minXY = torch.zeros(1, 1, 2) 31 | self.maxXY[0, 0, 0] = W - 1 32 | self.maxXY[0, 0, 1] = H - 1 33 | self.maxXY = self.maxXY.cuda() 34 | self.minXY = self.minXY.cuda() 35 | self.maxXY = self.maxXY.expand(B, N, -1) # B x N x 2 36 | self.minXY = self.minXY.expand(B, N, -1) # B x N x 2 37 | return self.maxXY, self.minXY 38 | 39 | def maskoff(self, feature_per_kp, keypoints): 40 | """ 41 | maskoff() set the features to be zeros if the keypoints are 0, 0 42 | Arguments: 43 | feature_per_kp [float tensor] B x C X N : standard feature tensor by number of key points 44 | keypoints [float tensor] B x N x 2: key points 45 | Returns: 46 | feature_per_kp [float tensor] B x C x N : standard feature tensor with invalid keypoints 47 | """ 48 | mask = (keypoints[:, :, :2] > 1e-10).float().mean(dim=2) # B x N 49 | mask = mask.unsqueeze(1) # B x 1 x N 50 | mask = mask.expand_as(feature_per_kp) # B x C x N 51 | feature_per_kp = feature_per_kp * mask 52 | return feature_per_kp 53 | 54 | def forward(self, feature, keypoints): 55 | """ 56 | Interpolator(): collects a set of sparse key points by interpolating from 57 | the feature map. 58 | Arguments 59 | feature [float tensor] B x C x H x W: standard feature map 60 | keypoints [float tensor] B x N x 2: key points 61 | Return 62 | UF [float tensor] B x C x N: the sparse interpolated features collected 63 | at the input sparse key point locations. 64 | note that the rows corresponding to invalid key points 65 | are masked off as zeros. 66 | """ 67 | B, C, H, W = feature.shape 68 | feature = feature.view(B, C, -1) # B x C x HW ie. 
B x 256 x 256 69 | # print('keypoints', keypoints[0,1]) 70 | # convert the key points from the image coordinate to feature map coordinate 71 | keypoints = keypoints / self.im_fe_ratio - 0.5 72 | _, N, _ = keypoints.shape 73 | 74 | # print(keypoints[0]) 75 | maxXY, minXY = self.getMaxMinXY(B, N, H, W) 76 | 77 | # nearest neighbours 78 | iLower = torch.max(torch.floor(keypoints), minXY) # B x N x 2, index of X, Y 79 | # iLower = torch.floor(keypoints) # B x N x 2, index of X, Y 80 | iUpper = torch.min(torch.ceil(keypoints), maxXY) 81 | # iUpper = torch.ceil(keypoints) 82 | upper = keypoints - iLower # note that weight is the 1 - distance 83 | lower = 1 - upper # B x N x 2 84 | 85 | iX = torch.cat((iLower[:, :, 0].unsqueeze(2), iUpper[:, :, 0].unsqueeze(2)), 2) 86 | iY = torch.cat((iLower[:, :, 1].unsqueeze(2), iUpper[:, :, 1].unsqueeze(2)), 2) 87 | xX = torch.cat((lower[:, :, 0].unsqueeze(2), upper[:, :, 0].unsqueeze(2)), 2) 88 | yY = torch.cat((lower[:, :, 1].unsqueeze(2), upper[:, :, 1].unsqueeze(2)), 2) 89 | 90 | iX = ( 91 | iX.unsqueeze(2).expand(-1, -1, 2, -1).long() 92 | ) # B x 32 x 2 x 2 ( x0 x1; x0 x1 ) 93 | iY = iY.unsqueeze(2).expand(-1, -1, 2, -1).transpose(2, 3).long() 94 | xX = xX.unsqueeze(2).expand(-1, -1, 2, -1) 95 | yY = yY.unsqueeze(2).expand(-1, -1, 2, -1).transpose(2, 3) 96 | 97 | iX = iX.view(B, N, -1) # B x N x 4 98 | iY = iY.view(B, N, -1) 99 | xX = xX.contiguous().view(B, N, -1) 100 | yY = yY.contiguous().view(B, N, -1) 101 | # print('iY', iY[0,1]) 102 | # print('iX', iX[0,1]) 103 | # print('xY', yY[0,1]) 104 | # print('xX', xX[0,1]) 105 | # print('xY*xY', (xX*yY)[0,1]) 106 | coeff = (xX * yY).contiguous().view(B, -1) # B x N*4 107 | # print('coeff', coeff[0,:8]) 108 | coeff = coeff.unsqueeze(dim=1).expand(-1, C, -1) # B x C x N*4 109 | 110 | # print('H', H, 'W', W) 111 | indices = (iY * W + iX).view(B, N * 4) # B x N*4 112 | # print('indices', indices[0,0:8]) 113 | indices = indices.unsqueeze(dim=1).expand(-1, C, -1) # B x C x N*4 114 | # print('2.indices', indices[0, 0, :8], indices[0, 1, :8]) 115 | UF = torch.gather(feature, 2, indices) # B x C x N*4 116 | 117 | UF = UF * coeff # B x 118 | # np.savetxt('UF', UF[0,:,:8].detach().cpu().numpy() ) 119 | UF = UF.view(B, C, N, -1) 120 | # np.savetxt('UF2', UF[0,:,1,:].detach().cpu().numpy() ) 121 | UF = UF.sum(dim=3) # B x C x N 122 | # print('UF',UF.shape) 123 | UF = self.maskoff(UF, keypoints) 124 | # print('UF', UF.shape) 125 | # print('UF', UF.shape) 126 | return UF 127 | 128 | 129 | class LocationInterpolator(nn.Module): 130 | def __init__(self, im_fe_ratio): 131 | super(LocationInterpolator, self).__init__() 132 | self.interpolator = Interpolator(im_fe_ratio) 133 | 134 | def forward(self, ijB_A, keypoints): 135 | """ 136 | LocationInterpolator() is to collect a set of interpolated correspondence pixel 137 | locations 138 | Arguments: 139 | ijB_A [long tensor]: B x 2 x H x W : is the tensor storing the 2D pixel 140 | locations from source image A to targe image B 141 | keypoints [float tensor] B x N x 2: key points 142 | Return: 143 | xyB_A [float tensor]: B x N x 2 the interpolated correspondnce map for the set of sparse 144 | key points. 145 | note that the rows corresponding to invalid key points 146 | are masked off as zeros. 
147 | """ 148 | xyB_A = ( 149 | self.interpolator(ijB_A.float(), keypoints) * self.interpolator.im_fe_ratio 150 | ) 151 | return xyB_A.transpose(2, 1) 152 | 153 | 154 | class InverInterpolator(Interpolator): 155 | def __init__(self, im_fe_ratio, kernel_size=5, N=32, mode=1): 156 | super(InverInterpolator, self).__init__(im_fe_ratio) 157 | 158 | self.kernel_size = kernel_size 159 | self.mode = mode 160 | if kernel_size > 0: 161 | # add gaussian 162 | gaussian_filter = nn.Conv2d( 163 | in_channels=N, 164 | out_channels=N, 165 | padding_mode="zeros", 166 | padding=(int(kernel_size / 2), int(kernel_size / 2)), 167 | kernel_size=kernel_size, 168 | groups=N, 169 | bias=False, 170 | ) 171 | if kernel_size == 3: 172 | gk = torch.FloatTensor( 173 | np.array( 174 | [ 175 | [1 / 16.0, 1 / 8.0, 1 / 16.0], 176 | [1 / 8.0, 1 / 4.0, 1 / 8.0], 177 | [1 / 16.0, 1 / 8.0, 1 / 16.0], 178 | ] 179 | ) 180 | ) 181 | elif kernel_size == 5: 182 | gk = ( 183 | torch.FloatTensor( 184 | np.array( 185 | [ 186 | [1, 4, 7, 4, 1], 187 | [4, 16, 26, 16, 4], 188 | [7, 26, 41, 26, 7], 189 | [4, 16, 26, 16, 4], 190 | [1, 4, 7, 4, 1], 191 | ] 192 | ) 193 | ) 194 | / 273.0 195 | ) 196 | elif kernel_size == 7: 197 | gk = torch.FloatTensor( 198 | np.array( 199 | [ 200 | [0, 0, 0, 5, 0, 0, 0], 201 | [0, 5, 18, 32, 18, 5, 0], 202 | [0, 18, 64, 100, 64, 18, 0], 203 | [5, 32, 100, 100, 100, 32, 5], 204 | [0, 18, 64, 100, 64, 18, 0], 205 | [0, 5, 18, 32, 18, 5, 0], 206 | [0, 0, 0, 5, 0, 0, 0], 207 | ] 208 | ) 209 | ) 210 | gk /= gk.sum() 211 | 212 | gk = gk.unsqueeze(0).unsqueeze(1) 213 | gk = gk.expand(N, -1, -1, -1) 214 | gaussian_filter.weight.data = gk 215 | gaussian_filter.weight.requires_grad = False 216 | self.gaussian_filter = gaussian_filter.cuda() 217 | 218 | def get_1nn(self, Xg, keypoint_g, H, W): 219 | """ 220 | Arguments: 221 | Xg [tensor] B x N x N 222 | keypoint_g [tensor] B x N x 2 223 | H height, resolution of H and W 224 | W width 225 | Return 226 | onehot [tensor] B x N x HW 227 | """ 228 | B, N, _ = keypoint_g.shape 229 | xyGt = ( 230 | torch.bmm(Xg, keypoint_g) / self.im_fe_ratio - 0.5 231 | ) # B x N x 2 float gt coordinate in feature map 232 | maxXY, minXY = self.getMaxMinXY(B, N, H, W) 233 | 234 | boundedXY = torch.max(xyGt, minXY) # B x N x 2, index of X, Y in feature map 235 | boundedXY = torch.min( 236 | boundedXY, maxXY 237 | ) # B x N x 2, index of X, Y in feature map 238 | 239 | boundedXY = boundedXY.long() 240 | indices = boundedXY[:, :, 1] * W + boundedXY[:, :, 0] # B x N x 1 241 | indices = indices.unsqueeze(2) 242 | coeff = torch.ones(B, N, 1).cuda() 243 | onehot = torch.zeros(B, N, H * W).cuda() 244 | onehot.scatter_(dim=2, index=indices, src=coeff) 245 | if self.kernel_size > 0: 246 | onehot = self.gaussian_filter(onehot.view(B, N, H, W)).view( 247 | B, N, H * W 248 | ) # add gaussian blur 249 | mask = Xg.sum(dim=2, keepdim=True).expand_as(onehot) # B x N x HW 250 | onehot *= mask 251 | 252 | # for n in range(N): 253 | # print(xyGt[0][n]) 254 | # plt.imshow(onehot[0][n].cpu().view(H,W)) 255 | # plt.show() 256 | return onehot 257 | 258 | def get_4nn(self, Xg, keypoint_g, H, W): 259 | """ 260 | Arguments: 261 | Xg [tensor] B x N x N 262 | keypoint_g [tensor] B x N x 2 263 | H height, resolution of H and W 264 | W width 265 | Return 266 | onehot [tensor] B x N x HW 267 | """ 268 | # convert into feature map coordinate 269 | B, N, _ = keypoint_g.shape 270 | xyGt = ( 271 | torch.bmm(Xg, keypoint_g) / self.im_fe_ratio - 0.5 272 | ) # B x N x 2 float gt coordinate in feature map 273 | maxXY, minXY = 
self.getMaxMinXY(B, N, H, W) 274 | 275 | # nearest neighbours 276 | iLower = torch.max( 277 | torch.floor(xyGt), minXY 278 | ) # B x N x 2, index of X, Y in feature map 279 | iUpper = torch.min(torch.ceil(xyGt), maxXY) # B x N x 2, 280 | upper = xyGt - iLower # note that weight is the 1 - distance 281 | lower = 1 - upper # B x N x 2 282 | 283 | iX = torch.cat((iLower[:, :, 0].unsqueeze(2), iUpper[:, :, 0].unsqueeze(2)), 2) 284 | iY = torch.cat((iLower[:, :, 1].unsqueeze(2), iUpper[:, :, 1].unsqueeze(2)), 2) 285 | xX = torch.cat((lower[:, :, 0].unsqueeze(2), upper[:, :, 0].unsqueeze(2)), 2) 286 | yY = torch.cat((lower[:, :, 1].unsqueeze(2), upper[:, :, 1].unsqueeze(2)), 2) 287 | 288 | iX = ( 289 | iX.unsqueeze(2).expand(-1, -1, 2, -1).long() 290 | ) # B x 32 x 2 x 2 ( x0 x1; x0 x1 ) 291 | iY = iY.unsqueeze(2).expand(-1, -1, 2, -1).transpose(2, 3).long() 292 | xX = xX.unsqueeze(2).expand(-1, -1, 2, -1) 293 | yY = yY.unsqueeze(2).expand(-1, -1, 2, -1).transpose(2, 3) 294 | 295 | iX = iX.view(B, N, -1) # B x N x 4 296 | iY = iY.view(B, N, -1) 297 | xX = xX.contiguous().view(B, N, -1) 298 | yY = yY.contiguous().view(B, N, -1) 299 | 300 | coeff = (xX * yY).contiguous() # B x N x 4 301 | indices = iY * W + iX # B x N x4 302 | 303 | onehot = torch.zeros(B, N, H * W).cuda() 304 | onehot.scatter_(dim=2, index=indices, src=coeff) 305 | if self.kernel_size > 0: 306 | onehot = self.gaussian_filter(onehot.view(B, N, H, W)).view( 307 | B, N, H * W 308 | ) # add gaussian blur 309 | mask = Xg.sum(dim=2, keepdim=True).expand_as(onehot) # B x N x HW 310 | onehot *= mask 311 | 312 | # for n in range(N): 313 | # print(xyGt[0][n]) 314 | # plt.imshow(onehot[0][n].cpu().view(H,W)) 315 | # plt.show() 316 | return onehot 317 | 318 | def forward(self, Xg, keypoint_g, H, W): 319 | """ 320 | Arguments: 321 | Xg [tensor] B x N x N 322 | keypoint_g [tensor] B x N x 2 323 | H height, resolution of H and W 324 | W width 325 | Return 326 | onehot [tensor] B x N x HW 327 | """ 328 | if self.mode == 0: 329 | return self.get_1nn(Xg, keypoint_g, H, W) 330 | elif self.mode == 1: 331 | return self.get_4nn(Xg, keypoint_g, H, W) 332 | 333 | return self.get_4nn(Xg, keypoint_g, H, W) 334 | 335 | -------------------------------------------------------------------------------- /lib/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | import torchvision.models as models 8 | import pretrainedmodels 9 | import numpy as np 10 | import numpy.matlib 11 | import pickle 12 | 13 | # import gluoncvth as gcv 14 | from lib.torch_util import Softmax1D 15 | from lib.conv4d import Conv4d 16 | 17 | 18 | def featureL2Norm(feature): 19 | epsilon = 1e-6 20 | norm = ( 21 | torch.pow(torch.sum(torch.pow(feature, 2), 1) + epsilon, 0.5) 22 | .unsqueeze(1) 23 | .expand_as(feature) 24 | ) 25 | return torch.div(feature, norm) 26 | 27 | 28 | class FeatureExtraction(torch.nn.Module): 29 | def get_feature_backbone(self, model): 30 | resnet_feature_layers = [ 31 | "conv1", 32 | "bn1", 33 | "relu", 34 | "maxpool", 35 | "layer1", 36 | "layer2", 37 | "layer3", 38 | "layer4", 39 | ] 40 | last_layer = "layer3" 41 | resnet_module_list = [getattr(model, l) for l in resnet_feature_layers] 42 | last_layer_idx = resnet_feature_layers.index(last_layer) 43 | model = nn.Sequential(*resnet_module_list[: last_layer_idx + 1]) 44 | 
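# the backbone is truncated after layer3, so the returned feature maps are at
# 1/16 of the input resolution, consistent with the im_fe_ratio=16 used in
# eval_pf_pascal.py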
return model 45 | 46 | def __init__( 47 | self, 48 | train_fe=False, 49 | feature_extraction_cnn="resnet101", 50 | feature_extraction_model_file="", 51 | normalization=True, 52 | last_layer="", 53 | use_cuda=True, 54 | ): 55 | super(FeatureExtraction, self).__init__() 56 | self.normalization = normalization 57 | self.feature_extraction_cnn = feature_extraction_cnn 58 | if feature_extraction_cnn == "vgg": 59 | self.model = models.vgg16(pretrained=True) 60 | # keep feature extraction network up to indicated layer 61 | vgg_feature_layers = [ 62 | "conv1_1", 63 | "relu1_1", 64 | "conv1_2", 65 | "relu1_2", 66 | "pool1", 67 | "conv2_1", 68 | "relu2_1", 69 | "conv2_2", 70 | "relu2_2", 71 | "pool2", 72 | "conv3_1", 73 | "relu3_1", 74 | "conv3_2", 75 | "relu3_2", 76 | "conv3_3", 77 | "relu3_3", 78 | "pool3", 79 | "conv4_1", 80 | "relu4_1", 81 | "conv4_2", 82 | "relu4_2", 83 | "conv4_3", 84 | "relu4_3", 85 | "pool4", 86 | "conv5_1", 87 | "relu5_1", 88 | "conv5_2", 89 | "relu5_2", 90 | "conv5_3", 91 | "relu5_3", 92 | "pool5", 93 | ] 94 | if last_layer == "": 95 | last_layer = "pool4" 96 | last_layer_idx = vgg_feature_layers.index(last_layer) 97 | self.model = nn.Sequential( 98 | *list(self.model.features.children())[: last_layer_idx + 1] 99 | ) 100 | # for resnet below 101 | resnet_feature_layers = [ 102 | "conv1", 103 | "bn1", 104 | "relu", 105 | "maxpool", 106 | "layer1", 107 | "layer2", 108 | "layer3", 109 | "layer4", 110 | ] 111 | if feature_extraction_cnn == "resnet101": 112 | self.model = models.resnet101(pretrained=True) 113 | self.model = self.get_feature_backbone(self.model) 114 | 115 | if feature_extraction_cnn == "resnet50": 116 | self.model = models.resnet50(pretrained=True) 117 | self.model = self.get_feature_backbone(self.model) 118 | 119 | if feature_extraction_cnn == "resnet152": 120 | model_name = "resnet152" # could be fbresnet152 or inceptionresnetv2 121 | model = pretrainedmodels.__dict__[model_name]( 122 | num_classes=1000, pretrained="imagenet" 123 | ) 124 | self.model = self.get_feature_backbone(model) 125 | 126 | if feature_extraction_cnn == "resnet101fcn": 127 | self.model = gcv.models.get_fcn_resnet101_voc(pretrained=True) 128 | 129 | if feature_extraction_cnn == "resnext101": 130 | model_name = "resnext101_32x4d" # could be fbresnet152 or inceptionresnetv2 131 | model = pretrainedmodels.__dict__[model_name]( 132 | num_classes=1000, pretrained="imagenet" 133 | ) 134 | self.model = model.features[:-1] 135 | 136 | if feature_extraction_cnn == "resnext101_64x4d": 137 | model_name = "resnext101_64x4d" 138 | model = pretrainedmodels.__dict__[model_name]( 139 | num_classes=1000, pretrained="imagenet" 140 | ) 141 | self.model = model.features[:-1] 142 | 143 | if feature_extraction_cnn == "resnet101fpn": 144 | if feature_extraction_model_file != "": 145 | resnet = models.resnet101(pretrained=True) 146 | # swap stride (2,2) and (1,1) in first layers (PyTorch ResNet is slightly different to caffe2 ResNet) 147 | # this is required for compatibility with caffe2 models 148 | resnet.layer2[0].conv1.stride = (2, 2) 149 | resnet.layer2[0].conv2.stride = (1, 1) 150 | resnet.layer3[0].conv1.stride = (2, 2) 151 | resnet.layer3[0].conv2.stride = (1, 1) 152 | resnet.layer4[0].conv1.stride = (2, 2) 153 | resnet.layer4[0].conv2.stride = (1, 1) 154 | else: 155 | resnet = models.resnet101(pretrained=True) 156 | resnet_module_list = [getattr(resnet, l) for l in resnet_feature_layers] 157 | conv_body = nn.Sequential(*resnet_module_list) 158 | self.model = fpn_body( 159 | conv_body, 160 | 
resnet_feature_layers, 161 | fpn_layers=["layer1", "layer2", "layer3"], 162 | normalize=normalization, 163 | hypercols=True, 164 | ) 165 | if feature_extraction_model_file != "": 166 | self.model.load_pretrained_weights(feature_extraction_model_file) 167 | 168 | if feature_extraction_cnn == "densenet201": 169 | self.model = models.densenet201(pretrained=True) 170 | self.model = nn.Sequential(*list(self.model.features.children())[:-3]) 171 | 172 | if train_fe == False: 173 | # freeze parameters 174 | for param in self.model.parameters(): 175 | param.requires_grad = False 176 | # move to GPU 177 | if use_cuda: 178 | self.model = self.model.cuda() 179 | 180 | def forward(self, image_batch): 181 | if self.feature_extraction_cnn == "resnet101fcn": 182 | features = self.model(image_batch) 183 | features = torch.cat((features[0], features[1]), 1) 184 | else: 185 | features = self.model(image_batch) 186 | 187 | if self.normalization and not self.feature_extraction_cnn == "resnet101fpn": 188 | features = featureL2Norm(features) 189 | 190 | return features 191 | 192 | 193 | class FeatureCorrelation(torch.nn.Module): 194 | def __init__(self, normalization=True): 195 | super(FeatureCorrelation, self).__init__() 196 | self.normalization = normalization 197 | self.ReLU = nn.ReLU() 198 | 199 | def forward(self, feature_A, feature_B): 200 | b, c, hA, wA = feature_A.size() 201 | b, c, hB, wB = feature_B.size() 202 | # reshape features for matrix multiplication 203 | feature_A = feature_A.view(b, c, hA * wA).transpose(1, 2) # size [b,c,h*w] 204 | feature_B = feature_B.view(b, c, hB * wB) # size [b,c,h*w] 205 | # perform matrix mult. 206 | feature_mul = torch.bmm(feature_A, feature_B) 207 | # indexed [batch,row_A,col_A,row_B,col_B] 208 | correlation_tensor = feature_mul.view(b, hA, wA, hB, wB).unsqueeze(1) 209 | 210 | if self.normalization: 211 | correlation_tensor = featureL2Norm(self.ReLU(correlation_tensor)) 212 | 213 | return correlation_tensor 214 | 215 | 216 | class SpatialContextNet(torch.nn.Module): 217 | def __init__(self, kernel_size=5, output_channel=1024, use_cuda=True): 218 | super(SpatialContextNet, self).__init__() 219 | self.kernel_size = kernel_size 220 | self.pad = kernel_size // 2 221 | self.conv = torch.nn.Conv2d( 222 | 1024 + self.kernel_size * self.kernel_size, 223 | output_channel, 224 | 1, 225 | bias=True, 226 | padding_mode="zeros", 227 | ) 228 | if use_cuda: 229 | self.conv = self.conv.cuda() 230 | 231 | def forward(self, feature): 232 | b, c, h, w = feature.size() 233 | feature_normalized = F.normalize(feature, p=2, dim=1) 234 | feature_pad = F.pad( 235 | feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0 236 | ) 237 | output = torch.zeros( 238 | [self.kernel_size * self.kernel_size, b, h, w], 239 | dtype=feature.dtype, 240 | requires_grad=feature.requires_grad, 241 | ) 242 | if feature.is_cuda: 243 | output = output.cuda(feature.get_device()) 244 | for c in range(self.kernel_size): 245 | for r in range(self.kernel_size): 246 | output[c * self.kernel_size + r] = ( 247 | feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized 248 | ).sum(1) 249 | 250 | output = output.transpose(0, 1).contiguous() 251 | output = torch.cat((feature, output), 1) 252 | output = self.conv(output) 253 | output = F.relu(output) 254 | return output 255 | 256 | 257 | class Pairwise(torch.nn.Module): 258 | def __init__(self, context_size=5, output_channel=128, use_cuda=True): 259 | super(Pairwise, self).__init__() 260 | self.context_size = context_size 261 | self.pad = 
context_size // 2 262 | self.conv = torch.nn.Conv2d( 263 | self.context_size * self.context_size, 264 | output_channel * 2, 265 | 3, 266 | padding=(1, 1), 267 | bias=True, 268 | padding_mode="zeros", 269 | ) 270 | self.conv1 = torch.nn.Conv2d( 271 | output_channel * 2, 272 | output_channel, 273 | 3, 274 | padding=(1, 1), 275 | bias=True, 276 | padding_mode="zeros", 277 | ) 278 | self.conv2 = torch.nn.Conv2d( 279 | output_channel, 280 | output_channel, 281 | 3, 282 | padding=(1, 1), 283 | bias=True, 284 | padding_mode="zeros", 285 | ) 286 | if use_cuda: 287 | self.conv = self.conv.cuda() 288 | self.conv1 = self.conv1.cuda() 289 | self.conv2 = self.conv2.cuda() 290 | 291 | def self_similarity(self, feature_normalized): 292 | b, c, h, w = feature_normalized.size() 293 | feature_pad = F.pad( 294 | feature_normalized, (self.pad, self.pad, self.pad, self.pad), "constant", 0 295 | ) 296 | output = torch.zeros( 297 | [self.context_size * self.context_size, b, h, w], 298 | dtype=feature_normalized.dtype, 299 | requires_grad=feature_normalized.requires_grad, 300 | ) 301 | if feature_normalized.is_cuda: 302 | output = output.cuda(feature_normalized.get_device()) 303 | for c in range(self.context_size): 304 | for r in range(self.context_size): 305 | output[c * self.context_size + r] = ( 306 | feature_pad[:, :, r : (h + r), c : (w + c)] * feature_normalized 307 | ).sum(1) 308 | 309 | output = output.transpose(0, 1).contiguous() 310 | return output 311 | 312 | def forward(self, feature): 313 | feature_normalized = F.normalize(feature, p=2, dim=1) 314 | ss = self.self_similarity(feature_normalized) 315 | 316 | ss1 = F.relu(self.conv(ss)) 317 | ss2 = F.relu(self.conv1(ss1)) 318 | output = torch.cat((ss, ss1, ss2), 1) 319 | return output 320 | 321 | 322 | def CreateCon4D(k1, k2, channels): 323 | num_layers = len(channels) 324 | nn_modules = list() 325 | for i in range(1, num_layers): 326 | nn_modules.append( 327 | Conv4d( 328 | in_channels=channels[i - 1], 329 | out_channels=channels[i], 330 | kernel_size=[k1, k1, k2, k2], 331 | bias=True, 332 | ) 333 | ) 334 | nn_modules.append(nn.ReLU(inplace=True)) 335 | conv = nn.Sequential(*nn_modules) 336 | return conv 337 | 338 | 339 | class NeighConsensus(torch.nn.Module): 340 | def __init__( 341 | self, 342 | use_cuda=True, 343 | kernel_sizes=[3, 3, 3], 344 | channels=[1, 10, 10, 1], 345 | symmetric_mode=True, 346 | ): 347 | super(NeighConsensus, self).__init__() 348 | self.symmetric_mode = symmetric_mode 349 | self.conv = CreateCon4D(kernel_sizes[0], kernel_sizes[0], channels) 350 | if use_cuda: 351 | self.conv.cuda() 352 | 353 | def forward(self, x): 354 | if self.symmetric_mode: 355 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 356 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 357 | x = self.conv(x) + self.conv(x.permute(0, 1, 4, 5, 2, 3)).permute( 358 | 0, 1, 4, 5, 2, 3 359 | ) 360 | # because of the ReLU layers in between linear layers, 361 | # this operation is different than convolving a single time with the filters+filters^T 362 | # and therefore it makes sense to do this. 
363 | else: 364 | x = self.conv(x) 365 | return x 366 | 367 | 368 | class NonIsotropicNCA(torch.nn.Module): 369 | def __init__(self, use_cuda=True, channels=[1, 16, 16, 1], symmetric_mode=True): 370 | super(NonIsotropicNCA, self).__init__() 371 | self.symmetric_mode = symmetric_mode 372 | self.conv0 = CreateCon4D(3, 5, [1, 8]) 373 | self.conv1 = CreateCon4D(5, 5, [1, 8]) 374 | self.conv2 = CreateCon4D(5, 5, [16, 16, 1]) 375 | if use_cuda: 376 | self.conv0.cuda() 377 | self.conv1.cuda() 378 | self.conv2.cuda() 379 | 380 | def forward(self, x): 381 | if self.symmetric_mode: 382 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 383 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 384 | x0 = self.conv0(x) + self.conv0(x.permute(0, 1, 4, 5, 2, 3)).permute( 385 | 0, 1, 4, 5, 2, 3 386 | ) 387 | x1 = self.conv1(x) + self.conv1(x.permute(0, 1, 4, 5, 2, 3)).permute( 388 | 0, 1, 4, 5, 2, 3 389 | ) 390 | x = torch.cat((x0, x1), 1) 391 | x = self.conv2(x) + self.conv2(x.permute(0, 1, 4, 5, 2, 3)).permute( 392 | 0, 1, 4, 5, 2, 3 393 | ) 394 | # because of the ReLU layers in between linear layers, 395 | # this operation is different than convolving a single time with the filters+filters^T 396 | # and therefore it makes sense to do this. 397 | else: 398 | x0 = self.conv0(x) 399 | x1 = self.conv1(x) 400 | x = torch.cat((x0, x1), 1) 401 | x = self.conv2(x) 402 | return x 403 | 404 | 405 | class NonIsotropicNCB(torch.nn.Module): 406 | def __init__(self, use_cuda=True, channels=[1, 16, 16, 1], symmetric_mode=True): 407 | super(NonIsotropicNCB, self).__init__() 408 | self.symmetric_mode = symmetric_mode 409 | self.conv0 = CreateCon4D(5, 5, [1, 16]) 410 | self.conv10 = CreateCon4D(3, 5, [16, 8]) 411 | self.conv11 = CreateCon4D(5, 5, [16, 8]) 412 | self.conv2 = CreateCon4D(5, 5, [16, 1]) 413 | if use_cuda: 414 | self.conv0.cuda() 415 | self.conv10.cuda() 416 | self.conv11.cuda() 417 | self.conv2.cuda() 418 | 419 | def forward(self, x): 420 | if self.symmetric_mode: 421 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 422 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 423 | x = self.conv0(x) + self.conv0(x.permute(0, 1, 4, 5, 2, 3)).permute( 424 | 0, 1, 4, 5, 2, 3 425 | ) 426 | x0 = self.conv10(x) + self.conv10(x.permute(0, 1, 4, 5, 2, 3)).permute( 427 | 0, 1, 4, 5, 2, 3 428 | ) 429 | x1 = self.conv11(x) + self.conv11(x.permute(0, 1, 4, 5, 2, 3)).permute( 430 | 0, 1, 4, 5, 2, 3 431 | ) 432 | x = torch.cat((x0, x1), 1) 433 | x = self.conv2(x) + self.conv2(x.permute(0, 1, 4, 5, 2, 3)).permute( 434 | 0, 1, 4, 5, 2, 3 435 | ) 436 | # because of the ReLU layers in between linear layers, 437 | # this operation is different than convolving a single time with the filters+filters^T 438 | # and therefore it makes sense to do this. 
439 | else: 440 | x = self.conv0(x) 441 | x0 = self.conv10(x) 442 | x1 = self.conv11(x) 443 | x = torch.cat((x0, x1), 1) 444 | x = self.conv2(x) 445 | # because of the ReLU layers in between linear layers, 446 | return x 447 | 448 | 449 | class NonIsotropicNCC(torch.nn.Module): 450 | def __init__(self, use_cuda=True, channels=[1, 16, 16, 1], symmetric_mode=True): 451 | super(NonIsotropicNCC, self).__init__() 452 | self.symmetric_mode = symmetric_mode 453 | self.conv00 = CreateCon4D(3, 5, [1, 8]) 454 | self.conv01 = CreateCon4D(5, 5, [1, 8]) 455 | self.conv10 = CreateCon4D(3, 5, [16, 8]) 456 | self.conv11 = CreateCon4D(5, 5, [16, 8]) 457 | self.conv2 = CreateCon4D(5, 5, [16, 1]) 458 | if use_cuda: 459 | self.conv00.cuda() 460 | self.conv01.cuda() 461 | self.conv10.cuda() 462 | self.conv11.cuda() 463 | self.conv2.cuda() 464 | 465 | def forward(self, x): 466 | if self.symmetric_mode: 467 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 468 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 469 | x0 = self.conv00(x) + self.conv00(x.permute(0, 1, 4, 5, 2, 3)).permute( 470 | 0, 1, 4, 5, 2, 3 471 | ) 472 | x1 = self.conv01(x) + self.conv01(x.permute(0, 1, 4, 5, 2, 3)).permute( 473 | 0, 1, 4, 5, 2, 3 474 | ) 475 | x = torch.cat((x0, x1), 1) 476 | x0 = self.conv10(x) + self.conv10(x.permute(0, 1, 4, 5, 2, 3)).permute( 477 | 0, 1, 4, 5, 2, 3 478 | ) 479 | x1 = self.conv11(x) + self.conv11(x.permute(0, 1, 4, 5, 2, 3)).permute( 480 | 0, 1, 4, 5, 2, 3 481 | ) 482 | x = torch.cat((x0, x1), 1) 483 | x = self.conv2(x) + self.conv2(x.permute(0, 1, 4, 5, 2, 3)).permute( 484 | 0, 1, 4, 5, 2, 3 485 | ) 486 | # because of the ReLU layers in between linear layers, 487 | # this operation is different than convolving a single time with the filters+filters^T 488 | # and therefore it makes sense to do this. 
489 | else: 490 | x0 = self.conv00(x) 491 | x1 = self.conv01(x) 492 | x = torch.cat((x0, x1), 1) 493 | x0 = self.conv10(x) 494 | x1 = self.conv11(x) 495 | x = torch.cat((x0, x1), 1) 496 | x = self.conv2(x) 497 | return x 498 | 499 | 500 | def MutualMatching(corr4d): 501 | # mutual matching 502 | batch_size, ch, fs1, fs2, fs3, fs4 = corr4d.size() 503 | 504 | corr4d_B = corr4d.view(batch_size, fs1 * fs2, fs3, fs4) # [batch_idx,k_A,i_B,j_B] 505 | corr4d_A = corr4d.view(batch_size, fs1, fs2, fs3 * fs4) 506 | 507 | # get max 508 | corr4d_B_max, _ = torch.max(corr4d_B, dim=1, keepdim=True) 509 | corr4d_A_max, _ = torch.max(corr4d_A, dim=3, keepdim=True) 510 | 511 | eps = 1e-5 512 | corr4d_B = corr4d_B / (corr4d_B_max + eps) 513 | corr4d_A = corr4d_A / (corr4d_A_max + eps) 514 | 515 | corr4d_B = corr4d_B.view(batch_size, 1, fs1, fs2, fs3, fs4) 516 | corr4d_A = corr4d_A.view(batch_size, 1, fs1, fs2, fs3, fs4) 517 | 518 | corr4d = corr4d * ( 519 | corr4d_A * corr4d_B 520 | ) # parenthesis are important for symmetric output 521 | 522 | return corr4d 523 | 524 | 525 | def maxpool4d(corr4d_hres, k_size=4): 526 | slices = [] 527 | for i in range(k_size): 528 | for j in range(k_size): 529 | for k in range(k_size): 530 | for l in range(k_size): 531 | slices.append( 532 | corr4d_hres[ 533 | :, 0, i::k_size, j::k_size, k::k_size, l::k_size 534 | ].unsqueeze(0) 535 | ) 536 | slices = torch.cat(tuple(slices), dim=1) 537 | corr4d, max_idx = torch.max(slices, dim=1, keepdim=True) 538 | max_l = torch.fmod(max_idx, k_size) 539 | max_k = torch.fmod(max_idx.sub(max_l).div(k_size), k_size) 540 | max_j = torch.fmod(max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size), k_size) 541 | max_i = max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size).sub(max_j).div(k_size) 542 | # i,j,k,l represent the *relative* coords of the max point in the box of size k_size*k_size*k_size*k_size 543 | return (corr4d, max_i, max_j, max_k, max_l) 544 | 545 | 546 | class ImMatchNet(nn.Module): 547 | def __init__( 548 | self, 549 | feature_extraction_cnn="resnet101", 550 | feature_extraction_last_layer="", 551 | feature_extraction_model_file=None, 552 | return_correlation=False, 553 | ncons_kernel_sizes=[3, 3, 3], 554 | ncons_channels=[1, 10, 10, 1], 555 | normalize_features=True, 556 | train_fe=False, 557 | use_cuda=True, 558 | relocalization_k_size=0, 559 | half_precision=False, 560 | checkpoint=None, 561 | pss=True, 562 | noniso=1, 563 | ): 564 | 565 | super(ImMatchNet, self).__init__() 566 | self.use_cuda = use_cuda 567 | self.normalize_features = normalize_features 568 | self.return_correlation = return_correlation 569 | self.relocalization_k_size = relocalization_k_size 570 | self.half_precision = half_precision 571 | self.pss = pss 572 | self.noniso = noniso 573 | 574 | self.FeatureExtraction = FeatureExtraction( 575 | train_fe=train_fe, 576 | feature_extraction_cnn=feature_extraction_cnn, 577 | feature_extraction_model_file=feature_extraction_model_file, 578 | last_layer=feature_extraction_last_layer, 579 | normalization=normalize_features, 580 | use_cuda=self.use_cuda, 581 | ) 582 | 583 | self.FeatureCorrelation = FeatureCorrelation(normalization=False) 584 | 585 | if self.noniso == 0: 586 | self.NeighConsensus = NeighConsensus( 587 | use_cuda=self.use_cuda, 588 | kernel_sizes=ncons_kernel_sizes, 589 | channels=ncons_channels, 590 | ) 591 | elif self.noniso == 1: 592 | self.NeighConsensus = NonIsotropicNCA( 593 | use_cuda=self.use_cuda, channels=ncons_channels 594 | ) 595 | elif self.noniso == 2: 596 | self.NeighConsensus = 
NonIsotropicNCB( 597 | use_cuda=self.use_cuda, channels=ncons_channels 598 | ) 599 | elif self.noniso == 3: 600 | self.NeighConsensus = NonIsotropicNCC( 601 | use_cuda=self.use_cuda, channels=ncons_channels 602 | ) 603 | 604 | if self.pss == 1: 605 | self.SS = Pairwise( 606 | ncons_kernel_sizes[0], output_channel=32, use_cuda=self.use_cuda 607 | ) 608 | elif self.pss == 2: 609 | self.SS = SpatialContextNet( 610 | ncons_kernel_sizes[0], output_channel=256, use_cuda=self.use_cuda 611 | ) 612 | # Load weights 613 | if checkpoint is not None and checkpoint is not "": 614 | print("Copying weights...") 615 | for name, param in self.FeatureExtraction.state_dict().items(): 616 | if "num_batches_tracked" not in name: 617 | self.FeatureExtraction.state_dict()[name].copy_( 618 | checkpoint["state_dict"]["module.FeatureExtraction." + name] 619 | ) 620 | for name, param in self.NeighConsensus.state_dict().items(): 621 | self.NeighConsensus.state_dict()[name].copy_( 622 | checkpoint["state_dict"]["module.NeighConsensus." + name] 623 | ) 624 | if self.pss > 0: 625 | for name, param in self.SS.state_dict().items(): 626 | self.SS.state_dict()[name].copy_( 627 | checkpoint["state_dict"]["module.SS." + name] 628 | ) 629 | 630 | print("Done!") 631 | 632 | self.FeatureExtraction.eval() 633 | 634 | if self.half_precision: 635 | for p in self.NeighConsensus.parameters(): 636 | p.data = p.data.half() 637 | for l in self.NeighConsensus.conv: 638 | if isinstance(l, Conv4d): 639 | l.use_half = True 640 | 641 | def ncnet(self, corr4d): 642 | corr4d = MutualMatching(corr4d) 643 | corr4d = self.NeighConsensus(corr4d) 644 | corr4d = MutualMatching(corr4d) 645 | return corr4d 646 | 647 | # used only for foward pass at eval and for training with strong supervision 648 | def forward(self, tnf_batch): 649 | """ 650 | Arguments: 651 | tnf_batch [dict]: source_image B x 3 x H x W image 256 x 256 652 | target_image B x 3 x H x W image 653 | Return: 654 | corr4d [Tensor float]: B x 1 x 16 x 16 x 16 x 16 655 | """ 656 | # feature extraction 657 | feature_A = self.FeatureExtraction(tnf_batch["source_image"]) # B x C x 16 x 16 658 | feature_B = self.FeatureExtraction(tnf_batch["target_image"]) 659 | 660 | if self.half_precision: 661 | feature_A = feature_A.half() 662 | feature_B = feature_B.half() 663 | # feature correlation 664 | corr4d = self.FeatureCorrelation( 665 | feature_A, feature_B 666 | ) # B x 1 x 16 x 16 x 16 x 16 667 | # do 4d maxpooling for relocalization 668 | if self.relocalization_k_size > 1: 669 | corr4d, max_i, max_j, max_k, max_l = maxpool4d( 670 | corr4d, k_size=self.relocalization_k_size 671 | ) 672 | 673 | # run match processing model 674 | corr4d = self.ncnet(corr4d) 675 | 676 | # pss is pairwise term 677 | if self.pss > 0: 678 | selfsimilarity_A = self.SS(feature_A) 679 | selfsimilarity_B = self.SS(feature_B) 680 | corr4d_s = self.FeatureCorrelation(selfsimilarity_A, selfsimilarity_B) 681 | # do 4d maxpooling for relocalization 682 | if self.relocalization_k_size > 1: 683 | corr4d_s, max_i_s, max_j_s, max_k_s, max_l_s = maxpool4d( 684 | corr4d_s, k_size=self.relocalization_k_size 685 | ) 686 | 687 | # run match processing model 688 | corr4d_s = self.ncnet(corr4d_s) 689 | corr4d = 0.5 * corr4d + 0.5 * corr4d_s 690 | 691 | if self.relocalization_k_size > 1: 692 | delta4d = (max_i, max_j, max_k, max_l) 693 | delta4d_s = (max_i_s, max_j_s, max_k_s, max_l_s) 694 | 695 | return (corr4d, delta4d, delta4d_s) 696 | else: 697 | return corr4d 698 | 
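# Illustrative usage sketch (a minimal, hedged example; `checkpoint`, `src` and `tgt` are assumed
# to exist): given a checkpoint dict whose "state_dict" uses the "module.*" key prefixes expected
# above, and a pair of normalised 256 x 256 image batches, the path from images to sparse matches
# could look roughly like this:
#
#   from lib.point_tnf import corr_to_matches
#
#   model = ImMatchNet(use_cuda=torch.cuda.is_available(), checkpoint=checkpoint)
#   model.eval()
#   batch = {"source_image": src, "target_image": tgt}   # each B x 3 x 256 x 256, normalised
#   with torch.no_grad():
#       corr4d = model(batch)                             # B x 1 x 16 x 16 x 16 x 16
#   xA, yA, xB, yB, score = corr_to_matches(corr4d)       # matches in unit coordinates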
-------------------------------------------------------------------------------- /lib/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torch.autograd import Variable 4 | 5 | 6 | class NormalizeImageDict(object): 7 | """ 8 | 9 | Normalizes Tensor images in dictionary 10 | 11 | Args: 12 | image_keys (list): dict. keys of the images to be normalized 13 | normalizeRange (bool): if True the image is divided by 255.0s 14 | 15 | """ 16 | 17 | def __init__(self, image_keys, normalizeRange=True): 18 | self.image_keys = image_keys 19 | self.normalizeRange = normalizeRange 20 | self.normalize = transforms.Normalize( 21 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 22 | ) 23 | 24 | def __call__(self, sample): 25 | for key in self.image_keys: 26 | if self.normalizeRange: 27 | sample[key] /= 255.0 28 | sample[key] = self.normalize(sample[key]) 29 | return sample 30 | 31 | 32 | def normalize_image( 33 | image, forward=True, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 34 | ): 35 | im_size = image.size() 36 | mean = torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2) 37 | std = torch.FloatTensor(std).unsqueeze(1).unsqueeze(2) 38 | if image.is_cuda: 39 | mean = mean.cuda() 40 | std = std.cuda() 41 | if isinstance(image, torch.autograd.variable.Variable): 42 | mean = Variable(mean, requires_grad=False) 43 | std = Variable(std, requires_grad=False) 44 | if forward: 45 | if len(im_size) == 3: 46 | result = image.sub(mean.expand(im_size)).div(std.expand(im_size)) 47 | elif len(im_size) == 4: 48 | result = image.sub(mean.unsqueeze(0).expand(im_size)).div( 49 | std.unsqueeze(0).expand(im_size) 50 | ) 51 | else: 52 | if len(im_size) == 3: 53 | result = image.mul(std.expand(im_size)).add(mean.expand(im_size)) 54 | elif len(im_size) == 4: 55 | result = image.mul(std.unsqueeze(0).expand(im_size)).add( 56 | mean.unsqueeze(0).expand(im_size) 57 | ) 58 | 59 | return result 60 | 61 | -------------------------------------------------------------------------------- /lib/pf_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from skimage import io 6 | import pandas as pd 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | from lib.transformation import AffineTnf 10 | import matplotlib.pyplot as plt 11 | 12 | 13 | class PFPascalDataset(Dataset): 14 | 15 | """ 16 | 17 | Proposal Flow PASCAL image pair dataset 18 | 19 | 20 | Args: 21 | csv_file (string): Path to the csv file with image names and transformations. 22 | dataset_path (string): Directory with the images. 23 | output_size (2-tuple): Desired output size 24 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 25 | 26 | """ 27 | 28 | def __init__( 29 | self, 30 | csv_file, 31 | dataset_path, 32 | output_size=(240, 240), 33 | transform=None, 34 | category=None, 35 | pck_procedure="pf", 36 | ): 37 | 38 | self.category_names = [ 39 | "aeroplane", 40 | "bicycle", 41 | "bird", 42 | "boat", 43 | "bottle", 44 | "bus", 45 | "car", 46 | "cat", 47 | "chair", 48 | "cow", 49 | "diningtable", 50 | "dog", 51 | "horse", 52 | "motorbike", 53 | "person", 54 | "pottedplant", 55 | "sheep", 56 | "sofa", 57 | "train", 58 | "tvmonitor", 59 | ] 60 | self.out_h, self.out_w = output_size 61 | self.pairs = pd.read_csv(csv_file) 62 | self.category = self.pairs.iloc[:, 2].values.astype("float") 63 | if category is not None: 64 | cat_idx = np.nonzero(self.category == category)[0] 65 | self.category = self.category[cat_idx] 66 | self.pairs = self.pairs.iloc[cat_idx, :] 67 | self.img_A_names = self.pairs.iloc[:, 0] 68 | self.img_B_names = self.pairs.iloc[:, 1] 69 | self.point_A_coords = self.pairs.iloc[:, 3:5] 70 | self.point_B_coords = self.pairs.iloc[:, 5:] 71 | self.dataset_path = dataset_path 72 | self.transform = transform 73 | # no cuda as dataset is called from CPU threads in dataloader and produces confilct 74 | self.affineTnf = AffineTnf(out_h=self.out_h, out_w=self.out_w, use_cuda=False) 75 | self.pck_procedure = pck_procedure 76 | 77 | def __len__(self): 78 | return len(self.pairs) 79 | 80 | def __getitem__(self, idx): 81 | # get pre-processed images 82 | image_A, im_size_A = self.get_image(self.img_A_names, idx) 83 | image_B, im_size_B = self.get_image(self.img_B_names, idx) 84 | 85 | # get pre-processed point coords 86 | point_A_coords = self.get_points(self.point_A_coords, idx) 87 | point_B_coords = self.get_points(self.point_B_coords, idx) 88 | 89 | # compute PCK reference length L_pck (equal to max bounding box side in image_A) 90 | # L_pck = torch.FloatTensor([torch.max(point_A_coords.max(1)[0]-point_A_coords.min(1)[0])]) 91 | N_pts = torch.sum(torch.ne(point_A_coords[0, :], -1)) 92 | 93 | if self.pck_procedure == "pf": 94 | L_pck = torch.FloatTensor( 95 | [ 96 | torch.max( 97 | point_A_coords[:, :N_pts].max(1)[0] 98 | - point_A_coords[:, :N_pts].min(1)[0] 99 | ) 100 | ] 101 | ) 102 | elif self.pck_procedure == "scnet": 103 | # modification to follow the evaluation procedure of SCNet 104 | point_A_coords[0, 0:N_pts] = ( 105 | point_A_coords[0, 0:N_pts] * self.out_w / im_size_A[1] 106 | ) 107 | point_A_coords[1, 0:N_pts] = ( 108 | point_A_coords[1, 0:N_pts] * self.out_h / im_size_A[0] 109 | ) 110 | 111 | point_B_coords[0, 0:N_pts] = ( 112 | point_B_coords[0, 0:N_pts] * self.out_w / im_size_B[1] 113 | ) 114 | point_B_coords[1, 0:N_pts] = ( 115 | point_B_coords[1, 0:N_pts] * self.out_h / im_size_B[0] 116 | ) 117 | 118 | im_size_A[0:2] = torch.FloatTensor([self.out_h, self.out_w]) 119 | im_size_B[0:2] = torch.FloatTensor([self.out_h, self.out_w]) 120 | 121 | L_pck = torch.FloatTensor([self.out_h]) 122 | 123 | sample = { 124 | "source_image": image_A, 125 | "target_image": image_B, 126 | "source_im_size": im_size_A, 127 | "target_im_size": im_size_B, 128 | "source_points": point_A_coords, 129 | "target_points": point_B_coords, 130 | "L_pck": L_pck, 131 | } 132 | 133 | # # get key points annotation 134 | # np_img_A = sample['source_image'].long().numpy().transpose(1,2,0) 135 | # np_img_B = sample['target_image'].long().numpy().transpose(1,2,0) 136 | 137 | # kp_A = sample['source_points'].transpose(1,0).numpy() 138 | # kp_B = sample['target_points'].transpose(1,0).numpy() 139 | 140 | # kp_A[:,0] 
*= self.out_w/float(im_size_A[1]) 141 | # kp_A[:,1] *= self.out_h/float(im_size_A[0]) 142 | # print('kp_A', kp_A) 143 | # print('L_pck', sample['L_pck']) 144 | # fig=plt.figure(figsize=(1, 2)) 145 | # ax0 = fig.add_subplot(1, 2, 1) 146 | # # ax0.add_patch(rect) 147 | # plt.imshow(np_img_A) 148 | # # dispaly bounding boxes 149 | # for i, kp in enumerate(kp_A): 150 | # if kp[0] == kp[0]: 151 | # ax0.scatter(kp[0],kp[1], s=5, color='r',alpha=1.) 152 | # ax1 = fig.add_subplot(1, 2, 2) 153 | # # rect = matplotlib.patches.Rectangle((bbox_B[0],bbox_B[1]),bbox_B[2]-bbox_B[0],bbox_B[3]-bbox_B[1],linewidth=1,edgecolor='r',facecolor='none') 154 | # # ax1.add_patch(rect) 155 | # plt.imshow(np_img_B) 156 | # for i, kp in enumerate(kp_B): 157 | # if kp[0] == kp[0]: 158 | # ax1.scatter(kp[0],kp[1], s=5, color='r',alpha=1.) 159 | # plt.show() 160 | 161 | if self.transform: 162 | sample = self.transform(sample) 163 | 164 | return sample 165 | 166 | def get_image(self, img_name_list, idx): 167 | img_name = os.path.join(self.dataset_path, img_name_list.iloc[idx]) 168 | image = io.imread(img_name) 169 | 170 | # get image size 171 | im_size = np.asarray(image.shape) 172 | 173 | # convert to torch Variable 174 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 175 | image = torch.Tensor(image.astype(np.float32)) 176 | image_var = Variable(image, requires_grad=False) 177 | 178 | # Resize image using bilinear sampling with identity affine tnf 179 | image = self.affineTnf(image_var).data.squeeze(0) 180 | 181 | im_size = torch.Tensor(im_size.astype(np.float32)) 182 | 183 | return (image, im_size) 184 | 185 | def get_points(self, point_coords_list, idx): 186 | X = np.fromstring(point_coords_list.iloc[idx, 0], sep=";") 187 | Y = np.fromstring(point_coords_list.iloc[idx, 1], sep=";") 188 | Xpad = -np.ones(20) 189 | Xpad[: len(X)] = X 190 | Ypad = -np.ones(20) 191 | Ypad[: len(X)] = Y 192 | point_coords = np.concatenate( 193 | (Xpad.reshape(1, 20), Ypad.reshape(1, 20)), axis=0 194 | ) 195 | 196 | # make arrays float tensor for subsequent processing 197 | point_coords = torch.Tensor(point_coords.astype(np.float32)) 198 | 199 | return point_coords 200 | 201 | -------------------------------------------------------------------------------- /lib/pf_pascal_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils.data import Dataset 6 | from skimage import io 7 | import pandas as pd 8 | import numpy as np 9 | from . import transformation as tf 10 | import scipy.io 11 | import matplotlib 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | class ImagePairDataset(Dataset): 16 | 17 | """ 18 | 19 | Image pair dataset used for weak supervision 20 | 21 | 22 | Args: 23 | csv_file (string): Path to the csv file with image names and transformations. 24 | training_image_path (string): Directory with the images. 25 | output_size (2-tuple): Desired output size 26 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 27 | 28 | """ 29 | 30 | def __init__( 31 | self, 32 | dataset_csv_path, 33 | dataset_csv_file, 34 | dataset_image_path, 35 | dataset_size=0, 36 | output_size=(240, 240), 37 | transform=None, 38 | random_crop=False, 39 | keypoints_on=False, 40 | original=True, 41 | test=False, 42 | ): 43 | self.category_names = [ 44 | "aeroplane", 45 | "bicycle", 46 | "bird", 47 | "boat", 48 | "bottle", 49 | "bus", 50 | "car", 51 | "cat", 52 | "chair", 53 | "cow", 54 | "diningtable", 55 | "dog", 56 | "horse", 57 | "motorbike", 58 | "person", 59 | "pottedplant", 60 | "sheep", 61 | "sofa", 62 | "train", 63 | "tvmonitor", 64 | ] 65 | self.random_crop = random_crop 66 | self.out_h, self.out_w = output_size 67 | self.annotations = os.path.join( 68 | dataset_image_path, "PF-dataset-PASCAL", "Annotations" 69 | ) 70 | self.train_data = pd.read_csv(os.path.join(dataset_csv_path, dataset_csv_file)) 71 | if dataset_size is not None and dataset_size != 0: 72 | dataset_size = min((dataset_size, len(self.train_data))) 73 | self.train_data = self.train_data.iloc[0:dataset_size, :] 74 | self.img_A_names = self.train_data.iloc[:, 0] 75 | self.img_B_names = self.train_data.iloc[:, 1] 76 | self.set = self.train_data.iloc[:, 2].values 77 | self.test = test 78 | if self.test == False: 79 | self.flip = self.train_data.iloc[:, 3].values.astype("int") 80 | self.dataset_image_path = dataset_image_path 81 | self.transform = transform 82 | # no cuda as dataset is called from CPU threads in dataloader and produces confilct 83 | self.affineTnf = tf.AffineTnf( 84 | out_h=self.out_h, out_w=self.out_w, use_cuda=False 85 | ) # resize 86 | self.keypoints_on = keypoints_on 87 | self.original = original 88 | 89 | def __len__(self): 90 | return len(self.img_A_names) 91 | 92 | def __getitem__(self, idx): 93 | # get pre-processed images 94 | image_set = self.set[idx] 95 | if self.test == False: 96 | flip = self.flip[idx] 97 | else: 98 | flip = False 99 | 100 | cat = self.category_names[image_set - 1] 101 | 102 | image_A, im_size_A, kp_A, bbox_A = self.get_image( 103 | self.img_A_names, idx, flip, category_name=cat 104 | ) 105 | image_B, im_size_B, kp_B, bbox_B = self.get_image( 106 | self.img_B_names, idx, flip, category_name=cat 107 | ) 108 | A, kp_A = self.get_gt_assignment(kp_A, kp_B) 109 | 110 | sample = { 111 | "source_image": image_A, 112 | "target_image": image_B, 113 | "source_im_size": im_size_A, 114 | "target_im_size": im_size_B, 115 | "set": image_set, 116 | "source_points": kp_A, 117 | "target_points": kp_B, 118 | "source_bbox": bbox_A, 119 | "target_bbox": bbox_B, 120 | "assignment": A, 121 | } 122 | 123 | if self.transform: 124 | sample = self.transform(sample) 125 | 126 | if self.original: 127 | sample["source_original"] = image_A 128 | sample["target_original"] = image_B 129 | # # get key points annotation 130 | # np_img_A = sample['source_original'].numpy().transpose(1,2,0) 131 | # np_img_B = sample['target_original'].numpy().transpose(1,2,0) 132 | # print('bbox_A', bbox_A) 133 | # print('bbox_B', bbox_B) 134 | # rect = matplotlib.patches.Rectangle((bbox_A[0],bbox_A[1]),bbox_A[2]-bbox_A[0],bbox_A[3]-bbox_A[1],linewidth=1,edgecolor='r',facecolor='none') 135 | # print(rect) 136 | 137 | # fig=plt.figure(figsize=(1, 2)) 138 | # ax0 = fig.add_subplot(1, 2, 1) 139 | # ax0.add_patch(rect) 140 | # plt.imshow(np_img_A) 141 | # # dispaly bounding boxes 142 | # for i, kp in enumerate(kp_A): 143 | # if kp[0] == kp[0]: 144 | # ax0.scatter(kp[0],kp[1], s=5, color='r',alpha=1.) 
145 | # ax1 = fig.add_subplot(1, 2, 2) 146 | # rect = matplotlib.patches.Rectangle((bbox_B[0],bbox_B[1]),bbox_B[2]-bbox_B[0],bbox_B[3]-bbox_B[1],linewidth=1,edgecolor='r',facecolor='none') 147 | # print(rect) 148 | # ax1.add_patch(rect) 149 | # plt.imshow(np_img_B) 150 | # for i, kp in enumerate(kp_B): 151 | # if kp[0] == kp[0]: 152 | # ax1.scatter(kp[0],kp[1], s=5, color='r',alpha=1.) 153 | # plt.show() 154 | 155 | return sample 156 | 157 | def get_gt_assignment(self, kp_A, kp_B): 158 | """ 159 | get_gt_assigment() get the ground truth assignment matrix 160 | Arguments: 161 | kp_A [Tensor, float32] Nx3: ground truth key points from the source image 162 | kp_B [Tensor, float32] Nx3: ground truth key points from the target image 163 | Returns: 164 | A [Tensor, float32] NxN: ground truth assignment matrix 165 | kp_A [Tensor, float32] Nx3: ground truth key points + change original idx into target column idx 166 | """ 167 | s = kp_A[:, 2].long() 168 | t = kp_B[:, 2].long() 169 | N = s.shape[0] 170 | A = torch.zeros(N, N) 171 | for n in range(N): 172 | if s[n] == 0: 173 | continue 174 | idx = (t == s[n]).nonzero() 175 | if idx.nelement() == 0: 176 | continue 177 | A[n, idx] = 1 178 | kp_A[n, 2] = idx + 1 179 | 180 | return A, kp_A 181 | 182 | def get_image(self, img_name_list, idx, flip, category_name=None): 183 | img_name = os.path.join(self.dataset_image_path, img_name_list.iloc[idx]) 184 | image = io.imread(img_name) 185 | 186 | # if grayscale convert to 3-channel image 187 | if image.ndim == 2: 188 | image = np.repeat(np.expand_dims(image, 2), axis=2, repeats=3) 189 | 190 | if self.keypoints_on: 191 | keypoints, bbox = self.get_annotations( 192 | img_name_list.iloc[idx], category_name 193 | ) 194 | 195 | # do random crop 196 | if self.random_crop: 197 | h, w, c = image.shape 198 | top = np.random.randint(h / 4) 199 | bottom = int(3 * h / 4 + np.random.randint(h / 4)) 200 | left = np.random.randint(w / 4) 201 | right = int(3 * w / 4 + np.random.randint(w / 4)) 202 | image = image[top:bottom, left:right, :] 203 | 204 | # get image size 205 | im_size = np.asarray(image.shape) 206 | 207 | # flip horizontally if needed 208 | if flip: 209 | image = np.flip(image, 1) 210 | if self.keypoints_on: 211 | N, _ = keypoints.shape 212 | for n in range(N): 213 | if keypoints[n, 2] > 0: 214 | keypoints[n, 0] = im_size[1] - keypoints[n, 0] 215 | bbox[0] = im_size[1] - bbox[0] 216 | bbox[2] = im_size[1] - bbox[2] 217 | tmp = bbox[0] 218 | bbox[0] = bbox[2] 219 | bbox[2] = tmp 220 | 221 | # convert to torch Variable 222 | image = np.expand_dims(image.transpose((2, 0, 1)), 0) 223 | image = torch.Tensor(image.astype(np.float32)) 224 | image_var = Variable(image, requires_grad=False) 225 | 226 | # Resize image using bilinear sampling with identity affine tnf 227 | image = self.affineTnf(image_var).data.squeeze( 228 | 0 229 | ) # the resized image becomes 400 x 400 230 | im_size = torch.Tensor(im_size.astype(np.float32)) # original image sise 231 | 232 | if self.keypoints_on: 233 | keypoints[:, 0] = keypoints[:, 0] / float(im_size[1]) * float(self.out_w) 234 | keypoints[:, 1] = keypoints[:, 1] / float(im_size[0]) * float(self.out_h) 235 | bbox[0] = bbox[0] / float(im_size[1]) * float(self.out_w) 236 | bbox[1] = bbox[1] / float(im_size[0]) * float(self.out_h) 237 | bbox[2] = bbox[2] / float(im_size[1]) * float(self.out_w) 238 | bbox[3] = bbox[3] / float(im_size[0]) * float(self.out_h) 239 | return (image, im_size, keypoints, bbox) 240 | else: 241 | return (image, im_size) 242 | 243 | def construct_graph(self, kp): 
244 | """ 245 | construct_graph() construct a sparse graph represented by G and H. 246 | Arguments: 247 | kp [np array float, N x 3] stores the key points 248 | Returns 249 | G [np.array float, 32 x 96]: stores nodes by edges, if c-th edge leaves r-th node 250 | H [np.array float, 32 x 96]: stores nodes by edges, if c-th edge ends at r-th node 251 | """ 252 | N = kp.shape[0] 253 | 254 | G = np.zeros(32, 96) 255 | H = np.zeros(32, 96) 256 | return G, H 257 | 258 | def get_annotations(self, keypoint_annotation, category_name): 259 | """ 260 | get_annotations() get key points annotation 261 | Arguments: 262 | keypoint_annotations str: the file name of the key point annotations 263 | category_name str: the category name of the image 264 | Returns: 265 | keypoint [Tensor float32] 32x3 266 | bbox [Tensor float32] 4 267 | """ 268 | base, _ = os.path.splitext(os.path.basename(keypoint_annotation)) 269 | # print('base', os.path.join(self.annotations, category_name, base +'.mat')) 270 | anno = scipy.io.loadmat( 271 | os.path.join(self.annotations, category_name, base + ".mat") 272 | ) 273 | keypoint = np.zeros((32, 3), dtype=np.float32) 274 | annotation = anno["kps"] 275 | N = annotation.shape[0] 276 | for i in range(N): 277 | if ( 278 | annotation[i, 0] == annotation[i, 0] 279 | and annotation[i, 1] == annotation[i, 1] 280 | ): # not nan 281 | keypoint[i, :2] = annotation[i] 282 | keypoint[i, 2] = i + 1 283 | 284 | np.random.shuffle(keypoint) 285 | 286 | keypoint = torch.Tensor(keypoint.astype(np.float32)) 287 | bbox = anno["bbox"][0].astype(np.float32) 288 | return keypoint, bbox 289 | 290 | 291 | class ImagePairDatasetKeyPoint(ImagePairDataset): 292 | def __init__( 293 | self, 294 | dataset_csv_path, 295 | dataset_csv_file, 296 | dataset_image_path, 297 | dataset_size=0, 298 | output_size=(240, 240), 299 | transform=None, 300 | random_crop=False, 301 | ): 302 | super(ImagePairDatasetKeyPoint, self).__init__( 303 | dataset_csv_path, 304 | dataset_csv_file, 305 | dataset_image_path, 306 | dataset_size=dataset_size, 307 | output_size=output_size, 308 | transform=transform, 309 | random_crop=random_crop, 310 | ) 311 | 312 | -------------------------------------------------------------------------------- /lib/plot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | def plot_image(im, batch_idx=0, return_im=False): 8 | if im.dim() == 4: 9 | im = im[batch_idx, :, :, :] 10 | mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]).view(3, 1, 1)) 11 | std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]).view(3, 1, 1)) 12 | if im.is_cuda: 13 | mean = mean.cuda() 14 | std = std.cuda() 15 | im = im.mul(std).add(mean) * 255.0 16 | im = im.permute(1, 2, 0).data.cpu().numpy().astype(np.uint8) 17 | if return_im: 18 | return im 19 | plt.imshow(im) 20 | plt.show() 21 | 22 | 23 | def save_plot(filename): 24 | plt.gca().set_axis_off() 25 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 26 | plt.margins(0, 0) 27 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 28 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 29 | plt.savefig(filename, bbox_inches="tight", pad_inches=0) 30 | 31 | -------------------------------------------------------------------------------- /lib/point_tnf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from 
torch.autograd import Variable 4 | import numpy as np 5 | 6 | 7 | def normalize_axis(x, L): 8 | # convert original pixel coordinate into unit coordinate (-1, 1) 9 | return (x - 1 - (L - 1) / 2) * 2 / (L - 1) 10 | # return (x-L/2)*2/L 11 | 12 | 13 | def unnormalize_axis(x, L): 14 | return x * (L - 1) / 2 + 1 + (L - 1) / 2 15 | # return x*L/2+L/2 16 | 17 | 18 | def corr_to_matches( 19 | corr4d, 20 | delta4d=None, 21 | k_size=1, 22 | do_softmax=False, 23 | scale="centered", 24 | return_indices=False, 25 | invert_matching_direction=False, 26 | ): 27 | """ 28 | corr_to_matches() interprente the corr4d into correspondences defined in the centred/postive unit coordinate system 29 | Arguements: 30 | corr4d: 31 | Returns: 32 | xA, yA, xB, yB, B x 256 ; A are regular unit coordinate form the source feature map, 33 | B are best matches of A in image B 34 | scores: B x 256; number of correlations 35 | """ 36 | to_cuda = lambda x: x.cuda() if corr4d.is_cuda else x 37 | batch_size, ch, fs1, fs2, fs3, fs4 = corr4d.size() 38 | 39 | if scale == "centered": 40 | XA, YA = np.meshgrid( 41 | np.linspace(-1, 1, fs2 * k_size), np.linspace(-1, 1, fs1 * k_size) 42 | ) 43 | XB, YB = np.meshgrid( 44 | np.linspace(-1, 1, fs4 * k_size), np.linspace(-1, 1, fs3 * k_size) 45 | ) 46 | elif scale == "positive": 47 | XA, YA = np.meshgrid( 48 | np.linspace(0, 1, fs2 * k_size), np.linspace(0, 1, fs1 * k_size) 49 | ) 50 | XB, YB = np.meshgrid( 51 | np.linspace(0, 1, fs4 * k_size), np.linspace(0, 1, fs3 * k_size) 52 | ) 53 | 54 | JA, IA = np.meshgrid(range(fs2), range(fs1)) 55 | JB, IB = np.meshgrid(range(fs4), range(fs3)) 56 | 57 | XA, YA = ( 58 | Variable(to_cuda(torch.FloatTensor(XA))), 59 | Variable(to_cuda(torch.FloatTensor(YA))), 60 | ) 61 | XB, YB = ( 62 | Variable(to_cuda(torch.FloatTensor(XB))), 63 | Variable(to_cuda(torch.FloatTensor(YB))), 64 | ) 65 | 66 | JA, IA = ( 67 | Variable(to_cuda(torch.LongTensor(JA).view(1, -1))), 68 | Variable(to_cuda(torch.LongTensor(IA).view(1, -1))), 69 | ) 70 | JB, IB = ( 71 | Variable(to_cuda(torch.LongTensor(JB).view(1, -1))), 72 | Variable(to_cuda(torch.LongTensor(IB).view(1, -1))), 73 | ) 74 | 75 | if invert_matching_direction: 76 | nc_A_Bvec = corr4d.view(batch_size, fs1, fs2, fs3 * fs4) 77 | 78 | if do_softmax: 79 | nc_A_Bvec = torch.nn.functional.softmax(nc_A_Bvec, dim=3) 80 | 81 | match_A_vals, idx_A_Bvec = torch.max(nc_A_Bvec, dim=3) 82 | score = match_A_vals.view(batch_size, -1) 83 | 84 | iB = IB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1) 85 | jB = JB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1) 86 | iA = IA.expand_as(iB) 87 | jA = JA.expand_as(jB) 88 | 89 | else: 90 | # actually runs 91 | nc_B_Avec = corr4d.view( 92 | batch_size, fs1 * fs2, fs3, fs4 93 | ) # [batch_idx,k_A,i_B,j_B] 94 | if do_softmax: 95 | nc_B_Avec = torch.nn.functional.softmax(nc_B_Avec, dim=1) 96 | 97 | match_B_vals, idx_B_Avec = torch.max(nc_B_Avec, dim=1) 98 | score = match_B_vals.view(batch_size, -1) 99 | 100 | iA = IA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, -1) 101 | jA = JA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, -1) 102 | iB = IB.expand_as(iA) 103 | jB = JB.expand_as(jA) 104 | 105 | if delta4d is not None: # relocalization 106 | delta_iA, delta_jA, delta_iB, delta_jB = delta4d 107 | 108 | diA = delta_iA.squeeze(0).squeeze(0)[ 109 | iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1) 110 | ] 111 | djA = delta_jA.squeeze(0).squeeze(0)[ 112 | iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1) 113 | ] 114 | diB = delta_iB.squeeze(0).squeeze(0)[ 115 | iA.view(-1), jA.view(-1), 
iB.view(-1), jB.view(-1) 116 | ] 117 | djB = delta_jB.squeeze(0).squeeze(0)[ 118 | iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1) 119 | ] 120 | 121 | iA = iA * k_size + diA.expand_as(iA) 122 | jA = jA * k_size + djA.expand_as(jA) 123 | iB = iB * k_size + diB.expand_as(iB) 124 | jB = jB * k_size + djB.expand_as(jB) 125 | 126 | xA = XA[iA.view(-1), jA.view(-1)].view(batch_size, -1) 127 | yA = YA[iA.view(-1), jA.view(-1)].view(batch_size, -1) 128 | xB = XB[iB.view(-1), jB.view(-1)].view(batch_size, -1) 129 | yB = YB[iB.view(-1), jB.view(-1)].view(batch_size, -1) 130 | 131 | if return_indices: 132 | return (xA, yA, xB, yB, score, iA, jA, iB, jB) 133 | else: 134 | return (xA, yA, xB, yB, score) 135 | 136 | 137 | def nearestNeighPointTnf(matches, target_points_norm): 138 | xA, yA, xB, yB = matches 139 | 140 | # match target points to grid 141 | deltaX = target_points_norm[:, 0, :].unsqueeze(1) - xB.unsqueeze(2) 142 | deltaY = target_points_norm[:, 1, :].unsqueeze(1) - yB.unsqueeze(2) 143 | distB = torch.sqrt(torch.pow(deltaX, 2) + torch.pow(deltaY, 2)) 144 | vals, idx = torch.min(distB, dim=1) 145 | 146 | warped_points_x = xA.view(-1)[idx.view(-1)].view(1, 1, -1) 147 | warped_points_y = yA.view(-1)[idx.view(-1)].view(1, 1, -1) 148 | warped_points_norm = torch.cat((warped_points_x, warped_points_y), dim=1) 149 | return warped_points_norm 150 | 151 | 152 | def bilinearInterpPointTnf(matches, target_points_norm): 153 | """ 154 | bilinearInterpPointTnf() 155 | Argument: 156 | matches tuple (xA, yA, xB, yB), xA, yA, xB, yB Bx 256 157 | target_points_norm [tensor] B x 2 x N 158 | 159 | 160 | """ 161 | xA, yA, xB, yB = matches # 162 | 163 | feature_size = int(np.sqrt(xB.shape[-1])) 164 | 165 | b, _, N = target_points_norm.size() 166 | 167 | X_ = xB.view(-1) # B*256 168 | Y_ = yB.view(-1) # B*256 169 | 170 | grid = torch.FloatTensor(np.linspace(-1, 1, feature_size)).unsqueeze(0).unsqueeze(2) 171 | # grid is 1 x 16 x 1 172 | if xB.is_cuda: 173 | grid = grid.cuda() 174 | if isinstance(xB, Variable): 175 | grid = Variable(grid) 176 | 177 | x_minus = ( 178 | torch.sum( 179 | ((target_points_norm[:, 0, :] - grid) > 0).long(), dim=1, keepdim=True 180 | ) 181 | - 1 182 | ) 183 | x_minus[x_minus < 0] = 0 # fix edge case 184 | x_plus = x_minus + 1 185 | x_plus[x_plus > (feature_size - 1)] = feature_size - 1 186 | 187 | y_minus = ( 188 | torch.sum( 189 | ((target_points_norm[:, 1, :] - grid) > 0).long(), dim=1, keepdim=True 190 | ) 191 | - 1 192 | ) 193 | y_minus[y_minus < 0] = 0 # fix edge case 194 | y_plus = y_minus + 1 195 | y_plus[y_plus > (feature_size - 1)] = feature_size - 1 196 | 197 | toidx = lambda x, y, L: y * L + x 198 | 199 | m_m_idx = toidx(x_minus, y_minus, feature_size) 200 | p_p_idx = toidx(x_plus, y_plus, feature_size) 201 | p_m_idx = toidx(x_plus, y_minus, feature_size) 202 | m_p_idx = toidx(x_minus, y_plus, feature_size) 203 | 204 | # print('m_m_idx',m_m_idx) 205 | # print('p_p_idx',p_p_idx) 206 | # print('p_m_idx',p_m_idx) 207 | # print('m_p_idx',m_p_idx) 208 | topoint = lambda idx, X, Y: torch.cat( 209 | ( 210 | X[idx.view(-1)].view(b, 1, N).contiguous(), 211 | Y[idx.view(-1)].view(b, 1, N).contiguous(), 212 | ), 213 | dim=1, 214 | ) 215 | 216 | P_m_m = topoint(m_m_idx, X_, Y_) 217 | P_p_p = topoint(p_p_idx, X_, Y_) 218 | P_p_m = topoint(p_m_idx, X_, Y_) 219 | P_m_p = topoint(m_p_idx, X_, Y_) 220 | 221 | multrows = lambda x: x[:, 0, :] * x[:, 1, :] 222 | 223 | f_p_p = multrows(torch.abs(target_points_norm - P_m_m)) 224 | f_m_m = multrows(torch.abs(target_points_norm - P_p_p)) 225 | f_m_p 
= multrows(torch.abs(target_points_norm - P_p_m)) 226 | f_p_m = multrows(torch.abs(target_points_norm - P_m_p)) 227 | 228 | Q_m_m = topoint(m_m_idx, xA.view(-1), yA.view(-1)) 229 | Q_p_p = topoint(p_p_idx, xA.view(-1), yA.view(-1)) 230 | Q_p_m = topoint(p_m_idx, xA.view(-1), yA.view(-1)) 231 | Q_m_p = topoint(m_p_idx, xA.view(-1), yA.view(-1)) 232 | 233 | warped_points_norm = ( 234 | Q_m_m * f_m_m + Q_p_p * f_p_p + Q_m_p * f_m_p + Q_p_m * f_p_m 235 | ) / (f_p_p + f_m_m + f_m_p + f_p_m) 236 | return warped_points_norm 237 | 238 | 239 | def PointsToUnitCoords(P, im_size): 240 | h, w = im_size[:, 0], im_size[:, 1] 241 | P_norm = P.clone() 242 | # normalize Y 243 | P_norm[:, 0, :] = normalize_axis(P[:, 0, :], w.unsqueeze(1).expand_as(P[:, 0, :])) 244 | # normalize X 245 | P_norm[:, 1, :] = normalize_axis(P[:, 1, :], h.unsqueeze(1).expand_as(P[:, 1, :])) 246 | return P_norm 247 | 248 | 249 | def PointsToPixelCoords(P, im_size): 250 | h, w = im_size[:, 0], im_size[:, 1] 251 | P_norm = P.clone() 252 | # normalize Y 253 | P_norm[:, 0, :] = unnormalize_axis(P[:, 0, :], w.unsqueeze(1).expand_as(P[:, 0, :])) 254 | # normalize X 255 | P_norm[:, 1, :] = unnormalize_axis(P[:, 1, :], h.unsqueeze(1).expand_as(P[:, 1, :])) 256 | return P_norm 257 | -------------------------------------------------------------------------------- /lib/tools.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch.nn.functional as F 3 | import os, time, sys, math 4 | import subprocess, shutil 5 | from . import constant 6 | from . import visualisation 7 | from os.path import * 8 | import numpy as np 9 | import numpy 10 | import torch 11 | import tqdm as tqdm 12 | import matplotlib.pyplot as plt 13 | import random 14 | from tqdm import tqdm 15 | from . import interpolator 16 | 17 | 18 | def seed_torch(seed=1029): 19 | np.random.seed(seed) 20 | os.environ["PYTHONHASHSEED"] = str(seed) 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed(seed) 23 | torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. 24 | torch.backends.cudnn.benchmark = False 25 | torch.backends.cudnn.deterministic = True 26 | 27 | 28 | def save_checkpoint(state, file): 29 | model_dir = dirname(file) 30 | model_fn = basename(file) 31 | # make dir if needed (should be non-empty) 32 | if model_dir != "" and not exists(model_dir): 33 | os.makedirs(model_dir) 34 | torch.save(state, join(model_dir, str(state["epoch"]) + "_" + model_fn)) 35 | 36 | 37 | def calc_gt_indices(batch_keys_gt, batch_assignments_gt): 38 | """ 39 | calc_gt_indices() calculate the ground truth indices and number of valid key points. 
40 | Arguments: 41 | batch_keys_gt: [tensor B x N x 3] the last column stores the indicator of the key points 42 | if it is larger than 0 it is valid, otherwise invalid 43 | batch_assignments_gt [tensor B x N x N], the ground truth assignment matrix 44 | Returns: 45 | indices_gt: [tensor B x N ]: from source to target, each element stores the index of target matches 46 | key_num_gt: [tensor B]: number of valid key points per input of the batch 47 | """ 48 | _, indices_gt = torch.max( 49 | batch_assignments_gt, 2 50 | ) # get ground truth matches from source to target 51 | indices_gt += ( 52 | 1 53 | ) # remember that indices start counting from 1 for 0 is used to store empty key points 54 | mask_gt = (batch_keys_gt[:, :, 2] > 0).long() # get the valid key point masks 55 | indices_gt = indices_gt * mask_gt 56 | key_num_gt = mask_gt.sum(dim=1).float() 57 | return indices_gt, key_num_gt 58 | 59 | 60 | def calc_accuracy(batch_assignments, indices_gt, src_key_num_gt): 61 | """ 62 | calc_accuracy() calculate the accuracy for each instance in a batch of ground truth key points 63 | and batch and predicted assignments. 64 | Arguments: 65 | batch_assignment [tensor float B x 32 x 32]: the batch of the predicted assignment matrix 66 | indices_gt [tensor long B x 32 ]: the batch of ground truth indices from source to target 67 | src_key_num_gt [tensor float Bx 1]: the ground truth number of valid key points for a batch 68 | with batch size B. 69 | Returns: 70 | accuracy [tensor float B x 32]: the accuracy for each instance of the batch is calculated. 71 | """ 72 | 73 | values, indices = torch.max( 74 | batch_assignments, 2 75 | ) # get matches for source key points 76 | indices += ( 77 | 1 78 | ) # remember that indices start counting from 1 for 0 is used to store empty key points 79 | 80 | accuracy = (indices_gt == indices).sum(dim=1).float() 81 | accuracy = torch.div(accuracy, src_key_num_gt) 82 | 83 | return accuracy 84 | 85 | 86 | def pure_pck(keys_pred, keys_gt, key_num_gt, image_scale, alpha): 87 | """ 88 | pure_pck() calculate the pck percentage for each instance in a batch of predicted key points 89 | Arguments: 90 | keys_pred [tensor float B x 32 x 3]: the predicted key points 91 | keys_gt [tensor float B x 32 x 2]: the ground truth key points. 92 | key_num_gt [tensor float B X 1]: the ground truth number of valid key points 93 | image_scale float: the length of image diagonal image_scale = sqrt( W^2 + H^2) 94 | alpha float: percentage threshold. 95 | Returns: 96 | pck [tensor float Bx 1]: the pck score for the batch 97 | """ 98 | dif = keys_pred - keys_gt 99 | err = dif.norm(dim=2) / image_scale 100 | wrong = (err > alpha).sum(dim=1).float() # number of incorrect predictior 101 | pck = 1 - torch.div(wrong, key_num_gt) 102 | return pck 103 | 104 | 105 | def calc_pck0( 106 | target_batch_keys, 107 | target_keys_pred, 108 | batch_assignments_gt, 109 | src_key_num_gt, 110 | image_scale, 111 | alpha=0.1, 112 | ): 113 | """ 114 | calc_pck() calculate the pck percentage for each instance in a batch of predicted key points 115 | Arguments: 116 | target_batch_keys [tensor float B x 32 x 3]: the target key points 117 | target_keys_pred [tensor float B x 32 x 2]: the predicted key points. 118 | batch_assignments_gt [tensor float B x 32 x 32]: the ground truth assignment matrix. 119 | src_key_num_gt [tensor float B X 1]: the ground truth number of valid key points 120 | image_scale float: the length of image diagonal image_scale = sqrt( W^2 + H^2) 121 | alpha float: percentage threshold. 
122 | Returns: 123 | pck [tensor float Bx 1]: the pck score for the batch 124 | target_keys_pred, [tensor float B x 2], predicted key points locations w.r.t. target image 125 | batch_keys_gt, [tensor float B x 2 ], ground truth key points locations w.r.t. target image 126 | """ 127 | batch_keys_gt = torch.bmm(batch_assignments_gt, target_batch_keys[:, :, :2]) 128 | pck = pure_pck(target_keys_pred, batch_keys_gt, src_key_num_gt, image_scale, alpha) 129 | return pck, target_keys_pred, batch_keys_gt 130 | 131 | 132 | def distance(keys_pred, keys_gt, key_num_gt): 133 | """ 134 | distance() calculate the mean distance between the predicted and ground truth key points for each instance in a batch 135 | Arguments: 136 | keys_pred [tensor float B x 32 x 3]: the predicted key points 137 | keys_gt [tensor float B x 32 x 2]: the ground truth key points. 138 | key_num_gt [tensor float B X 1]: the ground truth number of valid key points 139 | Returns: 140 | err [tensor float B x 1]: the mean key point distance for the batch 141 | """ 142 | mask = (keys_gt > 1e-10).float() 143 | dif = keys_pred * mask - keys_gt 144 | err = dif.norm(dim=2) 145 | err = err.sum(dim=1) 146 | err = torch.div(err, key_num_gt) 147 | return err 148 | 149 | 150 | def calc_distance( 151 | target_batch_keys, target_keys_pred, batch_assignments_gt, src_key_num_gt 152 | ): 153 | """ 154 | calc_distance() calculate the mean distance between the predicted and ground truth key points for each instance in a batch 155 | Arguments: 156 | target_batch_keys [tensor float B x 32 x 3]: the target key points 157 | target_keys_pred [tensor float B x 32 x 2]: the predicted key points. 158 | batch_assignments_gt [tensor float B x 32 x 32]: the ground truth assignment matrix. 159 | src_key_num_gt [tensor float B X 1]: the ground truth number of valid key points 160 | (the ground truth key point locations are recovered by multiplying the 161 | ground truth assignment matrix with the target key points, as in calc_pck0) 162 | Returns: 163 | err [tensor float B x 1]: the mean distance between the predicted key points 164 | and the ground truth key points for each instance of the batch, 165 | computed by distance() 166 | """ 167 | batch_keys_gt = torch.bmm(batch_assignments_gt, target_batch_keys[:, :, :2]) 168 | err = distance(target_keys_pred, batch_keys_gt, src_key_num_gt) 169 | return err 170 | 171 | 172 | def calc_pck( 173 | target_batch_keys, 174 | batch_assignments, 175 | batch_assignments_gt, 176 | src_key_num_gt, 177 | image_scale, 178 | alpha=0.1, 179 | ): 180 | """ 181 | calc_pck() calculate the pck percentage for each instance in a batch of predicted key points 182 | Arguments: 183 | target_batch_keys [tensor float B x 32 x 3]: the target key points 184 | batch_assignments [tensor float B x 32 x 32]: the predicted assignment matrix. 185 | batch_assignments_gt [tensor float B x 32 x 32]: the ground truth assignment matrix. 186 | src_key_num_gt [tensor float B X 1]: the ground truth number of valid key points 187 | image_scale float: the length of image diagonal image_scale = sqrt( W^2 + H^2) 188 | alpha float: percentage threshold. 189 | Returns: 190 | pck [tensor float Bx 1]: the pck score for the batch 191 | batch_keys_pred, [tensor float B x 2], predicted key points locations w.r.t. target image 192 | batch_keys_gt, [tensor float B x 2 ], ground truth key points locations w.r.t.
target image 193 | """ 194 | 195 | batch_keys_pred = torch.bmm(batch_assignments, target_batch_keys[:, :, :2]) 196 | return calc_pck0( 197 | target_batch_keys, 198 | batch_keys_pred, 199 | batch_assignments_gt, 200 | src_key_num_gt, 201 | image_scale, 202 | alpha, 203 | ) 204 | 205 | 206 | def calc_mto(batch_assignments, src_indices_gt, src_key_num_gt): 207 | """ 208 | calc_mto() calculate the one-to-many matching score, notice one is source, many is destination 209 | Arguments: 210 | batch_assignment [tensor float B x 32 x 32]: the batch of the predicted assignment matrix 211 | src_indices_gt [tensor long B x 32 ]: the batch of ground truth key point indices from source to target 212 | src_key_num_gt [tensor float Bx 1]: the ground truth number of valid key points for a batch with batch size B 213 | Returns: 214 | mto [tensor B x 1] cpu: the mto score for the batch 215 | """ 216 | values, indices = torch.max(batch_assignments, 2) 217 | indices += ( 218 | 1 219 | ) # remember that indices start counting from 1 for 0 is used to store empty key points 220 | mask = (src_indices_gt == indices).long() 221 | indices *= mask 222 | num_unique = torch.tensor([float(len(torch.unique(kk))) - 1 for kk in indices]) 223 | mto = 1 - torch.div(num_unique, src_key_num_gt.cpu()) 224 | return mto 225 | 226 | 227 | def graph_matching(batch_assignments, iterations=2): 228 | """ 229 | graph_matching() applying the graph matching update to refine the matches 230 | Arguments: 231 | batch_assignment [tensor float B x 32 x 32]: the batch of the predicted assignment matrix 232 | Returns: 233 | batch_assignment [tensor float B x 32 x 32]: the batch of the refined assignment matrix 234 | """ 235 | 236 | for i in range(iterations): 237 | batch_assignments = batch_assignments * batch_assignments 238 | Xrs_sum = torch.sum(batch_assignments, dim=2, keepdim=True) # normalisation 239 | batch_assignments = batch_assignments / (Xrs_sum + constant._eps) 240 | return batch_assignments 241 | 242 | 243 | def corr_to_matches(corr4d, do_softmax=False, source_to_target=True): 244 | B, ch, fs1, fs2, fs3, fs4 = corr4d.size() 245 | 246 | XA, YA = np.meshgrid(range(fs2), range(fs1)) # pixel coordinate 247 | XB, YB = np.meshgrid(range(fs4), range(fs3)) 248 | XA, YA = torch.FloatTensor(XA), torch.FloatTensor(YA) 249 | XB, YB = torch.FloatTensor(XB), torch.FloatTensor(YB) 250 | XA, YA = XA.view(-1).cuda(), YA.view(-1).cuda() 251 | XB, YB = XB.view(-1).cuda(), YB.view(-1).cuda() 252 | 253 | if source_to_target: 254 | # best match from source to target 255 | nc_A_Bvec = corr4d.view(B, fs1, fs2, fs3 * fs4) 256 | 257 | if do_softmax: 258 | nc_A_Bvec = torch.nn.functional.softmax(nc_A_Bvec, dim=3) 259 | 260 | match_A_vals, idx_A_Bvec = torch.max(nc_A_Bvec, dim=3) # B x fs1 x fs2 261 | score = match_A_vals.view(B, 1, fs1, fs2) # B x 1 x fs1 x fs2 262 | 263 | # idx_A_Bvec: 12 x 16 x 16 264 | xB = XB[idx_A_Bvec.view(-1)].view( 265 | B, 1, fs1, fs2 266 | ) # B x fs1*fs2: index of betch matches in B 267 | yB = YB[idx_A_Bvec.view(-1)].view(B, 1, fs1, fs2) 268 | xyB = torch.cat((xB, yB), 1) 269 | 270 | return xyB.contiguous(), score.contiguous() 271 | else: 272 | # best matches from target to source 273 | nc_B_Avec = corr4d.view(B, fs1 * fs2, fs3, fs4) # [batch_idx,k_A,i_B,j_B] 274 | if do_softmax: # default 275 | nc_B_Avec = torch.nn.functional.softmax(nc_B_Avec, dim=1) 276 | 277 | match_B_vals, idx_B_Avec = torch.max( 278 | nc_B_Avec, dim=1 279 | ) # idx_B_Avec is Bx (16*16 = 256) 280 | score = match_B_vals.view(B, 1, fs3, fs4) # score is B x 1 x 
fs3 x fs4 281 | 282 | # idx_B_Avec: 12 x 16 x 16 283 | xA = XA[idx_B_Avec.view(-1)].view( 284 | B, 1, fs3, fs4 285 | ) # B x 256, it stores the col index for A 286 | yA = YA[idx_B_Avec.view(-1)].view( 287 | B, 1, fs3, fs4 288 | ) # B x 256, it stores the col index for A 289 | xyA = torch.cat((xA, yA), 1) 290 | 291 | return xyA.contiguous(), score.contiguous() 292 | 293 | 294 | def NormalisationPerRow(keycorr): 295 | """ 296 | NormalisationPerRow() normalise the 3rd dimension by calculating its sum and divide the vector 297 | in last dimension by the sum 298 | Arguments 299 | keycorr: B x N x HW 300 | Returns 301 | keycorr: B x N x HW 302 | """ 303 | eps = 1e-15 304 | sum_per_row = keycorr.sum(dim=2, keepdim=True) + eps 305 | sum_per_row = sum_per_row.expand_as(keycorr) # B x N x L 306 | keycorr = keycorr / sum_per_row 307 | return keycorr 308 | 309 | 310 | class ExtractFeatureMap: 311 | def __init__(self, im_fe_ratio): 312 | self.im_fe_ratio = im_fe_ratio 313 | self.interp = interpolator.Interpolator(im_fe_ratio) 314 | self.offset = int(im_fe_ratio / 2 - 1) 315 | 316 | def upsampling_keycorr(self, keycorr, image_size): 317 | B, N, H, W = keycorr.shape 318 | keycorr = F.interpolate( 319 | keycorr, size=image_size, mode="bilinear", align_corners=False 320 | ) # (H2-1)*interp.im_fe_ratio+2) x (W2-1)*interp.im_fe_ratio+2) 321 | keycorr = keycorr.view(B, N, -1).contiguous() 322 | return keycorr 323 | 324 | def normalise_image(self, keycorr_original, kmin=None, krange=None): 325 | keycorr = keycorr_original.clone() 326 | eps = 1e-15 327 | B, N, C = keycorr.shape 328 | keycorr = keycorr.view(B, -1) 329 | if kmin is None and krange is None: 330 | kmin, _ = keycorr.min(dim=1, keepdim=True) # B x 1 331 | kmax, _ = keycorr.max(dim=1, keepdim=True) 332 | krange = kmax - kmin 333 | krange = krange.expand_as(keycorr) 334 | kmin = kmin.expand_as(keycorr) 335 | keycorr = (keycorr - kmin) / krange 336 | return keycorr.view(B, N, C), kmin, krange 337 | 338 | def __call__(self, corr, key_gt, source_to_target=True, image_size=None): 339 | """ 340 | extract_featuremap() extract the interpolated feature map for each query key points in key_gt 341 | Arguements 342 | corr [tensor float] B x 1 x H1 x W1 x H2 x W2: the 4d correlation map 343 | key_gt [tensor float] B x N x 2: the tensor stores the sparse query key points 344 | image_size [tuple int] H, W: original input image size, if it is None, then no interpolation 345 | interp [object of Interpolator]: to interpolate the correlation maps 346 | source_to_targe [boolean]: if true, query from source to target, otherwise, from target to source 347 | Return: 348 | keycorr [tensor float]: B x N x H2W2 (when source_to_targe = True) the correlation map for each source 349 | key ponit 350 | """ 351 | B, C, H1, W1, H2, W2 = corr.shape 352 | _, N, _ = key_gt.shape 353 | if source_to_target: 354 | corr = corr.view(B, H1, W1, H2 * W2) 355 | corr = corr.permute(0, 3, 1, 2) 356 | keycorr = self.interp(corr, key_gt) # keycorr B x H2*W2 x N, key is source 357 | keycorr = keycorr.permute(0, 2, 1) # B x N x H2*W2 358 | keycorr = keycorr.view(B, N, H2, W2) # B x N x H2 x W2 359 | else: 360 | corr = corr.view(B, H1 * W1, H2, W2) 361 | keycorr = self.interp( 362 | corr, key_gt 363 | ) # keycorr B x H1*W1 x N key_gt is target point 364 | keycorr = keycorr.permute(0, 2, 1) # B x N x H1*W1 365 | keycorr = keycorr.view(B, N, H1, W1) # B x N x H1 x W1 366 | 367 | if image_size is not None: 368 | keycorr = self.upsampling_keycorr(keycorr, image_size) 369 | keycorr = keycorr.view(B, N, 
-1).contiguous() 370 | # try softmax 371 | # keycorr = torch.softmax(keycorr, dim = 2) 372 | keycorr = NormalisationPerRow(keycorr) 373 | return keycorr 374 | 375 | def keycorr_to_matches(self, keycorr, image_size): 376 | """ 377 | keycorr_to_matches() 378 | keycorr [tensor float]: B x N x HW (when source_to_target = True) the correlation map for each source 379 | key point; note that H x W must be aligned with the original image size 380 | image_size (tuple) H x W: original image size 381 | Returns 382 | xyA [tensor float]: B x N x 2, the key points from source to target 383 | """ 384 | B, N, _ = keycorr.shape 385 | 386 | XA, YA = np.meshgrid( 387 | range(image_size[1]), range(image_size[0]) 388 | ) # pixel coordinates 389 | XA, YA = ( 390 | torch.FloatTensor(XA).view(-1).cuda(), 391 | torch.FloatTensor(YA).view(-1).cuda(), 392 | ) 393 | 394 | values, indices = torch.max(keycorr, dim=2) 395 | xA = XA[indices.view(-1)].view(B, N, 1) 396 | yA = YA[indices.view(-1)].view(B, N, 1) 397 | xyA = torch.cat((xA, yA), 2) 398 | return xyA 399 | 400 | 401 | def validate( 402 | model, 403 | loader, 404 | batch_preprocessing_fn, 405 | graph_layer, 406 | image_scale, 407 | im_fe_ratio=16, 408 | image_size=(256, 256), 409 | alpha=0.1, 410 | MAX=100, 411 | display=False, 412 | iterations=2, 413 | ): 414 | model.train(mode=False) 415 | avg_recall = 0.0 416 | 417 | total = min(MAX, len(loader)) 418 | 419 | mean_accuracy = 0 420 | mean_accuracy2 = 0 421 | mean_accuracy3 = 0 422 | mean_mto = 0 423 | mean_mto2 = 0 424 | mean_mto3 = 0 425 | mean_pck = 0 426 | mean_pck2 = 0 427 | mean_pck3 = torch.zeros((1, 1)) 428 | mean_pck4 = torch.zeros((1, 1)) 429 | 430 | output_dir = "output" 431 | if output_dir != "" and not exists(output_dir): 432 | os.makedirs(output_dir) 433 | 434 | extract_featuremap = ExtractFeatureMap(im_fe_ratio) 435 | progress = tqdm(loader, total=total) 436 | for i, data in enumerate(progress): 437 | if i >= total: 438 | break 439 | tnf_batch = batch_preprocessing_fn(data) 440 | corr = model(tnf_batch) # Xr is the predicted permutation matrix 441 | 442 | Xg = tnf_batch["assignment"] 443 | Xgt = tnf_batch["assignment"].permute(0, 2, 1) 444 | 445 | src_gt = tnf_batch["source_points"] 446 | dst_gt = tnf_batch["target_points"] 447 | # calc key_num 448 | src_indices_gt, src_key_num_gt = calc_gt_indices(src_gt, Xg) 449 | dst_indices_gt, dst_key_num_gt = calc_gt_indices(dst_gt, Xgt) 450 | 451 | keycorrB_A = extract_featuremap( 452 | corr, src_gt[:, :, :2], source_to_target=True, image_size=image_size 453 | ) 454 | keycorrA_B = extract_featuremap( 455 | corr, dst_gt[:, :, :2], source_to_target=False, image_size=image_size 456 | ) 457 | xyB_A = extract_featuremap.keycorr_to_matches(keycorrB_A, image_size) 458 | xyA_B = extract_featuremap.keycorr_to_matches(keycorrA_B, image_size) 459 | 460 | pck3, src_key_p3, src_key_gt3 = calc_pck0( 461 | dst_gt, xyB_A, Xg, dst_key_num_gt, image_scale, alpha 462 | ) 463 | pck4, src_key_p4, src_key_gt4 = calc_pck0( 464 | src_gt, xyA_B, Xgt, src_key_num_gt, image_scale, alpha 465 | ) 466 | mean_pck3 += pck3.mean() 467 | mean_pck4 += pck4.mean() 468 | 469 | # visualise results 470 | if display: 471 | 472 | source = (tnf_batch["source_original"] * 255).int() 473 | target = (tnf_batch["target_original"] * 255).int() 474 | B = source.shape[0] 475 | for b in range(B): 476 | file_name = join(output_dir, "{}_{}_".format(i, b)) 477 | visualisation.displayPair( 478 | source[b].detach().cpu().permute(1, 2, 0), 479 | src_gt[b].detach().cpu(), 480 | target[b].detach().cpu().permute(1, 2, 0),
481 | src_key_p3[b].detach().cpu(), 482 | src_key_gt3[b].detach().cpu(), 483 | file_name=file_name, 484 | ) 485 | 486 | mean_pck3 /= total 487 | mean_pck4 /= total 488 | model.train(mode=True) 489 | return mean_pck3, mean_pck4 490 | 491 | 492 | def visualise_feature( 493 | model, loader, batch_preprocessing_fn, image_size, im_fe_ratio=16, MAX=100 494 | ): 495 | model.train(mode=False) 496 | 497 | total = min(MAX, len(loader)) 498 | 499 | extract_featuremap = ExtractFeatureMap(im_fe_ratio) 500 | progress = tqdm(loader, total=total) 501 | cm_hot = plt.get_cmap("bwr") # coolwarm 502 | 503 | output_dir = "output" 504 | if output_dir != "" and not exists(output_dir): 505 | os.makedirs(output_dir) 506 | 507 | for i, data in enumerate(progress): 508 | if i >= total: 509 | break 510 | tnf_batch = batch_preprocessing_fn(data) 511 | 512 | category = tnf_batch["set"] 513 | Xg = tnf_batch["assignment"] 514 | src_gt_cuda = tnf_batch["source_points"] 515 | dst_gt_cuda = tnf_batch["target_points"] 516 | src_indices_gt, src_key_num_gt = calc_gt_indices(src_gt_cuda, Xg) 517 | 518 | src_gt = src_gt_cuda.detach().cpu() 519 | dst_gt = dst_gt_cuda.detach().cpu() 520 | target = (tnf_batch["target_original"] * 255).int() 521 | source = (tnf_batch["source_original"] * 255).int() 522 | B, N, _ = src_gt.shape 523 | 524 | corr = model(tnf_batch) # Xr is the predicted permutation matrix 525 | B, C, H1, W1, H2, W2 = corr.shape 526 | 527 | keycorrB_A = extract_featuremap( 528 | corr, src_gt_cuda[:, :, :2], source_to_target=True, image_size=image_size 529 | ) 530 | keycorrA_B = extract_featuremap( 531 | corr, dst_gt_cuda[:, :, :2], source_to_target=False, image_size=image_size 532 | ) 533 | xyB_A = extract_featuremap.keycorr_to_matches(keycorrB_A, image_size) 534 | xyA_B = extract_featuremap.keycorr_to_matches(keycorrA_B, image_size) 535 | keycorrB_A, _, _ = extract_featuremap.normalise_image(keycorrB_A) 536 | keycorrA_B, _, _ = extract_featuremap.normalise_image(keycorrA_B) 537 | 538 | keycorrB_A = keycorrB_A.view(B, N, *image_size).detach().cpu() 539 | keycorrA_B = keycorrA_B.view(B, N, *image_size).detach().cpu() 540 | xyB_A = xyB_A.detach().cpu() 541 | xyA_B = xyA_B.detach().cpu() 542 | 543 | for b in range(B): 544 | NN = min(32, int(src_key_num_gt[b])) 545 | fig, axes = plt.subplots(4, NN, sharex="all", sharey="all") 546 | nn = 0 547 | for n in range(N): 548 | 549 | tn = src_indices_gt[b, n] # target index 550 | if tn > 0: 551 | # paint source 552 | c = n % len(constant._colors) 553 | m = n // len(constant._colors) 554 | 555 | original = source[b].detach().cpu().permute(1, 2, 0) 556 | 557 | axes[0, nn].imshow(original) 558 | axes[0, nn].scatter( 559 | src_gt[b, n, 0], 560 | src_gt[b, n, 1], 561 | s=5., 562 | edgecolors='g', 563 | color='g', 564 | alpha=.7, 565 | marker='o', 566 | ) 567 | axes[0, nn].axis("off") 568 | 569 | # paint 4d corr 570 | im = cm_hot(keycorrB_A[b, n].detach().cpu()) * 255 571 | original = target[b].detach().cpu().permute(1, 2, 0).numpy() 572 | im = im[:, :, :3] * 0.5 + original * 0.5 573 | im = np.uint8(im) 574 | im = Image.fromarray(im) 575 | 576 | axes[1, nn].imshow(im) 577 | axes[1, nn].scatter( 578 | xyB_A[b, n, 0], 579 | xyB_A[b, n, 1], 580 | s=1., 581 | edgecolors="r", 582 | color='r', 583 | alpha=.7, 584 | marker='o', 585 | ) 586 | 587 | axes[1, nn].scatter( 588 | dst_gt[b, tn - 1, 0], 589 | dst_gt[b, tn - 1, 1], 590 | s=1., 591 | edgecolors='g', 592 | color='g', 593 | alpha=.7, 594 | marker='o', 595 | ) 596 | axes[1, nn].axis("off") 597 | 598 | # paint target 599 | axes[2, 
nn].imshow(target[b].detach().cpu().permute(1, 2, 0)) 600 | axes[2, nn].scatter( 601 | dst_gt[b, tn - 1, 0], 602 | dst_gt[b, tn - 1, 1], 603 | s=5., 604 | edgecolors='g', 605 | color='g', 606 | alpha=.7, 607 | marker='o', 608 | ) 609 | axes[2, nn].axis("off") 610 | 611 | im = cm_hot(keycorrA_B[b, tn - 1].detach().cpu()) * 255 612 | original = source[b].detach().cpu().permute(1, 2, 0).numpy() 613 | im = im[:, :, :3] * 0.5 + original * 0.5 614 | im = np.uint8(im) 615 | im = Image.fromarray(im) 616 | 617 | axes[3, nn].imshow(im) 618 | axes[3, nn].scatter( 619 | xyA_B[b, tn - 1, 0], 620 | xyA_B[b, tn - 1, 1], 621 | s=1., 622 | edgecolors="r", 623 | color='r', 624 | alpha=.7, 625 | marker='o', 626 | ) 627 | axes[3, nn].scatter( 628 | src_gt[b, n, 0], 629 | src_gt[b, n, 1], 630 | s=1., 631 | edgecolors='g', 632 | color='g', 633 | alpha=.7, 634 | marker='o', 635 | ) 636 | axes[3, nn].axis("off") 637 | 638 | nn += 1 639 | if nn >= NN: 640 | break 641 | 642 | # save the composed heat map figure 643 | file_name = join(output_dir, "{}_heatmaps.png".format(i)) 644 | plt.axis("off") 645 | plt.savefig( 646 | file_name, bbox_inches="tight", pad_inches=0, quality=100, dpi=1200 647 | ) 648 | plt.clf() 649 | # plt.show() 650 | -------------------------------------------------------------------------------- /lib/torch_util.py: -------------------------------------------------------------------------------- 1 | # from NCNet 2 | 3 | import shutil 4 | import torch 5 | from torch.autograd import Variable 6 | from os import makedirs, remove 7 | from os.path import exists, join, basename, dirname 8 | import collections 9 | from lib.dataloader import default_collate 10 | import argparse  # needed by str_to_bool() below 11 | 12 | def collate_custom(batch): 13 | """ Custom collate function for the Dataset class 14 | * It doesn't convert numpy arrays to stacked-tensors, but rather combines them in a list 15 | * This is useful for processing annotations of different sizes 16 | """ 17 | # this case will occur in the first pass, and will convert a 18 | # list of dictionaries (returned by the threads by sampling dataset[idx]) 19 | # to a unified dictionary of collated values 20 | if isinstance(batch[0], collections.Mapping): 21 | return {key: collate_custom([d[key] for d in batch]) for key in batch[0]} 22 | # these cases will occur in recursion 23 | elif torch.is_tensor(batch[0]): # for tensors, use the standard collating function 24 | return default_collate(batch) 25 | else: # for other types (i.e.
lists), return as is 26 | return batch 27 | 28 | 29 | class BatchTensorToVars(object): 30 | """Convert tensors in dict batch to vars 31 | """ 32 | 33 | def __init__(self, use_cuda=True): 34 | self.use_cuda = use_cuda 35 | 36 | def __call__(self, batch): 37 | batch_var = {} 38 | for key, value in batch.items(): 39 | if isinstance(value, torch.Tensor) and not self.use_cuda: 40 | batch_var[key] = Variable(value, requires_grad=False) 41 | elif isinstance(value, torch.Tensor) and self.use_cuda: 42 | batch_var[key] = Variable(value, requires_grad=False).cuda() 43 | else: 44 | batch_var[key] = value 45 | return batch_var 46 | 47 | 48 | def Softmax1D(x, dim): 49 | x_k = torch.max(x, dim)[0].unsqueeze(dim) 50 | x -= x_k.expand_as(x) 51 | exp_x = torch.exp(x) 52 | return torch.div(exp_x, torch.sum(exp_x, dim).unsqueeze(dim).expand_as(x)) 53 | 54 | 55 | def save_checkpoint(state, is_best, file, save_all_epochs=False): 56 | model_dir = dirname(file) 57 | model_fn = basename(file) 58 | # make dir if needed (should be non-empty) 59 | if model_dir != "" and not exists(model_dir): 60 | makedirs(model_dir) 61 | if save_all_epochs: 62 | torch.save(state, join(model_dir, str(state["epoch"]) + "_" + model_fn)) 63 | if is_best: 64 | shutil.copyfile( 65 | join(model_dir, str(state["epoch"]) + "_" + model_fn), 66 | join(model_dir, "best_" + model_fn), 67 | ) 68 | return join(model_dir, str(state["epoch"]) + "_" + model_fn) 69 | else: 70 | torch.save(state, file) 71 | if is_best: 72 | shutil.copyfile(file, join(model_dir, "best_" + model_fn)) 73 | 74 | 75 | def str_to_bool(v): 76 | if v.lower() in ("yes", "true", "t", "y", "1"): 77 | return True 78 | elif v.lower() in ("no", "false", "f", "n", "0"): 79 | return False 80 | else: 81 | raise argparse.ArgumentTypeError("Boolean value expected.") 82 | 83 | 84 | def expand_dim(tensor, dim, desired_dim_len): 85 | sz = list(tensor.size()) 86 | sz[dim] = desired_dim_len 87 | return tensor.expand(tuple(sz)) 88 | 89 | -------------------------------------------------------------------------------- /lib/transformation.py: -------------------------------------------------------------------------------- 1 | # from NCNet 2 | from __future__ import print_function, division 3 | import os 4 | import sys 5 | from skimage import io 6 | import pandas as pd 7 | import numpy as np 8 | import torch 9 | from torch.nn.modules.module import Module 10 | from torch.utils.data import Dataset 11 | from torch.autograd import Variable 12 | import torch.nn.functional as F 13 | 14 | from lib.torch_util import expand_dim 15 | 16 | 17 | class AffineTnf(object): 18 | def __init__(self, out_h=240, out_w=240, use_cuda=True): 19 | self.out_h = out_h 20 | self.out_w = out_w 21 | self.use_cuda = use_cuda 22 | self.gridGen = AffineGridGen(out_h=out_h, out_w=out_w, use_cuda=use_cuda) 23 | self.theta_identity = torch.Tensor( 24 | np.expand_dims(np.array([[1, 0, 0], [0, 1, 0]]), 0).astype(np.float32) 25 | ) 26 | if use_cuda: 27 | self.theta_identity = self.theta_identity.cuda() 28 | 29 | def __call__(self, image_batch, theta_batch=None, out_h=None, out_w=None): 30 | if image_batch is None: 31 | b = 1 32 | else: 33 | b = image_batch.size(0) 34 | if theta_batch is None: 35 | theta_batch = self.theta_identity 36 | theta_batch = theta_batch.expand(b, 2, 3).contiguous() 37 | theta_batch = Variable(theta_batch, requires_grad=False) 38 | 39 | # check if output dimensions have been specified at call time and have changed 40 | if (out_h is not None and out_w is not None) and ( 41 | out_h != self.out_h or out_w != 
self.out_w 42 | ): 43 | gridGen = AffineGridGen(out_h, out_w) 44 | else: 45 | gridGen = self.gridGen 46 | 47 | sampling_grid = gridGen(theta_batch) 48 | 49 | # sample transformed image 50 | warped_image_batch = F.grid_sample(image_batch, sampling_grid) 51 | 52 | return warped_image_batch 53 | 54 | 55 | class AffineGridGen(Module): 56 | def __init__(self, out_h=240, out_w=240, out_ch=3, use_cuda=True): 57 | super(AffineGridGen, self).__init__() 58 | self.out_h = out_h 59 | self.out_w = out_w 60 | self.out_ch = out_ch 61 | 62 | def forward(self, theta): 63 | b = theta.size()[0] 64 | if not theta.size() == (b, 2, 3): 65 | theta = theta.view(-1, 2, 3) 66 | theta = theta.contiguous() 67 | batch_size = theta.size()[0] 68 | out_size = torch.Size((batch_size, self.out_ch, self.out_h, self.out_w)) 69 | return F.affine_grid(theta, out_size) 70 | -------------------------------------------------------------------------------- /lib/visualisation.py: -------------------------------------------------------------------------------- 1 | from . import constant 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def displaySingle(img0, key0, file_name): 7 | plt.imshow(img0) 8 | cm_hot = plt.get_cmap('hsv') #coolwarm 9 | plt.axis('off') 10 | 11 | N = key0.shape[0] 12 | # display the key points 13 | num=0 14 | for i in range(N): 15 | tmp = key0[i] 16 | if tmp[2]> 0.1: 17 | num+=1 18 | plt.scatter(tmp[0],tmp[1], s=100, 19 | edgecolors='w', 20 | color=cm_hot(float(num/N)), 21 | alpha=1., 22 | marker='o') 23 | 24 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0, dpi=300) 25 | plt.clf() 26 | 27 | def displaySingle2(img0, key0, key_gt, file_name): 28 | plt.imshow(img0) 29 | cm_hot = plt.get_cmap('tab20') #coolwarm 30 | plt.axis('off') 31 | 32 | N = key0.shape[0] 33 | # display predicted key points together with their ground truth 34 | num=0 35 | for i in range(N): 36 | tmp = key0[i] 37 | gt = key_gt[i] 38 | if tmp[2]> 0.1: 39 | num+=1 40 | plt.scatter(tmp[0],tmp[1], s=100, 41 | edgecolors='w', 42 | color=cm_hot(float(num/N)), 43 | alpha=.5, 44 | marker='X') 45 | plt.scatter(gt[0],gt[1], s=100, 46 | edgecolors='w', 47 | color=cm_hot(float(num/N)), 48 | alpha=.5, 49 | marker='o') 50 | x = [tmp[0],gt[0]] 51 | y = [tmp[1], gt[1]] 52 | plt.plot(x, y, alpha=.5, color=cm_hot(float(num/N)),linewidth=3.0) 53 | 54 | plt.savefig(file_name, bbox_inches='tight', pad_inches=0, dpi=300) 55 | plt.clf() 56 | 57 | 58 | def displayPair(img0, k0, img1, k1, gt1, file_name=None): 59 | """ 60 | displayPair() visualises a pair of images to be matched 61 | Arguments: 62 | img0, img1 [numpy.array float] H x W x 3: source and target image 63 | k0, k1 [numpy.array float] N x 3: source and target key points; 64 | the first and second columns are the x and y coordinates in the image; 65 | the third column is an indicator: for k0, it is the index of the key 66 | point; for k1, the indicator can be unreliable and is therefore 67 | set from k0 inside this function. k0 and k1 may contain
empty zero rows, but the rows of k0 and k1 must be synchronized 69 | file_name [string]: prefix used for the output image file names 70 | """ 71 | tmp = np.zeros((k1.shape[0], 3), float) 72 | tmp[:, :2] = k1 73 | k1 = tmp 74 | for a0, a1 in zip(k0, k1): 75 | if a0[2] > 0: 76 | a1[2] = a0[2] 77 | else: 78 | a1[2] = 0.0 79 | 80 | # save the source, target, and prediction visualisations 81 | displaySingle(img0, k0, file_name+'source.png') 82 | displaySingle(img1, k1, file_name+'target.png') 83 | displaySingle2(img1, k1, gt1, file_name+'prediction.png') 84 | 85 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scikit-image==0.16.1 2 | pandas==0.25.2 3 | tqdm==4.36.1 4 | opencv-python==4.1.1.26 5 | pretrainedmodels==0.7.4 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | pip install -r requirements.txt 2 | wget -O ancnet.zip https://www.dropbox.com/s/bjul4f5z7beq3um/ancnet.zip?dl=0 3 | unzip -q ancnet.zip 4 | rm ancnet.zip 5 | 6 | python eval_pf_pascal.py --a 0.1 --num_examples 5 7 | 8 | --------------------------------------------------------------------------------
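Note: the snippet below is an illustrative sketch, not a file from the repository. It re-implements, in isolation, the refinement step that graph_matching() in lib/eval_util.py applies (element-wise squaring followed by row normalisation) so its sharpening effect can be seen on a toy assignment matrix. The epsilon value and the name graph_matching_sketch are assumptions made here for the example; the repository reads its epsilon from lib/constant.py as constant._eps.

import torch

def graph_matching_sketch(batch_assignments, iterations=2, eps=1e-15):
    # element-wise squaring boosts confident entries, and row normalisation
    # keeps each row of the soft assignment matrix summing to (roughly) one
    for _ in range(iterations):
        batch_assignments = batch_assignments * batch_assignments
        row_sum = torch.sum(batch_assignments, dim=2, keepdim=True)
        batch_assignments = batch_assignments / (row_sum + eps)
    return batch_assignments

if __name__ == "__main__":
    # toy batch containing one 3 x 3 soft assignment matrix
    X = torch.tensor([[[0.5, 0.3, 0.2],
                       [0.2, 0.6, 0.2],
                       [0.1, 0.2, 0.7]]])
    print(graph_matching_sketch(X))  # each row becomes more peaked around its argmax

After two iterations the first row, for example, moves from (0.5, 0.3, 0.2) to roughly (0.87, 0.11, 0.02), which is the sharpening behaviour the refinement relies on before the matching scores are evaluated.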