├── .gitignore ├── LICENSE ├── README.md ├── data ├── .gitkeep └── prepare_immatch_val_data.sh ├── data_pairs ├── download.sh ├── precompute_immatch_val_ovs.py └── prep_megadepth_training_pairs.ipynb ├── environment.yml ├── examples ├── images │ ├── pair_1 │ │ ├── 1.jpg │ │ └── 2.jpg │ ├── pair_2 │ │ ├── 1.jpg │ │ └── 2.jpg │ └── pair_3 │ │ ├── 1.jpg │ │ └── 2.jpg └── visualize_matches.ipynb ├── networks ├── modules.py ├── ncn │ ├── conv4d.py │ ├── extract_ncmatches.py │ └── model.py ├── patch2pix.py ├── resnet.py └── utils.py ├── patch2pix.png ├── pretrained └── download.sh ├── train_patch2pix.py └── utils ├── colmap ├── data_loading.py ├── read_database.py └── read_write_model.py ├── common ├── plotting.py ├── setup_helper.py └── visdom_helper.py ├── datasets ├── __init__.py ├── data_parsing.py ├── dataset_megadepth.py └── preprocess.py ├── eval ├── geometry.py ├── measure.py └── model_helper.py └── train ├── eval_epoch_immatch.py └── helper.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Self-defined ignores 2 | *.pdf 3 | *.ipynb 4 | *.npy 5 | *.pth 6 | data/ 7 | output/ 8 | 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Qunjie Zhou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following 
conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Patch2Pix for Accurate Image Correspondence Estimation 5 | This repository contains the PyTorch implementation of our paper accepted at CVPR 2021: Patch2Pix: Epipolar-Guided Pixel-Level Correspondences. [[Paper]](https://arxiv.org/abs/2012.01909) [[Video]](https://www.youtube.com/watch?v=Qxkyjsgi8xY) [[Slides]](https://drive.google.com/file/d/1prmyTiYLGq5iDJnDebL-GrANKh2lvB8f/view?usp=sharing). 6 | 7 | ![Overview](patch2pix.png) 8 | To use our code, first download the repository: 9 | ```` 10 | git clone git@github.com:GrumpyZhou/patch2pix.git 11 | ```` 12 | ## Setup Running Environment 13 | The code has been tested on Ubuntu (16.04 & 18.04) with Python 3.7 + PyTorch 1.7.0 + CUDA 10.2. 14 | We recommend using *Anaconda* to manage packages and reproduce the paper results. Run the following lines to automatically set up a ready environment for our code. 15 | ```` 16 | conda env create -f environment.yml 17 | conda activate patch2pix 18 | ```` 19 | 20 | ### Download Pretrained Models 21 | In order to run our examples, one needs to first download our pretrained Patch2Pix model. To further train a Patch2Pix model, one needs to download the pretrained NCNet. We provide the download links in [pretrained/download.sh](pretrained/download.sh). 22 | To download both, one can run 23 | ```` 24 | cd pretrained 25 | bash download.sh 26 | ```` 27 | 28 | ## Evaluation 29 | **❗️NOTICE❗️:** In this repository, we only provide examples to estimate correspondences using our Patch2Pix implementation. 30 | 31 | To reproduce our evaluations on the **HPatches**, **Aachen** and **InLoc** benchmarks, we refer you to **our toolbox for image matching**: [**image-matching-toolbox**](https://github.com/GrumpyZhou/image-matching-toolbox). There, you can also find implementations to reproduce the results of the other state-of-the-art methods that we compared to in our paper. 32 | 33 | ### Matching Examples 34 | In our notebook [examples/visualize_matches.ipynb](examples/visualize_matches.ipynb), we give examples of how to obtain matches for a pair of images using both **Patch2Pix** (our pretrained model) and **NCNet** (our adapted version). The example image pairs are borrowed from [D2Net](https://github.com/mihaidusmanu/d2-net); you can easily replace them with your own examples. 35 | 
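For a rough idea of what the notebook does, here is a minimal, hypothetical sketch (not the notebook's exact code) of running a loaded Patch2Pix model on one image pair via `predict_fine` from `networks/patch2pix.py`. The `model` argument is assumed to be a `Patch2Pix` instance already constructed with the pretrained weights downloaded above, and the plain ImageNet normalization below is a simplification; see the notebook for the exact model construction and preprocessing.

```python
import torch
from PIL import Image
import torchvision.transforms as tfm

# ImageNet normalization as a stand-in; the notebook defines the exact preprocessing.
_transform = tfm.Compose([
    tfm.ToTensor(),
    tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def match_pair(model, im1_path, im2_path, io_thres=0.5):
    """Match one image pair with a loaded Patch2Pix model and filter by confidence."""
    device = next(model.parameters()).device
    # Large images may need downscaling first to fit into GPU memory.
    im1 = _transform(Image.open(im1_path).convert('RGB')).unsqueeze(0).to(device)
    im2 = _transform(Image.open(im2_path).convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        # predict_fine() returns per-image lists of refined matches, their
        # confidence scores, and the coarse matches they were refined from.
        fine_matches, fine_scores, _ = model.predict_fine(
            im1, im2, ksize=2, ncn_thres=0.0, mutual=True)
    matches = fine_matches[0].cpu().numpy()  # N x 4 array of (x1, y1, x2, y2)
    scores = fine_scores[0].cpu().numpy()    # confidence scores in [0, 1]
    keep = scores > io_thres
    return matches[keep], scores[keep]

# Example usage (assuming `model` has been built as in the notebook):
# matches, scores = match_pair(model, 'examples/images/pair_1/1.jpg',
#                              'examples/images/pair_1/2.jpg')
```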
36 | ## Training 37 | *Note: the following steps are necessary **only if** you want to train a model yourself*. 38 | ### Data preparation 39 | We use the [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/) dataset for training. 40 | To keep more data for training, we didn't split a validation set from MegaDepth. Instead, we use the validation splits of [PhotoTourism](https://www.cs.ubc.ca/~kmyi/imw2020/data.html). 41 | The following steps describe how to prepare the same training and validation data that we used. 42 | 43 | **Prepare Training Data** 44 | 1. We preprocess the **MegaDepth** dataset following the preprocessing steps proposed by [D2Net](https://github.com/mihaidusmanu/d2-net). For details, please check out the *"Downloading and preprocessing the MegaDepth dataset"* section in their GitHub documentation. 45 | 46 | 2. Then place the processed MegaDepth dataset under the **data/** folder and name it **MegaDepth_undistort** (or create a symbolic link to it). 47 | 48 | 3. One can directly download our **pre-computed** training pairs using our download script. 49 | ```` 50 | cd data_pairs 51 | bash download.sh 52 | ```` 53 | In case one wants to generate pairs with different settings, we provide a notebook to **generate pairs from scratch**. Once you finish steps 1 and 2, the training pairs can be generated using our notebook [data_pairs/prep_megadepth_training_pairs.ipynb](data_pairs/prep_megadepth_training_pairs.ipynb). 54 | 55 | **Prepare Validation Data** 56 | 1. Use our script to download and extract the subset of train and val sequences from the **PhotoTourism** dataset. 57 | ``` 58 | cd data 59 | bash prepare_immatch_val_data.sh 60 | ``` 61 | 2. Precompute pairwise image overlaps for fast loading of validation pairs. 62 | ``` 63 | # Under the root folder: patch2pix/ 64 | python -m data_pairs.precompute_immatch_val_ovs \ 65 | --data_root data/immatch_benchmark/val_dense 66 | ``` 67 | 68 | ### Training Examples 69 | 70 | To train our best model: 71 | ```` 72 | python -m train_patch2pix --gpu 0 \ 73 | --epochs 25 --batch 4 \ 74 | --save_step 1 --plot_counts 20 --data_root 'data' \ 75 | --change_stride --panc 8 --ptmax 400 \ 76 | --pretrain 'pretrained/ncn_ivd_5ep.pth' \ 77 | -lr 0.0005 -lrd 'multistep' 0.2 5 \ 78 | --cls_dthres 50 5 --epi_dthres 50 5 \ 79 | -o 'output/patch2pix' 80 | ```` 81 | 82 | The above command saves the log file and checkpoints to the output folder specified by `-o`. Our best model was trained on a 48GB GPU. 83 | To train on a smaller GPU, e.g., with 12GB, one can either set `--batch 1` or reduce `--ptmax` (e.g., `--ptmax 250`), which defines the maximum number of match proposals to be refined for each image pair. 84 | However, in our experience those changes may also decrease the training performance. 85 | Note that during testing, our network requires only a 12GB GPU. 86 | 87 | **Usage of Visdom Server** 88 | Our training script monitors the training process using Visdom. To enable the monitoring, one needs to: 89 | 1) Run a Visdom server on your localhost, for example: 90 | ``` 91 | # Feel free to change the port 92 | python -m visdom.server -port 9333 \ 93 | -env_path ~/.visdom/patch2pix 94 | ``` 95 | 2) Append the options `-vh 'localhost' -vp 9333` to the training command in the example above. 
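As a reference for the optimizer flags in the training command above: judging from `init_optimizer()` in `networks/modules.py`, the options `-lr 0.0005 -lrd 'multistep' 0.2 5` correspond to an Adam optimizer with initial learning rate 0.0005 whose learning rate is multiplied by 0.2 at epoch 5. Below is a minimal sketch of the equivalent plain-PyTorch setup; the actual flag parsing happens in `train_patch2pix.py`, which is not shown here.

```python
import torch

# Toy parameter standing in for the trainable Patch2Pix parameters.
params = [torch.nn.Parameter(torch.zeros(10))]

# -lr 0.0005 -> initial learning rate (config.lr_init in networks/modules.py)
optimizer = torch.optim.Adam(params, lr=0.0005)

# -lrd 'multistep' 0.2 5 -> MultiStepLR with gamma=0.2 and milestones=[5],
# i.e. the learning rate drops to 0.0001 after epoch 5.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[5], gamma=0.2)
```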
96 | 97 | ## BibTeX 98 | If you use our method or code in your project, please cite our paper: 99 | ``` 100 | @inproceedings{ZhouCVPRpatch2pix, 101 | author = "Zhou, Qunjie and Sattler, Torsten and Leal-Taixe, Laura", 102 | title = "Patch2Pix: Epipolar-Guided Pixel-Level Correspondences", 103 | booktitle = "CVPR", 104 | year = 2021, 105 | } 106 | ``` 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/data/.gitkeep -------------------------------------------------------------------------------- /data/prepare_immatch_val_data.sh: -------------------------------------------------------------------------------- 1 | # Download and extracting validation set 2 | # Dataset link: https://www.cs.ubc.ca/~kmyi/imw2020/data.html 3 | 4 | mkdir immatch_benchmark 5 | mkdir immatch_benchmark/val_dense 6 | cd immatch_benchmark/val_dense 7 | 8 | wget https://www.cs.ubc.ca/research/kmyi_data/imw2020/TrainingData/reichstag.tar.gz 9 | tar -xvzf reichstag.tar.gz 10 | 11 | wget https://www.cs.ubc.ca/research/kmyi_data/imw2020/TrainingData/sacre_coeur.tar.gz 12 | tar -xvzf sacre_coeur.tar.gz 13 | 14 | wget https://www.cs.ubc.ca/research/kmyi_data/imw2020/TrainingData/st_peters_square.tar.gz 15 | tar -xvzf st_peters_square.tar.gz 16 | 17 | wget https://www.cs.ubc.ca/research/kmyi_data/imw2020/TrainingData/taj_mahal.tar.gz 18 | tar -xvzf taj_mahal.tar.gz 19 | 20 | wget https://www.cs.ubc.ca/research/kmyi_data/imw2020/TrainingData/temple_nara_japan.tar.gz 21 | tar -xvzf temple_nara_japan.tar.gz 22 | 23 | rm *.tar.gz 24 | -------------------------------------------------------------------------------- /data_pairs/download.sh: -------------------------------------------------------------------------------- 1 | # Pre-computed pairs: https://drive.google.com/file/d/1u8sfc23c9ZhXSA_IVct2l_T3JLrS5air/view?usp=sharing 2 | gdown 1u8sfc23c9ZhXSA_IVct2l_T3JLrS5air 3 | -------------------------------------------------------------------------------- /data_pairs/precompute_immatch_val_ovs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import numpy as np 5 | from utils.colmap.data_loading import sav_model_multi_ov_pairs 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--data_root', type=str, default='data/immatch_benchmark/val_as_train') 9 | args = parser.parse_args() 10 | data_root = args.data_root 11 | 12 | overlaps = [0.1, 0.2, 0.3, 0.4, 0.5] 13 | scenes = os.listdir(data_root) 14 | print(f'Target scenes: {scenes}, ovs: {overlaps}\n') 15 | for scene in scenes: 16 | print(f'Start processing scene: {scene}') 17 | model_dir = os.path.join(data_root, scene, 'dense/sparse') 18 | t0 = time.time() 19 | ov_pair_dict = sav_model_multi_ov_pairs(model_dir, overlaps) 20 | print(f'Finished, time {time.time() - t0}') -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: patch2pix 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.7 8 | - pytorch 9 | - cudatoolkit=10.2 10 | - torchvision 11 | - jupyter 12 | - scipy 13 | - matplotlib 14 | - pillow=6.0.0=py37he7afcd5_0 15 | - pip 16 | - pip: 17 | - transforms3d 18 | - 
pydegensac 19 | - h5py 20 | - opencv-python 21 | - visdom==0.1.8.8 -------------------------------------------------------------------------------- /examples/images/pair_1/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_1/1.jpg -------------------------------------------------------------------------------- /examples/images/pair_1/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_1/2.jpg -------------------------------------------------------------------------------- /examples/images/pair_2/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_2/1.jpg -------------------------------------------------------------------------------- /examples/images/pair_2/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_2/2.jpg -------------------------------------------------------------------------------- /examples/images/pair_3/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_3/1.jpg -------------------------------------------------------------------------------- /examples/images/pair_3/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/examples/images/pair_3/2.jpg -------------------------------------------------------------------------------- /networks/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | L2Normalize = lambda feat, dim: feat / torch.pow(torch.sum(torch.pow(feat, 2), dim=dim) + 1e-6, 0.5).unsqueeze(dim) 7 | 8 | def cal_conv_out_size(w, kernel_size, stride, padding): 9 | return (w - kernel_size + 2 * padding) // stride + 1 10 | 11 | def maxpool4d(corr4d_hres, k_size=4): 12 | slices=[] 13 | for i in range(k_size): 14 | for j in range(k_size): 15 | for k in range(k_size): 16 | for l in range(k_size): 17 | sl = corr4d_hres[:,:,i::k_size,j::k_size,k::k_size,l::k_size] # Support batches 18 | slices.append(sl) 19 | 20 | slices = torch.cat(tuple(slices),dim=1) # B, ksize*4, h1, w1, h2, w2 21 | corr4d, max_idx = torch.max(slices,dim=1,keepdim=True) 22 | 23 | # i,j,k,l represent the *relative* coords of the max point in the box of size k_size*k_size*k_size*k_size 24 | if torch.__version__ >= '1.6.0': 25 | max_l=torch.fmod(max_idx,k_size) 26 | max_k=torch.fmod(max_idx.sub(max_l).floor_divide(k_size),k_size) 27 | max_j=torch.fmod(max_idx.sub(max_l).floor_divide(k_size).sub(max_k).floor_divide(k_size),k_size) 28 | max_i=max_idx.sub(max_l).floor_divide(k_size).sub(max_k).floor_divide(k_size).sub(max_j).floor_divide(k_size) 29 | else: 30 | max_l=torch.fmod(max_idx,k_size) 31 | max_k=torch.fmod(max_idx.sub(max_l).div(k_size),k_size) 32 | 
max_j=torch.fmod(max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size),k_size) 33 | max_i=max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size).sub(max_j).div(k_size) 34 | return (corr4d,max_i,max_j,max_k,max_l) 35 | 36 | class FeatCorrelation(torch.nn.Module): 37 | def __init__(self, shape='4D'): 38 | super().__init__() 39 | self.shape = shape 40 | 41 | def forward(self, feat1, feat2): 42 | b, c, h1, w1 = feat1.size() 43 | b, c, h2, w2 = feat2.size() 44 | feat1 = feat1.view(b, c, h1*w1).transpose(1, 2) # size [b, h1*w1, c] 45 | feat2 = feat2.view(b, c, h2*w2) # size [b, c, h2*w2] 46 | 47 | # Matrix multiplication 48 | correlation = torch.bmm(feat1, feat2) # [b, h1*w1, h2*w2] 49 | if self.shape == '3D': 50 | correlation = correlation.view(b, h1, w1, h2*w2).permute(0, 3, 1, 2) # [b, h2*w2, h1, w1] 51 | elif self.shape == '4D': 52 | correlation = correlation.view(b, h1, w1, h2, w2).unsqueeze(1) # [b, 1, h1, w1, h2, w2] 53 | return correlation 54 | 55 | 56 | class FeatRegressNet(nn.Module): 57 | def __init__(self, config, psize=16, out_dim=5): 58 | super().__init__() 59 | self.psize = psize 60 | self.conv_strs = config.conv_strs if 'conv_strs' in config else [2] * len(config.conv_kers) 61 | self.conv_dims = config.conv_dims 62 | self.conv_kers = config.conv_kers 63 | self.feat_comb = config.feat_comb # Combine 2 feature maps before the conv or after the conv 64 | self.feat_dim = config.feat_dim if self.feat_comb == 'post' else 2 * config.feat_dim 65 | self.fc_in_dim = config.conv_dims[-1] * 2 if self.feat_comb == 'post' else config.conv_dims[-1] 66 | 67 | # Build layers 68 | self.conv = self.make_conv_layers(self.feat_dim, self.conv_dims, self.conv_kers) 69 | self.fc = self.make_fc_layers(self.fc_in_dim, config.fc_dims, out_dim) 70 | print(f'FeatRegressNet: feat_comb:{self.feat_comb} ' \ 71 | f'psize:{self.psize} out:{out_dim} ' \ 72 | f'feat_dim:{self.feat_dim} conv_kers:{self.conv_kers} ' \ 73 | f'conv_dims:{self.conv_dims} conv_str:{self.conv_strs} ' 74 | ) 75 | 76 | def make_conv_layers(self, in_dim, conv_dims, conv_kers, bias=False): 77 | layers = [] 78 | w = self.psize # Initial spatial size 79 | for out_dim, kernel_size, stride in zip(conv_dims, conv_kers, self.conv_strs): 80 | layers.append(nn.Conv2d(in_dim, out_dim, kernel_size, stride=stride, padding=1, bias=bias)) 81 | layers.append(nn.BatchNorm2d(out_dim)) 82 | w = cal_conv_out_size(w, kernel_size, stride, 1) 83 | in_dim = out_dim 84 | layers.append(nn.ReLU()) 85 | # To make sure spatial dim goes to 1, one can also use AdaptiveMaxPool 86 | layers.append(nn.MaxPool2d(kernel_size=w)) 87 | return nn.Sequential(*layers) 88 | 89 | def make_fc_layers(self, in_dim, fc_dims, fc_out_dim): 90 | layers = [] 91 | for out_dim in fc_dims: 92 | layers.append(nn.Linear(in_dim, out_dim)) 93 | layers.append(nn.BatchNorm1d(out_dim)), 94 | layers.append(nn.ReLU()) 95 | in_dim = out_dim 96 | 97 | # Final layer 98 | layers.append(nn.Linear(in_dim, fc_out_dim)) 99 | return nn.Sequential(*layers) 100 | 101 | def forward(self, feat1, feat2): 102 | # feat1, feat2: shape (N, D, 16, 16) 103 | if self.feat_comb == 'pre': 104 | feat = torch.cat([feat1, feat2], dim=1) 105 | feat = self.conv(feat) # N, D, 1, 1 106 | else: 107 | feat1 = self.conv(feat1) 108 | feat2 = self.conv(feat2) 109 | feat = torch.cat([feat1, feat2], dim=1) # N, D, 1, 1 110 | feat = feat.view(-1, feat.shape[1]) 111 | out = self.fc(feat) # N, 5 112 | return out 113 | 114 | def init_optimizer(params, config): 115 | if config.opt == 'adam': 116 | optimizer = torch.optim.Adam(params, 
lr=config.lr_init, weight_decay=config.weight_decay) 117 | print('Setup Adam optimizer(lr={},wd={})'.format(config.lr_init, config.weight_decay)) 118 | 119 | elif config.opt == 'sgd': 120 | optimizer = torch.optim.SGD(params, momentum=0.9, lr=config.lr_init, weight_decay=config.weight_decay) 121 | print('Setup SGD optimizer(lr={},wd={},mom=0.9)'.format(config.lr_init, config.weight_decay)) 122 | 123 | if config.optimizer_dict: 124 | optimizer.load_state_dict(config.optimizer_dict) 125 | 126 | # Schedule learning rate decay lr_decay = ['name', params] or None 127 | lr_scheduler = None 128 | if 'lr_decay' in config and config.lr_decay: 129 | if config.lr_decay[0] == 'step': 130 | decay_factor, decay_step = float(config.lr_decay[1]), int(config.lr_decay[2]) 131 | last_epoch = config.start_epoch - 1 132 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 133 | step_size=decay_step, 134 | gamma=decay_factor, 135 | last_epoch=last_epoch) 136 | print(f'Setup StepLR Decay: decay_factor={decay_factor} ' 137 | f'step={decay_step} last_epoch={last_epoch}') 138 | 139 | elif config.lr_decay[0] == 'multistep': 140 | decay_factor = float(config.lr_decay[1]) 141 | decay_steps = [int(v) for v in config.lr_decay[2::]] 142 | last_epoch = config.start_epoch - 1 143 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 144 | milestones=decay_steps, 145 | gamma=decay_factor, 146 | last_epoch=last_epoch) 147 | print(f'Setup MultiStepLR Decay: decay_factor={decay_factor} ' 148 | f'steps={decay_steps} last_epoch={last_epoch}') 149 | 150 | if config.lr_scheduler_dict and lr_scheduler: 151 | lr_scheduler.load_state_dict(config.lr_scheduler_dict) 152 | return optimizer, lr_scheduler 153 | 154 | def xavier_init_func_(m): 155 | classname = m.__class__.__name__ 156 | if classname.startswith('Conv'): 157 | nn.init.xavier_uniform_(m.weight.data) 158 | if m.bias is not None: # Incase bias is turned off 159 | nn.init.constant_(m.bias.data, 0.0) 160 | elif classname.find('Linear') != -1: 161 | nn.init.xavier_uniform_(m.weight.data) 162 | if m.bias is not None: # Incase bias is turned off 163 | nn.init.constant_(m.bias.data, 0.0) 164 | elif classname.find('BatchNorm2d') != -1: 165 | nn.init.normal_(m.weight.data, 1.0, 0.02) 166 | nn.init.constant_(m.bias.data, 0.0) 167 | -------------------------------------------------------------------------------- /networks/ncn/conv4d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | from torch.nn import Module 6 | from torch.nn.modules.conv import _ConvNd 7 | from torch.nn.modules.utils import _quadruple 8 | from torch.autograd import Variable 9 | from torch.nn import Conv2d 10 | 11 | 12 | def conv4d(data, filters, bias=None, permute_filters=True, use_half=False): 13 | b, c, h, w, d, t = data.size() 14 | 15 | data = data.permute( 16 | 2, 0, 1, 3, 4, 5 17 | ).contiguous() # permute to avoid making contiguous inside loop 18 | 19 | # Same permutation is done with filters, unless already provided with permutation 20 | if permute_filters: 21 | filters = filters.permute( 22 | 2, 0, 1, 3, 4, 5 23 | ).contiguous() # permute to avoid making contiguous inside loop 24 | 25 | c_out = filters.size(1) 26 | if use_half: 27 | output = Variable( 28 | torch.HalfTensor(h, b, c_out, w, d, t), requires_grad=data.requires_grad 29 | ) 30 | else: 31 | output = Variable( 32 | torch.zeros(h, b, c_out, w, d, t), requires_grad=data.requires_grad 
33 | ) 34 | 35 | padding = filters.size(0) // 2 36 | if use_half: 37 | Z = Variable(torch.zeros(padding, b, c, w, d, t).half()) 38 | else: 39 | Z = Variable(torch.zeros(padding, b, c, w, d, t)) 40 | 41 | if data.is_cuda: 42 | Z = Z.cuda(data.get_device()) 43 | output = output.cuda(data.get_device()) 44 | 45 | data_padded = torch.cat((Z, data, Z), 0) 46 | 47 | for i in range(output.size(0)): # loop on first feature dimension 48 | # convolve with center channel of filter (at position=padding) 49 | output[i, :, :, :, :, :] = F.conv3d( 50 | data_padded[i + padding, :, :, :, :, :], 51 | filters[padding, :, :, :, :, :], 52 | bias=bias, 53 | stride=1, 54 | padding=padding, 55 | ) 56 | # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) 57 | for p in range(1, padding + 1): 58 | output[i, :, :, :, :, :] = output[i, :, :, :, :, :] + F.conv3d( 59 | data_padded[i + padding - p, :, :, :, :, :], 60 | filters[padding - p, :, :, :, :, :], 61 | bias=None, 62 | stride=1, 63 | padding=padding, 64 | ) 65 | output[i, :, :, :, :, :] = output[i, :, :, :, :, :] + F.conv3d( 66 | data_padded[i + padding + p, :, :, :, :, :], 67 | filters[padding + p, :, :, :, :, :], 68 | bias=None, 69 | stride=1, 70 | padding=padding, 71 | ) 72 | 73 | output = output.permute(1, 2, 0, 3, 4, 5).contiguous() 74 | return output 75 | 76 | 77 | class Conv4d(_ConvNd): 78 | """Applies a 4D convolution over an input signal composed of several input 79 | planes. 80 | """ 81 | 82 | def __init__( 83 | self, 84 | in_channels, 85 | out_channels, 86 | kernel_size, 87 | bias=True, 88 | pre_permuted_filters=True, 89 | ): 90 | # stride, dilation and groups !=1 functionality not tested 91 | stride = 1 92 | dilation = 1 93 | groups = 1 94 | # zero padding is added automatically in conv4d function to preserve tensor size 95 | padding = 0 96 | kernel_size = _quadruple(kernel_size) 97 | stride = _quadruple(stride) 98 | padding = _quadruple(padding) 99 | dilation = _quadruple(dilation) 100 | 101 | super().__init__( 102 | in_channels, 103 | out_channels, 104 | kernel_size, 105 | stride, 106 | padding, 107 | dilation, 108 | False, 109 | _quadruple(0), 110 | groups, 111 | bias, 112 | padding_mode="zeros", 113 | ) 114 | 115 | # weights will be sliced along one dimension during convolution loop 116 | # make the looping dimension to be the first one in the tensor, 117 | # so that we don't need to call contiguous() inside the loop 118 | self.pre_permuted_filters = pre_permuted_filters 119 | if self.pre_permuted_filters: 120 | self.weight.data = self.weight.data.permute(2, 0, 1, 3, 4, 5).contiguous() 121 | self.use_half = False 122 | 123 | def forward(self, input): 124 | return conv4d( 125 | input, 126 | self.weight, 127 | bias=self.bias, 128 | permute_filters=not self.pre_permuted_filters, 129 | use_half=self.use_half, 130 | ) # filters pre-permuted in constructor 131 | -------------------------------------------------------------------------------- /networks/ncn/extract_ncmatches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def corr_to_matches(corr4d, delta4d=None, ksize=1, do_softmax=True, scale='positive', 7 | invert_matching_direction=False, return_indices=True): 8 | to_cuda = lambda x: x.to(corr4d.device) if corr4d.is_cuda else x 9 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() # b, c, h, w, h, w 10 | if scale=='centered': 11 | 
XA,YA=np.meshgrid(np.linspace(-1,1,fs2*ksize),np.linspace(-1,1,fs1*ksize)) 12 | XB,YB=np.meshgrid(np.linspace(-1,1,fs4*ksize),np.linspace(-1,1,fs3*ksize)) 13 | elif scale=='positive': 14 | # Upsampled resolution linear space 15 | XA,YA=np.meshgrid(np.linspace(0,1,fs2*ksize),np.linspace(0,1,fs1*ksize)) 16 | XB,YB=np.meshgrid(np.linspace(0,1,fs4*ksize),np.linspace(0,1,fs3*ksize)) 17 | # Index meshgrid for current resolution 18 | JA,IA=np.meshgrid(range(fs2),range(fs1)) 19 | JB,IB=np.meshgrid(range(fs4),range(fs3)) 20 | 21 | XA,YA=Variable(to_cuda(torch.FloatTensor(XA))),Variable(to_cuda(torch.FloatTensor(YA))) 22 | XB,YB=Variable(to_cuda(torch.FloatTensor(XB))),Variable(to_cuda(torch.FloatTensor(YB))) 23 | 24 | JA,IA=Variable(to_cuda(torch.LongTensor(JA).view(1,-1))),Variable(to_cuda(torch.LongTensor(IA).view(1,-1))) 25 | JB,IB=Variable(to_cuda(torch.LongTensor(JB).view(1,-1))),Variable(to_cuda(torch.LongTensor(IB).view(1,-1))) 26 | 27 | if invert_matching_direction: 28 | nc_A_Bvec=corr4d.view(batch_size,fs1,fs2,fs3*fs4) 29 | 30 | if do_softmax: 31 | nc_A_Bvec=torch.nn.functional.softmax(nc_A_Bvec,dim=3) 32 | 33 | # Max and argmax 34 | match_A_vals,idx_A_Bvec=torch.max(nc_A_Bvec,dim=3) 35 | score=match_A_vals.view(batch_size,-1) 36 | 37 | # Pick the indices for the best score 38 | iB=IB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size,-1).contiguous() # b, h1*w1 39 | jB=JB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size,-1).contiguous() 40 | iA=IA.expand_as(iB).contiguous() 41 | jA=JA.expand_as(jB).contiguous() 42 | 43 | else: 44 | nc_B_Avec=corr4d.view(batch_size,fs1*fs2,fs3,fs4) # [batch_idx,k_A,i_B,j_B] 45 | if do_softmax: 46 | nc_B_Avec=torch.nn.functional.softmax(nc_B_Avec,dim=1) 47 | 48 | match_B_vals,idx_B_Avec=torch.max(nc_B_Avec,dim=1) 49 | score=match_B_vals.view(batch_size,-1) 50 | 51 | iA=IA.view(-1)[idx_B_Avec.view(-1)].view(batch_size,-1).contiguous() # b, h2*w2 52 | jA=JA.view(-1)[idx_B_Avec.view(-1)].view(batch_size,-1).contiguous() 53 | iB=IB.expand_as(iA).contiguous() 54 | jB=JB.expand_as(jA).contiguous() 55 | 56 | if delta4d is not None: # relocalization, it is also the case ksize > 1 57 | # The shift within the pooling window reference to (0,0,0,0) 58 | delta_iA, delta_jA, delta_iB, delta_jB = delta4d # b, 1, h1, w1, h2, w2 59 | 60 | """ Original implementation 61 | # Reorder the indices according 62 | diA = delta_iA.squeeze(0).squeeze(0)[iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1)] 63 | djA = delta_jA.squeeze(0).squeeze(0)[iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1)] 64 | diB = delta_iB.squeeze(0).squeeze(0)[iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1)] 65 | djB = delta_jB.squeeze(0).squeeze(0)[iA.view(-1), jA.view(-1), iB.view(-1), jB.view(-1)] 66 | 67 | # *ksize place the pixel to the 1st location in upsampled 4D-Volumn 68 | iA = iA * ksize + diA.expand_as(iA) 69 | jA = jA * ksize + djA.expand_as(jA) 70 | iB = iB * ksize + diB.expand_as(iB) 71 | jB = jB * ksize + djB.expand_as(jB) 72 | """ 73 | 74 | # Support batches 75 | for ibx in range(batch_size): 76 | diA = delta_iA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] # h*w 77 | djA = delta_jA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 78 | diB = delta_iB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 79 | djB = delta_jB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 80 | 81 | iA[ibx] = iA[ibx] * ksize + diA 82 | jA[ibx] = jA[ibx] * ksize + djA 83 | iB[ibx] = iB[ibx] * ksize + diB 84 | jB[ibx] = jB[ibx] * ksize + djB 85 | 86 | xA = XA[iA.view(-1), jA.view(-1)].view(batch_size, -1) 87 | yA = YA[iA.view(-1), 
jA.view(-1)].view(batch_size, -1) 88 | xB = XB[iB.view(-1), jB.view(-1)].view(batch_size, -1) 89 | yB = YB[iB.view(-1), jB.view(-1)].view(batch_size, -1) 90 | 91 | if return_indices: 92 | return (jA,iA,jB,iB,score) 93 | else: 94 | return (xA,yA,xB,yB,score) 95 | 96 | def corr_to_matches_topk(corr4d, delta4d=None, topk=1, ksize=1, do_softmax=True, 97 | invert_matching_direction=False): 98 | 99 | device = corr4d.device 100 | batch_size, ch, fs1, fs2, fs3, fs4 = corr4d.size() # b, c, h, w, h, w 101 | 102 | # Index meshgrid for current resolution 103 | JA, IA = np.meshgrid(range(fs2), range(fs1)) 104 | JB, IB = np.meshgrid(range(fs4), range(fs3)) 105 | JA, IA = torch.LongTensor(JA).view(1,-1).to(device), torch.LongTensor(IA).view(1,-1).to(device) 106 | JB, IB = torch.LongTensor(JB).view(1,-1).to(device), torch.LongTensor(IB).view(1,-1).to(device) 107 | 108 | if invert_matching_direction: 109 | nc_A_Bvec = corr4d.view(batch_size, fs1, fs2, fs3 * fs4) 110 | 111 | if do_softmax: 112 | nc_A_Bvec = torch.nn.functional.softmax(nc_A_Bvec, dim=3) 113 | 114 | # Max and argmax 115 | match_A_vals, idx_A_Bvec = torch.topk(nc_A_Bvec, topk, dim=3, largest=True, sorted=True) 116 | score = match_A_vals.view(batch_size, -1) 117 | 118 | # Pick the indices for the best score 119 | iB = IB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1, topk).contiguous() 120 | jB = JB.view(-1)[idx_A_Bvec.view(-1)].view(batch_size, -1, topk).contiguous() 121 | iA = IA.unsqueeze(-1).expand_as(iB).contiguous() 122 | jA = JA.unsqueeze(-1).expand_as(jB).contiguous() 123 | 124 | else: 125 | nc_B_Avec = corr4d.view(batch_size, fs1 * fs2, fs3, fs4) # [batch_idx,k_A,i_B,j_B] 126 | if do_softmax: 127 | nc_B_Avec = torch.nn.functional.softmax(nc_B_Avec, dim=1) 128 | 129 | match_B_vals, idx_B_Avec = torch.topk(nc_B_Avec, topk, dim=1, largest=True, sorted=True) 130 | score = match_B_vals.view(batch_size, -1) 131 | 132 | iA = IA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, topk, -1).contiguous() 133 | jA = JA.view(-1)[idx_B_Avec.view(-1)].view(batch_size, topk, -1).contiguous() 134 | iB = IB.unsqueeze(1).expand_as(iA).contiguous() 135 | jB = JB.unsqueeze(1).expand_as(jA).contiguous() 136 | 137 | iA = iA.view(batch_size, -1) 138 | jA = jA.view(batch_size, -1) 139 | iB = iB.view(batch_size, -1) 140 | jB = jB.view(batch_size, -1) 141 | 142 | if delta4d is not None: # relocalization, it is also the case ksize > 1 143 | # The shift within the pooling window reference to (0,0,0,0) 144 | delta_iA, delta_jA, delta_iB, delta_jB = delta4d 145 | 146 | # Support batches 147 | for ibx in range(batch_size): 148 | diA = delta_iA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] # h*w 149 | djA = delta_jA[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 150 | diB = delta_iB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 151 | djB = delta_jB[ibx][0][iA[ibx], jA[ibx], iB[ibx], jB[ibx]] 152 | 153 | iA[ibx] = iA[ibx] * ksize + diA 154 | jA[ibx] = jA[ibx] * ksize + djA 155 | iB[ibx] = iB[ibx] * ksize + diB 156 | jB[ibx] = jB[ibx] * ksize + djB 157 | 158 | return (jA, iA, jB, iB, score) 159 | -------------------------------------------------------------------------------- /networks/ncn/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torchvision.models as models 7 | from .conv4d import Conv4d 8 | 9 | def Softmax1D(x,dim): 10 | x_k = 
torch.max(x,dim)[0].unsqueeze(dim) 11 | x -= x_k.expand_as(x) 12 | exp_x = torch.exp(x) 13 | return torch.div(exp_x,torch.sum(exp_x,dim).unsqueeze(dim).expand_as(x)) 14 | 15 | def featureL2Norm(feature): 16 | epsilon = 1e-6 17 | norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature) 18 | feat_norm = torch.div(feature,norm) 19 | return feat_norm 20 | 21 | class FeatureExtraction(torch.nn.Module): 22 | def __init__(self, train_fe=False, feature_extraction_cnn='resnet101', feature_extraction_model_file='', normalization=True, last_layer='', use_cuda=True): 23 | super(FeatureExtraction, self).__init__() 24 | self.normalization = normalization 25 | self.feature_extraction_cnn=feature_extraction_cnn 26 | if feature_extraction_cnn == 'vgg': 27 | self.model = models.vgg16(pretrained=True) 28 | # keep feature extraction network up to indicated layer 29 | vgg_feature_layers=['conv1_1','relu1_1','conv1_2','relu1_2','pool1','conv2_1', 30 | 'relu2_1','conv2_2','relu2_2','pool2','conv3_1','relu3_1', 31 | 'conv3_2','relu3_2','conv3_3','relu3_3','pool3','conv4_1', 32 | 'relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4', 33 | 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5'] 34 | if last_layer=='': 35 | last_layer = 'pool4' 36 | last_layer_idx = vgg_feature_layers.index(last_layer) 37 | self.model = nn.Sequential(*list(self.model.features.children())[:last_layer_idx+1]) 38 | # for resnet below 39 | resnet_feature_layers = ['conv1','bn1','relu','maxpool','layer1','layer2','layer3','layer4'] 40 | if feature_extraction_cnn=='resnet101': 41 | self.model = models.resnet101(pretrained=True) 42 | if last_layer=='': 43 | last_layer = 'layer3' 44 | resnet_module_list = [getattr(self.model,l) for l in resnet_feature_layers] 45 | last_layer_idx = resnet_feature_layers.index(last_layer) 46 | self.model = nn.Sequential(*resnet_module_list[:last_layer_idx+1]) 47 | 48 | if feature_extraction_cnn=='resnet101fpn': 49 | if feature_extraction_model_file!='': 50 | resnet = models.resnet101(pretrained=True) 51 | # swap stride (2,2) and (1,1) in first layers (PyTorch ResNet is slightly different to caffe2 ResNet) 52 | # this is required for compatibility with caffe2 models 53 | resnet.layer2[0].conv1.stride=(2,2) 54 | resnet.layer2[0].conv2.stride=(1,1) 55 | resnet.layer3[0].conv1.stride=(2,2) 56 | resnet.layer3[0].conv2.stride=(1,1) 57 | resnet.layer4[0].conv1.stride=(2,2) 58 | resnet.layer4[0].conv2.stride=(1,1) 59 | else: 60 | resnet = models.resnet101(pretrained=True) 61 | resnet_module_list = [getattr(resnet,l) for l in resnet_feature_layers] 62 | conv_body = nn.Sequential(*resnet_module_list) 63 | self.model = fpn_body(conv_body, 64 | resnet_feature_layers, 65 | fpn_layers=['layer1','layer2','layer3'], 66 | normalize=normalization, 67 | hypercols=True) 68 | if feature_extraction_model_file!='': 69 | self.model.load_pretrained_weights(feature_extraction_model_file) 70 | 71 | if feature_extraction_cnn == 'densenet201': 72 | self.model = models.densenet201(pretrained=True) 73 | # keep feature extraction network up to denseblock3 74 | # self.model = nn.Sequential(*list(self.model.features.children())[:-3]) 75 | # keep feature extraction network up to transitionlayer2 76 | self.model = nn.Sequential(*list(self.model.features.children())[:-4]) 77 | if train_fe==False: 78 | # freeze parameters 79 | for param in self.model.parameters(): 80 | param.requires_grad = False 81 | # move to GPU 82 | if use_cuda: 83 | self.model = self.model.cuda() 84 | 85 | def 
forward(self, image_batch): 86 | features = self.model(image_batch) 87 | if self.normalization and not self.feature_extraction_cnn=='resnet101fpn': 88 | features = featureL2Norm(features) 89 | return features 90 | 91 | class FeatureCorrelation(torch.nn.Module): 92 | def __init__(self,shape='3D',normalization=True): 93 | super(FeatureCorrelation, self).__init__() 94 | self.normalization = normalization 95 | self.shape=shape 96 | self.ReLU = nn.ReLU() 97 | 98 | def forward(self, feature_A, feature_B): 99 | if self.shape=='3D': 100 | b,c,h,w = feature_A.size() 101 | # reshape features for matrix multiplication 102 | feature_A = feature_A.transpose(2,3).contiguous().view(b,c,h*w) 103 | feature_B = feature_B.view(b,c,h*w).transpose(1,2) 104 | # perform matrix mult. 105 | feature_mul = torch.bmm(feature_B,feature_A) 106 | # indexed [batch,idx_A=row_A+h*col_A,row_B,col_B] 107 | correlation_tensor = feature_mul.view(b,h,w,h*w).transpose(2,3).transpose(1,2) 108 | elif self.shape=='4D': 109 | b,c,hA,wA = feature_A.size() 110 | b,c,hB,wB = feature_B.size() 111 | # reshape features for matrix multiplication 112 | feature_A = feature_A.view(b,c,hA*wA).transpose(1,2) # size [b,c,h*w] 113 | feature_B = feature_B.view(b,c,hB*wB) # size [b,c,h*w] 114 | # perform matrix mult. 115 | feature_mul = torch.bmm(feature_A,feature_B) 116 | # indexed [batch,row_A,col_A,row_B,col_B] 117 | correlation_tensor = feature_mul.view(b,hA,wA,hB,wB).unsqueeze(1) 118 | 119 | if self.normalization: 120 | correlation_tensor = featureL2Norm(self.ReLU(correlation_tensor)) 121 | 122 | return correlation_tensor 123 | 124 | class NeighConsensus(torch.nn.Module): 125 | def __init__(self, use_cuda=True, kernel_sizes=[3,3,3], channels=[10,10,1], symmetric_mode=True): 126 | super(NeighConsensus, self).__init__() 127 | self.symmetric_mode = symmetric_mode 128 | self.kernel_sizes = kernel_sizes 129 | self.channels = channels 130 | num_layers = len(kernel_sizes) 131 | nn_modules = list() 132 | for i in range(num_layers): 133 | if i==0: 134 | ch_in = 1 135 | else: 136 | ch_in = channels[i-1] 137 | ch_out = channels[i] 138 | k_size = kernel_sizes[i] 139 | nn_modules.append(Conv4d(in_channels=ch_in,out_channels=ch_out,kernel_size=k_size,bias=True)) 140 | nn_modules.append(nn.ReLU(inplace=True)) 141 | self.conv = nn.Sequential(*nn_modules) 142 | if use_cuda: 143 | self.conv.cuda() 144 | 145 | def forward(self, x): 146 | if self.symmetric_mode: 147 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 148 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 149 | x = self.conv(x)+self.conv(x.permute(0,1,4,5,2,3)).permute(0,1,4,5,2,3) 150 | # because of the ReLU layers in between linear layers, 151 | # this operation is different than convolving a single time with the filters+filters^T 152 | # and therefore it makes sense to do this. 
153 | else: 154 | x = self.conv(x) 155 | return x 156 | 157 | def MutualMatching(corr4d): 158 | # mutual matching 159 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() 160 | 161 | corr4d_B=corr4d.view(batch_size,fs1*fs2,fs3,fs4) # [batch_idx,k_A,i_B,j_B] 162 | corr4d_A=corr4d.view(batch_size,fs1,fs2,fs3*fs4) 163 | 164 | # get max 165 | corr4d_B_max,_=torch.max(corr4d_B,dim=1,keepdim=True) 166 | corr4d_A_max,_=torch.max(corr4d_A,dim=3,keepdim=True) 167 | 168 | eps = 1e-5 169 | corr4d_B=corr4d_B/(corr4d_B_max+eps) 170 | corr4d_A=corr4d_A/(corr4d_A_max+eps) 171 | 172 | corr4d_B=corr4d_B.view(batch_size,1,fs1,fs2,fs3,fs4) 173 | corr4d_A=corr4d_A.view(batch_size,1,fs1,fs2,fs3,fs4) 174 | 175 | corr4d=corr4d*(corr4d_A*corr4d_B) # parenthesis are important for symmetric output 176 | return corr4d 177 | 178 | def MutualNorm(corr4d): 179 | # mutual matching 180 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() 181 | 182 | corr4d_B=corr4d.view(batch_size,fs1*fs2,fs3,fs4) # [batch_idx,k_A,i_B,j_B] 183 | corr4d_A=corr4d.view(batch_size,fs1,fs2,fs3*fs4) 184 | 185 | # get max 186 | corr4d_B_max,_=torch.max(corr4d_B,dim=1,keepdim=True) 187 | corr4d_A_max,_=torch.max(corr4d_A,dim=3,keepdim=True) 188 | 189 | eps = 1e-5 190 | corr4d_B=corr4d_B/(corr4d_B_max+eps) 191 | corr4d_A=corr4d_A/(corr4d_A_max+eps) 192 | 193 | corr4d_B=corr4d_B.view(batch_size,1,fs1,fs2,fs3,fs4) 194 | corr4d_A=corr4d_A.view(batch_size,1,fs1,fs2,fs3,fs4) 195 | return (corr4d_A*corr4d_B) 196 | 197 | def maxpool4d(corr4d_hres,k_size=4): 198 | slices=[] 199 | for i in range(k_size): 200 | for j in range(k_size): 201 | for k in range(k_size): 202 | for l in range(k_size): 203 | sl = corr4d_hres[:,0,i::k_size,j::k_size,k::k_size,l::k_size].unsqueeze(0) 204 | slices.append(sl) 205 | 206 | slices=torch.cat(tuple(slices),dim=1) 207 | corr4d,max_idx=torch.max(slices,dim=1,keepdim=True) 208 | max_l=torch.fmod(max_idx,k_size) 209 | max_k=torch.fmod(max_idx.sub(max_l).div(k_size),k_size) 210 | max_j=torch.fmod(max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size),k_size) 211 | max_i=max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size).sub(max_j).div(k_size) 212 | # i,j,k,l represent the *relative* coords of the max point in the box of size k_size*k_size*k_size*k_size 213 | return (corr4d,max_i,max_j,max_k,max_l) 214 | 215 | class ImMatchNet(nn.Module): 216 | def __init__(self, 217 | feature_extraction_cnn='resnet101', 218 | feature_extraction_last_layer='', 219 | feature_extraction_model_file=None, 220 | return_correlation=False, 221 | ncons_kernel_sizes=[3,3,3], 222 | ncons_channels=[10,10,1], 223 | normalize_features=True, 224 | train_fe=False, 225 | use_cuda=True, 226 | relocalization_k_size=0, 227 | half_precision=False, 228 | checkpoint=None, 229 | ): 230 | 231 | super(ImMatchNet, self).__init__() 232 | # Load checkpoint 233 | if checkpoint is not None and checkpoint is not '': 234 | print('Loading checkpoint...') 235 | checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage) 236 | checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in checkpoint['state_dict'].items()]) 237 | # override relevant parameters 238 | print('Using checkpoint parameters: ') 239 | ncons_channels=checkpoint['args'].ncons_channels 240 | print(' ncons_channels: '+str(ncons_channels)) 241 | ncons_kernel_sizes=checkpoint['args'].ncons_kernel_sizes 242 | print(' ncons_kernel_sizes: '+str(ncons_kernel_sizes)) 243 | 244 | self.use_cuda = use_cuda 245 | self.normalize_features = normalize_features 246 | self.return_correlation = 
return_correlation 247 | self.relocalization_k_size = relocalization_k_size 248 | self.half_precision = half_precision 249 | 250 | self.FeatureExtraction = FeatureExtraction(train_fe=train_fe, 251 | feature_extraction_cnn=feature_extraction_cnn, 252 | feature_extraction_model_file=feature_extraction_model_file, 253 | last_layer=feature_extraction_last_layer, 254 | normalization=normalize_features, 255 | use_cuda=self.use_cuda) 256 | 257 | self.FeatureCorrelation = FeatureCorrelation(shape='4D',normalization=False) 258 | 259 | self.NeighConsensus = NeighConsensus(use_cuda=self.use_cuda, 260 | kernel_sizes=ncons_kernel_sizes, 261 | channels=ncons_channels) 262 | 263 | # Load weights 264 | if checkpoint is not None and checkpoint is not '': 265 | print('Copying weights...') 266 | for name, param in self.FeatureExtraction.state_dict().items(): 267 | if 'num_batches_tracked' not in name: 268 | self.FeatureExtraction.state_dict()[name].copy_(checkpoint['state_dict']['FeatureExtraction.' + name]) 269 | for name, param in self.NeighConsensus.state_dict().items(): 270 | self.NeighConsensus.state_dict()[name].copy_(checkpoint['state_dict']['NeighConsensus.' + name]) 271 | print('Done!') 272 | 273 | self.FeatureExtraction.eval() 274 | 275 | if self.half_precision: 276 | for p in self.NeighConsensus.parameters(): 277 | p.data=p.data.half() 278 | for l in self.NeighConsensus.conv: 279 | if isinstance(l,Conv4d): 280 | l.use_half=True 281 | 282 | # used only for foward pass at eval and for training with strong supervision 283 | def forward(self, tnf_batch): 284 | # feature extraction 285 | feature_A = self.FeatureExtraction(tnf_batch['source_image']) 286 | feature_B = self.FeatureExtraction(tnf_batch['target_image']) 287 | if self.half_precision: 288 | feature_A=feature_A.half() 289 | feature_B=feature_B.half() 290 | 291 | # feature correlation 292 | corr4d = self.FeatureCorrelation(feature_A,feature_B) 293 | 294 | # do 4d maxpooling for relocalization 295 | if self.relocalization_k_size>1: 296 | corr4d,max_i,max_j,max_k,max_l=maxpool4d(corr4d,k_size=self.relocalization_k_size) 297 | 298 | # run match processing model 299 | corr4d = MutualMatching(corr4d) 300 | corr4d = self.NeighConsensus(corr4d) 301 | corr4d = MutualMatching(corr4d) 302 | 303 | if self.relocalization_k_size>1: 304 | delta4d=(max_i,max_j,max_k,max_l) 305 | return (corr4d,delta4d) 306 | else: 307 | return corr4d 308 | 309 | def forward_feat(self, featA, featB, normalize=True): 310 | # feature normalization 311 | if normalize: 312 | feature_A = featureL2Norm(featA) 313 | feature_B = featureL2Norm(featB) 314 | else: 315 | feature_A = featA 316 | feature_B = featB 317 | if self.half_precision: 318 | feature_A=feature_A.half() 319 | feature_B=feature_B.half() 320 | 321 | # feature correlation 322 | corr4d = self.FeatureCorrelation(feature_A,feature_B) 323 | # do 4d maxpooling for relocalization 324 | if self.relocalization_k_size>1: 325 | corr4d,max_i,max_j,max_k,max_l=maxpool4d(corr4d,k_size=self.relocalization_k_size) 326 | corr4d = MutualMatching(corr4d) 327 | corr4d = self.NeighConsensus(corr4d) 328 | corr4d = MutualMatching(corr4d) 329 | if self.relocalization_k_size>1: 330 | delta4d=(max_i,max_j,max_k,max_l) 331 | return (corr4d,delta4d) 332 | else: 333 | return corr4d 334 | -------------------------------------------------------------------------------- /networks/patch2pix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 
4 | import numpy as np 5 | 6 | import networks.resnet as resnet 7 | from networks.modules import * 8 | from networks.utils import select_local_patch_feats, filter_coarse 9 | from networks.ncn.model import MutualMatching, NeighConsensus 10 | from networks.ncn.extract_ncmatches import corr_to_matches 11 | 12 | class Patch2Pix(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | self.device = config.device 16 | self.backbone = config.backbone 17 | self.change_stride = config.change_stride 18 | self.upsample = 16 19 | self.feats_downsample = [1, 2, 2, 2, 2] 20 | feat_dims = [3, 64, 64, 128, 256] # Resnet34 block feature out dims 21 | 22 | # Initialize necessary network components 23 | self.extract = resnet.__dict__[self.backbone]() 24 | if self.change_stride: 25 | self.extract.change_stride(target='layer3') 26 | self.upsample //= 2 27 | self.feats_downsample[-1] = 1 28 | print(f'Initialize Patch2Pix: backbone={self.backbone} ' 29 | f'cstride={self.change_stride} upsample={self.upsample}') 30 | 31 | self.combine = FeatCorrelation(shape='4D') 32 | self.ncn = NeighConsensus(kernel_sizes=[3, 3], channels=[16, 1]) 33 | 34 | # Initialize regressor 35 | self.regressor_config = config.regressor_config 36 | if not self.regressor_config: 37 | # If no regressor defined, model only computes coarse matches 38 | self.regress_mid = None 39 | self.regress_fine = None 40 | else: 41 | print(f'Init regressor {self.regressor_config}') 42 | self.regr_batch = config.regr_batch 43 | self.feat_idx = config.feat_idx 44 | feat_dim = 0 # Regressor's input feature dim 45 | for idx in self.feat_idx: 46 | feat_dim += feat_dims[idx] 47 | self.regressor_config.feat_dim = feat_dim 48 | self.ptype = ['center', 'center'] 49 | self.psize = config.regressor_config.psize 50 | self.pshift = config.regressor_config.pshift 51 | self.panc = config.regressor_config.panc 52 | self.shared = config.regressor_config.shared 53 | self.regress_mid = FeatRegressNet(self.regressor_config, psize=self.psize[0]) 54 | if self.shared: 55 | self.regress_fine = self.regress_mid 56 | self.psize[1] = self.psize[0] 57 | else: 58 | self.regress_fine = FeatRegressNet(self.regressor_config, psize=self.psize[1]) 59 | 60 | self.to(self.device) 61 | self.init_weights_(weights_dict=config.weights_dict, pretrained=True) 62 | 63 | if config.training: 64 | self.freeze_feat = config.freeze_feat 65 | # Freeze (part of) the backbone 66 | print('Freezing feature extractor params upto layer {}'.format(self.freeze_feat)) 67 | for i, param in enumerate(self.extract.parameters()): 68 | # Resnet34 layer3=[48:87] blocks: 69 | # 0=[48:57] 1=[57:63] 2=[63:69] 70 | # 3=[69:75] 4=[75:81] 5=[81:87] 71 | if i < self.freeze_feat: 72 | param.requires_grad = False 73 | 74 | # Always freeze resnet layer4, since never used 75 | if i >= 87: 76 | param.requires_grad = False 77 | 78 | config.optim_config.start_epoch = config.start_epoch 79 | self.set_optimizer_(config.optim_config) 80 | 81 | def set_optimizer_(self, optim_config): 82 | params = [] 83 | if self.regress_mid: 84 | params += list(self.regress_mid.parameters()) 85 | if self.regress_fine and not self.shared: 86 | params += list(self.regress_fine.parameters()) 87 | params += list(self.ncn.parameters()) 88 | if self.freeze_feat < 87: 89 | params += list(self.extract.parameters())[self.freeze_feat:87] 90 | self.optimizer, self.lr_scheduler = init_optimizer(params, optim_config) 91 | print('Init optimizer, items: {}'.format(len(params))) 92 | 93 | def optim_step_(self, loss): 94 | self.optimizer.zero_grad() 95 
| loss.backward() 96 | self.optimizer.step() 97 | 98 | def init_weights_(self, weights_dict=None, pretrained=True): 99 | print('Xavier initialize all model parameters') 100 | self.apply(xavier_init_func_) 101 | if pretrained: 102 | self.extract.load_pretrained_() 103 | if weights_dict: 104 | if len(weights_dict.items()) == len(self.state_dict()): 105 | print('Reload all model parameters from weights dict') 106 | self.load_state_dict(weights_dict) 107 | else: 108 | print('Reload part of model parameters from weights dict') 109 | self.load_state_dict(weights_dict, strict=False) 110 | 111 | def load_batch_(self, batch, dtype='pair'): 112 | im_src = batch['src_im'].to(self.device) 113 | im_pos = batch['pos_im'].to(self.device) 114 | Fs = batch['F'].to(self.device) 115 | if dtype == 'triplet': 116 | im_neg = batch['neg_im'].to(self.device) 117 | return im_src, im_pos, im_neg, Fs 118 | return im_src, im_pos, Fs 119 | 120 | def forward_coarse_match(self, feat1, feat2, ksize=1): 121 | # Feature normalization 122 | feat1 = L2Normalize(feat1, dim=1) 123 | feat2 = L2Normalize(feat2, dim=1) 124 | 125 | # Feature correlation 126 | corr4d = self.combine(feat1, feat2) 127 | 128 | # Do 4d maxpooling for relocalization 129 | delta4d = None 130 | if ksize > 1: 131 | corr4d, max_i, max_j, max_k, max_l = maxpool4d(corr4d, k_size=ksize) 132 | delta4d = (max_i,max_j,max_k,max_l) 133 | corr4d = MutualMatching(corr4d) 134 | corr4d = self.ncn(corr4d) 135 | corr4d = MutualMatching(corr4d) 136 | return corr4d, delta4d 137 | 138 | def parse_regressor_out(self, out, psize, ptype, imatches, max_val): 139 | w1, h1, w2, h2 = max_val 140 | offset = out[:, :4] # N, 4 141 | offset = psize * torch.tanh(nn.functional.relu(offset)) 142 | if ptype == 'center': 143 | shift = psize // 2 144 | offset -= shift 145 | fmatches = imatches.float() + offset 146 | io_probs = out[:, 4] 147 | io_probs = torch.sigmoid(io_probs) 148 | 149 | # Prevent out of range 150 | x1 = fmatches[:, 0].clamp(min=0, max=w1) 151 | y1 = fmatches[:, 1].clamp(min=0, max=h1) 152 | x2 = fmatches[:, 2].clamp(min=0, max=w2) 153 | y2 = fmatches[:, 3].clamp(min=0, max=h2) 154 | fmatches = torch.stack([x1, y1, x2, y2], dim=-1) 155 | return fmatches, io_probs 156 | 157 | def forward_fine_match_mini_batch(self, feats1, feats2, ibatch, imatches, 158 | psize, ptype, regressor): 159 | # ibatch: to index the feature map 160 | # imatches: input coarse matches, N, 4 161 | N = imatches.shape[0] 162 | _, _, h1, w1 = feats1[0].shape 163 | _, _, h2, w2 = feats2[0].shape 164 | max_val = [w1, h1, w2, h2] 165 | 166 | f1s, f2s, _, _ = select_local_patch_feats(feats1, feats2, 167 | ibatch, imatches, 168 | feat_idx=self.feat_idx, 169 | feats_downsample=self.feats_downsample, 170 | psize=psize, 171 | ptype=ptype) 172 | # Feature normalization 173 | f1s = L2Normalize(f1s, dim=0) # D, N*psize*psize 174 | f2s = L2Normalize(f2s, dim=0) # D, N*psize*psize 175 | 176 | # Reshaping: -> (D, N, psize, psize) -> (N, D, psize, psize) 177 | f1s = f1s.view(-1, N, psize, psize).permute(1, 0, 2, 3) 178 | f2s = f2s.view(-1, N, psize, psize).permute(1, 0, 2, 3) 179 | 180 | 181 | # From im1 to im2 182 | out = regressor(f1s, f2s) # N, 5 183 | fmatches, io_probs = self.parse_regressor_out(out, psize, ptype, imatches, max_val) 184 | return fmatches, io_probs 185 | 186 | def forward_fine_match(self, feats1, feats2, coarse_matches, 187 | psize, ptype, regressor): 188 | batch_size = self.regr_batch 189 | masks = [] 190 | fine_matches = [] 191 | for ibatch, imatches in enumerate(coarse_matches): 192 | # Use 
mini-batch if too many matches 193 | N = imatches.shape[0] 194 | if N > batch_size: 195 | batch_inds = [batch_size*i for i in range(N // batch_size + 1)] 196 | if batch_inds[-1] < N: 197 | if N - batch_inds[-1] == 1: 198 | # Special case, slicing leads to 1-dim missing 199 | batch_inds[-1] = N 200 | else: 201 | batch_inds += [N] 202 | fmatches = [] 203 | io_probs = [] 204 | for bi, (ist, ied) in enumerate(zip(batch_inds[0:-1], batch_inds[1::])): 205 | mini_results = self.forward_fine_match_mini_batch(feats1, feats2, 206 | ibatch, imatches[ist:ied], 207 | psize, ptype, regressor) 208 | fmatches.append(mini_results[0]) 209 | io_probs.append(mini_results[1]) 210 | fmatches = torch.cat(fmatches, dim=0).squeeze() 211 | io_probs = torch.cat(io_probs, dim=0).squeeze() 212 | else: 213 | fmatches, io_probs = self.forward_fine_match_mini_batch(feats1, feats2, 214 | ibatch, imatches, 215 | psize, ptype, regressor) 216 | fine_matches.append(fmatches) 217 | masks.append(io_probs) 218 | return fine_matches, masks 219 | 220 | def forward(self, im1, im2, ksize=1, return_feats=False): 221 | if return_feats: 222 | feat1s=[] 223 | feat2s=[] 224 | self.extract.forward_all(im1, feat1s, early_feat=True) 225 | self.extract.forward_all(im2, feat2s, early_feat=True) 226 | feat1 = feat1s[-1] 227 | feat2 = feat2s[-1] 228 | else: 229 | feat1 = self.extract(im1, early_feat=True) 230 | feat2 = self.extract(im2, early_feat=True) # Shared weights 231 | 232 | corr4d, delta4d = self.forward_coarse_match(feat1, feat2, ksize=ksize) 233 | 234 | if return_feats: 235 | return corr4d, delta4d, feat1s, feat2s 236 | else: 237 | return corr4d, delta4d 238 | 239 | 240 | def predict_coarse(self, im1, im2, ksize=2, ncn_thres=0.0, 241 | mutual=False, center=True): 242 | corr4d, delta4d = self.forward(im1, im2, ksize) 243 | coarse_matches, match_scores = self.cal_coarse_matches(corr4d, delta4d, ksize=ksize, 244 | upsample=self.upsample, center=center) 245 | 246 | # Filter coarse matches 247 | coarse_matches, match_scores = filter_coarse(coarse_matches, match_scores, ncn_thres, mutual) 248 | return coarse_matches, match_scores 249 | 250 | def predict_fine(self, im1, im2, ksize=2, ncn_thres=0.0, 251 | mutual=True, return_all=False): 252 | corr4d, delta4d, feats1, feats2 = self.forward(im1, im2, ksize=ksize, return_feats=True) 253 | coarse_matches, match_scores = self.cal_coarse_matches(corr4d, delta4d, ksize=ksize, 254 | upsample=self.upsample, center=True) 255 | # Filter coarse matches 256 | coarse_matches, match_scores = filter_coarse(coarse_matches, match_scores, ncn_thres, mutual) 257 | 258 | # Locate initial anchors 259 | coarse_matches = self.shift_to_anchors(coarse_matches) 260 | 261 | # Mid level matching 262 | mid_matches, mid_scores = self.forward_fine_match(feats1, feats2, 263 | coarse_matches, 264 | psize=self.psize[0], 265 | ptype=self.ptype[0], 266 | regressor=self.regress_mid) 267 | 268 | # Fine level matching 269 | fine_matches, fine_scores = self.forward_fine_match(feats1, feats2, 270 | mid_matches, 271 | psize=self.psize[1], 272 | ptype=self.ptype[1], 273 | regressor=self.regress_fine) 274 | if return_all: 275 | return fine_matches, fine_scores, mid_matches, mid_scores, coarse_matches 276 | return fine_matches, fine_scores, coarse_matches 277 | 278 | def refine_matches(self, im1, im2, coarse_matches, io_thres): 279 | # Handle empty coarse matches 280 | if len(coarse_matches) == 0: 281 | return np.empty((0, 4)), np.empty((0,)), np.empty((0, 4)) 282 | 283 | if type(coarse_matches) == np.ndarray: 284 | coarse_matches_ = 
torch.from_numpy(coarse_matches).to(self.device).unsqueeze(0) # 1, N, 4 285 | elif type(coarse_matches) == torch.Tensor: 286 | coarse_matches_ = coarse_matches.unsqueeze(0) # 1, N, 4 287 | coarse_matches = coarse_matches.cpu().data.numpy() 288 | 289 | # Extract local features 290 | feat1s=[] 291 | feat2s=[] 292 | self.extract.forward_all(im1, feat1s, early_feat=True) 293 | self.extract.forward_all(im2, feat2s, early_feat=True) 294 | 295 | # Mid level matching 296 | mid_matches, mid_scores = self.forward_fine_match(feat1s, feat2s, 297 | coarse_matches_, 298 | psize=self.psize[0], 299 | ptype=self.ptype[0], 300 | regressor=self.regress_mid) 301 | 302 | # Fine level matching 303 | fine_matches, fine_scores = self.forward_fine_match(feat1s, feat2s, 304 | mid_matches, 305 | psize=self.psize[1], 306 | ptype=self.ptype[1], 307 | regressor=self.regress_fine) 308 | refined_matches = fine_matches[0].cpu().data.numpy() 309 | scores = fine_scores[0].cpu().data.numpy() 310 | 311 | # Further filtering with threshold 312 | if io_thres > 0: 313 | pos_ids = np.where(scores > io_thres)[0] 314 | if len(pos_ids) > 0: 315 | coarse_matches = coarse_matches[pos_ids] 316 | refined_matches = refined_matches[pos_ids] 317 | scores = scores[pos_ids] 318 | return refined_matches, scores, coarse_matches 319 | 320 | def cal_coarse_score(self, corr4d, normalize='softmax'): 321 | if normalize is None: 322 | normalize = lambda x: x 323 | elif normalize == 'softmax': 324 | normalize = lambda x: nn.functional.softmax(x, 1) 325 | elif normalize == 'l1': 326 | normalize = lambda x: x / (torch.sum(x, dim=1, keepdim=True) + 0.0001) 327 | 328 | # Mutual matching score 329 | batch_size, _, h1, w1, h2, w2 = corr4d.shape 330 | nc_B_Avec=corr4d.view(batch_size, h1*w1, h2, w2) 331 | nc_A_Bvec=corr4d.view(batch_size, h1, w1, h2*w2).permute(0,3,1,2) # 332 | nc_B_Avec = normalize(nc_B_Avec) 333 | nc_A_Bvec = normalize(nc_A_Bvec) 334 | scores_B,_= torch.max(nc_B_Avec, dim=1) 335 | scores_A,_= torch.max(nc_A_Bvec, dim=1) 336 | scores_AB = torch.cat([scores_A.view(-1, h1*w1), scores_B.view(-1, h2*w2)], dim=1) 337 | score = scores_AB.mean() 338 | return score 339 | 340 | def cal_coarse_matches(self, corr4d, delta4d, ksize=1, do_softmax=True, 341 | upsample=16, sort=False, center=True, pshift=0): 342 | 343 | # Original nc implementation: only max locations 344 | (xA_, yA_, xB_, yB_, score_) = corr_to_matches(corr4d, delta4d=delta4d, 345 | do_softmax=do_softmax, 346 | ksize=ksize) 347 | (xA2_, yA2_, xB2_, yB2_, score2_) = corr_to_matches(corr4d, delta4d=delta4d, 348 | do_softmax=do_softmax, 349 | ksize=ksize, 350 | invert_matching_direction=True) 351 | xA_ = torch.cat((xA_, xA2_), 1) 352 | yA_ = torch.cat((yA_, yA2_), 1) 353 | xB_ = torch.cat((xB_, xB2_), 1) 354 | yB_ = torch.cat((yB_, yB2_), 1) 355 | score_ = torch.cat((score_, score2_),1) 356 | 357 | # Sort as descend 358 | if sort: 359 | sorted_index = torch.sort(-score_)[1] 360 | xA_ = torch.gather(xA_, 1, sorted_index) # B, 1, N 361 | yA_ = torch.gather(yA_, 1, sorted_index) 362 | xB_ = torch.gather(xB_, 1, sorted_index) 363 | yB_ = torch.gather(yB_, 1, sorted_index) 364 | score_ = torch.gather(score_, 1, sorted_index) # B, N 365 | 366 | xA_ = xA_.unsqueeze(1) 367 | yA_ = yA_.unsqueeze(1) 368 | xB_ = xB_.unsqueeze(1) 369 | yB_ = yB_.unsqueeze(1) 370 | # Create matches and upscale to input resolution 371 | matches_ = upsample * torch.cat([xA_, yA_, xB_, yB_], dim=1).permute(0, 2, 1) # B, N, 4 372 | if center: 373 | delta = upsample // 2 374 | matches_ += torch.tensor([[delta, delta, delta, 
delta]]).unsqueeze(0).to(matches_) 375 | return matches_, score_ 376 | 377 | def shift_to_anchors(self, matches): 378 | pshift = self.pshift 379 | panc = self.panc 380 | if panc == 1: 381 | return matches 382 | 383 | # Move pt1/pt2 to its upper-left, upper-right, down-left, down-right 384 | # location by pshift, leading to 4 corner anchors 385 | # Then take center vs corner from two directions as new matches 386 | shift_template = torch.tensor([ 387 | [-pshift, -pshift, 0, 0], 388 | [pshift, -pshift, 0, 0], 389 | [-pshift, pshift, 0, 0], 390 | [pshift, pshift, 0, 0], 391 | [0, 0, -pshift, -pshift], 392 | [0, 0, pshift, -pshift], 393 | [0, 0, -pshift, pshift], 394 | [0, 0, pshift, pshift] 395 | ]).to(self.device) 396 | 397 | matches_ = [] 398 | for imatches in matches: 399 | imatches = imatches.unsqueeze(1) + shift_template # N, 16, 4 400 | imatches = imatches.reshape(-1, 4) 401 | matches_.append(imatches) 402 | return matches_ 403 | -------------------------------------------------------------------------------- /networks/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is an adapted version of 3 | https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | The goal is to keep ResNet* as only feature extractor, 5 | so the code can be used independent of the types of specific tasks, 6 | i.e., classification or regression. 7 | """ 8 | import torch.nn as nn 9 | import torch.utils.model_zoo as model_zoo 10 | from collections import OrderedDict 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=1, bias=False) 17 | 18 | class BasicBlock(nn.Module): 19 | expansion = 1 20 | 21 | def __init__(self, inplanes, planes, stride=1, downsample=None): 22 | super(BasicBlock, self).__init__() 23 | self.conv1 = conv3x3(inplanes, planes, stride) 24 | self.bn1 = nn.BatchNorm2d(planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | self.conv2 = conv3x3(planes, planes) 27 | self.bn2 = nn.BatchNorm2d(planes) 28 | self.downsample = downsample 29 | self.stride = stride 30 | 31 | def forward(self, x): 32 | residual = x 33 | 34 | out = self.conv1(x) 35 | out = self.bn1(out) 36 | out = self.relu(out) 37 | 38 | out = self.conv2(out) 39 | out = self.bn2(out) 40 | 41 | if self.downsample is not None: 42 | residual = self.downsample(x) 43 | 44 | out += residual 45 | out = self.relu(out) 46 | 47 | return out 48 | 49 | class Bottleneck(nn.Module): 50 | expansion = 4 51 | 52 | def __init__(self, inplanes, planes, stride=1, downsample=None): 53 | super(Bottleneck, self).__init__() 54 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 55 | self.bn1 = nn.BatchNorm2d(planes) 56 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 57 | padding=1, bias=False) 58 | self.bn2 = nn.BatchNorm2d(planes) 59 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 60 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.downsample = downsample 63 | self.stride = stride 64 | 65 | def forward(self, x): 66 | residual = x 67 | 68 | out = self.conv1(x) 69 | out = self.bn1(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv2(out) 73 | out = self.bn2(out) 74 | out = self.relu(out) 75 | 76 | out = self.conv3(out) 77 | out = self.bn3(out) 78 | 79 | if self.downsample is not None: 80 | residual = self.downsample(x) 
81 | 82 | out += residual 83 | out = self.relu(out) 84 | 85 | return out 86 | 87 | class ResNet(nn.Module): 88 | PRETRAINED_URLs = { 89 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 90 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 91 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 92 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 93 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 94 | } 95 | 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def _build_model(self, block, layers): 100 | self.inplanes = 64 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 109 | 110 | def _make_layer(self, block, planes, blocks, stride=1): 111 | downsample = None 112 | if stride != 1 or self.inplanes != planes * block.expansion: 113 | downsample = nn.Sequential( 114 | nn.Conv2d(self.inplanes, planes * block.expansion, 115 | kernel_size=1, stride=stride, bias=False), 116 | nn.BatchNorm2d(planes * block.expansion), 117 | ) 118 | layers = [] 119 | layers.append(block(self.inplanes, planes, stride, downsample)) 120 | self.inplanes = planes * block.expansion 121 | for i in range(1, blocks): 122 | layers.append(block(self.inplanes, planes)) 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x, early_feat=False): 126 | x = self.conv1(x) 127 | x = self.bn1(x) 128 | x = self.relu(x) 129 | x = self.maxpool(x) 130 | x = self.layer1(x) 131 | x = self.layer2(x) 132 | x = self.layer3(x) 133 | if early_feat: 134 | return x 135 | x = self.layer4(x) 136 | return x 137 | 138 | def forward_all(self, x, feat_list=[], early_feat=True): 139 | feat_list.append(x) 140 | x = self.conv1(x) 141 | x = self.bn1(x) 142 | x = self.relu(x) 143 | feat_list.append(x) 144 | 145 | x = self.maxpool(x) 146 | x = self.layer1(x) 147 | feat_list.append(x) 148 | 149 | x = self.layer2(x) 150 | feat_list.append(x) 151 | 152 | x = self.layer3(x) 153 | feat_list.append(x) 154 | 155 | if not early_feat: 156 | x = self.layer4(x) 157 | feat_list.append(x) 158 | 159 | def load_pretrained_(self, ignore='fc'): 160 | print('Initialize ResNet using pretrained model from {}'.format(self.pretrained_url)) 161 | state_dict = model_zoo.load_url(self.pretrained_url) 162 | new_state_dict = OrderedDict() 163 | for k, v in state_dict.items(): 164 | if ignore in k: 165 | continue 166 | new_state_dict[k] = v 167 | self.load_state_dict(new_state_dict) 168 | 169 | def change_stride(self, target='layer3'): 170 | layer = getattr(self, target) 171 | layer[0].conv1.stride = (1, 1) 172 | layer[0].conv2.stride = (1, 1) 173 | layer[0].downsample[0].stride = (1, 1) 174 | 175 | class ResNet34(ResNet): 176 | def __init__(self): 177 | super().__init__() 178 | self.pretrained_url = self.PRETRAINED_URLs['resnet34'] 179 | self._build_model(BasicBlock, [3, 4, 6, 3]) 180 | 181 | class ResNet50(ResNet): 182 | def __init__(self): 183 | super().__init__() 184 | self.pretrained_url = self.PRETRAINED_URLs['resnet50'] 185 | self._build_model(Bottleneck, [3, 4, 6, 3]) 186 | 187 | class 
ResNet101(ResNet): 188 | def __init__(self): 189 | super().__init__() 190 | self.pretrained_url = self.PRETRAINED_URLs['resnet101'] 191 | self._build_model(Bottleneck, [3, 4, 23, 3]) 192 | -------------------------------------------------------------------------------- /networks/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def select_local_patch_feats(feats1, feats2, ibatch, imatches, 5 | feat_idx=[1, 2, 3, 4], 6 | feats_downsample=[1, 2, 2, 2, 2], 7 | psize=16, ptype='center'): 8 | dy, dx = torch.meshgrid(torch.arange(psize), torch.arange(psize)) 9 | dx = dx.flatten().view(1, -1).to(imatches.device) 10 | dy = dy.flatten().view(1, -1).to(imatches.device) 11 | 12 | if ptype == 'center': 13 | shift = psize // 2 14 | dy -= shift 15 | dx -= shift 16 | 17 | _, _, h1, w1 = feats1[0].shape 18 | _, _, h2, w2 = feats2[0].shape 19 | x1, y1, x2, y2 = imatches.permute(1, 0).long() 20 | 21 | # Ids for local patch 22 | get_x_pids = lambda x, w, ds: ((x.view(-1, 1) + dx).view(-1) // ds).long().clamp(min=0, max=w//ds-1) 23 | get_y_pids = lambda y, h, ds: ((y.view(-1, 1) + dy).view(-1) // ds).long().clamp(min=0, max=h//ds-1) 24 | 25 | # Collect features for local matches 26 | f1s, f2s = [], [] 27 | for j, (fmap1, fmap2) in enumerate(zip(feats1, feats2)): 28 | if j not in feat_idx: 29 | continue 30 | ds = np.prod(feats_downsample[0:j+1]) 31 | f1s.append(fmap1[ibatch, :, get_y_pids(y1, h1, ds), get_x_pids(x1, w1, ds)]) 32 | f2s.append(fmap2[ibatch, :, get_y_pids(y2, h2, ds), get_x_pids(x2, w2, ds)]) 33 | 34 | f1s = torch.cat(f1s, dim=0) # D, N*psize*psize 35 | f2s = torch.cat(f2s, dim=0) # D, N*psize*psize 36 | return f1s, f2s, dx.squeeze(), dy.squeeze() 37 | 38 | def filter_coarse(coarse_matches, match_scores, ncn_thres=0.0, mutual=True, ptmax=None): 39 | matches = [] 40 | scores = [] 41 | for imatches, iscores in zip(coarse_matches, match_scores): 42 | _, ids, counts = np.unique(imatches.cpu().data.numpy(), axis=0, return_index=True, return_counts=True) 43 | if mutual: 44 | # Consider only matches that are mutually consistent 45 | ids = ids[counts > 1] 46 | #print(len(imatches), len(ids)) 47 | 48 | if len(ids) > 0: 49 | iscores = iscores[ids] 50 | imatches = imatches[ids] 51 | 52 | # NC score filtering 53 | ids = torch.nonzero(iscores.flatten() > ncn_thres, as_tuple=False).flatten() 54 | 55 | # Cut or fill up to ptmax for memory control 56 | if ptmax: 57 | if len(ids) == 0: 58 | # No match survived the filtering, repeat the first match as a placeholder 59 | ids = torch.tensor([0, 0, 0, 0]).long() 60 | iids = np.arange(len(ids)) 61 | np.random.shuffle(iids) 62 | iids = np.tile(iids, (ptmax // len(ids) + 1))[:ptmax] 63 | ids = ids[iids] 64 | 65 | if len(ids) > 0: 66 | iscores = iscores[ids] 67 | imatches = imatches[ids] 68 | 69 | matches.append(imatches) 70 | scores.append(iscores) 71 | 72 | return matches, scores 73 | 74 | def sym_epi_dist(matches, F, sqrt=True, eps=1e-8): 75 | # matches: Nx4 76 | # F: 3x3 77 | N = matches.shape[0] 78 | matches = matches.to(F) 79 | ones = torch.ones((N,1)).to(F) 80 | p1 = torch.cat([matches[:, 0:2] , ones], dim=1) 81 | p2 = torch.cat([matches[:, 2:4] , ones], dim=1) 82 | 83 | # l2=F*x1, l1=F^T*x2 84 | l2 = F.matmul(p1.transpose(1, 0)) # 3,N 85 | l1 = F.transpose(1, 0).matmul(p2.transpose(1, 0)) 86 | dd = (l2.transpose(1, 0) * p2).sum(dim=1) 87 | 88 | sqrt = False # Hard-coded override: the squared-distance branch below is always used 89 | if sqrt: 90 | d = dd.abs() * (1.0 / (eps + l1[0, :] ** 2 + l1[1, :] ** 2).sqrt() + 1.0 / (eps + l2[0, :] ** 2 + l2[1, :] ** 2).sqrt()) 91 | else: 92 | d = dd ** 2 * (1.0 / (eps + l1[0, :] ** 
2 + l1[1, :] ** 2) + 1.0 / (eps + l2[0, :] ** 2 + l2[1, :] ** 2)) 93 | return d.float() 94 | 95 | def sampson_dist(matches, F, eps=1e-8): 96 | # First-order approximation to reprojection error 97 | # matches: Nx4 98 | # F: 3x3 99 | N = matches.shape[0] 100 | matches = matches.to(F) 101 | ones = torch.ones((N,1)).to(F) 102 | p1 = torch.cat([matches[:, 0:2] , ones], dim=1) 103 | p2 = torch.cat([matches[:, 2:4] , ones], dim=1) 104 | 105 | # l2=F*x1, l1=F^T*x2 106 | l2 = F.matmul(p1.transpose(1, 0)) # 3,N 107 | l1 = F.transpose(1, 0).matmul(p2.transpose(1, 0)) 108 | dd = (l2.transpose(1, 0) * p2).sum(dim=1) 109 | d = dd ** 2 / (eps + l1[0, :] ** 2 + l1[1, :] ** 2 + l2[0, :] ** 2 + l2[1, :] ** 2) 110 | return d.float() 111 | 112 | -------------------------------------------------------------------------------- /patch2pix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GrumpyZhou/patch2pix/0777495b2cff1e876b16cba7f1f7cfcde400bef5/patch2pix.png -------------------------------------------------------------------------------- /pretrained/download.sh: -------------------------------------------------------------------------------- 1 | # Patch2pix pretrained - https://drive.google.com/file/d/1ZbJIE0LcZ3Oti-h8zU72JRL7LryVxJOh/view?usp=drive_link 2 | gdown 1ZbJIE0LcZ3Oti-h8zU72JRL7LryVxJOh 3 | 4 | # NCNet IVD - https://drive.google.com/file/d/10GZ0x3CmObKzbAg1GKQhrkPeSLRpD4Rp/view?usp=drive_link 5 | gdown 10GZ0x3CmObKzbAg1GKQhrkPeSLRpD4Rp 6 | -------------------------------------------------------------------------------- /train_patch2pix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | import os 4 | import time 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.utils.data as data 9 | 10 | from utils.datasets import ImMatchDatasetMega 11 | from utils.train.helper import * 12 | from utils.train.eval_epoch_immatch import eval_immatch_val_sets 13 | from utils.common.setup_helper import * 14 | from networks.utils import sampson_dist, filter_coarse 15 | from networks.patch2pix import Patch2Pix 16 | 17 | def parse_agrs(): 18 | parser = argparse.ArgumentParser(description='Train Patch2Pix Matching Network') 19 | parser.add_argument('--gpu', '-gpu', type=int, default=0) 20 | parser.add_argument('--seed', type=int, default=1) 21 | parser.add_argument('--epochs', type=int, default=100) 22 | parser.add_argument('--save_step', type=int, default=1) 23 | parser.add_argument('--plot_counts', type=int, default=5) 24 | parser.add_argument('--batch', type=int, default=8) 25 | parser.add_argument('--regr_batch', type=int, default=1200) 26 | parser.add_argument('--visdom_host', '-vh', type=str, default=None) 27 | parser.add_argument('--visdom_port', '-vp', type=str, default=None) 28 | parser.add_argument('--prefix', type=str, default='') 29 | parser.add_argument('--out_dir', '-o', type=str, default='output/patch2pix') 30 | 31 | # Data loading config 32 | parser.add_argument('--dataset', type=str, default='MegaDepth') 33 | parser.add_argument('--data_root', type=str, default='data') 34 | parser.add_argument('--pair_root', type=str, default='data_pairs') 35 | parser.add_argument( 36 | '--match_npy', type=str, 37 | default='megadepth_pairs.ov0.35_imrat1.5.pair500.excl_test.npy' 38 | ) 39 | 40 | # Model architecture 41 | parser.add_argument('--backbone', type=str, default='ResNet34') 42 | parser.add_argument('--change_stride', 
action='store_true') 43 | parser.add_argument('--ksize', type=int, default=2) 44 | parser.add_argument('--freeze_feat', type=int, default=87) 45 | parser.add_argument('--feat_idx', type=int, nargs='*', default=[0, 1, 2, 3]) 46 | parser.add_argument('--feat_comb', type=str, default='pre') 47 | parser.add_argument('--conv_kers', type=int, nargs='*', default=[3, 3]) 48 | parser.add_argument('--conv_dims', type=int, nargs='*', default=[512, 512]) 49 | parser.add_argument('--conv_strs', type=int, nargs='*', default=[2, 1]) 50 | parser.add_argument('--fc_dims', type=int, nargs='*', default=[512, 256]) 51 | parser.add_argument('--psize', type=int, nargs=2, default=[16, 16]) 52 | parser.add_argument('--pshift', type=int, default=8) 53 | parser.add_argument('--panc', type=int, choices=[8, 1], default=8) 54 | parser.add_argument('--ptmax', type=int, default=400) 55 | parser.add_argument('--shared', action='store_true') 56 | 57 | # Matching thresholds 58 | parser.add_argument('--cthres', type=float, default=0.5) 59 | parser.add_argument('--cls_dthres', type=int, nargs=2, default=[50, 5]) 60 | parser.add_argument('--epi_dthres', type=int, nargs=2, default=[50, 5]) 61 | 62 | # Model intialize 63 | parser.add_argument('--pretrain', type=str, default=None) 64 | parser.add_argument('--ckpt', type=str, default=None) 65 | parser.add_argument('--resume', action='store_true') # Auto load last cpkt 66 | 67 | # Optimization 68 | parser.add_argument('--lr_init', '-lr', metavar='%f', type=float, default=5e-4) 69 | parser.add_argument('--lr_decay', '-lrd', metavar='%s[type] %f[*factor] %d[*step]', nargs='*', default=None) # Opt: 'step' 'multistep' 70 | parser.add_argument('--weight_decay', '-wd', metavar='%f', type=float, default=0) 71 | parser.add_argument('--weight_cls', '-wcls', metavar='%f', type=float, default=10.0) 72 | parser.add_argument('--weight_epi', '-wepi', metavar='%f[fine] %f[mid]', type=float, nargs='*', default=[1, 1]) 73 | 74 | args = parser.parse_args() 75 | return args 76 | 77 | def train_epoch(epoch, net, train_loader, train_vis, args, lprint_): 78 | net.train() 79 | train_vis.clear() # Clearn visdom plot data per epoch 80 | plot_step = len(train_loader) // args.plot_counts 81 | 82 | # Setup threshold params 83 | net.panc = args.panc 84 | ksize = args.ksize 85 | cthres, cls_dthres, epi_dthres = args.cthres, args.cls_dthres, args.epi_dthres 86 | cls_loss_weight = args.weight_cls 87 | efine_weight, emid_weight = args.weight_epi 88 | 89 | # Start training 90 | skipped = 0 91 | lprint_(f'ksize={ksize} cthres={cthres} cls_dthres={cls_dthres} ' 92 | f'epi_dthre={epi_dthres} ptmax={args.ptmax} panc={net.panc}') 93 | for i, batch in enumerate(train_loader): 94 | im_src, im_pos, Fs = net.load_batch_(batch, dtype='pair') 95 | 96 | # Estimate patch-level matches 97 | corr4d, delta4d, feats1, feats2 = net.forward(im_src, im_pos, ksize=ksize, return_feats=True) 98 | coarse_matches, match_scores = net.cal_coarse_matches(corr4d, delta4d, ksize=ksize, upsample=net.upsample, center=True) 99 | 100 | if net.panc > 1 and args.ptmax > 0: 101 | coarse_matches, match_scores = filter_coarse(coarse_matches, match_scores, 0.0, True, ptmax=args.ptmax) 102 | 103 | # Coarse matches to locate anchors 104 | coarse_matches = net.shift_to_anchors(coarse_matches) 105 | 106 | # Mid level matching on positive pairs 107 | mid_matches, mid_probs = net.forward_fine_match(feats1, feats2, 108 | coarse_matches, 109 | psize=net.psize[0], 110 | ptype=net.ptype[0], 111 | regressor=net.regress_mid) 112 | 113 | # Fine level matching based 
on mid matches 114 | fine_matches, fine_probs = net.forward_fine_match(feats1, feats2, 115 | mid_matches, 116 | psize=net.psize[1], 117 | ptype=net.ptype[1], 118 | regressor=net.regress_fine) 119 | # Calculate per pair losses 120 | cls_batch_lss = [] 121 | epi_batch_lss = [] 122 | for F, cmat, mmat, fmat, mcls_pred, fcls_pred in zip(Fs, coarse_matches, 123 | mid_matches, fine_matches, 124 | mid_probs, fine_probs): 125 | N = len(cmat) 126 | 127 | # Classification gt based on coarse matches 128 | cdist = net.geo_dist_fn(cmat, F) 129 | mdist = net.geo_dist_fn(mmat, F) 130 | fdist = net.geo_dist_fn(fmat, F) 131 | ones = torch.ones_like(cdist) 132 | zeros = torch.zeros_like(cdist) 133 | 134 | # Classification loss 135 | mcls_pos = torch.where(cdist < cls_dthres[0], ones, zeros) 136 | fcls_pos = torch.where(mdist < cls_dthres[1], ones, zeros) 137 | mcls_neg = 1 - mcls_pos 138 | fcls_neg = 1 - fcls_pos 139 | 140 | if mcls_pos.sum() == 0 or fcls_pos.sum() == 0: 141 | skipped += 1 142 | continue 143 | 144 | mcls_weights = mcls_neg.sum() / mcls_pos.sum() * mcls_pos + mcls_neg 145 | mcls_lss = nn.functional.binary_cross_entropy(mcls_pred, mcls_pos, reduction='none') 146 | mcls_lss = (mcls_weights * mcls_lss).mean() 147 | 148 | fcls_weights = fcls_neg.sum() / fcls_pos.sum() * fcls_pos + fcls_neg 149 | fcls_lss = nn.functional.binary_cross_entropy(fcls_pred, fcls_pos, reduction='none') 150 | fcls_lss = (fcls_weights * fcls_lss).mean() 151 | 152 | cls_lss = mcls_lss + fcls_lss 153 | cls_batch_lss.append(cls_lss) 154 | 155 | # Plot cls metric 156 | plot_cls_metric(mcls_pred, mcls_pos, cthres, train_vis.plots.cls_mid) 157 | plot_cls_metric(fcls_pred, fcls_pos, cthres, train_vis.plots.cls_fine) 158 | 159 | # Plot statis 160 | train_vis.plots.cls_ratios.mpos_gt.append(mcls_pos.sum().item() / N) 161 | train_vis.plots.cls_ratios.fpos_gt.append(fcls_pos.sum().item() / N) 162 | train_vis.plots.loss.cls_mid.append(mcls_lss.item()) 163 | train_vis.plots.loss.cls_fine.append(fcls_lss.item()) 164 | 165 | # Epipolar loss 166 | mids_gt = torch.where(cdist < epi_dthres[0], ones, zeros).nonzero(as_tuple=False).flatten() 167 | fids_gt = torch.where(mdist < epi_dthres[1], ones, zeros).nonzero(as_tuple=False).flatten() 168 | #lprint_(f'{len(mdist)} {len(mids_gt)} {len(fdist)} {len(fids_gt)}') 169 | 170 | if len(fids_gt) == 0 and len(mids_gt) == 0: 171 | skipped += 1 172 | continue 173 | 174 | epi_mid = mdist[mids_gt].mean() if len(mids_gt) > 0 else torch.tensor(0).to(mdist) 175 | epi_fine = fdist[fids_gt].mean() if len(fids_gt) > 0 else torch.tensor(0).to(fdist) 176 | epi_lss = emid_weight * epi_mid + efine_weight * epi_fine 177 | epi_batch_lss.append(epi_lss) 178 | 179 | # Plot epi dists 180 | if len(mids_gt) > 0: 181 | train_vis.plots.loss.epi_mid.append(epi_mid.item()) 182 | train_vis.plots.match_dist.mmid_gt.append(epi_mid.item()) 183 | train_vis.plots.match_dist.cmid_gt.append(cdist[mids_gt].mean().item()) 184 | 185 | if len(fids_gt) > 0: 186 | train_vis.plots.loss.epi_fine.append(epi_fine.item()) 187 | train_vis.plots.match_dist.ffid_gt.append(epi_fine.item()) 188 | train_vis.plots.match_dist.mfid_gt.append(mdist[fids_gt].mean().item()) 189 | 190 | # Total loss 191 | cls_loss = torch.stack(cls_batch_lss).mean() if len(cls_batch_lss) > 0 else torch.tensor(0.0, requires_grad=True).to(net.device) 192 | epi_loss = torch.stack(epi_batch_lss).mean() if len(epi_batch_lss) > 0 else torch.tensor(0.0, requires_grad=True).to(net.device) 193 | loss = cls_loss_weight * cls_loss + epi_loss 194 | 
train_vis.plots.loss.pair.append(loss.item()) 195 | 196 | # Optimize 197 | net.optim_step_(loss) 198 | 199 | # Monitor memory usage 200 | rss, vms = get_sys_mem() 201 | train_vis.plots.mem.rss.append(rss) 202 | train_vis.plots.mem.vms.append(vms) 203 | 204 | gpu_maloc, gpu_mres = get_gpu_mem() 205 | train_vis.plots.mem.gpu_maloc.append(gpu_maloc) 206 | train_vis.plots.mem.gpu_mres.append(gpu_mres) 207 | torch.cuda.empty_cache() 208 | 209 | # Update plots periodically 210 | if i % plot_step == 0 and i > 0: 211 | train_vis.plot(epoch=epoch + (i / len(train_loader))) 212 | lprint_('Batch:{} Loss:{}'.format(i, train_vis.get_plot_print(train_vis.plots.loss))) 213 | lprint_('Cls_mid:{}'.format(train_vis.get_plot_print(train_vis.plots.cls_mid))) 214 | lprint_('Cls_fine:{}'.format(train_vis.get_plot_print(train_vis.plots.cls_fine))) 215 | lprint_('Match:{}\n'.format(train_vis.get_plot_print(train_vis.plots.match_dist))) 216 | 217 | # Always update plots at the end of an epoch 218 | train_vis.plot(epoch=epoch + 1) 219 | lprint_('>Epoch:{} Skipped:{} Loss:{}'.format(epoch + 1, skipped, train_vis.get_plot_print(train_vis.plots.loss))) 220 | lprint_('Cls_mid:{}'.format(train_vis.get_plot_print(train_vis.plots.cls_mid))) 221 | lprint_('Cls_fine:{}'.format(train_vis.get_plot_print(train_vis.plots.cls_fine))) 222 | lprint_('Match:{}'.format(train_vis.get_plot_print(train_vis.plots.match_dist))) 223 | 224 | def main(): 225 | np.set_printoptions(precision=3) 226 | args = parse_agrs() 227 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 228 | make_deterministic(args.seed) 229 | print(args) 230 | 231 | # Init data loader 232 | match_npy_name = args.match_npy 233 | match_npy = os.path.join(args.pair_root, match_npy_name) 234 | pair_type = match_npy_name.replace('megadepth_pairs.', '').replace('_imrat1.5','').replace('.npy','') 235 | data_tag = 'Mega.' + pair_type 236 | train_set = ImMatchDatasetMega(args.data_root, match_npy, wt=480, ht=320) 237 | train_loader = data.DataLoader(train_set, batch_size=args.batch, shuffle=True) 238 | 239 | # Init output dir names 240 | if args.prefix != '': 241 | odir_tag = args.prefix + '.' + data_tag 242 | else: 243 | odir_tag = data_tag 244 | odir_tag += '.freeze{}'.format(args.freeze_feat) 245 | if args.change_stride: 246 | odir_tag += '.cs' 247 | if args.pretrain: 248 | odir_tag += '.pretrain' 249 | 250 | # fe1234nc0.9ep50-100_cls{}lr5e-4.. 
251 | feat_tag = 'ks{}fe{}'.format(args.ksize, ''.join([str(v) for v in args.feat_idx])) 252 | thres_tag = 'ep{}-{}cls{}-{}'.format(args.epi_dthres[0], args.epi_dthres[1], 253 | args.cls_dthres[0], args.cls_dthres[1]) 254 | train_tag = '_wcls{}wepi{}-{}.lr{}'.format(args.weight_cls, args.weight_epi[0], 255 | args.weight_epi[1], args.lr_init) 256 | 257 | if args.weight_decay > 0: 258 | train_tag += 'wd{}'.format(args.weight_decay) 259 | if args.lr_decay: 260 | decay_type = args.lr_decay[0] 261 | if decay_type == 'step': 262 | train_tag = '{}lrst{}-{}'.format(train_tag, args.lr_decay[1], args.lr_decay[2]) 263 | elif decay_type == 'multistep': 264 | train_tag = '{}lrms{}-{}'.format(train_tag, args.lr_decay[1], args.lr_decay[2]) 265 | 266 | exp_tag = '{}{}{}'.format(feat_tag, thres_tag, train_tag) 267 | 268 | # Regressor 269 | regress_tag = '{}{}_conv{}dim{}str{}fc{}_psz{}-{}a{}'.format(args.feat_comb, args.ptmax, 270 | ''.join(map(str, args.conv_kers)), 271 | '-'.join(map(str, args.conv_dims)), 272 | '-'.join(map(str, args.conv_strs)), 273 | '-'.join(map(str, args.fc_dims)), 274 | args.psize[0], args.psize[1], 275 | args.panc) 276 | if args.shared: 277 | regress_tag += '.shared' 278 | 279 | # Create output dirs and log file 280 | out_dir = os.path.join(args.out_dir, odir_tag, exp_tag, regress_tag) 281 | args.out_dir = out_dir 282 | if not os.path.exists(out_dir): 283 | os.makedirs(out_dir) 284 | log = open(os.path.join(out_dir, 'log.txt'), 'a') 285 | lprint_ = lambda ms: lprint(ms, log) 286 | lprint_(config2str(args)) 287 | lprint_('Log dir {}'.format(out_dir)) 288 | lprint_(f'>>>Load dataset:{data_tag}, train:{len(train_loader.dataset)}') 289 | 290 | 291 | # Initialize visdom 292 | env = '{}.{}_{}{}_{}'.format(odir_tag, feat_tag, thres_tag, train_tag, regress_tag) 293 | server = args.visdom_host 294 | port = args.visdom_port 295 | lprint_('>>Visdom server: {} port: {} env: {}'.format(server, port, env)) 296 | train_vis = get_visdom_plots(prefix='train', env=env, server=server, port=port) 297 | test_vis = get_visdom_plots(prefix='test', env=env, server=server, port=port) 298 | 299 | # Initialize model 300 | config, best_vals = init_model_config(args, lprint_) 301 | config.freeze_nc = True 302 | net = Patch2Pix(config) 303 | if args.weight_epi[0] == 0: 304 | lprint_('Freeze regressor_mid ...') 305 | for param in net.regress_mid.parameters(): 306 | param.requires_grad = False 307 | 308 | lprint_('Params backboone={} ncn={} regress_mid={} regress_fine={}'.format( 309 | count_parameters(net.extract), 310 | count_parameters(net.ncn), 311 | count_parameters(net.regress_mid), 312 | count_parameters(net.regress_fine) 313 | )) 314 | lprint_('Set geo dist: sampson distance') 315 | net.geo_dist_fn = sampson_dist 316 | 317 | 318 | # Training and validation 319 | t0 = time.time() 320 | lprint_('Start training from {} to {} ..'.format(config.start_epoch, args.epochs)) 321 | for epoch in range(config.start_epoch, args.epochs): 322 | 323 | # Always train on normally matching pairs 324 | lprint_('\n>>>Epoch {} training...'.format(epoch+1)) 325 | lprint_('>>>Current_lr={}\n'.format(net.optimizer.param_groups[0]['lr'])) 326 | t1 = time.time() 327 | train_epoch(epoch, net, train_loader, train_vis, args, lprint_) 328 | lprint_('Epoch training time: {:.2f}s'.format(time.time() - t1)) 329 | 330 | # Validation 331 | net.panc = 1 # Hard set topk to 1 332 | lprint_(f'Validation setting: panc={net.panc}') 333 | 334 | # Always save last ckpt 335 | save_ckpt(net, epoch, out_dir, best_vals=best_vals, last_ckpt=True) 336 | 
337 | # Save model periodically 338 | if (epoch + 1) % args.save_step == 0: 339 | save_ckpt(net, epoch, out_dir, best_vals=best_vals) 340 | 341 | 342 | # Eval immatch 343 | try: 344 | res = eval_immatch_val_sets(net, 345 | data_root=f'{args.data_root}/immatch_benchmark/val_dense', 346 | ksize=2, imsize=1024, 347 | eval_type='fine', io_thres=0.5, 348 | sample_max=150, lprint_=lprint_) 349 | 350 | # Save the best model based on immatch 351 | qt_err, pass_rate = res 352 | rate = 0.34 * pass_rate[0] + 0.33 * pass_rate[4] + 0.33 * pass_rate[9] # % < 1/5/10 px 353 | if qt_err < best_vals[2] or rate > best_vals[3]: 354 | if qt_err < best_vals[2]: 355 | best_vals[2] = qt_err 356 | if rate > best_vals[3]: 357 | best_vals[3] = rate 358 | save_ckpt(net, epoch, out_dir, best_vals=best_vals, name='immatch_best_ckpt') 359 | lprint_('>>Save best immatch model: epoch={} qt={:.3f} rate={:.2f}%'.format(epoch+1, qt_err, rate)) 360 | 361 | except: 362 | lprint_('Failed to eval immatch') 363 | res = None 364 | 365 | # Update the learning rate 366 | if net.lr_scheduler: 367 | net.lr_scheduler.step() 368 | lprint_('Finished, time:{:.4f}s'.format(time.time() - t0)) 369 | log.close() 370 | 371 | 372 | if __name__ == '__main__': 373 | main() 374 | -------------------------------------------------------------------------------- /utils/colmap/data_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import numpy as np 4 | from utils.colmap.read_database import COLMAPDataLoader 5 | from utils.colmap.read_write_model import * 6 | 7 | def sav_model_multi_ov_pairs(model_dir, overlaps): 8 | sav_file_path = os.path.join(model_dir, 'ov_pairs.npy') 9 | if os.path.exists(sav_file_path): 10 | ov_pair_dict = np.load(sav_file_path, allow_pickle=True).item() 11 | all_exists = True 12 | for k in overlaps: 13 | if k not in ov_pair_dict: 14 | all_exists = False 15 | if all_exists: 16 | print('All overlaps have been computed.') 17 | return ov_pair_dict 18 | 19 | ov_pair_dict = {} 20 | images = read_images_binary(os.path.join(model_dir, 'images.bin')) 21 | im_ids = list(images.keys()) 22 | overlap_scores, _ = cal_overlap_scores(im_ids, images) 23 | 24 | for min_overlap in overlaps: 25 | if min_overlap in ov_pair_dict: 26 | print(f'ov>{min_overlap} exists, skip.') 27 | continue 28 | valid_scores = np.logical_and(overlap_scores >= min_overlap, overlap_scores < 1) 29 | pair_ids = np.vstack(np.where(valid_scores)).T 30 | pair_names = [] 31 | for id1, id2 in pair_ids: 32 | im1 = images[im_ids[id1]].name 33 | im2 = images[im_ids[id2]].name 34 | pair_names.append((max(im1, im2), min(im1, im2))) 35 | print(f'ov>{min_overlap} pairs: {len(pair_names)}') 36 | ov_pair_dict[min_overlap] = pair_names 37 | np.save(sav_file_path, ov_pair_dict) 38 | return ov_pair_dict 39 | 40 | def load_model_ov_pairs(model_dir, min_overlap=0.3): 41 | images = read_images_binary(os.path.join(model_dir, 'images.bin')) 42 | im_ids = list(images.keys()) 43 | overlap_scores, _ = cal_overlap_scores(im_ids, images) 44 | valid_scores = np.logical_and(overlap_scores >= min_overlap, overlap_scores < 1) 45 | pair_ids = np.vstack(np.where(valid_scores)).T 46 | pair_names = [] 47 | for id1, id2 in pair_ids: 48 | im1 = images[im_ids[id1]].name 49 | im2 = images[im_ids[id2]].name 50 | pair_names.append((max(im1, im2), min(im1, im2))) 51 | print('Loaded ov>{} pairs: {}'.format(min_overlap, len(pair_names))) 52 | return pair_names 53 | 54 | def cal_overlap_scores(im_ids, images): 55 | N = 
len(im_ids) 56 | overlap_scores = np.eye(N) 57 | im_3ds = [np.where(images[i].point3D_ids > 0)[0] for i in im_ids] 58 | 59 | for i in range(N): 60 | im1 = images[im_ids[i]] 61 | pts1 = im_3ds[i] 62 | for j in range(N): 63 | if j <= i : 64 | continue 65 | im2 = images[im_ids[j]] 66 | pts2 = im_3ds[j] 67 | ov = len(np.intersect1d(pts1, pts2)) / max(len(pts1), len(pts2)) 68 | overlap_scores[i, j] = ov 69 | nums_3d = np.array([len(v) for v in im_3ds]) 70 | return overlap_scores, nums_3d 71 | 72 | def load_model_ims(model_dir): 73 | cameras = read_cameras_binary(os.path.join(model_dir, 'cameras.bin')) 74 | images = read_images_binary(os.path.join(model_dir, 'images.bin')) 75 | #print(len(cameras), len(images)) 76 | imdict = {} 77 | for i in images: 78 | cid = images[i].camera_id 79 | if cid not in cameras: 80 | continue 81 | data = parse_data(images[i], cameras[cid]) 82 | imdict[data.name] = data 83 | return imdict 84 | 85 | def cam_params_to_matrix(params, model): 86 | if model == 'SIMPLE_PINHOLE': 87 | f, ox, oy = params 88 | K = np.array([[f, 0, ox], [0, f, oy], [0, 0, 1]]) 89 | elif model == 'PINHOLE': 90 | f1, f2, ox, oy = params 91 | K = np.array([[f1, 0, ox], [0, f2, oy], [0, 0, 1]]) 92 | elif model == 'SIMPLE_RADIAL': 93 | f, ox, oy, _ = params 94 | K = np.array([[f, 0, ox], [0, f, oy], [0, 0, 1]]) 95 | elif model == 'RADIAL': 96 | f, ox, oy, _, _ = params 97 | K = np.array([[f, 0, ox], [0, f, oy], [0, 0, 1]]) 98 | return K 99 | 100 | def parse_data(im, cam): 101 | # Extract information from Image&Camera objects 102 | K = cam_params_to_matrix(cam.params, cam.model) 103 | q = im.qvec 104 | R = qvec2rotmat(q) 105 | t = im.tvec 106 | c = - R.T.dot(t) 107 | return Namespace(name=im.name, K=K, c=c, q=q, R=R, id=im.id) 108 | 109 | def load_colmap_matches(db_path, pair_names): 110 | # Loading data from colmap database 111 | db_loader = COLMAPDataLoader(db_path) 112 | keypoints = db_loader.load_keypoints(key_len=6) 113 | images = db_loader.load_images(name_based=True) 114 | pair_ids = [(images[im1][0], images[im2][0]) for im1, im2 in pair_names] 115 | db_matches = db_loader.load_pair_matches(pair_ids) 116 | match_dict = {} 117 | for pname, pid in zip(pair_names, pair_ids): 118 | (im1, im2) = pid 119 | kpts1 = keypoints[im1] 120 | kpts2 = keypoints[im2] 121 | if pid not in db_matches: 122 | matches = None 123 | else: 124 | key_ids = db_matches[pid] 125 | N = key_ids.shape[0] 126 | matches = np.zeros((N, 4)) 127 | for j in range(N): 128 | k1, k2 = key_ids[j,:] 129 | x1, y1 = kpts1[k1][0:2] 130 | x2, y2 = kpts2[k2][0:2] 131 | matches[j, :] = [x1, y1, x2, y2] 132 | match_dict[pname] = matches 133 | return match_dict 134 | 135 | def export_intrinsics_txt(model_dir, sav_path): 136 | cameras = read_cameras_binary(os.path.join(model_dir, 'cameras.bin')) 137 | images = read_images_binary(os.path.join(model_dir, 'images.bin')) 138 | with open(sav_path, 'w') as f: 139 | for imid in images: 140 | im = images[imid] 141 | cid = im.camera_id 142 | if cid not in cameras: 143 | continue 144 | cam = cameras[cid] 145 | model = cam.model 146 | w, h = cam.width, cam.height 147 | params = cam.params 148 | 149 | # Line format: im SIMPLE_RADIAL 1600 1200 1199.91 800 600 -0.0324314 150 | line = f'{im.name} {model} {w} {h} ' 151 | for p in params: 152 | line += f'{p} ' 153 | line += '\n' 154 | f.write(line) 155 | f.flush() 156 | print('Finished, save to ', sav_path) 157 | 158 | def parse_camera_matrices(intrinsic_txt): 159 | with open(intrinsic_txt) as f: 160 | # Line format: im SIMPLE_RADIAL 1600 1200 1199.91 800 600 
-0.0324314 161 | intrinsic_lines = f.readlines() 162 | camera_matrices = {} 163 | for line in intrinsic_lines: 164 | cur = line.split() 165 | name, model = cur[0:2] 166 | params = [float(v) for v in cur[4::]] 167 | K = cam_params_to_matrix(params, model) 168 | camera_matrices[name] = K 169 | return camera_matrices 170 | -------------------------------------------------------------------------------- /utils/colmap/read_database.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import sqlite3 3 | import numpy as np 4 | import collections 5 | 6 | from utils.colmap.read_write_model import CAMERA_MODEL_IDS, Camera 7 | 8 | IS_PYTHON3 = sys.version_info[0] >= 3 9 | MAX_IMAGE_ID = 2**31 - 1 10 | 11 | """ 12 | Camera parameters: 13 | f, fx, fy: focal length 14 | ox, oy: principle point 15 | r, r1, r2: radial distortion 16 | """ 17 | CAMERA_PARAMS = { 18 | 'SIMPLE_PINHOLE' : collections.namedtuple("CameraParam", ['f', 'ox', 'oy']), 19 | 'PINHOLE' : collections.namedtuple("CameraParam", ['fx', 'fy', 'ox', 'oy']), 20 | 'SIMPLE_RADIAL' : collections.namedtuple("CameraParam", ['f', 'ox', 'oy', 'r']), 21 | 'RADIAL' : collections.namedtuple("CameraParam", ['f', 'ox', 'oy', 'r1', 'r2']), 22 | } 23 | 24 | def image_ids_to_pair_id(image_id1, image_id2): 25 | if image_id1 > image_id2: 26 | image_id1, image_id2 = image_id2, image_id1 27 | return image_id1 * MAX_IMAGE_ID + image_id2 28 | 29 | def pair_id_to_image_ids(pair_id): 30 | image_id2 = pair_id % MAX_IMAGE_ID 31 | image_id1 = (pair_id - image_id2) / MAX_IMAGE_ID 32 | return int(image_id1), int(image_id2) 33 | 34 | def array_to_blob(array): 35 | if IS_PYTHON3: 36 | return array.tostring() 37 | else: 38 | return np.getbuffer(array) 39 | 40 | def blob_to_array(blob, dtype, shape=(-1,)): 41 | if IS_PYTHON3: 42 | return np.fromstring(blob, dtype=dtype).reshape(*shape) 43 | else: 44 | return np.frombuffer(blob, dtype=dtype).reshape(*shape) 45 | 46 | 47 | class COLMAPDataLoader: 48 | def __init__(self, database_path): 49 | super().__init__() 50 | self.db = sqlite3.connect(database_path) 51 | self.cameras = None 52 | self.images_name_based = None 53 | self.images_id_based = None 54 | self.keypoints = None 55 | self.matches = None 56 | self.descriptors = None 57 | 58 | def load_images(self, name_based=False): 59 | if not self.images_id_based or not self.images_name_based: 60 | images_name_based = {} 61 | images_id_based = {} 62 | for image_id, name, camera_id in self.db.execute("SELECT image_id, name, camera_id FROM images"): 63 | images_name_based[name] = (image_id, camera_id) 64 | images_id_based[image_id] = (name, camera_id) 65 | self.images_name_based = images_name_based 66 | self.images_id_based = images_id_based 67 | print('Load images to dataloader') 68 | return self.images_name_based if name_based else self.images_id_based 69 | 70 | def load_cameras(self,): 71 | if not self.cameras: 72 | cameras = {} 73 | for row in self.db.execute("SELECT * FROM cameras"): 74 | camera_id, model_id, width, height, params, prior = row 75 | model_name = CAMERA_MODEL_IDS[model_id].model_name 76 | #num_params = CAMERA_MODEL_IDS[model_id].num_params 77 | params = blob_to_array(params, np.float64) 78 | CameraParam = CAMERA_PARAMS[model_name] 79 | cameras[camera_id] = Camera(id=camera_id, 80 | model=model_name, 81 | width=width, 82 | height=height, 83 | params=CameraParam._make(params)) 84 | self.cameras = cameras 85 | print('Load cameras to dataloader') 86 | return self.cameras 87 | 88 | def load_descriptors(self): 89 | if 
not self.descriptors: 90 | descriptors = dict( 91 | (image_id, blob_to_array(data, np.uint8, (-1, 128))) 92 | for image_id, data in self.db.execute( 93 | "SELECT image_id, data FROM descriptors")) 94 | self.descriptors = descriptors 95 | print('Load descriptors to dataloader') 96 | return self.descriptors 97 | 98 | def load_keypoints(self, key_len=6): 99 | """ 100 | Note that COLMAP supports: 101 | - 2D keypoints: (x, y) 102 | - 4D keypoints: (x, y, scale, orientation) 103 | - 6D affine keypoints: (x, y, a_11, a_12, a_21, a_22) 104 | 105 | Return: 106 | keypoints: dict {image_id: keypoints} 107 | """ 108 | 109 | if not self.keypoints: 110 | keypoints = dict( 111 | (image_id, blob_to_array(data, np.float32, (-1, key_len))) 112 | for image_id, data in self.db.execute( 113 | "SELECT image_id, data FROM keypoints")) 114 | self.keypoints = keypoints 115 | print('Load keypoints to dataloader') 116 | return self.keypoints 117 | 118 | def load_matches(self): 119 | """Load all matches. 120 | 121 | Note that this can take a long time when there are many images. 122 | Return: 123 | matches: dict {image_pair_id: matches} 124 | """ 125 | 126 | if not self.matches: 127 | matches = {} 128 | for pair_id, data in self.db.execute("SELECT pair_id, data FROM matches"): 129 | if data is not None: 130 | im_pair_id = pair_id_to_image_ids(pair_id) 131 | matches[im_pair_id] = blob_to_array(data, np.uint32, (-1, 2)) 132 | self.matches = matches 133 | print('Load matches to dataloader') 134 | return self.matches 135 | 136 | def load_pair_matches(self, im_pair_ids): 137 | '''Load specified matches 138 | 139 | Arg: 140 | im_pair_ids: list of tuple (im1_id, im2_id) 141 | Return: 142 | matches: dict {image_pair_id: matches} 143 | ''' 144 | if not self.matches: 145 | matches = {} 146 | for im_pair_id in im_pair_ids: 147 | im1, im2 = im_pair_id 148 | pair_id = image_ids_to_pair_id(im1, im2) 149 | data = self.db.execute("SELECT data FROM matches where pair_id={}".format(pair_id)).fetchall()[0][0] 150 | if data is not None: 151 | match_val = blob_to_array(data, np.uint32, (-1, 2)) 152 | if im1 > im2: 153 | match_val = match_val[:,::-1] # swap the indices 154 | matches[(im1, im2)] = match_val 155 | self.matches = matches 156 | print('Load matches to dataloader') 157 | return self.matches 158 | 159 | def get_intrinsics(self, im_name): 160 | self.load_images(name_based=True) 161 | self.load_cameras() 162 | cid = self.images_name_based[im_name][1] 163 | camera = self.cameras[cid] 164 | param = camera.params 165 | ox, oy = param.ox, param.oy 166 | if 'f' in param._fields: 167 | fx, fy = param.f, param.f 168 | else: 169 | fx, fy = param.fx, param.fy 170 | return (fx, fy, ox, oy) 171 | 172 | 173 | 174 | def load_two_view_geometry(self): 175 | raise NotImplementedError 176 | -------------------------------------------------------------------------------- /utils/colmap/read_write_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018, ETH Zurich and UNC Chapel Hill. 2 | # All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions are met: 6 | # 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 
9 | # 10 | # * Redistributions in binary form must reproduce the above copyright 11 | # notice, this list of conditions and the following disclaimer in the 12 | # documentation and/or other materials provided with the distribution. 13 | # 14 | # * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of 15 | # its contributors may be used to endorse or promote products derived 16 | # from this software without specific prior written permission. 17 | # 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 21 | # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE 22 | # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 23 | # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 24 | # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 25 | # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 26 | # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 27 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 28 | # POSSIBILITY OF SUCH DAMAGE. 29 | # 30 | # Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de) 31 | 32 | import os 33 | import sys 34 | import collections 35 | import numpy as np 36 | import struct 37 | import argparse 38 | 39 | 40 | CameraModel = collections.namedtuple( 41 | "CameraModel", ["model_id", "model_name", "num_params"]) 42 | Camera = collections.namedtuple( 43 | "Camera", ["id", "model", "width", "height", "params"]) 44 | BaseImage = collections.namedtuple( 45 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 46 | Point3D = collections.namedtuple( 47 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 48 | 49 | 50 | class Image(BaseImage): 51 | def qvec2rotmat(self): 52 | return qvec2rotmat(self.qvec) 53 | 54 | 55 | CAMERA_MODELS = { 56 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 57 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 58 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 59 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 60 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 61 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 62 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 63 | CameraModel(model_id=7, model_name="FOV", num_params=5), 64 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 65 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 66 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 67 | } 68 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 69 | for camera_model in CAMERA_MODELS]) 70 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 71 | for camera_model in CAMERA_MODELS]) 72 | 73 | 74 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 75 | """Read and unpack the next bytes from a binary file. 76 | :param fid: 77 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 78 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 79 | :param endian_character: Any of {@, =, <, >, !} 80 | :return: Tuple of read and unpacked values. 
81 | """ 82 | data = fid.read(num_bytes) 83 | return struct.unpack(endian_character + format_char_sequence, data) 84 | 85 | 86 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 87 | """pack and write to a binary file. 88 | :param fid: 89 | :param data: data to send, if multiple elements are sent at the same time, 90 | they should be encapsuled either in a list or a tuple 91 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 92 | should be the same length as the data list or tuple 93 | :param endian_character: Any of {@, =, <, >, !} 94 | """ 95 | if isinstance(data, (list, tuple)): 96 | bytes = struct.pack(endian_character + format_char_sequence, *data) 97 | else: 98 | bytes = struct.pack(endian_character + format_char_sequence, data) 99 | fid.write(bytes) 100 | 101 | 102 | def read_cameras_text(path): 103 | """ 104 | see: src/base/reconstruction.cc 105 | void Reconstruction::WriteCamerasText(const std::string& path) 106 | void Reconstruction::ReadCamerasText(const std::string& path) 107 | """ 108 | cameras = {} 109 | with open(path, "r") as fid: 110 | while True: 111 | line = fid.readline() 112 | if not line: 113 | break 114 | line = line.strip() 115 | if len(line) > 0 and line[0] != "#": 116 | elems = line.split() 117 | camera_id = int(elems[0]) 118 | model = elems[1] 119 | width = int(elems[2]) 120 | height = int(elems[3]) 121 | params = np.array(tuple(map(float, elems[4:]))) 122 | cameras[camera_id] = Camera(id=camera_id, model=model, 123 | width=width, height=height, 124 | params=params) 125 | return cameras 126 | 127 | 128 | def read_cameras_binary(path_to_model_file): 129 | """ 130 | see: src/base/reconstruction.cc 131 | void Reconstruction::WriteCamerasBinary(const std::string& path) 132 | void Reconstruction::ReadCamerasBinary(const std::string& path) 133 | """ 134 | cameras = {} 135 | with open(path_to_model_file, "rb") as fid: 136 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 137 | for camera_line_index in range(num_cameras): 138 | camera_properties = read_next_bytes( 139 | fid, num_bytes=24, format_char_sequence="iiQQ") 140 | camera_id = camera_properties[0] 141 | model_id = camera_properties[1] 142 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 143 | width = camera_properties[2] 144 | height = camera_properties[3] 145 | num_params = CAMERA_MODEL_IDS[model_id].num_params 146 | params = read_next_bytes(fid, num_bytes=8*num_params, 147 | format_char_sequence="d"*num_params) 148 | cameras[camera_id] = Camera(id=camera_id, 149 | model=model_name, 150 | width=width, 151 | height=height, 152 | params=np.array(params)) 153 | assert len(cameras) == num_cameras 154 | return cameras 155 | 156 | 157 | def write_cameras_text(cameras, path): 158 | """ 159 | see: src/base/reconstruction.cc 160 | void Reconstruction::WriteCamerasText(const std::string& path) 161 | void Reconstruction::ReadCamerasText(const std::string& path) 162 | """ 163 | HEADER = '# Camera list with one line of data per camera:\n' 164 | '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n' 165 | '# Number of cameras: {}\n'.format(len(cameras)) 166 | with open(path, "w") as fid: 167 | fid.write(HEADER) 168 | for _, cam in cameras.items(): 169 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 170 | line = " ".join([str(elem) for elem in to_write]) 171 | fid.write(line + "\n") 172 | 173 | 174 | def write_cameras_binary(cameras, path_to_model_file): 175 | """ 176 | see: src/base/reconstruction.cc 177 | void 
Reconstruction::WriteCamerasBinary(const std::string& path) 178 | void Reconstruction::ReadCamerasBinary(const std::string& path) 179 | """ 180 | with open(path_to_model_file, "wb") as fid: 181 | write_next_bytes(fid, len(cameras), "Q") 182 | for _, cam in cameras.items(): 183 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 184 | camera_properties = [cam.id, 185 | model_id, 186 | cam.width, 187 | cam.height] 188 | write_next_bytes(fid, camera_properties, "iiQQ") 189 | for p in cam.params: 190 | write_next_bytes(fid, float(p), "d") 191 | return cameras 192 | 193 | 194 | def read_images_text(path): 195 | """ 196 | see: src/base/reconstruction.cc 197 | void Reconstruction::ReadImagesText(const std::string& path) 198 | void Reconstruction::WriteImagesText(const std::string& path) 199 | """ 200 | images = {} 201 | with open(path, "r") as fid: 202 | while True: 203 | line = fid.readline() 204 | if not line: 205 | break 206 | line = line.strip() 207 | if len(line) > 0 and line[0] != "#": 208 | elems = line.split() 209 | image_id = int(elems[0]) 210 | qvec = np.array(tuple(map(float, elems[1:5]))) 211 | tvec = np.array(tuple(map(float, elems[5:8]))) 212 | camera_id = int(elems[8]) 213 | image_name = elems[9] 214 | elems = fid.readline().split() 215 | xys = np.column_stack([tuple(map(float, elems[0::3])), 216 | tuple(map(float, elems[1::3]))]) 217 | point3D_ids = np.array(tuple(map(int, elems[2::3]))) 218 | images[image_id] = Image( 219 | id=image_id, qvec=qvec, tvec=tvec, 220 | camera_id=camera_id, name=image_name, 221 | xys=xys, point3D_ids=point3D_ids) 222 | return images 223 | 224 | 225 | def read_images_binary(path_to_model_file): 226 | """ 227 | see: src/base/reconstruction.cc 228 | void Reconstruction::ReadImagesBinary(const std::string& path) 229 | void Reconstruction::WriteImagesBinary(const std::string& path) 230 | """ 231 | images = {} 232 | with open(path_to_model_file, "rb") as fid: 233 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 234 | for image_index in range(num_reg_images): 235 | binary_image_properties = read_next_bytes( 236 | fid, num_bytes=64, format_char_sequence="idddddddi") 237 | image_id = binary_image_properties[0] 238 | qvec = np.array(binary_image_properties[1:5]) 239 | tvec = np.array(binary_image_properties[5:8]) 240 | camera_id = binary_image_properties[8] 241 | image_name = "" 242 | current_char = read_next_bytes(fid, 1, "c")[0] 243 | while current_char != b"\x00": # look for the ASCII 0 entry 244 | image_name += current_char.decode("utf-8") 245 | current_char = read_next_bytes(fid, 1, "c")[0] 246 | num_points2D = read_next_bytes(fid, num_bytes=8, 247 | format_char_sequence="Q")[0] 248 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 249 | format_char_sequence="ddq"*num_points2D) 250 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 251 | tuple(map(float, x_y_id_s[1::3]))]) 252 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 253 | images[image_id] = Image( 254 | id=image_id, qvec=qvec, tvec=tvec, 255 | camera_id=camera_id, name=image_name, 256 | xys=xys, point3D_ids=point3D_ids) 257 | return images 258 | 259 | 260 | def write_images_text(images, path): 261 | """ 262 | see: src/base/reconstruction.cc 263 | void Reconstruction::ReadImagesText(const std::string& path) 264 | void Reconstruction::WriteImagesText(const std::string& path) 265 | """ 266 | if len(images) == 0: 267 | mean_observations = 0 268 | else: 269 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 270 | HEADER = '# Image 
list with two lines of data per image:\n' 271 | '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n' 272 | '# POINTS2D[] as (X, Y, POINT3D_ID)\n' 273 | '# Number of images: {}, mean observations per image: {}\n'.format(len(images), mean_observations) 274 | 275 | with open(path, "w") as fid: 276 | fid.write(HEADER) 277 | for _, img in images.items(): 278 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 279 | first_line = " ".join(map(str, image_header)) 280 | fid.write(first_line + "\n") 281 | 282 | points_strings = [] 283 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 284 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 285 | fid.write(" ".join(points_strings) + "\n") 286 | 287 | 288 | def write_images_binary(images, path_to_model_file): 289 | """ 290 | see: src/base/reconstruction.cc 291 | void Reconstruction::ReadImagesBinary(const std::string& path) 292 | void Reconstruction::WriteImagesBinary(const std::string& path) 293 | """ 294 | with open(path_to_model_file, "wb") as fid: 295 | write_next_bytes(fid, len(images), "Q") 296 | for _, img in images.items(): 297 | write_next_bytes(fid, img.id, "i") 298 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 299 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 300 | write_next_bytes(fid, img.camera_id, "i") 301 | for char in img.name: 302 | write_next_bytes(fid, char.encode("utf-8"), "c") 303 | write_next_bytes(fid, b"\x00", "c") 304 | write_next_bytes(fid, len(img.point3D_ids), "Q") 305 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 306 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 307 | 308 | 309 | def read_points3D_text(path): 310 | """ 311 | see: src/base/reconstruction.cc 312 | void Reconstruction::ReadPoints3DText(const std::string& path) 313 | void Reconstruction::WritePoints3DText(const std::string& path) 314 | """ 315 | points3D = {} 316 | with open(path, "r") as fid: 317 | while True: 318 | line = fid.readline() 319 | if not line: 320 | break 321 | line = line.strip() 322 | if len(line) > 0 and line[0] != "#": 323 | elems = line.split() 324 | point3D_id = int(elems[0]) 325 | xyz = np.array(tuple(map(float, elems[1:4]))) 326 | rgb = np.array(tuple(map(int, elems[4:7]))) 327 | error = float(elems[7]) 328 | image_ids = np.array(tuple(map(int, elems[8::2]))) 329 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 330 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 331 | error=error, image_ids=image_ids, 332 | point2D_idxs=point2D_idxs) 333 | return points3D 334 | 335 | 336 | def read_points3d_binary(path_to_model_file): 337 | """ 338 | see: src/base/reconstruction.cc 339 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 340 | void Reconstruction::WritePoints3DBinary(const std::string& path) 341 | """ 342 | points3D = {} 343 | with open(path_to_model_file, "rb") as fid: 344 | num_points = read_next_bytes(fid, 8, "Q")[0] 345 | for point_line_index in range(num_points): 346 | binary_point_line_properties = read_next_bytes( 347 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 348 | point3D_id = binary_point_line_properties[0] 349 | xyz = np.array(binary_point_line_properties[1:4]) 350 | rgb = np.array(binary_point_line_properties[4:7]) 351 | error = np.array(binary_point_line_properties[7]) 352 | track_length = read_next_bytes( 353 | fid, num_bytes=8, format_char_sequence="Q")[0] 354 | track_elems = read_next_bytes( 355 | fid, num_bytes=8*track_length, 356 | format_char_sequence="ii"*track_length) 357 | image_ids = np.array(tuple(map(int, 
track_elems[0::2]))) 358 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2]))) 359 | points3D[point3D_id] = Point3D( 360 | id=point3D_id, xyz=xyz, rgb=rgb, 361 | error=error, image_ids=image_ids, 362 | point2D_idxs=point2D_idxs) 363 | return points3D 364 | 365 | 366 | def write_points3D_text(points3D, path): 367 | """ 368 | see: src/base/reconstruction.cc 369 | void Reconstruction::ReadPoints3DText(const std::string& path) 370 | void Reconstruction::WritePoints3DText(const std::string& path) 371 | """ 372 | if len(points3D) == 0: 373 | mean_track_length = 0 374 | else: 375 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) 376 | HEADER = '# 3D point list with one line of data per point:\n' 377 | '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n' 378 | '# Number of points: {}, mean track length: {}\n'.format(len(points3D), mean_track_length) 379 | 380 | with open(path, "w") as fid: 381 | fid.write(HEADER) 382 | for _, pt in points3D.items(): 383 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 384 | fid.write(" ".join(map(str, point_header)) + " ") 385 | track_strings = [] 386 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): 387 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 388 | fid.write(" ".join(track_strings) + "\n") 389 | 390 | 391 | def write_points3d_binary(points3D, path_to_model_file): 392 | """ 393 | see: src/base/reconstruction.cc 394 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 395 | void Reconstruction::WritePoints3DBinary(const std::string& path) 396 | """ 397 | with open(path_to_model_file, "wb") as fid: 398 | write_next_bytes(fid, len(points3D), "Q") 399 | for _, pt in points3D.items(): 400 | write_next_bytes(fid, pt.id, "Q") 401 | write_next_bytes(fid, pt.xyz.tolist(), "ddd") 402 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 403 | write_next_bytes(fid, pt.error, "d") 404 | track_length = pt.image_ids.shape[0] 405 | write_next_bytes(fid, track_length, "Q") 406 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 407 | write_next_bytes(fid, [image_id, point2D_id], "ii") 408 | 409 | 410 | def read_model(path, ext): 411 | if ext == ".txt": 412 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 413 | images = read_images_text(os.path.join(path, "images" + ext)) 414 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 415 | else: 416 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 417 | images = read_images_binary(os.path.join(path, "images" + ext)) 418 | points3D = read_points3d_binary(os.path.join(path, "points3D") + ext) 419 | return cameras, images, points3D 420 | 421 | 422 | def write_model(cameras, images, points3D, path, ext): 423 | if ext == ".txt": 424 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 425 | write_images_text(images, os.path.join(path, "images" + ext)) 426 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 427 | else: 428 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 429 | write_images_binary(images, os.path.join(path, "images" + ext)) 430 | write_points3d_binary(points3D, os.path.join(path, "points3D") + ext) 431 | return cameras, images, points3D 432 | 433 | 434 | def qvec2rotmat(qvec): 435 | return np.array([ 436 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 437 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 438 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 439 | [2 * qvec[1] * qvec[2] + 2 * 
qvec[0] * qvec[3], 440 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 441 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 442 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 443 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 444 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 445 | 446 | 447 | def rotmat2qvec(R): 448 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 449 | K = np.array([ 450 | [Rxx - Ryy - Rzz, 0, 0, 0], 451 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 452 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 453 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 454 | eigvals, eigvecs = np.linalg.eigh(K) 455 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 456 | if qvec[0] < 0: 457 | qvec *= -1 458 | return qvec 459 | 460 | 461 | def main(): 462 | parser = argparse.ArgumentParser(description='Read and write COLMAP binary and text models') 463 | parser.add_argument('input_model', help='path to input model folder') 464 | parser.add_argument('input_format', choices=['.bin', '.txt'], 465 | help='input model format') 466 | parser.add_argument('--output_model', metavar='PATH', 467 | help='path to output model folder') 468 | parser.add_argument('--output_format', choices=['.bin', '.txt'], 469 | help='outut model format', default='.txt') 470 | args = parser.parse_args() 471 | 472 | cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format) 473 | 474 | print("num_cameras:", len(cameras)) 475 | print("num_images:", len(images)) 476 | print("num_points3D:", len(points3D)) 477 | 478 | if args.output_model is not None: 479 | write_model(cameras, images, points3D, path=args.output_model, ext=args.output_format) 480 | 481 | 482 | if __name__ == "__main__": 483 | main() -------------------------------------------------------------------------------- /utils/common/plotting.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | def plot_imlist_to_pdf(ims, sav_name, figsize=(50, 35), dpi=250): 6 | import matplotlib.pyplot as plt 7 | row_num = len(plot_ims) 8 | fig = plt.figure(figsize=figsize) 9 | for i in range(row_num): 10 | ax = fig.add_subplot(row_num, 1, i+1) 11 | ax.imshow(plot_ims[i]) 12 | ax.axis('off') 13 | fig.tight_layout() 14 | fig.savefig(sav_name, dpi=dpi, bbox_inches='tight') 15 | plt.show() 16 | 17 | def plot_imlist(imlist): 18 | '''Plot a list of images in a row''' 19 | import matplotlib.pyplot as plt 20 | if type(imlist) is str: 21 | fig = plt.figure(figsize=(5, 3)) 22 | imlist = [imlist] 23 | else: 24 | fig = plt.figure(figsize=(25, 3)) 25 | num = len(imlist) 26 | for i, im in enumerate(imlist): 27 | im = Image.open(im) 28 | ax = fig.add_subplot(1, num, i+1) 29 | ax.imshow(im) 30 | plt.show() 31 | 32 | def plot_pair(pair): 33 | import matplotlib.pyplot as plt 34 | 35 | fig = plt.figure(figsize=(20, 5)) 36 | im1 = Image.open(pair[0]) 37 | im2 = Image.open(pair[1]) 38 | ax1 = fig.add_subplot(1, 2, 1) 39 | ax1.imshow(im1) 40 | ax2 = fig.add_subplot(1, 2, 2) 41 | ax2.imshow(im2) 42 | plt.show() 43 | 44 | def plot_triple(pair): 45 | import matplotlib.pyplot as plt 46 | 47 | fig = plt.figure(figsize=(20, 5)) 48 | im1 = Image.open(pair[0]) 49 | im2 = Image.open(pair[1]) 50 | im3 = Image.open(pair[2]) 51 | ax1 = fig.add_subplot(1, 3, 1) 52 | ax1.imshow(im1) 53 | ax2 = fig.add_subplot(1, 3, 2) 54 | ax2.imshow(im2) 55 | ax3 = fig.add_subplot(1, 3, 3) 56 | ax3.imshow(im3) 57 | plt.show() 58 | 59 | def torch2rgb(im): 60 | im = im.squeeze().permute(1, 2, 0) 61 | if 
im.device.type == 'cuda': 62 | im = im.data.cpu().numpy() 63 | else: 64 | im = im.data.numpy() 65 | return im.astype(np.uint8) 66 | 67 | def undo_normalize_scale(im): 68 | mean=[0.485, 0.456, 0.406] 69 | std=[0.229, 0.224, 0.225] 70 | im = im * std + mean 71 | im *= 255.0 72 | return im.astype(np.uint8) 73 | 74 | def recover_im_from_torch(torch_im): 75 | im = torch_im 76 | im = im.squeeze().permute(1, 2, 0).cpu().data.numpy() 77 | im = undo_normalize_scale(im) 78 | return im 79 | 80 | def scatter_pts(im, pts, unnormalize=True): 81 | import matplotlib.pyplot as plt 82 | 83 | if isinstance(im, torch.Tensor): 84 | im = im.squeeze().permute(1, 2, 0).cpu().data.numpy() 85 | if unnormalize: 86 | im = undo_normalize_scale(im) 87 | I = Image.fromarray(im) 88 | elif isinstance(im, np.ndarray): 89 | I = Image.fromarray(im) 90 | elif isinstance(im, str): 91 | I = Image.open(im) 92 | else: 93 | I = im 94 | plt.imshow(I) 95 | ax = plt.gca() 96 | for x in pts: 97 | ax.add_artist(plt.Circle((x[0], x[1]), radius=1, color='red')) 98 | plt.gcf().set_dpi(150) 99 | plt.show() 100 | 101 | def plot_pair_loader(data_loader, row_max=2, normalize_and_scale=False): 102 | import matplotlib.pyplot as plt 103 | for i, batch in enumerate(data_loader): 104 | print('>>>>>>>>>') 105 | fig1 = plt.figure(figsize=(20, 5)) 106 | fig2 = plt.figure(figsize=(20, 5)) 107 | num = len(batch['im_pairs'][0]) 108 | for j in range(num): 109 | im_pair = batch['im_pairs'] 110 | im1 = im_pair[0][j,:, :, :].permute(1, 2, 0).data.numpy() 111 | im2 = im_pair[1][j,:, :, :].permute(1, 2, 0).data.numpy() 112 | if normalize_and_scale: 113 | im1 = undo_normalize_scale(im1) 114 | im2 = undo_normalize_scale(im2) 115 | else: 116 | im1 = im1.astype(np.uint8) 117 | im2 = im2.astype(np.uint8) 118 | ax1 = fig1.add_subplot(1, num, j+1) 119 | ax1.imshow(im1) 120 | ax2 = fig2.add_subplot(1, num, j+1) 121 | ax2.imshow(im2) 122 | plt.show() 123 | if i >= row_max: 124 | break 125 | 126 | def plot_immatch_loader(data_loader, normalize_and_scale=False, num_sample=2, 127 | axis='on', dtype='triplet'): 128 | import matplotlib.pyplot as plt 129 | 130 | num = data_loader.batch_size 131 | count = 0 132 | if dtype == 'pair': 133 | ncols = 2 134 | else: 135 | ncols = 3 136 | 137 | for i, batch in enumerate(data_loader): 138 | print('Batch >>>>>>>>>') 139 | for j in range(num): 140 | fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(20, 5)) 141 | im_src = batch['src_im'][j, :, :, :].permute(1, 2, 0).data.numpy() 142 | im_pos = batch['pos_im'][j, :, :, :].permute(1, 2, 0).data.numpy() 143 | 144 | im_src = undo_normalize_scale(im_src) if normalize_and_scale else im_src.astype(np.uint8) 145 | im_pos = undo_normalize_scale(im_pos) if normalize_and_scale else im_pos.astype(np.uint8) 146 | axs[0].imshow(im_src) 147 | axs[0].axis(axis) 148 | axs[1].imshow(im_pos) 149 | axs[1].axis(axis) 150 | 151 | if ncols == 3: 152 | im_neg = batch['neg_im'][j, :, :, :].permute(1, 2, 0).data.numpy() 153 | im_neg = undo_normalize_scale(im_neg) if normalize_and_scale else im_neg.astype(np.uint8) 154 | axs[2].imshow(im_neg) 155 | axs[2].axis(axis) 156 | count += 1 157 | plt.gcf().set_dpi(350) 158 | plt.show() 159 | 160 | if count > num_sample: 161 | break 162 | 163 | def plot_triple_loader(data_loader, normalize_and_scale=False, num_sample=2, 164 | axis='on', dtype='triplet'): 165 | import matplotlib.pyplot as plt 166 | 167 | num = data_loader.batch_size 168 | count = 0 169 | if dtype == 'pair': 170 | ncols = 3 171 | else: 172 | ncols = 4 173 | 174 | for i, batch in 
enumerate(data_loader): 175 | print('Batch >>>>>>>>>') 176 | for j in range(num): 177 | fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(20, 5)) 178 | im1 = batch['im1'][j, :, :, :].permute(1, 2, 0).data.numpy() 179 | im2 = batch['im2'][j, :, :, :].permute(1, 2, 0).data.numpy() 180 | im3 = batch['im3'][j, :, :, :].permute(1, 2, 0).data.numpy() 181 | 182 | im1 = undo_normalize_scale(im1) if normalize_and_scale else im_src.astype(np.uint8) 183 | im2 = undo_normalize_scale(im2) if normalize_and_scale else im_pos.astype(np.uint8) 184 | im3 = undo_normalize_scale(im3) if normalize_and_scale else im_pos.astype(np.uint8) 185 | 186 | axs[0].imshow(im1) 187 | axs[0].axis(axis) 188 | axs[1].imshow(im2) 189 | axs[1].axis(axis) 190 | axs[2].imshow(im3) 191 | axs[2].axis(axis) 192 | 193 | if ncols == 4: 194 | im_neg = batch['neg_im'][j, :, :, :].permute(1, 2, 0).data.numpy() 195 | im_neg = undo_normalize_scale(im_neg) if normalize_and_scale else im_neg.astype(np.uint8) 196 | axs[3].imshow(im_neg) 197 | axs[3].axis(axis) 198 | count += 1 199 | plt.gcf().set_dpi(350) 200 | plt.show() 201 | 202 | if count > num_sample: 203 | break 204 | 205 | def plot_matches_cv(im1, im2, matches, inliers=None, Npts=1000, radius=3, dpi=350, sav_fig=None, ret_im=False): 206 | import matplotlib.pyplot as plt 207 | import cv2 208 | 209 | # Read images and resize 210 | if isinstance(im1, torch.Tensor): 211 | im1 = im1.squeeze().permute(1, 2, 0).cpu().data.numpy() 212 | im2 = im2.squeeze().permute(1, 2, 0).cpu().data.numpy() 213 | I1 = undo_normalize_scale(im1) 214 | I2 = undo_normalize_scale(im2) 215 | elif isinstance(im1, str): 216 | I1 = np.array(Image.open(im1)) 217 | I2 = np.array(Image.open(im2)) 218 | else: 219 | I1 = im1 220 | I2 = im2 221 | 222 | if inliers is None: 223 | inliers = np.arange(len(matches)) 224 | 225 | if Npts < len(inliers): 226 | inliers = inliers[:Npts] 227 | 228 | # Only matches 229 | p1s = [] 230 | p2s = [] 231 | dmatches = [] 232 | for i, (x1, y1, x2, y2) in enumerate(matches): 233 | if i in inliers: 234 | p1s.append(cv2.KeyPoint(x1, y1, 1)) 235 | p2s.append(cv2.KeyPoint(x2, y2, 1)) 236 | j = len(p1s) - 1 237 | dmatches.append(cv2.DMatch(j, j, 1)) 238 | print('Plot {} matches'.format(len(dmatches))) 239 | 240 | I3 = cv2.drawMatches(I1, p1s, I2, p2s, dmatches, None) 241 | 242 | fig = plt.figure(figsize=(50, 50)) 243 | axis = fig.add_subplot(1, 1, 1) 244 | axis.imshow(I3) 245 | axis.axis('off') 246 | if sav_fig: 247 | fig.savefig(sav_fig, dpi=150, bbox_inches='tight') 248 | plt.show() 249 | if ret_im: 250 | return I3 251 | 252 | def plot_matches(im1, im2, matches, inliers=None, Npts=None, lines=False, 253 | unnormalize=True, radius=5, dpi=150, sav_fig=False, 254 | colors=None): 255 | import matplotlib.pyplot as plt 256 | 257 | # Read images and resize 258 | if isinstance(im1, torch.Tensor): 259 | im1 = im1.squeeze().permute(1, 2, 0).cpu().data.numpy() 260 | im2 = im2.squeeze().permute(1, 2, 0).cpu().data.numpy() 261 | 262 | if unnormalize: 263 | im1 = undo_normalize_scale(im1) 264 | im2 = undo_normalize_scale(im2) 265 | else: 266 | im1 = im1.astype(np.uint8) 267 | im2 = im2.astype(np.uint8) 268 | I1 = Image.fromarray(im1) 269 | I2 = Image.fromarray(im2) 270 | elif isinstance(im1, np.ndarray): 271 | I1 = Image.fromarray(im1) 272 | I2 = Image.fromarray(im2) 273 | elif isinstance(im1, str): 274 | I1 = Image.open(im1) 275 | I2 = Image.open(im2) 276 | else: 277 | I1 = im1 278 | I2 = im2 279 | 280 | w1, h1 = I1.size 281 | w2, h2 = I2.size 282 | 283 | if h1 <= h2: 284 | scale1 = 1; 285 | scale2 = 
h1/h2 286 | w2 = int(scale2 * w2) 287 | I2 = I2.resize((w2, h1)) 288 | else: 289 | scale1 = h2/h1 290 | scale2 = 1 291 | w1 = int(scale1 * w1) 292 | I1 = I1.resize((w1, h2)) 293 | catI = np.concatenate([np.array(I1), np.array(I2)], axis=1) 294 | 295 | # Load all matches 296 | match_num = matches.shape[0] 297 | if inliers is None: 298 | if Npts is not None: 299 | Npts = Npts if Npts < match_num else match_num 300 | else: 301 | Npts = matches.shape[0] 302 | inliers = range(Npts) # Everthing as an inlier 303 | else: 304 | if Npts is not None and Npts < len(inliers): 305 | inliers = inliers[:Npts] 306 | print('Plotting inliers: ', len(inliers)) 307 | 308 | x1 = scale1*matches[inliers, 0] 309 | y1 = scale1*matches[inliers, 1] 310 | x2 = scale2*matches[inliers, 2] + w1 311 | y2 = scale2*matches[inliers, 3] 312 | c = np.random.rand(len(inliers), 3) 313 | 314 | if colors is not None: 315 | c = colors 316 | 317 | # Plot images and matches 318 | fig = plt.figure(figsize=(30, 20)) 319 | axis = plt.gca()#fig.add_subplot(1, 1, 1) 320 | axis.imshow(catI) 321 | axis.axis('off') 322 | 323 | #plt.imshow(catI) 324 | #ax = plt.gca() 325 | for i, inid in enumerate(inliers): 326 | # Plot 327 | axis.add_artist(plt.Circle((x1[i], y1[i]), radius=radius, color=c[i,:])) 328 | axis.add_artist(plt.Circle((x2[i], y2[i]), radius=radius, color=c[i,:])) 329 | if lines: 330 | axis.plot([x1[i], x2[i]], [y1[i], y2[i]], c=c[i,:], linestyle='-', linewidth=radius) 331 | if sav_fig: 332 | fig.savefig(sav_fig, dpi=dpi, bbox_inches='tight') 333 | plt.show() 334 | 335 | 336 | def plot_epilines(im1, im2, x1s, x2s, F, Npts=50, 337 | figsize=(30, 20), unnormalize=True, dpi=350): 338 | """ 339 | Args: 340 | - x1s, x2s: shape (N, 3) 341 | 342 | """ 343 | import matplotlib.pyplot as plt 344 | 345 | # Read images and resize 346 | if isinstance(im1, torch.Tensor): 347 | im1 = im1.squeeze().permute(1, 2, 0).cpu().data.numpy() 348 | im2 = im2.squeeze().permute(1, 2, 0).cpu().data.numpy() 349 | 350 | if unnormalize: 351 | im1 = undo_normalize_scale(im1) 352 | im2 = undo_normalize_scale(im2) 353 | 354 | I1 = Image.fromarray(im1) 355 | I2 = Image.fromarray(im2) 356 | elif isinstance(im1, np.ndarray): 357 | I1 = Image.fromarray(im1) 358 | I2 = Image.fromarray(im2) 359 | elif isinstance(im1, str): 360 | I1 = Image.open(im1) 361 | I2 = Image.open(im2) 362 | else: 363 | I1 = im1 364 | I2 = im2 365 | 366 | w1, h1 = I1.size 367 | w2, h2 = I2.size 368 | 369 | fig = plt.figure(figsize=figsize) 370 | ax1 = fig.add_subplot(211) 371 | ax1.imshow(I1) 372 | ax2 = fig.add_subplot(212) 373 | ax2.imshow(I2) 374 | 375 | num_pts = x1s.shape[0] 376 | Npts = min(num_pts, Npts) 377 | ids = np.random.permutation(num_pts)[0:Npts] 378 | colors = np.random.rand(Npts, 3) 379 | 380 | for p1, p2, color in zip(x1s[ids, :], x2s[ids, :], colors): 381 | ax1.add_artist(plt.Circle((p1[0], p1[1]), radius=3, color=color)) 382 | ax2.add_artist(plt.Circle((p2[0], p2[1]), radius=3, color=color)) 383 | 384 | # Calculate epilines 385 | l2 = F.dot(p1) # On the second image 386 | a, b, c = l2 387 | ax2.plot([0, w2 ], [-c/b, -(c + a*w2)/b], c=color, linestyle='-', linewidth=0.4) 388 | 389 | l1 = F.T.dot(p2) 390 | a, b, c = l1 391 | ax1.plot([0, w1 ], [-c/b, -(c + a*w1)/b], c=color, linestyle='-', linewidth=0.4) 392 | plt.gcf().set_dpi(dpi) 393 | plt.show() 394 | -------------------------------------------------------------------------------- /utils/common/setup_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 
3 | import numpy as np 4 | 5 | mem_size_of = lambda a: a.element_size() * a.nelement() # Check array/tensor size 6 | 7 | def count_parameters(model): 8 | if not model: 9 | return 0 10 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 11 | 12 | def get_sys_mem(): 13 | import psutil 14 | gb = lambda bs : bs / 2. ** 30 15 | p = psutil.Process() 16 | pmem = p.memory_info() 17 | return gb(pmem.rss), gb(pmem.vms) 18 | 19 | def get_gpu_mem(): 20 | gb = lambda bs : bs / 2. ** 30 21 | max_allocated = torch.cuda.max_memory_allocated() 22 | max_reserved = torch.cuda.max_memory_reserved() 23 | return gb(max_allocated), gb(max_reserved) 24 | 25 | def load_weights(weights_dir, device): 26 | map_location = lambda storage, loc: storage.cuda(device.index) if torch.cuda.is_available() else storage 27 | weights_dict = None 28 | if weights_dir is not None: 29 | weights_dict = torch.load(weights_dir, map_location=map_location) 30 | return weights_dict 31 | 32 | def lprint(ms, log=None): 33 | '''Print message on console and in a log file''' 34 | print(ms) 35 | if log: 36 | log.write(ms+'\n') 37 | log.flush() 38 | 39 | def make_deterministic(seed, benchmark=False): 40 | random.seed(seed) 41 | np.random.seed(seed) 42 | torch.manual_seed(seed) 43 | torch.backends.cudnn.deterministic = True 44 | torch.backends.cudnn.benchmark = benchmark # Important also 45 | 46 | def config2str(config): 47 | print_ignore = ['weights_dict', 'optimizer_dict'] 48 | args = vars(config) 49 | separator = '\n' 50 | confstr = '' 51 | confstr += '------------ Configuration -------------{}'.format(separator) 52 | for k, v in sorted(args.items()): 53 | if k in print_ignore: 54 | if v is not None: 55 | confstr += '{}:{}{}'.format(k, len(v), separator) 56 | continue 57 | confstr += '{}:{}{}'.format(k, str(v), separator) 58 | confstr += '----------------------------------------{}'.format(separator) 59 | return confstr 60 | -------------------------------------------------------------------------------- /utils/common/visdom_helper.py: -------------------------------------------------------------------------------- 1 | import visdom 2 | import numpy as np 3 | 4 | class VisMeter: 5 | def __init__(self, name, vis=None, env='', ptit=''): 6 | self.name = name 7 | self.meter = [] 8 | self.vis = vis 9 | self.env = env 10 | self.ptit = ptit 11 | self.opts = dict(mode='lines', showlegend=True, 12 | layoutopts={'plotly': dict(title=ptit, 13 | xaxis={'title': 'iters'})}) 14 | def clear(self): 15 | self.meter = [] 16 | 17 | def append(self, x): 18 | self.meter.append(x) 19 | 20 | def mean(self): 21 | if len(self.meter) > 0: 22 | return np.mean(self.meter) 23 | else: 24 | return None 25 | 26 | def plot(self, epoch): 27 | if self.vis is None: 28 | return 29 | 30 | if self.mean() is None: 31 | return 32 | 33 | X = [epoch] 34 | Y = [self.mean()] 35 | self.vis.line(X=X, Y=Y, env=self.env, 36 | win=self.ptit, 37 | name=self.name, 38 | opts=self.opts, 39 | update='append') 40 | self.vis.save(envs=[self.env]) 41 | 42 | def __repr__(self): 43 | return 'Visdom meter(env={}, plot={}, name={})'.format(self.env, self.ptit, self.name) 44 | 45 | 46 | class VisPlots: 47 | def __init__(self, plots, vis, env, prefix='train'): 48 | """ 49 | plots: namespace(plot1=namespace(legend1, legend2), plot2=namespace(legend1, legend2)..) 
50 | """ 51 | self.vis = vis 52 | self.env = env 53 | 54 | for name in plots.__dict__: 55 | self.init_plot_meters('{}.{}'.format(prefix, name), plots.__dict__[name]) 56 | self.plots = plots 57 | 58 | def init_plot_meters(self, name, plot): 59 | """ 60 | name: plot name 61 | plot: Namespace(leg1=None, leg2=None) 62 | """ 63 | for legend in plot.__dict__ : 64 | plot.__dict__[legend] = VisMeter(legend, self.vis, self.env, ptit=name) 65 | 66 | def plot(self, epoch): 67 | plots = self.plots 68 | for name in plots.__dict__: 69 | plot = plots.__dict__[name] 70 | for legend in plot.__dict__: 71 | plot.__dict__[legend].plot(epoch) 72 | 73 | def clear(self): 74 | plots = self.plots 75 | for name in plots.__dict__: 76 | plot = plots.__dict__[name] 77 | for legend in plot.__dict__: 78 | plot.__dict__[legend].clear() 79 | 80 | def get_plot_print(self, plot): 81 | """ 82 | plot: Namespace(leg1=None, leg2=None) 83 | """ 84 | mprint = '' 85 | for legend in plot.__dict__: 86 | meter = plot.__dict__[legend] 87 | val = meter.mean() 88 | if val: 89 | mprint = '{}{}={:.2f} '.format(mprint, legend, meter.mean()) 90 | return mprint 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /utils/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_megadepth import ImMatchDatasetMega 2 | -------------------------------------------------------------------------------- /utils/datasets/data_parsing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def parse_3d_points_from_nvm(nvm_file): 5 | """ 6 | Formats of nvm file: 7 | 8 | 9 | = 10 | = 11 | """ 12 | cams = [] # List image frames 13 | cam_points = {} # Map key: index of frame, value: list of indices of 3d points that are visible to this frame. 
14 | points = [] # List of 3d points in the reconstruction model 15 | 16 | print('Read 3D points from {}'.format(nvm_file)) 17 | with open(nvm_file, 'r') as f: 18 | next(f) # Skip headding lines 19 | next(f) 20 | 21 | # Load images 22 | cam_num = int(next(f).split()[0]) 23 | for i in range(cam_num): 24 | line = next(f) 25 | frame = line.split()[0] 26 | cams.append(frame) 27 | cam_points[frame] = [] 28 | 29 | next(f) # Skip the separation line 30 | point_num = int(next(f).split()[0]) 31 | for i in range(point_num): 32 | line = next(f) 33 | cur = line.split() 34 | X = cur[0:3] 35 | points.append(X) 36 | measure_num = int(cur[6]) 37 | for j in range(measure_num): 38 | idx = int(cur[7+j*4]) 39 | frame = cams[idx] 40 | cam_points[frame].append(i) 41 | print('Loading finished: camera frames {}, total 3d points {}'.format(len(cam_points), len(points))) 42 | return (points, cam_points) 43 | 44 | def parse_abs_pose_txt(fpath): 45 | """Absolute pose label format: 46 | 3 header lines 47 | list of samples with format: 48 | image x y z w p q r 49 | """ 50 | 51 | pose_dict = {} 52 | f = open(fpath) 53 | for line in f.readlines()[3::]: # Skip 3 header lines 54 | cur = line.split(' ') 55 | c = np.array([float(v) for v in cur[1:4]], dtype=np.float32) 56 | q = np.array([float(v) for v in cur[4:8]], dtype=np.float32) 57 | im = cur[0] 58 | pose_dict[im] = (c, q) 59 | f.close() 60 | return pose_dict 61 | 62 | class CambridgeIntrinsics: 63 | scenes = ['KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch'] 64 | def __init__(self, base_dir, scene, wt=1920, ht=1080, w=1920, h=1080): 65 | assert scene in self.scenes 66 | self.base_dir = base_dir 67 | self.scene = scene 68 | self.wt, self.ht = wt, ht 69 | self.w, self.h = w, h 70 | self.ox, self.oy = w / 2, h / 2 71 | self.sK = np.array([[wt / w, 0, 0], 72 | [0, ht / h, 0], 73 | [0, 0, 1]]) 74 | self.focals = self.get_focals() 75 | self.im_list = list(self.focals.keys()) 76 | self.intrinsic_matrices = {} 77 | for im in self.im_list: 78 | f = self.focals[im] 79 | K = np.array([[f, 0, self.ox], 80 | [0, f, self.oy], 81 | [0, 0, 1]], dtype=np.float32) 82 | K = self.sK.dot(K) 83 | self.intrinsic_matrices[im] = K 84 | 85 | def get_focals(self): 86 | focals = {} 87 | nvm = os.path.join(self.base_dir, self.scene,'reconstruction.nvm') 88 | with open(nvm, 'r') as f: 89 | # Skip headding lines 90 | next(f) 91 | next(f) 92 | cam_num = int(f.readline().split()[0]) 93 | print('Loading focals scene: {} cameras: {}'.format(self.scene, cam_num)) 94 | 95 | focals = {} 96 | for i in range(cam_num): 97 | line = f.readline() 98 | cur = line.split() 99 | focals[cur[0].replace('jpg', 'png')] = float(line.split()[1]) 100 | return focals 101 | 102 | def get_intrinsic_matrices(self): 103 | return self.intrinsic_matrices 104 | 105 | def get_im_intrinsics(self, im): 106 | return self.intrinsic_matrices[im] 107 | 108 | def get_positive_pairs(cam_points, imlist, thres_min=0.15, thres_max=0.8): 109 | """ 110 | Args: 111 | cam_points: (key:cam, val: [3d_point_ids]) 112 | thres_min, thres_max: min/max thresholds for overlapped 3D points 113 | Return: 114 | pairs: {(im1, im2): PosPair}) 115 | """ 116 | from argparse import Namespace 117 | from transforms3d.quaternions import quat2mat 118 | from utils.eval.geometry import abs2relapose 119 | 120 | # Pairwise overlapping calculation 121 | pairs = [] 122 | overlaps = [] 123 | total_num_pos = 0 124 | for i, im1 in enumerate(imlist): 125 | for j, im2 in enumerate(imlist): 126 | if j <= i: 127 | continue 128 | 129 | # Calculate overlapping 130 | p1 
= cam_points[im1.name.replace('png', 'jpg')] 131 | p2 = cam_points[im2.name.replace('png', 'jpg')] 132 | p12 = list(set(p1).intersection(p2)) # Common visible points 133 | score = min(1.0 * len(p12) / len(p1), 1.0 * len(p12) / len(p2)) 134 | overlaps.append(score) 135 | if score < thres_min or score > thres_max: 136 | continue 137 | 138 | # Calculate relative pose and essential matrix 139 | (t, q) = abs2relapose(im1.c, im2.c, im1.q, im2.q) # t12 is un-normalized version 140 | R = quat2mat(q) 141 | pairs.append(Namespace(im1=im1.name, im2=im2.name, 142 | overlap=score, K1=im1.K, K2=im2.K, t=t, q=q, R=R)) 143 | print('Total pairs {} Positive({} imsize and imsize > 0: 13 | scale = imsize / max(wo, ho) 14 | ht, wt = int(round(ho * scale)), int(round(wo * scale)) 15 | im = im.resize((wt, ht), Image.BICUBIC) 16 | scale = (wo / wt, ho / ht) 17 | 18 | # Gray 19 | gray = None 20 | if with_gray: 21 | im_gray = np.array(im.convert('L')) 22 | gray = transforms.functional.to_tensor(im_gray).unsqueeze(0).to(device) 23 | 24 | # RGB 25 | im = transforms.functional.to_tensor(im) 26 | im = transforms.functional.normalize(im , mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 27 | im = im.unsqueeze(0).to(device) 28 | if with_gray: 29 | return im, gray, scale 30 | return im, scale 31 | 32 | def load_im_flexible(im_path, k_size=2, upsample=16, imsize=None, crop_square=False): 33 | """Load image for matching on fly. 34 | The image can be arbitrary size and will be processed to 35 | fulfil the format required for matching. 36 | """ 37 | # Load original image 38 | img = Image.open(im_path) 39 | img = img.convert('RGB') 40 | wo, ho = img.width, img.height 41 | if not (imsize and imsize > 0): 42 | imsize = max(wo, ho) 43 | elif imsize > max(wo, ho): # Disable upsampling 44 | imsize = max(wo, ho) 45 | 46 | wt, ht = cal_rescale_size(image_size=imsize, w=wo, h=ho, k_size=k_size, 47 | scale_factor=1./upsample, no_print=True) 48 | 49 | # Resize image and transform to tensor 50 | ops = get_tuple_transform_ops(resize=None, normalize=True) 51 | img = transforms.functional.resize(img, (ht, wt), Image.BICUBIC) 52 | img = ops([img])[0] 53 | 54 | # Mainly for beauty plotting 55 | if crop_square: 56 | _, h, w = img.shape 57 | img = img[:, :w,:] 58 | 59 | scale = (wo / wt, ho / ht) 60 | return img, scale 61 | 62 | def crop_from_bottom_right(w, h, target_ratio=1.5, min_ratio=1.3, max_ratio=1.7): 63 | ratio = w / h 64 | if ratio < min_ratio or ratio > max_ratio: 65 | return None 66 | if ratio == target_ratio: 67 | return 0, 0 68 | if ratio > target_ratio: 69 | # Cut the width 70 | dh = h % 2 71 | ht = h - dh 72 | dw = w - ht * target_ratio 73 | wt = w - dw 74 | 75 | if ratio < target_ratio: 76 | # Cut the height 77 | dw = w % 3 78 | wt = w - dw 79 | dh = h - wt / target_ratio 80 | ht = h - dh 81 | return dw, dh 82 | 83 | def cal_rescale_size(image_size, w, h, k_size=2, scale_factor=1/16, no_print=False): 84 | # Calculate target image size (lager side=image_size) 85 | wt = int(np.floor(w/(max(w, h)/image_size)*scale_factor/k_size)/scale_factor*k_size) 86 | ht = int(np.floor(h/(max(w, h)/image_size)*scale_factor/k_size)/scale_factor*k_size) 87 | N = wt * ht * scale_factor * scale_factor / (k_size ** 2) 88 | if not no_print: 89 | print('Target size {} Original: (w={},h={}), Rescaled: (w={},h={}) , matches resolution: {}'.format(image_size, w, h, 90 | wt, ht, N)) 91 | return wt, ht 92 | 93 | def get_tuple_transform_ops(resize=None, normalize=True, unscale=False): 94 | ops = [] 95 | if resize: 96 | ops.append(TupleResize(resize)) 
97 | if normalize: 98 | ops.append(TupleToTensorScaled()) 99 | ops.append(TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) # Imagenet mean/std 100 | else: 101 | if unscale: 102 | ops.append(TupleToTensorUnscaled()) 103 | else: 104 | ops.append(TupleToTensorScaled()) 105 | return TupleCompose(ops) 106 | 107 | class ToTensorScaled(object): 108 | '''Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]''' 109 | def __call__(self, im): 110 | im = np.array(im, dtype=np.float32).transpose((2, 0, 1)) 111 | im /= 255.0 112 | return torch.from_numpy(im) 113 | 114 | def __repr__(self): 115 | return 'ToTensorScaled(./255)' 116 | 117 | class TupleToTensorScaled(object): 118 | def __init__(self): 119 | self.to_tensor = ToTensorScaled() 120 | 121 | def __call__(self, im_tuple): 122 | return [self.to_tensor(im) for im in im_tuple] 123 | 124 | def __repr__(self): 125 | return 'TupleToTensorScaled(./255)' 126 | 127 | class ToTensorUnscaled(object): 128 | '''Convert a RGB PIL Image to a CHW ordered Tensor''' 129 | def __call__(self, im): 130 | return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1))) 131 | 132 | def __repr__(self): 133 | return 'ToTensorUnscaled()' 134 | 135 | class TupleToTensorUnscaled(object): 136 | '''Convert a RGB PIL Image to a CHW ordered Tensor''' 137 | def __init__(self): 138 | self.to_tensor = ToTensorUnscaled() 139 | 140 | def __call__(self, im_tuple): 141 | return [self.to_tensor(im) for im in im_tuple] 142 | 143 | def __repr__(self): 144 | return 'TupleToTensorUnscaled()' 145 | 146 | class TupleResize(object): 147 | def __init__(self, size): 148 | self.size = size 149 | self.resize = transforms.Resize(size, Image.BICUBIC) 150 | 151 | def __call__(self, im_tuple): 152 | return [self.resize(im) for im in im_tuple] 153 | 154 | def __repr__(self): 155 | return 'TupleResize(size={})'.format(self.size) 156 | 157 | class TupleNormalize(object): 158 | def __init__(self, mean, std): 159 | self.mean = mean 160 | self.std = std 161 | self.normalize = transforms.Normalize(mean=mean, std=std) 162 | 163 | def __call__(self, im_tuple): 164 | return [self.normalize(im) for im in im_tuple] 165 | 166 | def __repr__(self): 167 | return 'TupleNormalize(mean={}, std={})'.format(self.mean, self.std) 168 | 169 | class TupleCompose(object): 170 | def __init__(self, transforms): 171 | self.transforms = transforms 172 | 173 | def __call__(self, im_tuple): 174 | for t in self.transforms: 175 | im_tuple = t(im_tuple) 176 | return im_tuple 177 | 178 | def __repr__(self): 179 | format_string = self.__class__.__name__ + '(' 180 | for t in self.transforms: 181 | format_string += '\n' 182 | format_string += ' {0}'.format(t) 183 | format_string += '\n)' 184 | return format_string -------------------------------------------------------------------------------- /utils/eval/geometry.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from transforms3d.quaternions import quat2mat, mat2quat 3 | 4 | 5 | # The skew-symmetric matrix of vector 6 | skew = lambda v: np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]]) 7 | 8 | # Essential matrix & fundamental matrix 9 | ess2fund = lambda K1, K2, E: np.linalg.inv(K2).T @ E @ np.linalg.inv(K1) 10 | ess2fund_inv = lambda K1_inv, K2_inv, E: K2_inv.T @ E @ K1_inv 11 | fund2ess = lambda F, K2, K1: K2.T @ F @ K1 12 | 13 | # Camera relative pose to fundamental matrix 14 | pose2ess = lambda R, t: skew(t.reshape(3,)) @ R 15 | pose2fund = lambda K1, 
K2, R, t: np.linalg.inv(K2).T @ R @ K1.T @ skew((K1 @ R.T).dot(t.reshape(3,))) 16 | pose2fund_inv = lambda K1, K2_inv, R, t: K2_inv.T @ R @ K1.T @ skew((K1 @ R.T).dot(t)) 17 | 18 | # Normalize fundamental matrix 19 | normF = lambda F: F / F[-1,-1] # Normalize F by the last value 20 | normalize = lambda A: A / np.linalg.norm(A) 21 | 22 | def compose_projection_matrix(R, t): 23 | """Construct projection matrix 24 | Args: 25 | - R: rotation matrix, size (3,3); 26 | - t: translation vector, size (3,); 27 | Return: 28 | - projection matrix [R|t], size (3,4) 29 | """ 30 | return np.hstack([R, np.expand_dims(t, axis=1)]) 31 | 32 | def matches2relapose_cv(p1, p2, K1, K2, rthres=1): 33 | import cv2 34 | # Move back to image center based coordinates 35 | f1, f2, = K1[0,0], K2[0, 0] 36 | pc1 = np.array([K1[:2, 2]]) 37 | pc2 = np.array([K2[:2, 2]]) 38 | 39 | # Rescale to im2 's focal setting 40 | p1 = (p1 - pc1) * f2 / f1 41 | p2 = (p2 - pc2) 42 | K = np.array([[f2, 0, 0], 43 | [0, f2, 0], 44 | [0, 0, 1]]) 45 | E, inls = cv2.findEssentialMat(p1, p2, cameraMatrix=K, method=cv2.FM_RANSAC, threshold=rthres) 46 | inls = np.where(inls > 0)[0] 47 | _, R, t, _ = cv2.recoverPose(E, p1[inls], p2[inls], K) 48 | return E, inls, R, t 49 | 50 | def matches2relapose_degensac(p1, p2, K1, K2, rthres=1): 51 | import pydegensac 52 | import cv2 53 | 54 | # Move back to image center based coordinates 55 | f1, f2 = K1[0,0], K2[0, 0] 56 | pc1 = np.array([K1[:2, 2]]) 57 | pc2 = np.array([K2[:2, 2]]) 58 | 59 | # Rescale to im2 's focal setting 60 | p1 = (p1 - pc1) * f2 / f1 61 | p2 = (p2 - pc2) 62 | K = np.array([[f2, 0, 0], 63 | [0, f2, 0], 64 | [0, 0, 1]]) 65 | K1 = K2 = K 66 | 67 | F, inls = pydegensac.findFundamentalMatrix(p1, p2, rthres) 68 | E = fund2ess(F, K1, K2) 69 | inls = np.where(inls > 0)[0] 70 | _, R, t, _ = cv2.recoverPose(E, p1[inls], p2[inls], K) 71 | return E, inls, R, t 72 | 73 | def abs2relapose(c1, c2, q1, q2): 74 | """Calculate relative pose between two cameras 75 | Args: 76 | - c1: absolute position of the first camera 77 | - c2: absolute position of the second camera 78 | - q1: orientation quaternion of the first camera 79 | - q2: orientation quaternion of the second camera 80 | Return: 81 | - (t12, q12): relative pose giving the transformation from the 1st camera to the 2nd camera coordinates, 82 | t12 is translation, q12 is relative rotation quaternion 83 | """ 84 | r1 = quat2mat(q1) 85 | r2 = quat2mat(q2) 86 | r12 = r2.dot(r1.T) 87 | q12 = mat2quat(r12) 88 | t12 = r2.dot(c1 - c2) 89 | return (t12, q12) 90 | -------------------------------------------------------------------------------- /utils/eval/measure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | get_statis = lambda arr: 'Size={} Min={:.2f} Max={:.2f} Mean={:.2f} Median={:.2f}'.format( 4 | arr.shape, np.min(arr), np.max(arr), np.mean(arr), np.median(arr)) 5 | 6 | def expand_homo_ones(arr2d, axis=1): 7 | """Raise 2D array to homogenous coordinates 8 | Args: 9 | - arr2d: (N, 2) or (2, N) 10 | - axis: the axis to append the ones 11 | """ 12 | if axis == 0: 13 | ones = np.ones((1, arr2d.shape[1])) 14 | else: 15 | ones = np.ones((arr2d.shape[0], 1)) 16 | return np.concatenate([arr2d, ones], axis=axis) 17 | 18 | def sampson_distance(pts1, pts2, F, homos=True, eps=1e-8): 19 | """Calculate symmetric epipolar distance between 2 sets of points 20 | Args: 21 | - pts1, pts2: points correspondences in the two images, 22 | each has shape of (num_points, 2) 23 | - F: fundamental matrix 
that fulfills x2^T*F*x1=0, 24 | where x1 and x2 are the correspondence points in the 1st and 2nd image 25 | Return: 26 | A vector of (num_points,), containing root-squared epipolar distances 27 | 28 | """ 29 | 30 | # Homogenous coordinates 31 | if homos: 32 | pts1 = expand_homo_ones(pts1, axis=1) #if pts1.shape[1] == 2 else pts1 33 | pts2 = expand_homo_ones(pts2, axis=1) #if pts2.shape[1] == 2 else pts2 34 | 35 | # l2=F*x1, l1=F^T*x2 36 | l2 = np.dot(F, pts1.T) # 3,N 37 | l1 = np.dot(F.T, pts2.T) 38 | dd = np.sum(l2.T * pts2, 1) # Distance from pts2 to l2 39 | d = dd ** 2 / (eps + l1[0, :] ** 2 + l1[1, :] ** 2 + l2[0, :] ** 2 + l2[1, :] ** 2) 40 | return d 41 | 42 | 43 | def symmetric_epipolar_distance(pts1, pts2, F, homos=True, sqrt=False): 44 | """Calculate symmetric epipolar distance between 2 sets of points 45 | Args: 46 | - pts1, pts2: points correspondences in the two images, 47 | each has shape of (num_points, 2) 48 | - F: fundamental matrix that fulfills x2^T*F*x1=0, 49 | where x1 and x2 are the correspondence points in the 1st and 2nd image 50 | Return: 51 | A vector of (num_points,), containing root-squared epipolar distances 52 | 53 | """ 54 | 55 | # Homogenous coordinates 56 | if homos: 57 | pts1 = expand_homo_ones(pts1, axis=1) #if pts1.shape[1] == 2 else pts1 58 | pts2 = expand_homo_ones(pts2, axis=1) #if pts2.shape[1] == 2 else pts2 59 | 60 | # l2=F*x1, l1=F^T*x2 61 | l2 = np.dot(F, pts1.T) # 3,N 62 | l1 = np.dot(F.T, pts2.T) 63 | dd = np.sum(l2.T * pts2, 1) # Distance from pts2 to l2 64 | 65 | if sqrt: 66 | # The one following DFM and find correspondence paper 67 | d = np.abs(dd) * (1.0 / np.sqrt(l1[0, :] ** 2 + l1[1, :] ** 2) + 1.0 / np.sqrt(l2[0, :] ** 2 + l2[1, :] ** 2)) 68 | else: 69 | # Original one as in MVG Hartley. 70 | d = dd ** 2 * (1.0 / (l1[0, :] ** 2 + l1[1, :] ** 2) + 1.0 /(l2[0, :] ** 2 + l2[1, :] ** 2)) 71 | return d 72 | 73 | def cal_vec_angle_error(label, pred, eps=1e-14): 74 | if len(label.shape) == 1: 75 | label = np.expand_dims(label, axis=0) 76 | if len(pred.shape) == 1: 77 | pred = np.expand_dims(pred, axis=0) 78 | 79 | v1 = pred / (np.linalg.norm(pred, axis=1, keepdims=True) + eps) 80 | v2 = label / (np.linalg.norm(label, axis=1, keepdims=True) + eps) 81 | d = np.sum(np.multiply(v1,v2), axis=1, keepdims=True) 82 | d = np.clip(d, a_min=-1, a_max=1) 83 | error = np.degrees(np.arccos(d)) 84 | return error.squeeze() 85 | 86 | def cal_quat_angle_error(label, pred, eps=1e-14): 87 | if len(label.shape) == 1: 88 | label = np.expand_dims(label, axis=0) 89 | if len(pred.shape) == 1: 90 | pred = np.expand_dims(pred, axis=0) 91 | q1 = pred / (np.linalg.norm(pred, axis=1, keepdims=True) + eps) 92 | q2 = label / (np.linalg.norm(label, axis=1, keepdims=True) + eps) 93 | d = np.abs(np.sum(np.multiply(q1,q2), axis=1, keepdims=True)) 94 | d = np.clip(d, a_min=-1, a_max=1) 95 | error = 2 * np.degrees(np.arccos(d)) 96 | return error.squeeze() 97 | 98 | def cal_rot_angle_error(Rgt, Rpred): 99 | # Identical to quaternion angular error 100 | return np.rad2deg(np.arccos((np.trace(Rpred.T.dot(Rgt)) - 1) / 2)) 101 | 102 | def eval_matches_relapose(matches, K1, K2, q_, t_, cv_thres=1.0): 103 | from utils.eval.geometry import matches2relapose_cv 104 | from transforms3d.quaternions import mat2quat 105 | 106 | p1 = matches[:,:2] 107 | p2 = matches[:,2:4] 108 | E, inls, R, t = matches2relapose_cv(p1, p2, K1, K2, rthres=cv_thres) 109 | 110 | # Calculate relative angle errors 111 | terr = cal_vec_angle_error(t.squeeze(), t_) 112 | qerr = cal_quat_angle_error(mat2quat(R), q_) 113 | 
return terr, qerr, inls 114 | 115 | def check_inliers_distr(inlier_dists, 116 | bins = [0, 1e-2, 1, 5, 10, 25, 50, 100, 400, 2500, 1e5], 117 | tag='', return_ratios=False): 118 | if not inlier_dists: 119 | if return_ratios: 120 | return None, '' 121 | return '' 122 | inlier_ratios = [] 123 | Npts = [] 124 | for dists in inlier_dists: 125 | N = len(dists) 126 | if N == 0: 127 | continue 128 | Npts.append(N) 129 | hists = np.histogram(dists, bins)[0] 130 | inlier_ratios.append(hists / N) 131 | 132 | ratio_print = '{} Sample:{} N(mean/max/min):{:.0f}/{:.0f}/{:.0f}\nRatios(%):'.format(tag, len(inlier_dists), np.mean(Npts), 133 | np.max(Npts), np.min(Npts)) 134 | 135 | ratios = [] 136 | for val, low, high in zip(np.mean(inlier_ratios, axis=0), bins[0:-1], bins[1::]): 137 | ratio_print = '{} [{},{})={:.2f}'.format(ratio_print, low, high, 100*val) 138 | ratios.append(100*val) 139 | if return_ratios: 140 | return ratios, ratio_print 141 | return ratio_print 142 | 143 | def check_data_hist(data_list, bins, tag='', return_hist=False): 144 | if not data_list: 145 | return '' 146 | hists = [] 147 | means = [] 148 | for data in data_list: 149 | if len(data) == 0: 150 | continue 151 | nums = np.histogram(data, bins)[0] 152 | hists.append(nums / len(data)) 153 | means.append(np.mean(data)) 154 | 155 | hist_print = f'{tag} mean={np.mean(means):.2f}' 156 | mean_hists = np.mean(hists, axis=0) 157 | for val, low, high in zip(mean_hists, bins[0:-1], bins[1::]): 158 | hist_print += ' [{},{})={:.2f}'.format(low, high, 100 * val) 159 | if return_hist: 160 | return mean_hists, hist_print 161 | return hist_print 162 | -------------------------------------------------------------------------------- /utils/eval/model_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from argparse import Namespace 5 | 6 | from utils.common.setup_helper import load_weights 7 | from utils.datasets.preprocess import load_im_flexible, load_im_tensor 8 | from networks.patch2pix import Patch2Pix 9 | 10 | def init_patch2pix_matcher(args): 11 | net = load_model(args.ckpt, method='patch2pix') 12 | matcher = lambda imq, imr: estimate_matches(net, imq, imr, 13 | ksize=args.ksize, 14 | io_thres=args.io_thres, 15 | eval_type='fine', 16 | imsize=args.imsize) 17 | return matcher 18 | 19 | def init_ncn_matcher(args): 20 | net = load_model(args.ckpt, method='nc') 21 | matcher = lambda imq, imr: estimate_matches(net, imq, imr, 22 | ksize=args.ksize, 23 | ncn_thres=args.ncn_thres, 24 | eval_type='coarse', 25 | imsize=args.imsize) 26 | return matcher 27 | 28 | def load_model(ckpt_path, method='patch2pix', lprint=print): 29 | # Initialize network 30 | device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu') 31 | ckpt = load_weights(ckpt_path, device) 32 | config = Namespace(training=False, 33 | device=device, 34 | regr_batch=1200, 35 | backbone='ResNet34', 36 | feat_idx=None, 37 | weights_dict=None, 38 | regressor_config=None, 39 | change_stride=True) 40 | lprint('\nLoad model method:{} '.format(method)) 41 | if 'patch2pix' in method: 42 | config.backbone = ckpt['backbone'] 43 | config.feat_idx = ckpt['feat_idx'] 44 | config.weights_dict = ckpt['state_dict'] 45 | config.regressor_config = ckpt['regressor_config'] 46 | config.regressor_config.panc = 1 # Only use panc 1 during evaluation 47 | if 'last_epoch' in ckpt: 48 | epoch = ckpt['last_epoch']+1 49 | lprint(f'Ckpt:{ckpt_path} epochs:{epoch}') 50 | else: 51 | 
lprint(f'Ckpt:{ckpt_path}') 52 | 53 | elif 'nc' in method: 54 | if type(ckpt) is dict: 55 | ckpt = ckpt['state_dict'] 56 | config.weights_dict = ckpt 57 | lprint('Load pretrained weights: {}'.format(ckpt_path)) 58 | else: 59 | lprint('Wrong method name.') 60 | net = Patch2Pix(config) 61 | net.eval() 62 | return net 63 | 64 | def estimate_matches(net, im1, im2, ksize=2, ncn_thres=0.0, mutual=True, 65 | io_thres=0.25, eval_type='fine', imsize=None): 66 | # Assume batch size is 1 67 | # Load images 68 | im1, sc1 = load_im_flexible(im1, ksize, net.upsample, imsize=imsize) 69 | im2, sc2 = load_im_flexible(im2, ksize, net.upsample, imsize=imsize) 70 | upscale = np.array([sc1 + sc2]) 71 | im1 = im1.unsqueeze(0).to(net.device) 72 | im2 = im2.unsqueeze(0).to(net.device) 73 | 74 | # Coarse matching 75 | if eval_type == 'coarse': 76 | with torch.no_grad(): 77 | coarse_matches, scores = net.predict_coarse(im1, im2, ksize=ksize, 78 | ncn_thres=ncn_thres, 79 | mutual=mutual) 80 | matches = coarse_matches[0].cpu().data.numpy() 81 | scores = scores[0].cpu().data.numpy() 82 | matches = upscale * matches 83 | return matches, scores, matches 84 | 85 | # Fine matching 86 | if eval_type == 'fine': 87 | # Fine matches 88 | with torch.no_grad(): 89 | fine_matches, fine_scores, coarse_matches = net.predict_fine(im1, im2, ksize=ksize, 90 | ncn_thres=ncn_thres, 91 | mutual=mutual) 92 | coarse_matches = coarse_matches[0].cpu().data.numpy() 93 | fine_matches = fine_matches[0].cpu().data.numpy() 94 | fine_scores = fine_scores[0].cpu().data.numpy() 95 | 96 | # Inlier filtering 97 | pos_ids = np.where(fine_scores > io_thres)[0] 98 | if len(pos_ids) > 0: 99 | coarse_matches = coarse_matches[pos_ids] 100 | matches = fine_matches[pos_ids] 101 | scores = fine_scores[pos_ids] 102 | else: 103 | # Simply take all matches for this case 104 | matches = fine_matches 105 | scores = fine_scores 106 | 107 | matches = upscale * matches 108 | coarse_matches = upscale * coarse_matches 109 | return matches, scores, coarse_matches 110 | 111 | def refine_matches(im1_path, im2_path, net, coarse_matcher, 112 | io_thres=0.0, imsize=None, coarse_only=False): 113 | # Load images 114 | im1, grey1, sc1 = load_im_tensor(im1_path, net.device, imsize, with_gray=True) 115 | im2, grey2, sc2 = load_im_tensor(im2_path, net.device, imsize, with_gray=True) 116 | upscale = np.array([sc1 + sc2]) 117 | 118 | # Predict coarse matches 119 | coarse_matches = coarse_matcher(grey1, grey2) 120 | if coarse_only: 121 | coarse_all = upscale * coarse_matches.cpu().data.numpy() 122 | return coarse_all, None, None 123 | 124 | refined_matches, scores, coarse_matches = net.refine_matches(im1, im2, coarse_matches, io_thres) 125 | refined_matches = upscale * refined_matches 126 | coarse_matches = upscale * coarse_matches 127 | return refined_matches, scores, coarse_matches 128 | 129 | 130 | -------------------------------------------------------------------------------- /utils/train/eval_epoch_immatch.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | import os 3 | import numpy as np 4 | import time 5 | from utils.colmap.data_loading import load_model_ims 6 | from utils.eval.geometry import abs2relapose, pose2fund 7 | from transforms3d.quaternions import quat2mat 8 | from utils.eval.model_helper import estimate_matches 9 | from utils.eval.measure import check_inliers_distr, eval_matches_relapose, sampson_distance 10 | 11 | 12 | def eval_immatch_val_sets(net, 
data_root='data/immatch_benchmark/val_dense', 13 | ksize=2, eval_type='fine', io_thres=0.5, 14 | ncn_thres=0.0, imsize=1024, 15 | rthres=0.5, sample_max=300, min_overlap=0.3, 16 | lprint_=print): 17 | net.eval() 18 | np.random.seed(0) # Deterministic pairs for evaluation 19 | scenes = os.listdir(data_root) 20 | lprint_(f'\n>>Eval on immatch: rthres={rthres} eval_type={eval_type} ov<{min_overlap} ' 21 | f'nc={ncn_thres} ksize={ksize} io={io_thres} im={imsize}') 22 | 23 | errs = Namespace(qt=[], fdist=[], cdist=[], indist=[], irat=[], 24 | num_matches=[], num_inls=[], match_failed=[], geo_failed=[]) 25 | count = 0 26 | start_time = time.time() 27 | for scene in scenes: 28 | # Load scene ims and pairs 29 | model_dir = os.path.join(data_root, scene, 'dense/sparse') 30 | im_dir = os.path.join(data_root, scene, 'dense/images') 31 | ims = load_model_ims(model_dir) 32 | ov_pair_dict = np.load(os.path.join(model_dir,'ov_pairs.npy'), allow_pickle=True).item() 33 | pair_names = ov_pair_dict[min_overlap] 34 | total = len(pair_names) 35 | if total > sample_max: 36 | np.random.shuffle(pair_names) 37 | pair_names = pair_names[0:sample_max] 38 | 39 | for i, (im1_name, im2_name) in enumerate(pair_names): 40 | im1 = ims[im1_name] 41 | im2 = ims[im2_name] 42 | t, q = abs2relapose(im1.c, im2.c, im1.q, im2.q) 43 | R = quat2mat(q) 44 | F = pose2fund(im1.K, im2.K, R, t) 45 | 46 | # Compute matches 47 | im1_path = os.path.join(im_dir, im1_name) 48 | im2_path = os.path.join(im_dir, im2_name) 49 | count += 1 50 | try: 51 | matches, scores, coarse_matches = estimate_matches(net, im1_path, im2_path, 52 | ksize=ksize, 53 | ncn_thres=ncn_thres, 54 | eval_type=eval_type, 55 | io_thres=io_thres, 56 | imsize=imsize) 57 | except: 58 | errs.match_failed.append((im1_path, im2_path)) 59 | continue 60 | 61 | N = len(matches) 62 | cdists = sampson_distance(coarse_matches[:, 0:2], coarse_matches[:, 2:4], F) 63 | fdists = sampson_distance(matches[:, 0:2], matches[:, 2:4], F) 64 | errs.cdist.append(cdists) 65 | errs.fdist.append(fdists) 66 | errs.num_matches.append(N) 67 | 68 | try: 69 | # Eval relaposes 70 | terr, qerr, inls = eval_matches_relapose(matches, im1.K, im2.K, q, t, rthres) 71 | indists = fdists[inls] 72 | irat = len(inls) / N 73 | # print(f't={terr:.2f} q={qerr:.2f} inls={len(inls)}') 74 | except: 75 | errs.geo_failed.append((im1_path, im2_path)) 76 | continue 77 | errs.qt.append(max(terr, qerr)) 78 | errs.irat.append(irat) 79 | errs.indist.append(indists) 80 | errs.num_inls.append(len(inls)) 81 | runtime = time.time() - start_time 82 | lprint_(f'Pairs {count} match_failed={len(errs.match_failed)} geo_failed={len(errs.geo_failed)} ' 83 | f'num_matches={np.mean(errs.num_matches):.2f} irat={ np.mean(errs.irat):.3f} time:{runtime:.2f}s') 84 | 85 | bins = [0, 1e-2, 1, 5, 10, 25, 50, 100, 2500, 1e5] 86 | cdist_print = check_inliers_distr(errs.cdist, bins=bins, tag='cdist') 87 | fdist_ratios, fdist_print = check_inliers_distr(errs.fdist, bins=bins, tag='fdist', return_ratios=True) 88 | indist_ratios, indist_print = check_inliers_distr(errs.indist, bins=bins, tag='indist', return_ratios=True) 89 | lprint_(cdist_print) 90 | lprint_(fdist_print) 91 | lprint_(indist_print) 92 | 93 | pass_rate = np.array([100.0*np.mean(np.array(errs.qt) < thre) for thre in range(1, 11, 1)]) 94 | qt_err_mean = np.mean(errs.qt) 95 | qt_err_med = np.median(errs.qt) 96 | lprint_('Pose err: qt_mean={:.2f}/{:.2f} qt<[1-10]deg:{}'.format(qt_err_mean, qt_err_med, pass_rate)) 97 | 98 | return qt_err_mean, pass_rate 99 | 
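# Minimal usage sketch (illustration only, not part of the training pipeline):
# evaluate a pretrained Patch2Pix checkpoint on the validation scenes.
# The checkpoint filename below is an assumption -- point it to whatever
# pretrained/download.sh fetched; data_root matches the default argument of
# eval_immatch_val_sets above.
if __name__ == '__main__':
    from utils.eval.model_helper import load_model

    net = load_model('pretrained/patch2pix_pretrained.pth', method='patch2pix')
    qt_err_mean, pass_rate = eval_immatch_val_sets(
        net,
        data_root='data/immatch_benchmark/val_dense',
        ksize=2, eval_type='fine', io_thres=0.5,
        ncn_thres=0.0, imsize=1024, rthres=0.5)
    print('Mean qt error: {:.2f} deg, pass rates (1-10 deg): {}'.format(
        qt_err_mean, pass_rate))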
-------------------------------------------------------------------------------- /utils/train/helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | from argparse import Namespace 5 | import visdom 6 | from utils.common.visdom_helper import VisPlots 7 | from utils.common.setup_helper import * 8 | 9 | def save_ckpt(net, epoch, sav_dir, best_vals=None, last_ckpt=False, is_best=False, name=None): 10 | ckpt = {'last_epoch': epoch, 11 | 'best_vals' : best_vals, 12 | 'backbone' : net.backbone, 13 | 'feat_idx' : net.feat_idx, 14 | 'change_stride' : net.change_stride, 15 | 'regressor_config' : net.regressor_config, 16 | 'state_dict': net.state_dict(), 17 | 'optim': net.optimizer.state_dict()} 18 | 19 | if net.lr_scheduler: 20 | ckpt['lr_scheduler'] = net.lr_scheduler.state_dict() 21 | 22 | if last_ckpt: 23 | ckpt_name = 'last_ckpt.pth' 24 | elif is_best: 25 | ckpt_name = 'best_ckpt.pth' 26 | else: 27 | ckpt_name = 'ckpt_ep{}.pth'.format(epoch+1) 28 | 29 | # Overwrite name 30 | if name: 31 | ckpt_name = '{}.pth'.format(name) 32 | ckpt_path = os.path.join(sav_dir, ckpt_name) 33 | torch.save(ckpt, ckpt_path) 34 | 35 | def load_ckpt(ckpt_path, config, resume=False): 36 | # Fine matching dist and io , qt_err, pass_rate 37 | best_vals = [np.inf, 0.0, np.inf, 0.0] 38 | ckpt = load_weights(ckpt_path, config.device) 39 | 40 | if 'backbone' in ckpt: 41 | config.feat_idx = ckpt['feat_idx'] 42 | config.weights_dict = ckpt['state_dict'] 43 | config.backbone = ckpt['backbone'] 44 | config.regressor_config = ckpt['regressor_config'] 45 | if 'change_stride' in ckpt: 46 | config.change_stride = ckpt['change_stride'] 47 | 48 | if resume: 49 | config.start_epoch = ckpt['last_epoch'] + 1 50 | config.optim_config.optimizer_dict = ckpt['optim'] 51 | if 'lr_scheduler' in ckpt: 52 | config.optim_config.lr_scheduler_dict = ckpt['lr_scheduler'] 53 | 54 | if 'best_vals' in ckpt: 55 | if len(ckpt['best_vals']) == len(best_vals): 56 | best_vals = ckpt['best_vals'] 57 | 58 | else: 59 | # Only the pretrained weights 60 | config.weights_dict = ckpt 61 | return best_vals 62 | 63 | 64 | def init_model_config(args, lprint_): 65 | """This is a quick wrapper for model initialization 66 | Currently support method = patch2pix / ncnet. 
67 | """ 68 | 69 | # Initialize model 70 | device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu') 71 | regressor_config = Namespace(conv_dims=args.conv_dims, 72 | conv_kers=args.conv_kers, 73 | conv_strs=args.conv_strs, 74 | fc_dims=args.fc_dims, 75 | feat_comb=args.feat_comb, 76 | psize=args.psize, 77 | pshift = args.pshift, 78 | panc = args.panc, 79 | shared = args.shared) 80 | 81 | optim_config = Namespace(opt='adam', 82 | lr_init=args.lr_init, 83 | weight_decay=args.weight_decay, 84 | lr_decay=args.lr_decay, 85 | optimizer_dict=None, 86 | lr_scheduler_dict=None) 87 | 88 | config = Namespace(training=True, 89 | start_epoch=0, 90 | device=device, 91 | regr_batch=args.regr_batch, 92 | backbone=args.backbone, 93 | freeze_feat=args.freeze_feat, 94 | change_stride=args.change_stride, 95 | feat_idx=args.feat_idx, 96 | regressor_config=regressor_config, 97 | weights_dict=None, 98 | optim_config=optim_config) 99 | 100 | 101 | # Fine matching dist and io , qt_err, pass_rate 102 | best_vals = [np.inf, 0.0, np.inf, 0.0] 103 | if args.resume: 104 | # Continue training 105 | ckpt = os.path.join(args.out_dir, 'last_ckpt.pth') 106 | if os.path.exists(ckpt): 107 | args.ckpt = ckpt 108 | if args.pretrain: 109 | # Initialize with pretrained nc 110 | best_vals = load_ckpt(args.pretrain, config) 111 | lprint_('Load pretrained: {} vals: {}'.format(args.pretrain, best_vals)) 112 | if args.ckpt: 113 | # Load a specific model 114 | best_vals = load_ckpt(args.ckpt, config, resume=args.resume) 115 | lprint_('Load model: {} vals: {}'.format(args.ckpt, best_vals)) 116 | 117 | return config, best_vals 118 | 119 | def get_visdom_plots(prefix='train', env='main', server='localhost', port=9333): 120 | """Initialize Visdom plots following the pre-defined schema. 121 | Adapt train_plots Namespace if one needs to add/remove legends or plots. 122 | Args: 123 | - prefix: the name prefix will be add to the original name of each plot. 124 | - env: the name of visdom envirionment where plots will appear there. 125 | - server: the name of the host where visdom server is running. 126 | Make sure visdom service is running correctly on specified port and host. 127 | No visdom connection will be initialized if server is None. 128 | And the program gives dummy plots. 129 | """ 130 | if server is None: 131 | vis = None 132 | else: 133 | 134 | vis = visdom.Visdom(server='http://{}'.format(server), port=port) 135 | 136 | """Initialize visdom plots: 137 | plots format: Namespace([plot_name]=Namespace([plot_legends...])) 138 | """ 139 | plots = Namespace(pair_scores=Namespace(pos=None, neg=None), 140 | cls_ratios=Namespace(mpos_gt=None, mpos_pred=None, 141 | fpos_gt=None, fpos_pred=None), 142 | loss=Namespace(pair=None, nc=None, 143 | cls_mid=None, cls_fine=None, 144 | epi_mid=None, epi_fine=None), 145 | cls_mid=Namespace(rec=None, prec=None, spec=None, acc=None, f1=None), 146 | cls_fine=Namespace(rec=None, prec=None, spec=None, acc=None, f1=None), 147 | match_dist=Namespace(cmid_gt=None, mmid_gt=None, 148 | mfid_gt=None, ffid_gt=None, 149 | cmid_pred=None, mmid_pred=None, 150 | mfid_pred=None, ffid_pred=None), 151 | mem=Namespace(rss=None, vms=None, 152 | gpu_maloc=None, 153 | gpu_mres=None)) 154 | vis_plots = VisPlots(plots, vis, env=env, prefix=prefix) 155 | return vis_plots 156 | 157 | def plot_cls_metric(mpred, mgt, thres=0.5, plot=None): 158 | """ 159 | Args: 160 | - mpred: predicted probability(mask), torch tensor with shape (N,) 161 | - mgt: ground truth probability(mask), same shape as pred. 
162 | """ 163 | if not plot: 164 | return 165 | 166 | try: 167 | Pgt = mgt 168 | Ngt = (mgt == 0).float() 169 | Pgt_num = Pgt.sum() 170 | Ngt_num = Ngt.sum() 171 | 172 | Ppred = (mpred > thres).float() 173 | Npred = (mpred <= thres).float() 174 | TP = (Ppred * Pgt).sum() 175 | TN = (Npred * Ngt).sum() 176 | 177 | recall = (TP / Pgt_num).item() if Pgt_num > 0 else (1.0 if Ppred.sum() == 0.0 else 0) # Correct pred pos among all gt pos 178 | specifity = (TN / Ngt_num).item() if Ngt_num > 0 else (1.0 if Npred.sum() == 0.0 else 0) # Correct pred neg among all gt neg 179 | precision = (TP / Ppred.sum()).item() if Ppred.sum() > 0 else 0 # Correct pos among predicted pos 180 | accuracy = (Pgt == Ppred).float().mean().item() # Correct preds among all preds 181 | f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0 182 | 183 | # Update the plot 184 | plot.rec.append(recall) 185 | plot.spec.append(specifity) 186 | plot.prec.append(precision) 187 | plot.acc.append(accuracy) 188 | plot.f1.append(f1) 189 | except: 190 | print('Error happened during plot cls') 191 | print(f'mpred={mpred.shape} mgt={mgt.shape}') 192 | print(f'Pgt_num={Pgt_num} Ngt_num={Ngt_num}') 193 | return 194 | 195 | return Namespace(recall=recall, specifity=specifity, precision=precision, accuracy=accuracy, f1=f1) 196 | 197 | --------------------------------------------------------------------------------