├── .gitignore
├── INSTALL.md
├── LICENSE
├── README.md
├── datasets
│   ├── contents.txt
│   ├── hpatches
│   │   └── cache-top
│   │       ├── d2-net.npy
│   │       ├── delf.npy
│   │       ├── hesaff.npy
│   │       ├── hesaffnet.npy
│   │       ├── ncnet.densencnet_1600_hard_2k.npy
│   │       ├── ncnet.densencnet_800_noreloc_2k.npy
│   │       ├── ncnet.sparsencnet_1600_hard_2k.npy
│   │       ├── ncnet.sparsencnet_1600_hard_soft_2k.npy
│   │       ├── ncnet.sparsencnet_3200_hard_soft_1k.npy
│   │       ├── ncnet.sparsencnet_3200_hard_soft_2k.npy
│   │       ├── ncnet.sparsencnet_800_noreloc_2k.npy
│   │       ├── r2d2.npy
│   │       └── superpoint.npy
│   └── inloc
│       └── shortlists
│           ├── densePE_top100_shortlist_cvpr18.mat
│           ├── densePV_top10_shortlist_cvpr18.mat
│           ├── ncnet_shortlist_neurips18.mat
│           ├── sparsencnet_shortlist_1600_hard.mat
│           └── sparsencnet_shortlist_3200_hard_soft.mat
├── demo
│   ├── aachen_db.jpg
│   ├── aachen_query.jpg
│   └── demo.ipynb
├── eval
│   ├── eval_aachen_extract.py
│   ├── eval_aachen_reconstruct.py
│   ├── eval_hpatches_extract.py
│   ├── eval_hpatches_generate_plot.ipynb
│   ├── eval_inloc_compute_poses.m
│   ├── eval_inloc_extract.py
│   └── eval_inloc_generate_plot.m
├── eval_ncnetdense
│   ├── eval_hpatches_ncnetdense_extract.py
│   └── eval_inloc_ncnetdense_extract.py
├── lib
│   ├── conv4d.py
│   ├── dataloader.py
│   ├── eval_util.py
│   ├── im_pair_dataset.py
│   ├── knn.py
│   ├── model.py
│   ├── normalization.py
│   ├── plot.py
│   ├── point_tnf.py
│   ├── py_util.py
│   ├── relocalize.py
│   ├── sparse.py
│   ├── torch_util.py
│   └── transformation.py
├── lib_matlab
│   ├── at_imageresize_nc4d.m
│   ├── at_pv_wrapper.m
│   ├── ht_plotcurve_WUSTL.m
│   ├── ht_top10_NC4D_PV_localization.m
│   ├── ir_top100_NC4D_localization_pnponly.m
│   ├── p2c.m
│   ├── p2dist.m
│   ├── parfor_NC4D_PE_pnponly.m
│   ├── parfor_nc4d_PV.m
│   └── show_matches2_horizontal.m
├── train.py
└── trained_models
    ├── .gitignore
    └── download.sh

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints
__pycache__
**/.DS_Store
matches/
/*.ipynb
datasets/Aachen-Day-Night
datasets/hpatches/hpatches-sequences-release
datasets/ivd

--------------------------------------------------------------------------------
/INSTALL.md:
--------------------------------------------------------------------------------
# Installation

This project depends on the following main libraries: `pytorch>=1.3`, `faiss-gpu>=1.4` and `MinkowskiEngine>=0.4`. You will need a full installation of CUDA 10.1.243 in order to compile MinkowskiEngine (CUDA can be installed in any directory of your choice; point the `CUDA_HOME` environment variable to that path).
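Once the environment described below has been set up, a minimal sanity check along the following lines (a sketch, not one of the repo's scripts) can confirm that the three main libraries are importable and can see the GPU:

```
import torch
import MinkowskiEngine as ME  # fails if compilation against CUDA_HOME did not succeed
import faiss

print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('MinkowskiEngine loaded from:', ME.__file__)
print('faiss sees', faiss.get_num_gpus(), 'GPU(s)')
```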

One possible way to install these three libraries (and some other relevant ones) on Linux-64 through Anaconda and Pip is by using the following commands:

```
# create a new environment and install the compiler and other tools
conda create -n sparsencnet python=3.6.9=h265db76_0
conda activate sparsencnet
conda install gcc-5 -c psi4
conda install numpy openblas
conda install libstdcxx-ng -c anaconda

# set environment variables for the compilation of MinkowskiEngine
export CUDA_HOME=/your_path_to/cuda-10.1.243
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64":"${CONDA_PREFIX}/lib":"/usr/lib/x86_64-linux-gnu/"
export PATH="${CONDA_PREFIX}/bin":"${CUDA_HOME}/bin":/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
export CPP="${CONDA_PREFIX}/bin/g++ -E"
export CXX="${CONDA_PREFIX}/bin/g++"
export LIBRARY_PATH=$LD_LIBRARY_PATH
export PYTHONPATH="${CONDA_PREFIX}/lib/python3.6/site-packages/"

# install PyTorch and MinkowskiEngine
pip install torch torchvision
pip install -U MinkowskiEngine # compilation may take a while

# download the CUDA 8 runtime libs required by faiss-gpu; the dependency
# metadata of the faiss-gpu package is broken, so this is done manually
conda install https://repo.anaconda.com/pkgs/free/linux-64/cudatoolkit-8.0-3.tar.bz2
conda install --force --no-deps faiss-gpu=1.4.0=py36_cuda8.0.61_1 -c pytorch

# install some additional libraries
conda install matplotlib scikit-image pandas

# replace pillow with the faster pillow-simd
pip uninstall pillow
CC="cc -mavx2" pip install -U --force-reinstall pillow-simd

# install JupyterLab for the evaluation on HPatches Sequences
conda install -c conda-forge jupyterlab
```

With this newly created environment, you can now clone this repo and start using it.

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Ignacio Rocco

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sparse Neighbourhood Consensus Networks

![](https://www.di.ens.fr/willow/research/sparse-ncnet/images/teaser.jpg)

## About
This is the implementation of the paper "Efficient Neighbourhood Consensus Networks via Submanifold Sparse Convolutions" by Ignacio Rocco, Relja Arandjelović and Josef Sivic, accepted to ECCV 2020 [[arXiv](https://arxiv.org/abs/2004.10566)].

## Installation
For installation instructions, please see [INSTALL.md](INSTALL.md).

## Quickstart
For a demo of the method, see the Jupyter notebook [`demo/demo.ipynb`](demo/demo.ipynb).

## Training
To train a model with the default parameters, run `python train.py`.

## Evaluation on HPatches Sequences
1. Browse to `eval/`.
2. Run `python eval_hpatches_extract.py`, adjusting the checkpoint and experiment name (an example invocation is shown below).
3. Use `eval_hpatches_generate_plot.ipynb` with the appropriate experiment name to generate the plot.
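For example, a run corresponding to the `sparsencnet_3200_hard_soft` configuration would look as follows (the values shown are the script's own defaults, spelled out for clarity; see `eval_hpatches_extract.py` for the full list of options):

```
cd eval/
python eval_hpatches_extract.py \
    --checkpoint ../trained_models/sparsencnet_k10.pth.tar \
    --image_size 3200 \
    --relocalize 1 \
    --reloc_type hard_soft \
    --experiment_name sparsencnet_3200_hard_soft
```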
and Sivic, J.", 53 | title = "Efficient Neighbourhood Consensus Networks via Submanifold Sparse Convolutions", 54 | booktitle = "European Conference on Computer Vision", 55 | year = 2020, 56 | } 57 | ```` 58 | 59 | -------------------------------------------------------------------------------- /datasets/contents.txt: -------------------------------------------------------------------------------- 1 | This dataset should contain the following folders (or symlinks to them) 2 | 3 | ./ivd: Indoor-venue dataset (for training) 4 | ./hpatches/hpatches-sequences-release - HPatches dataset (for evaluation) 5 | ./hpatches/caches - precomputed curves for evaluation on HPatches 6 | ./Aachen-Day-Night - Aachen Day-Night dataset (for evaluation) 7 | ./inloc: InLoc dataset (for evaluation) -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/d2-net.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/d2-net.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/delf.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/delf.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/hesaff.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/hesaff.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/hesaffnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/hesaffnet.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.densencnet_1600_hard_2k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.densencnet_1600_hard_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.densencnet_800_noreloc_2k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.densencnet_800_noreloc_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.sparsencnet_1600_hard_2k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.sparsencnet_1600_hard_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.sparsencnet_1600_hard_soft_2k.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.sparsencnet_1600_hard_soft_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.sparsencnet_3200_hard_soft_1k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.sparsencnet_3200_hard_soft_1k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.sparsencnet_3200_hard_soft_2k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.sparsencnet_3200_hard_soft_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/ncnet.sparsencnet_800_noreloc_2k.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/ncnet.sparsencnet_800_noreloc_2k.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/r2d2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/r2d2.npy -------------------------------------------------------------------------------- /datasets/hpatches/cache-top/superpoint.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/hpatches/cache-top/superpoint.npy -------------------------------------------------------------------------------- /datasets/inloc/shortlists/densePE_top100_shortlist_cvpr18.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/inloc/shortlists/densePE_top100_shortlist_cvpr18.mat -------------------------------------------------------------------------------- /datasets/inloc/shortlists/densePV_top10_shortlist_cvpr18.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/inloc/shortlists/densePV_top10_shortlist_cvpr18.mat -------------------------------------------------------------------------------- /datasets/inloc/shortlists/ncnet_shortlist_neurips18.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/inloc/shortlists/ncnet_shortlist_neurips18.mat -------------------------------------------------------------------------------- /datasets/inloc/shortlists/sparsencnet_shortlist_1600_hard.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/inloc/shortlists/sparsencnet_shortlist_1600_hard.mat -------------------------------------------------------------------------------- /datasets/inloc/shortlists/sparsencnet_shortlist_3200_hard_soft.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/datasets/inloc/shortlists/sparsencnet_shortlist_3200_hard_soft.mat -------------------------------------------------------------------------------- /demo/aachen_db.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/demo/aachen_db.jpg -------------------------------------------------------------------------------- /demo/aachen_query.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ignacio-rocco/sparse-ncnet/45866540bfb49e413b1feeb45c2cd4d3c276923b/demo/aachen_query.jpg -------------------------------------------------------------------------------- /eval/eval_aachen_extract.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | import os 7 | from os.path import exists, join, basename 8 | from collections import OrderedDict 9 | 10 | import sys 11 | sys.path.append('..') 12 | 13 | from lib.model import ImMatchNet, MutualMatching 14 | from lib.normalization import imreadth, resize, normalize 15 | from lib.torch_util import str_to_bool 16 | from lib.point_tnf import normalize_axis,unnormalize_axis,corr_to_matches 17 | from lib.sparse import get_matches_both_dirs, torch_to_me, me_to_torch 18 | from lib.relocalize import relocalize, relocalize_soft, eval_model_reloc 19 | 20 | import numpy as np 21 | import numpy.random 22 | from scipy.io import loadmat 23 | from scipy.io import savemat 24 | 25 | import argparse 26 | 27 | print('Sparse-NCNet evaluation script - Aachen dataset') 28 | 29 | use_cuda = torch.cuda.is_available() 30 | 31 | # Argument parsing 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--checkpoint', type=str, default='../trained_models/sparsencnet_k10.pth.tar') 34 | parser.add_argument('--aachen_path', type=str, default='../datasets/Aachen-Day-Night') 35 | parser.add_argument('--k_size', type=int, default=1) 36 | parser.add_argument('--image_size', type=int, default=3200) 37 | parser.add_argument('--experiment_name', type=str, default='sparsencnet_3200_hard') 38 | parser.add_argument('--symmetric_mode', type=str_to_bool, default=True) 39 | parser.add_argument('--nchunks', type=int, default=1) 40 | parser.add_argument('--chunk_idx', type=int, default=0) 41 | parser.add_argument('--skip_up_to', type=str, default='') 42 | parser.add_argument('--relocalize', type=int, default=1) 43 | parser.add_argument('--reloc_type', type=str, default='hard') 44 | parser.add_argument('--change_stride', type=int, default=1) 45 | parser.add_argument('--benchmark', type=int, default=0) 46 | parser.add_argument('--no_ncnet', type=int, default=0) 47 | parser.add_argument('--Npts', type=int, default=8000) 48 | parser.add_argument('--image_pairs', type=str, default='all') 49 | 50 | args = parser.parse_args() 51 | 52 | print(args) 53 | 54 | chp_args = 
torch.load(args.checkpoint)['args'] 55 | model = ImMatchNet(use_cuda=use_cuda, 56 | checkpoint=args.checkpoint, 57 | ncons_kernel_sizes=chp_args.ncons_kernel_sizes, 58 | ncons_channels=chp_args.ncons_channels, 59 | sparse=True, 60 | symmetric_mode=bool(chp_args.symmetric_mode), 61 | feature_extraction_cnn=chp_args.feature_extraction_cnn, 62 | bn=bool(chp_args.bn), 63 | k=chp_args.k, 64 | return_fs=True) 65 | 66 | scale_factor = 0.0625 67 | if args.relocalize==1: 68 | scale_factor = scale_factor/2 69 | if args.change_stride==1: 70 | scale_factor = scale_factor*2 71 | elif args.change_stride==2: 72 | scale_factor = scale_factor*4 73 | 74 | if args.change_stride>=1: 75 | model.FeatureExtraction.model[-1][0].conv1.stride=(1,1) 76 | model.FeatureExtraction.model[-1][0].conv2.stride=(1,1) 77 | model.FeatureExtraction.model[-1][0].downsample[0].stride=(1,1) 78 | if args.change_stride>=2: 79 | model.FeatureExtraction.model[-2][0].conv1.stride=(1,1) 80 | model.FeatureExtraction.model[-2][0].conv2.stride=(1,1) 81 | model.FeatureExtraction.model[-2][0].downsample[0].stride=(1,1) 82 | 83 | try: 84 | os.mkdir(os.path.join(args.aachen_path,'matches')) 85 | except FileExistsError: 86 | pass 87 | 88 | try: 89 | os.mkdir(os.path.join(args.aachen_path,'matches',args.experiment_name)) 90 | except FileExistsError: 91 | pass 92 | 93 | # Get shortlists for each query image 94 | if args.image_pairs=='all': 95 | pair_names_fn = os.path.join(args.aachen_path,'image_pairs_to_match.txt') 96 | elif args.image_pairs=='queries': 97 | pair_names_fn = os.path.join(args.aachen_path,'query_pairs_to_match.txt') 98 | 99 | with open(pair_names_fn) as f: 100 | pair_names = [line.rstrip('\n') for line in f] 101 | 102 | pair_names=np.array(pair_names) 103 | pair_names_split = np.array_split(pair_names,args.nchunks) 104 | pair_names_chunk = pair_names_split[args.chunk_idx] 105 | 106 | pair_names_chunk=list(pair_names_chunk) 107 | if args.skip_up_to!='': 108 | pair_names_chunk = pair_names_chunk[pair_names_chunk.index(args.skip_up_to)+1:] 109 | 110 | if args.benchmark: 111 | start = torch.cuda.Event(enable_timing=True) 112 | match = torch.cuda.Event(enable_timing=True) 113 | reloc = torch.cuda.Event(enable_timing=True) 114 | end = torch.cuda.Event(enable_timing=True) 115 | pair_names_chunk = [pair_names_chunk[0]] 116 | indices = [2 for i in range(21)] 117 | first_iter=True 118 | else: 119 | indices = range(2, 7) 120 | 121 | for pair in pair_names_chunk: 122 | src_fn = os.path.join(args.aachen_path,'images','images_upright',pair.split(' ')[0]) 123 | src=imreadth(src_fn) 124 | hA,wA=src.shape[-2:] 125 | src=resize(normalize(src), args.image_size, scale_factor) 126 | hA_,wA_=src.shape[-2:] 127 | 128 | tgt_fn = os.path.join(args.aachen_path,'images','images_upright',pair.split(' ')[1]) 129 | tgt=imreadth(tgt_fn) 130 | hB,wB=tgt.shape[-2:] 131 | tgt=resize(normalize(tgt), args.image_size, scale_factor) 132 | hB_,wB_=tgt.shape[-2:] 133 | 134 | if args.benchmark: 135 | start.record() 136 | 137 | with torch.no_grad(): 138 | if args.benchmark: 139 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4, fe_time, cnn_time = eval_model_reloc( 140 | model, 141 | {'source_image':src, 142 | 'target_image':tgt}, 143 | args 144 | ) 145 | else: 146 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4 = eval_model_reloc( 147 | model, 148 | {'source_image':src, 149 | 'target_image':tgt}, 150 | args 151 | ) 152 | 153 | delta4d=None 154 | if args.benchmark: 155 | match.record() 156 | 157 | xA_, yA_, xB_, yB_, score_ = 
get_matches_both_dirs(corr4d, fs1, fs2, fs3, fs4) 158 | 159 | if args.Npts is not None: 160 | matches_idx_sorted = torch.argsort(-score_.view(-1)) 161 | N_matches = min(args.Npts, matches_idx_sorted.shape[0]) 162 | matches_idx_sorted = matches_idx_sorted[:N_matches] 163 | score_ = score_[:,matches_idx_sorted] 164 | xA_ = xA_[:,matches_idx_sorted] 165 | yA_ = yA_[:,matches_idx_sorted] 166 | xB_ = xB_[:,matches_idx_sorted] 167 | yB_ = yB_[:,matches_idx_sorted] 168 | 169 | if args.benchmark: 170 | reloc.record() 171 | 172 | if args.relocalize: 173 | if args.reloc_type=='hard': 174 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 175 | elif args.reloc_type=='hard_soft': 176 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 177 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=False) 178 | elif args.reloc_type=='soft': 179 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=True) 180 | elif args.reloc_type=='hard_hard': 181 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 182 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=False) 183 | 184 | fs1,fs2,fs3,fs4=2*fs1,2*fs2,2*fs3,2*fs4 185 | 186 | if args.benchmark: 187 | end.record() 188 | torch.cuda.synchronize() 189 | total_time = start.elapsed_time(end)/1000 190 | processing_time = start.elapsed_time(match)/1000 191 | match_processing_time = match.elapsed_time(reloc)/1000 192 | reloc_processing_time = reloc.elapsed_time(end)/1000 193 | max_mem = torch.cuda.max_memory_allocated()/1024/1024 194 | if first_iter: 195 | first_iter=False 196 | ttime = [] 197 | mmem = [] 198 | else: 199 | ttime.append(total_time) 200 | mmem.append(max_mem) 201 | print('fe: {:.2f}, cnn: {:.2f}, pp: {:.2f}, reloc: {:.2f}, total: {:.2f}, max mem: {:.2f}MB'.format(fe_time, cnn_time, 202 | match_processing_time, 203 | reloc_processing_time, 204 | total_time, 205 | max_mem)) 206 | 207 | YA,XA=torch.meshgrid(torch.arange(fs1),torch.arange(fs2)) 208 | YB,XB=torch.meshgrid(torch.arange(fs3),torch.arange(fs4)) 209 | 210 | YA = YA.contiguous() 211 | XA = XA.contiguous() 212 | YB = YB.contiguous() 213 | XB = XB.contiguous() 214 | 215 | YA=(YA+0.5)/(fs1)*hA 216 | XA=(XA+0.5)/(fs2)*wA 217 | YB=(YB+0.5)/(fs3)*hB 218 | XB=(XB+0.5)/(fs4)*wB 219 | 220 | XA = XA.view(-1).data.cpu().float().numpy() 221 | YA = YA.view(-1).data.cpu().float().numpy() 222 | XB = XB.view(-1).data.cpu().float().numpy() 223 | YB = YB.view(-1).data.cpu().float().numpy() 224 | 225 | keypoints_A=np.stack((XA,YA),axis=1) 226 | keypoints_B=np.stack((XB,YB),axis=1) 227 | 228 | # idx_A = (yA_*fs2+xA_).long().view(-1,1) 229 | # idx_B = (yB_*fs4+xB_).long().view(-1,1) 230 | idx_A = (yA_*fs2+xA_).view(-1,1) 231 | idx_B = (yB_*fs4+xB_).view(-1,1) 232 | score = score_.view(-1,1) 233 | 234 | matches = torch.cat((idx_A,idx_B,score),dim=1).cpu().numpy() 235 | 236 | kp_A_fn = src_fn+'.'+args.experiment_name 237 | kp_B_fn = tgt_fn+'.'+args.experiment_name 238 | 239 | if not args.benchmark and not os.path.exists(kp_A_fn): 240 | with open(kp_A_fn, 'wb') as output_file: 241 | np.savez(output_file,keypoints=keypoints_A) 242 | 243 | if not args.benchmark and not os.path.exists(kp_B_fn): 244 | with open(kp_B_fn, 'wb') as output_file: 245 | np.savez(output_file,keypoints=keypoints_B) 246 | 247 | matches_fn = 
pair.replace('/','-').replace(' ','--')+'.'+args.experiment_name 248 | matches_path = os.path.join(args.aachen_path,'matches',args.experiment_name,matches_fn) 249 | 250 | if not args.benchmark: 251 | with open(matches_path, 'wb') as output_file: 252 | np.savez(output_file,matches=matches) 253 | print(matches_fn) 254 | 255 | del corr4d,delta4d,src,tgt, feature_A_2x, feature_B_2x 256 | del xA_,xB_,yA_,yB_,score_ 257 | torch.cuda.empty_cache() 258 | torch.cuda.reset_max_memory_allocated() 259 | 260 | if args.benchmark: 261 | print('{}x{},{:.4f},{:.4f}'.format( 262 | wA_, 263 | hA_, 264 | torch.tensor(ttime).mean(), 265 | torch.tensor(mmem).mean())) -------------------------------------------------------------------------------- /eval/eval_aachen_reconstruct.py: -------------------------------------------------------------------------------- 1 | # Code adapted by Ignacio Rocco from `reconstruction_pipeline.py` from https://github.com/tsattler/visuallocalizationbenchmark by Torsten Sattler and Mihai Dusmanu 2 | 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | import os 8 | 9 | import shutil 10 | 11 | import subprocess 12 | 13 | import sqlite3 14 | 15 | import torch 16 | 17 | import types 18 | 19 | from tqdm import tqdm 20 | 21 | from camera import Camera 22 | 23 | from utils import quaternion_to_rotation_matrix, camera_center_to_translation 24 | 25 | from matchers import mutual_nn_matcher 26 | 27 | import sys 28 | IS_PYTHON3 = sys.version_info[0] >= 3 29 | 30 | COLMAP_VER = subprocess.check_output(['colmap', 'help']).decode("utf-8")[7:10] 31 | 32 | def array_to_blob(array): 33 | if IS_PYTHON3: 34 | return array.tostring() 35 | else: 36 | return np.getbuffer(array) 37 | 38 | def recover_database_images_and_ids(paths, args): 39 | # Connect to the database. 40 | connection = sqlite3.connect(paths.database_path) 41 | cursor = connection.cursor() 42 | 43 | # Recover database images and ids. 44 | images = {} 45 | cameras = {} 46 | cursor.execute("SELECT name, image_id, camera_id FROM images;") 47 | for row in cursor: 48 | images[row[0]] = row[1] 49 | cameras[row[0]] = row[2] 50 | 51 | # Close the connection to the database. 52 | cursor.close() 53 | connection.close() 54 | 55 | return images, cameras 56 | 57 | 58 | def preprocess_reference_model(paths, args): 59 | print('Preprocessing the reference model...') 60 | 61 | # Recover intrinsics. 62 | with open(os.path.join(paths.reference_model_path, 'database_intrinsics.txt')) as f: 63 | raw_intrinsics = f.readlines() 64 | 65 | camera_parameters = {} 66 | 67 | for intrinsics in raw_intrinsics: 68 | intrinsics = intrinsics.strip('\n').split(' ') 69 | 70 | image_name = intrinsics[0] 71 | 72 | camera_model = intrinsics[1] 73 | 74 | intrinsics = [float(param) for param in intrinsics[2 :]] 75 | 76 | camera = Camera() 77 | camera.set_intrinsics(camera_model=camera_model, intrinsics=intrinsics) 78 | 79 | camera_parameters[image_name] = camera 80 | 81 | # Recover poses. 82 | with open(os.path.join(paths.reference_model_path, 'aachen_cvpr2018_db.nvm')) as f: 83 | raw_extrinsics = f.readlines() 84 | 85 | # Skip the header. 86 | n_cameras = int(raw_extrinsics[2]) 87 | raw_extrinsics = raw_extrinsics[3 : 3 + n_cameras] 88 | 89 | for extrinsics in raw_extrinsics: 90 | extrinsics = extrinsics.strip('\n').split(' ') 91 | 92 | image_name = extrinsics[0] 93 | 94 | # Skip the focal length. Skip the distortion and terminal 0. 
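# Each NVM camera entry has the form: <image name> <focal> <qw qx qy qz> <cx cy cz>
# <radial distortion> <0>, so taking extrinsics[2:-2] keeps exactly the rotation
# quaternion and the camera centre.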
95 | qw, qx, qy, qz, cx, cy, cz = [float(param) for param in extrinsics[2 : -2]] 96 | 97 | qvec = np.array([qw, qx, qy, qz]) 98 | c = np.array([cx, cy, cz]) 99 | 100 | # NVM -> COLMAP. 101 | t = camera_center_to_translation(c, qvec) 102 | 103 | camera_parameters[image_name].set_pose(qvec=qvec, t=t) 104 | 105 | return camera_parameters 106 | 107 | 108 | def generate_empty_reconstruction(images, cameras, camera_parameters, paths, args): 109 | print('Generating the empty reconstruction...') 110 | 111 | if not os.path.exists(paths.empty_model_path): 112 | os.mkdir(paths.empty_model_path) 113 | 114 | with open(os.path.join(paths.empty_model_path, 'cameras.txt'), 'w') as f: 115 | for image_name in images: 116 | image_id = images[image_name] 117 | camera_id = cameras[image_name] 118 | try: 119 | camera = camera_parameters[image_name] 120 | except: 121 | continue 122 | f.write('%d %s %s\n' % ( 123 | camera_id, 124 | camera.camera_model, 125 | ' '.join(map(str, camera.intrinsics)) 126 | )) 127 | 128 | with open(os.path.join(paths.empty_model_path, 'images.txt'), 'w') as f: 129 | for image_name in images: 130 | image_id = images[image_name] 131 | camera_id = cameras[image_name] 132 | try: 133 | camera = camera_parameters[image_name] 134 | except: 135 | continue 136 | f.write('%d %s %s %d %s\n\n' % ( 137 | image_id, 138 | ' '.join(map(str, camera.qvec)), 139 | ' '.join(map(str, camera.t)), 140 | camera_id, 141 | image_name 142 | )) 143 | 144 | with open(os.path.join(paths.empty_model_path, 'points3D.txt'), 'w') as f: 145 | pass 146 | 147 | 148 | def import_features(images, paths, args): 149 | # Connect to the database. 150 | connection = sqlite3.connect(paths.database_path) 151 | cursor = connection.cursor() 152 | 153 | # Import the features. 154 | print('Importing features...') 155 | 156 | for image_name, image_id in tqdm(images.items(), total=len(images.items())): 157 | try: 158 | features_path = os.path.join(paths.image_path, '%s.%s' % (image_name, args.method_name.split('--N_')[0].split('_')[0])) 159 | 160 | keypoints = np.load(features_path)['keypoints'] 161 | n_keypoints = keypoints.shape[0] 162 | 163 | # Keep only x, y coordinates. 164 | keypoints = keypoints[:, : 2] 165 | # Add placeholder scale, orientation. 166 | keypoints = np.concatenate([keypoints, np.ones((n_keypoints, 1)), np.zeros((n_keypoints, 1))], axis=1).astype(np.float32) 167 | 168 | keypoints_str = keypoints.tostring() 169 | cursor.execute("INSERT INTO keypoints(image_id, rows, cols, data) VALUES(?, ?, ?, ?);", 170 | (image_id, keypoints.shape[0], keypoints.shape[1], keypoints_str)) 171 | connection.commit() 172 | except: 173 | print('skipping {}'.format(image_name)) 174 | pass 175 | 176 | # Close the connection to the database. 177 | cursor.close() 178 | connection.close() 179 | 180 | 181 | def image_ids_to_pair_id(image_id1, image_id2): 182 | if image_id1 > image_id2: 183 | return 2147483647 * image_id2 + image_id1 184 | else: 185 | return 2147483647 * image_id1 + image_id2 186 | 187 | 188 | def import_matches(images, paths, args): 189 | # Connect to the database. 190 | connection = sqlite3.connect(paths.database_path) 191 | cursor = connection.cursor() 192 | 193 | # Match the features and insert the matches in the database. 
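# Note that no feature matching is computed here: the match files were precomputed
# by eval_aachen_extract.py, and are only loaded from disk and inserted into the
# COLMAP database in the format it expects.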
194 | print('Matching...') 195 | 196 | with open(paths.match_list_path, 'r') as f: 197 | raw_pairs = f.readlines() 198 | 199 | image_pair_ids = set() 200 | for raw_pair in tqdm(raw_pairs, total=len(raw_pairs)): 201 | # for raw_pair in raw_pairs: 202 | image_name1, image_name2 = raw_pair.strip('\n').split(' ') 203 | 204 | matches_fn = image_name1.replace('/','-')+'--'+image_name2.replace('/','-')+'.'+args.method_name.split('--N_')[0] 205 | try: 206 | matches = np.load(os.path.join(paths.matches_path.split('--N_')[0],matches_fn))['matches'] 207 | matches=matches[:,:2].astype(np.uint32) 208 | if args.N_matches>0: 209 | matches=matches[:args.N_matches,:] 210 | 211 | image_id1, image_id2 = images[image_name1], images[image_name2] 212 | image_pair_id = image_ids_to_pair_id(image_id1, image_id2) 213 | if image_pair_id in image_pair_ids: 214 | continue 215 | image_pair_ids.add(image_pair_id) 216 | 217 | if image_id1 > image_id2: 218 | matches = matches[:, [1, 0]] 219 | 220 | matches_str = matches.tostring() 221 | cursor.execute("INSERT INTO matches(pair_id, rows, cols, data) VALUES(?, ?, ?, ?);", 222 | (image_pair_id, matches.shape[0], matches.shape[1], matches_str)) 223 | connection.commit() 224 | # print('adding {}'.format(matches_fn)) 225 | except: 226 | print('skipping {}'.format(matches_fn)) 227 | pass 228 | 229 | # Close the connection to the database. 230 | cursor.close() 231 | connection.close() 232 | 233 | 234 | def geometric_verification(paths, args): 235 | print('Running geometric verification...') 236 | 237 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'matches_importer', 238 | '--database_path', paths.database_path, 239 | '--match_list_path', paths.match_list_path, 240 | '--match_type', 'pairs']) 241 | 242 | 243 | def reconstruct(paths, args): 244 | if not os.path.isdir(paths.database_model_path): 245 | os.mkdir(paths.database_model_path) 246 | 247 | # Reconstruct the database model. 248 | if COLMAP_VER == '3.4': 249 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'point_triangulator', 250 | '--database_path', paths.database_path, 251 | '--image_path', paths.image_path, 252 | '--import_path', paths.empty_model_path, 253 | '--export_path', paths.database_model_path, 254 | '--Mapper.ba_refine_focal_length', '0', 255 | '--Mapper.ba_refine_principal_point', '0', 256 | '--Mapper.ba_refine_extra_params', '0']) 257 | else: 258 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'point_triangulator', 259 | '--database_path', paths.database_path, 260 | '--image_path', paths.image_path, 261 | '--input_path', paths.empty_model_path, 262 | '--output_path', paths.database_model_path, 263 | '--Mapper.ba_refine_focal_length', '0', 264 | '--Mapper.ba_refine_principal_point', '0', 265 | '--Mapper.ba_refine_extra_params', '0']) 266 | 267 | 268 | def register_queries(paths, args): 269 | if not os.path.isdir(paths.final_model_path): 270 | os.mkdir(paths.final_model_path) 271 | 272 | # Register the query images. 
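# COLMAP renamed image_registrator's --import_path/--export_path options to
# --input_path/--output_path in releases after 3.4, hence the version check below.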
273 | if COLMAP_VER == '3.4': 274 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'image_registrator', 275 | '--database_path', paths.database_path, 276 | '--import_path', paths.database_model_path, 277 | '--export_path', paths.final_model_path, 278 | '--Mapper.ba_refine_focal_length', '0', 279 | '--Mapper.ba_refine_principal_point', '0', 280 | '--Mapper.ba_refine_extra_params', '0']) 281 | else: 282 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'image_registrator', 283 | '--database_path', paths.database_path, 284 | '--input_path', paths.database_model_path, 285 | '--output_path', paths.final_model_path, 286 | '--Mapper.ba_refine_focal_length', '0', 287 | '--Mapper.ba_refine_principal_point', '0', 288 | '--Mapper.ba_refine_extra_params', '0']) 289 | 290 | 291 | def recover_query_poses(paths, args): 292 | print('Recovering query poses...') 293 | 294 | if not os.path.isdir(paths.final_txt_model_path): 295 | os.mkdir(paths.final_txt_model_path) 296 | 297 | # Convert the model to TXT. 298 | subprocess.call([os.path.join(args.colmap_path, 'colmap'), 'model_converter', 299 | '--input_path', paths.final_model_path, 300 | '--output_path', paths.final_txt_model_path, 301 | '--output_type', 'TXT']) 302 | 303 | # Recover query names. 304 | query_image_list_path = os.path.join(args.dataset_path, 'queries/night_time_queries_with_intrinsics.txt') 305 | 306 | with open(query_image_list_path) as f: 307 | raw_queries = f.readlines() 308 | 309 | query_names = set() 310 | for raw_query in raw_queries: 311 | raw_query = raw_query.strip('\n').split(' ') 312 | query_name = raw_query[0] 313 | query_names.add(query_name) 314 | 315 | with open(os.path.join(paths.final_txt_model_path, 'images.txt')) as f: 316 | raw_extrinsics = f.readlines() 317 | 318 | f = open(paths.prediction_path, 'w') 319 | 320 | # Skip the header. 321 | for extrinsics in raw_extrinsics[4 :: 2]: 322 | extrinsics = extrinsics.strip('\n').split(' ') 323 | 324 | image_name = extrinsics[-1] 325 | 326 | if image_name in query_names: 327 | # Skip the IMAGE_ID ([0]), CAMERA_ID ([-2]), and IMAGE_NAME ([-1]). 328 | f.write('%s %s\n' % (image_name.split('/')[-1], ' '.join(extrinsics[1 : -2]))) 329 | 330 | f.close() 331 | 332 | 333 | if __name__ == "__main__": 334 | parser = argparse.ArgumentParser() 335 | parser.add_argument('--dataset_path', required=True, help='Path to the dataset') 336 | parser.add_argument('--colmap_path', required=True, help='Path to the COLMAP executable folder') 337 | parser.add_argument('--method_name', required=True, help='Name of the method') 338 | parser.add_argument('--N_matches', required=False, type=int, default=0, help='Number of matches') 339 | parser.add_argument('--N_db', required=False, type=int, default=0, help='Number of db images per query') 340 | args = parser.parse_args() 341 | 342 | # Torch settings for the matcher. 343 | use_cuda = torch.cuda.is_available() 344 | device = torch.device("cuda:0" if use_cuda else "cpu") 345 | 346 | pairs_file = 'image_pairs_to_match.txt' 347 | 348 | # Create the extra paths. 
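# All derived artifacts (database, empty/triangulated/final models, and the
# predictions file) are named after method_name, so several methods can coexist
# in the same dataset folder without clobbering each other.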
349 | paths = types.SimpleNamespace() 350 | if args.N_db>0: 351 | pairs_file = 'image_pairs_to_match_{}_top{}.txt'.format(args.method_name,args.N_db) 352 | print('Using {} pair file'.format(pairs_file)) 353 | args.method_name=args.method_name+'--N_db{}'.format(args.N_db) 354 | if args.N_matches>0: 355 | args.method_name=args.method_name+'--N_matches{}'.format(args.N_matches) 356 | 357 | paths.dummy_database_path = os.path.join(args.dataset_path, 'database.db') 358 | paths.database_path = os.path.join(args.dataset_path, args.method_name + '.db') 359 | paths.image_path = os.path.join(args.dataset_path, 'images', 'images_upright') 360 | paths.features_path = os.path.join(args.dataset_path, args.method_name) 361 | paths.reference_model_path = os.path.join(args.dataset_path, '3D-models') 362 | paths.match_list_path = os.path.join(args.dataset_path, pairs_file ) 363 | paths.empty_model_path = os.path.join(args.dataset_path, 'sparse-%s-empty' % args.method_name) 364 | paths.database_model_path = os.path.join(args.dataset_path, 'sparse-%s-database' % args.method_name) 365 | paths.final_model_path = os.path.join(args.dataset_path, 'sparse-%s-final' % args.method_name) 366 | paths.final_txt_model_path = os.path.join(args.dataset_path, 'sparse-%s-final-txt' % args.method_name) 367 | paths.prediction_path = os.path.join(args.dataset_path, 'Aachen_eval_[%s].txt' % args.method_name) 368 | paths.matches_path = os.path.join(args.dataset_path,'matches',args.method_name) 369 | 370 | # Create a copy of the dummy database. 371 | if os.path.exists(paths.database_path): 372 | raise FileExistsError('The database file %s already exists.' % paths.database_path) 373 | shutil.copyfile(paths.dummy_database_path, paths.database_path) 374 | 375 | # Reconstruction pipeline. 376 | camera_parameters = preprocess_reference_model(paths, args) 377 | images, cameras = recover_database_images_and_ids(paths, args) 378 | generate_empty_reconstruction(images, cameras, camera_parameters, paths, args) 379 | import_features(images, paths, args) 380 | import_matches(images, paths, args) 381 | geometric_verification(paths, args) 382 | reconstruct(paths, args) 383 | register_queries(paths, args) 384 | recover_query_poses(paths, args) 385 | -------------------------------------------------------------------------------- /eval/eval_hpatches_extract.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | import os 7 | from os.path import exists, join, basename 8 | from collections import OrderedDict 9 | 10 | import sys 11 | sys.path.append('..') 12 | 13 | from lib.model import ImMatchNet, MutualMatching 14 | from lib.normalization import imreadth, resize, normalize 15 | from lib.torch_util import str_to_bool 16 | from lib.point_tnf import normalize_axis,unnormalize_axis,corr_to_matches 17 | from lib.sparse import get_matches_both_dirs, torch_to_me, me_to_torch, unique 18 | from lib.relocalize import relocalize, relocalize_soft, eval_model_reloc 19 | 20 | import numpy as np 21 | import numpy.random 22 | from scipy.io import loadmat 23 | from scipy.io import savemat 24 | 25 | import argparse 26 | 27 | print('Sparse-NCNet evaluation script - HPatches Sequences dataset') 28 | 29 | use_cuda = torch.cuda.is_available() 30 | 31 | # Argument parsing 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--checkpoint', type=str, default='../trained_models/sparsencnet_k10.pth.tar') 34 | 
parser.add_argument('--hseq_path', type=str, default='../datasets/hpatches/hpatches-sequences-release') 35 | parser.add_argument('--k_size', type=int, default=1) 36 | parser.add_argument('--image_size', type=int, default=3200) 37 | parser.add_argument('--experiment_name', type=str, default='sparsencnet_3200_hard_soft') 38 | parser.add_argument('--symmetric_mode', type=str_to_bool, default=True) 39 | parser.add_argument('--nchunks', type=int, default=1) 40 | parser.add_argument('--chunk_idx', type=int, default=0) 41 | parser.add_argument('--skip_up_to', type=str, default='') 42 | parser.add_argument('--relocalize', type=int, default=1) 43 | parser.add_argument('--reloc_type', type=str, default='hard_soft') 44 | parser.add_argument('--reloc_hard_crop_size', type=int, default=2) 45 | parser.add_argument('--change_stride', type=int, default=1) 46 | parser.add_argument('--benchmark', type=int, default=0) 47 | parser.add_argument('--no_ncnet', type=int, default=0) 48 | parser.add_argument('--Npts', type=int, default=2000) 49 | 50 | args = parser.parse_args() 51 | 52 | print(args) 53 | 54 | chp_args = torch.load(args.checkpoint)['args'] 55 | model = ImMatchNet(use_cuda=use_cuda, 56 | checkpoint=args.checkpoint, 57 | ncons_kernel_sizes=chp_args.ncons_kernel_sizes, 58 | ncons_channels=chp_args.ncons_channels, 59 | sparse=True, 60 | symmetric_mode=bool(chp_args.symmetric_mode), 61 | feature_extraction_cnn=chp_args.feature_extraction_cnn, 62 | bn=bool(chp_args.bn), 63 | k=chp_args.k, 64 | return_fs=True, 65 | change_stride=args.change_stride 66 | ) 67 | 68 | scale_factor = 0.0625 69 | if args.relocalize==1: 70 | scale_factor = scale_factor/2 71 | if args.change_stride==1: 72 | scale_factor = scale_factor*2 73 | 74 | # Get shortlists for each query image 75 | dataset_path=args.hseq_path 76 | seq_names = sorted(os.listdir(dataset_path)) 77 | 78 | seq_names=np.array(seq_names) 79 | seq_names_split = np.array_split(seq_names,args.nchunks) 80 | seq_names_chunk = seq_names_split[args.chunk_idx] 81 | 82 | seq_names_chunk=list(seq_names_chunk) 83 | if args.skip_up_to!='': 84 | seq_names_chunk = seq_names_chunk[seq_names_chunk.index(args.skip_up_to)+1:] 85 | 86 | if args.benchmark: 87 | start = torch.cuda.Event(enable_timing=True) 88 | match = torch.cuda.Event(enable_timing=True) 89 | reloc = torch.cuda.Event(enable_timing=True) 90 | end = torch.cuda.Event(enable_timing=True) 91 | seq_names_chunk = [seq_names_chunk[0]] 92 | indices = [2 for i in range(21)] 93 | first_iter=True 94 | else: 95 | indices = range(2, 7) 96 | 97 | for seq_name in seq_names_chunk: 98 | # load query image 99 | # load database image 100 | for idx in indices: 101 | src_fn = os.path.join(args.hseq_path,seq_name,'1.ppm') 102 | src=imreadth(src_fn) 103 | hA,wA=src.shape[-2:] 104 | src=resize(normalize(src), args.image_size, scale_factor) 105 | hA_,wA_=src.shape[-2:] 106 | 107 | tgt_fn = os.path.join(args.hseq_path,seq_name,'{}.ppm'.format(idx)) 108 | tgt=imreadth(tgt_fn) 109 | hB,wB=tgt.shape[-2:] 110 | tgt=resize(normalize(tgt), args.image_size, scale_factor) 111 | hB_,wB_=tgt.shape[-2:] 112 | 113 | if args.benchmark: 114 | start.record() 115 | 116 | with torch.no_grad(): 117 | if args.benchmark: 118 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4, fe_time, cnn_time = eval_model_reloc( 119 | model, 120 | {'source_image':src, 121 | 'target_image':tgt}, 122 | args 123 | ) 124 | else: 125 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4 = eval_model_reloc( 126 | model, 127 | {'source_image':src, 128 | 'target_image':tgt}, 
129 | args 130 | ) 131 | 132 | delta4d=None 133 | 134 | if args.benchmark: 135 | match.record() 136 | 137 | xA_, yA_, xB_, yB_, score_ = get_matches_both_dirs(corr4d, fs1, fs2, fs3, fs4) 138 | 139 | if args.Npts is not None: 140 | matches_idx_sorted = torch.argsort(-score_.view(-1)) 141 | # if args.relocalize: 142 | # N_matches = min(int(args.Npts*1.25), matches_idx_sorted.shape[0]) 143 | # else: 144 | # N_matches = min(args.Npts, matches_idx_sorted.shape[0]) 145 | N_matches = min(args.Npts, matches_idx_sorted.shape[0]) 146 | matches_idx_sorted = matches_idx_sorted[:N_matches] 147 | score_ = score_[:,matches_idx_sorted] 148 | xA_ = xA_[:,matches_idx_sorted] 149 | yA_ = yA_[:,matches_idx_sorted] 150 | xB_ = xB_[:,matches_idx_sorted] 151 | yB_ = yB_[:,matches_idx_sorted] 152 | 153 | if args.benchmark: 154 | reloc.record() 155 | 156 | if args.relocalize: 157 | fs1,fs2,fs3,fs4=2*fs1,2*fs2,2*fs3,2*fs4 158 | # relocalization stage 1: 159 | if args.reloc_type.startswith('hard'): 160 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_, 161 | yA_, 162 | xB_, 163 | yB_, 164 | score_, 165 | feature_A_2x, 166 | feature_B_2x, 167 | crop_size=args.reloc_hard_crop_size) 168 | if args.reloc_hard_crop_size==3: 169 | _,uidx = unique(yA_.double()*fs2*fs3*fs4+xA_.double()*fs3*fs4+yB_.double()*fs4+xB_.double(),return_index=True) 170 | xA_=xA_[:,uidx] 171 | yA_=yA_[:,uidx] 172 | xB_=xB_[:,uidx] 173 | yB_=yB_[:,uidx] 174 | score_=score_[:,uidx] 175 | elif args.reloc_type=='soft': 176 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 177 | 178 | # relocalization stage 2: 179 | if args.reloc_type=='hard_soft': 180 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=False) 181 | 182 | elif args.reloc_type=='hard_hard': 183 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=False) 184 | 185 | yA_=(yA_+0.5)/(fs1) 186 | xA_=(xA_+0.5)/(fs2) 187 | yB_=(yB_+0.5)/(fs3) 188 | xB_=(xB_+0.5)/(fs4) 189 | 190 | if args.benchmark: 191 | end.record() 192 | torch.cuda.synchronize() 193 | total_time = start.elapsed_time(end)/1000 194 | processing_time = start.elapsed_time(match)/1000 195 | match_processing_time = match.elapsed_time(reloc)/1000 196 | reloc_processing_time = reloc.elapsed_time(end)/1000 197 | max_mem = torch.cuda.max_memory_allocated()/1024/1024 198 | if first_iter: 199 | first_iter=False 200 | ttime = [] 201 | mmem = [] 202 | else: 203 | ttime.append(total_time) 204 | mmem.append(max_mem) 205 | print('fe: {:.2f}, cnn: {:.2f}, pp: {:.2f}, reloc: {:.2f}, total: {:.2f}, max mem: {:.2f}MB'.format(fe_time, cnn_time, 206 | match_processing_time, 207 | reloc_processing_time, 208 | total_time, 209 | max_mem)) 210 | 211 | 212 | xA = xA_.view(-1).data.cpu().float().numpy()*wA 213 | yA = yA_.view(-1).data.cpu().float().numpy()*hA 214 | xB = xB_.view(-1).data.cpu().float().numpy()*wB 215 | yB = yB_.view(-1).data.cpu().float().numpy()*hB 216 | score = score_.view(-1).data.cpu().float().numpy() 217 | 218 | keypoints_A=np.stack((xA,yA),axis=1) 219 | keypoints_B=np.stack((xB,yB),axis=1) 220 | 221 | matches_file = '{}/{}_{}.npz.{}'.format(seq_name,'1',idx,args.experiment_name) 222 | 223 | if not args.benchmark: 224 | with open(os.path.join(args.hseq_path,matches_file), 'wb') as output_file: 225 | np.savez( 226 | output_file, 227 | keypoints_A=keypoints_A, 228 | keypoints_B=keypoints_B, 229 | scores=score 230 | ) 231 | 232 | print(matches_file) 233 | 234 | del 
corr4d,delta4d,src,tgt, feature_A_2x, feature_B_2x 235 | del xA,xB,yA,yB,score 236 | del xA_,xB_,yA_,yB_,score_ 237 | torch.cuda.empty_cache() 238 | torch.cuda.reset_max_memory_allocated() 239 | 240 | if args.benchmark: 241 | print('{}x{},{:.4f},{:.4f}'.format( 242 | wA_, 243 | hA_, 244 | torch.tensor(ttime).mean(), 245 | torch.tensor(mmem).mean())) 246 |
--------------------------------------------------------------------------------
/eval/eval_inloc_compute_poses.m:
--------------------------------------------------------------------------------
% Evaluate Sparse-NCNet matches on top of densePE shortlist

% adjust path and experiment name
inloc_demo_path = '/your_path_to/InLoc_demo_old/';
experiment = 'sparsencnet_3200_hard_soft';

if exist('ncnet_path')==0
    ncnet_path=fullfile(pwd,'..');
end
matches_path = fullfile(ncnet_path,'datasets','inloc','matches');

sorted_list_fn = 'densePE_top100_shortlist_cvpr18.mat';
sorted_list = load(fullfile(ncnet_path,'datasets','inloc','shortlists',sorted_list_fn));

addpath(fullfile(ncnet_path,'lib_matlab'));

% init paths
cd(inloc_demo_path)
startup;
[ params ] = setup_project_ht_WUSTL;
% add extra parameters
params.output.dir = experiment;
params.output.gv_nc4d.dir = fullfile(params.output.dir, 'gv_nc4d'); % dense matching results path
params.output.gv_nc4d.matformat = '.gv_nc4d.mat'; % dense matching results
params.output.pnp_nc4d.matformat = '.pnp_nc4d_inlier.mat'; % PnP results
% redefine gt poses path
params.gt.dir = fullfile(ncnet_path,'lib_matlab')

Nq=length(sorted_list.ImgList);

pnp_topN=10;
% set parameters
% params.ncnet.thr = 0.75;
params.ncnet.thr = 0;
params.ncnet.pnp_thr = 0.2;
params.output.pnp_nc4d_inlier.dir = fullfile(params.output.dir, ...
    sprintf('top_%i_PnP_thr%03d_rthr%03d',pnp_topN,params.ncnet.thr*100,params.ncnet.pnp_thr*100));
NC4D_matname = fullfile(params.output.dir, 'shortlist_densePE.mat');

% compute poses from matches
ir_top100_NC4D_localization_pnponly;

do_densePV=true

if do_densePV
    params.output.synth.dir = fullfile(params.output.dir, ...
        sprintf('top_%i_thr%03d_rthr%03d_densePV',pnp_topN,params.ncnet.thr*100,params.ncnet.pnp_thr*100));

    nc4dPV_matname = fullfile(params.output.dir, 'shortlist_densePV.mat');
    % run pose verification by rendering synthetic views
    ht_top10_NC4D_PV_localization
end

--------------------------------------------------------------------------------
/eval/eval_inloc_extract.py:
--------------------------------------------------------------------------------
1 | import faiss 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | import os 7 | from os.path import exists, join, basename 8 | from collections import OrderedDict 9 | 10 | import sys 11 | sys.path.append('..') 12 | 13 | from lib.model import ImMatchNet, MutualMatching 14 | from lib.normalization import imreadth, resize, normalize 15 | from lib.torch_util import str_to_bool 16 | from lib.point_tnf import normalize_axis,unnormalize_axis,corr_to_matches 17 | from lib.sparse import get_matches_both_dirs, torch_to_me, me_to_torch 18 | from lib.relocalize import relocalize, relocalize_soft, eval_model_reloc 19 | 20 | import numpy as np 21 | import numpy.random 22 | from scipy.io import loadmat 23 | from scipy.io import savemat 24 | 25 | import argparse 26 | 27 | print('Sparse-NCNet evaluation script - InLoc dataset') 28 | 29 | use_cuda = torch.cuda.is_available() 30 | 31 | # Argument parsing 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('--checkpoint', type=str, default='../trained_models/sparsencnet_k10.pth.tar') 34 | parser.add_argument('--inloc_shortlist', type=str, default='../datasets/inloc/shortlists/densePE_top100_shortlist_cvpr18.mat') 35 | parser.add_argument('--pano_path', type=str, default='../datasets/inloc/', help='path to InLoc panos - should contain CSE3,CSE4,CSE5,DUC1 and DUC2 folders') 36 | parser.add_argument('--query_path', type=str, default='../datasets/inloc/query/iphone7/', help='path to InLoc queries') 37 | parser.add_argument('--k_size', type=int, default=1) 38 | parser.add_argument('--image_size', type=int, default=3200) 39 | parser.add_argument('--experiment_name', type=str, default='sparsencnet_3200_hard_soft') 40 | parser.add_argument('--symmetric_mode', type=str_to_bool, default=True) 41 | parser.add_argument('--nchunks', type=int, default=1) 42 | parser.add_argument('--chunk_idx', type=int, default=0) 43 | parser.add_argument('--skip_up_to', type=str, default='') 44 | parser.add_argument('--relocalize', type=int, default=1) 45 | parser.add_argument('--reloc_type', type=str, default='hard_soft') 46 | parser.add_argument('--reloc_hard_crop_size', type=int, default=2) 47 | parser.add_argument('--change_stride', type=int, default=1) 48 | parser.add_argument('--benchmark', type=int, default=0) 49 | parser.add_argument('--no_ncnet', type=int, default=0) 50 | parser.add_argument('--Npts', type=int, default=2000) 51 | parser.add_argument('--n_queries', type=int, default=356) 52 | parser.add_argument('--n_panos', type=int, default=10) 53 | 54 | args = parser.parse_args() 55 | 56 | print(args) 57 | 58 | chp_args = torch.load(args.checkpoint)['args'] 59 | model = ImMatchNet(use_cuda=use_cuda, 60 | checkpoint=args.checkpoint, 61 | ncons_kernel_sizes=chp_args.ncons_kernel_sizes, 62 | ncons_channels=chp_args.ncons_channels, 63 | sparse=True, 64 | symmetric_mode=bool(chp_args.symmetric_mode), 65 | feature_extraction_cnn=chp_args.feature_extraction_cnn, 66 | bn=bool(chp_args.bn), 67 | k=chp_args.k, 68 | return_fs=True) 69 | 70 | # Generate output folder path 71 | output_folder =
args.inloc_shortlist.split('/')[-1].split('.')[0]+'_'+args.experiment_name 72 | print('Output matches folder: '+output_folder) 73 | 74 | scale_factor = 0.0625 75 | if args.relocalize==1: 76 | scale_factor = scale_factor/2 77 | if args.change_stride==1: 78 | scale_factor = scale_factor*2 79 | elif args.change_stride==2: 80 | scale_factor = scale_factor*4 81 | 82 | if args.change_stride>=1: 83 | model.FeatureExtraction.model[-1][0].conv1.stride=(1,1) 84 | model.FeatureExtraction.model[-1][0].conv2.stride=(1,1) 85 | model.FeatureExtraction.model[-1][0].downsample[0].stride=(1,1) 86 | if args.change_stride>=2: 87 | model.FeatureExtraction.model[-2][0].conv1.stride=(1,1) 88 | model.FeatureExtraction.model[-2][0].conv2.stride=(1,1) 89 | model.FeatureExtraction.model[-2][0].downsample[0].stride=(1,1) 90 | 91 | # Get shortlists for each query image 92 | shortlist_fn = args.inloc_shortlist 93 | 94 | dbmat = loadmat(shortlist_fn) 95 | db = dbmat['ImgList'][0,:] 96 | 97 | query_fn_all=np.squeeze(np.vstack(tuple([db[q][0] for q in range(len(db))]))) 98 | pano_fn_all=np.vstack(tuple([db[q][1] for q in range(len(db))])) 99 | 100 | Nqueries=args.n_queries 101 | Npanos=args.n_panos 102 | 103 | try: 104 | os.mkdir('../datasets/inloc/matches/') 105 | except FileExistsError: 106 | pass 107 | 108 | try: 109 | os.mkdir('../datasets/inloc/matches/'+output_folder) 110 | except FileExistsError: 111 | pass 112 | 113 | queries_idx = np.arange(Nqueries) 114 | queries_idx_split = np.array_split(queries_idx,args.nchunks) 115 | queries_idx_chunk = queries_idx_split[args.chunk_idx] 116 | 117 | queries_idx_chunk=list(queries_idx_chunk) 118 | 119 | if args.skip_up_to!='': 120 | queries_idx_chunk = queries_idx_chunk[queries_idx_chunk.index(args.skip_up_to)+1:] 121 | 122 | if args.benchmark: 123 | start = torch.cuda.Event(enable_timing=True) 124 | match = torch.cuda.Event(enable_timing=True) 125 | reloc = torch.cuda.Event(enable_timing=True) 126 | end = torch.cuda.Event(enable_timing=True) 127 | queries_idx_chunk = [queries_idx_chunk[0]] 128 | indices = [0 for i in range(21)] 129 | first_iter=True 130 | else: 131 | indices = range(Npanos) 132 | 133 | for q in queries_idx_chunk: 134 | print(q) 135 | matches=numpy.zeros((1,Npanos,args.Npts,5)) 136 | # load query image 137 | src_fn = os.path.join(args.query_path,db[q][0].item()) 138 | src=imreadth(src_fn) 139 | hA,wA=src.shape[-2:] 140 | src=resize(normalize(src), args.image_size, scale_factor) 141 | hA_,wA_=src.shape[-2:] 142 | 143 | # load database image 144 | for idx in indices: 145 | tgt_fn = os.path.join(args.pano_path,db[q][1].ravel()[idx].item()) 146 | tgt=imreadth(tgt_fn) 147 | hB,wB=tgt.shape[-2:] 148 | tgt=resize(normalize(tgt), args.image_size, scale_factor) 149 | hB_,wB_=tgt.shape[-2:] 150 | 151 | if args.benchmark: 152 | start.record() 153 | 154 | with torch.no_grad(): 155 | if args.benchmark: 156 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4, fe_time, cnn_time = eval_model_reloc( 157 | model, 158 | {'source_image':src, 159 | 'target_image':tgt}, 160 | args 161 | ) 162 | else: 163 | corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4 = eval_model_reloc( 164 | model, 165 | {'source_image':src, 166 | 'target_image':tgt}, 167 | args 168 | ) 169 | 170 | delta4d=None 171 | if args.benchmark: 172 | match.record() 173 | 174 | xA_, yA_, xB_, yB_, score_ = get_matches_both_dirs(corr4d, fs1, fs2, fs3, fs4) 175 | 176 | if args.Npts is not None: 177 | matches_idx_sorted = torch.argsort(-score_.view(-1)) 178 | N_matches = min(args.Npts, 
matches_idx_sorted.shape[0]) 179 | matches_idx_sorted = matches_idx_sorted[:N_matches] 180 | score_ = score_[:,matches_idx_sorted] 181 | xA_ = xA_[:,matches_idx_sorted] 182 | yA_ = yA_[:,matches_idx_sorted] 183 | xB_ = xB_[:,matches_idx_sorted] 184 | yB_ = yB_[:,matches_idx_sorted] 185 | 186 | if args.benchmark: 187 | reloc.record() 188 | 189 | if args.relocalize: 190 | if args.reloc_type=='hard': 191 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 192 | elif args.reloc_type=='hard_soft': 193 | xA_, yA_, xB_, yB_, score_ = relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x) 194 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=False) 195 | elif args.reloc_type=='soft': 196 | xA_, yA_, xB_, yB_, score_ = relocalize_soft(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, upsample_positions=True) 197 | 198 | fs1,fs2,fs3,fs4=2*fs1,2*fs2,2*fs3,2*fs4 199 | 200 | yA_=(yA_+0.5)/(fs1) 201 | xA_=(xA_+0.5)/(fs2) 202 | yB_=(yB_+0.5)/(fs3) 203 | xB_=(xB_+0.5)/(fs4) 204 | 205 | if args.benchmark: 206 | end.record() 207 | torch.cuda.synchronize() 208 | total_time = start.elapsed_time(end)/1000 209 | processing_time = start.elapsed_time(match)/1000 210 | match_processing_time = match.elapsed_time(reloc)/1000 211 | reloc_processing_time = reloc.elapsed_time(end)/1000 212 | max_mem = torch.cuda.max_memory_allocated()/1024/1024 213 | if first_iter: 214 | first_iter=False 215 | ttime = [] 216 | mmem = [] 217 | else: 218 | ttime.append(total_time) 219 | mmem.append(max_mem) 220 | print('fe: {:.2f}, cnn: {:.2f}, pp: {:.2f}, reloc: {:.2f}, total: {:.2f}, max mem: {:.2f}MB'.format(fe_time, cnn_time, 221 | match_processing_time, 222 | reloc_processing_time, 223 | total_time, 224 | max_mem)) 225 | 226 | 227 | xA = xA_.view(-1).data.cpu().float().numpy() 228 | yA = yA_.view(-1).data.cpu().float().numpy() 229 | xB = xB_.view(-1).data.cpu().float().numpy() 230 | yB = yB_.view(-1).data.cpu().float().numpy() 231 | score = score_.view(-1).data.cpu().float().numpy() 232 | 233 | matches[0,idx,:,0]=xA 234 | matches[0,idx,:,1]=yA 235 | matches[0,idx,:,2]=xB 236 | matches[0,idx,:,3]=yB 237 | matches[0,idx,:,4]=score 238 | 239 | del corr4d,delta4d,tgt, feature_A_2x, feature_B_2x 240 | del xA,xB,yA,yB,score 241 | del xA_,xB_,yA_,yB_,score_ 242 | torch.cuda.empty_cache() 243 | torch.cuda.reset_max_memory_allocated() 244 | 245 | print(">>>"+str(idx)) 246 | 247 | if not args.benchmark: 248 | matches_file=os.path.join('matches/',output_folder,str(q+1)+'.mat') 249 | savemat(matches_file,{'matches':matches,'query_fn':db[q][0].item(),'pano_fn':pano_fn_all},do_compression=True) 250 | print(matches_file) 251 | del src 252 | 253 | if args.benchmark: 254 | print('{}x{},{:.4f},{:.4f}'.format( 255 | wA_, 256 | hA_, 257 | torch.tensor(ttime).mean(), 258 | torch.tensor(mmem).mean())) -------------------------------------------------------------------------------- /eval/eval_inloc_generate_plot.m: -------------------------------------------------------------------------------- 1 | % adjust path and experiment name 2 | inloc_demo_path = '/your_path_to/InLoc_demo_old/'; 3 | addpath(fullfile(pwd,'..','lib_matlab')); 4 | run(fullfile(inloc_demo_path,'startup.m')) 5 | addpath(inloc_demo_path) 6 | [ params ] = setup_project_ht_WUSTL 7 | 8 | params.gt.dir='lib_matlab' 9 | params.output.dir='' 10 | 11 | densePE = load('../datasets/inloc/shortlists/densePE_top100_shortlist_cvpr18') 12 | inloc = 
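# --- illustrative aside (hypothetical sizes): the (x_ + 0.5) / fs mapping above turns a
# feature-grid index x_ in {0, ..., fs-1} into the center of its cell in normalized
# [0, 1] image coordinates, which downstream consumers can scale by the pixel size:
import numpy as np
fs = 4                               # feature-map width
idx = np.arange(fs)                  # grid indices 0..3
norm = (idx + 0.5) / fs              # cell centers: [0.125 0.375 0.625 0.875]
pixels = norm * 1600                 # e.g. for a 1600-px-wide image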
load('../datasets/inloc/shortlists/densePV_top10_shortlist_cvpr18') 13 | dense_ncnet=load('../datasets/inloc/shortlists/ncnet_shortlist_neurips18.mat'); 14 | sparse_ncnet_1600_hard=load('../datasets/inloc/shortlists/sparsencnet_shortlist_1600_hard.mat'); 15 | sparse_ncnet_3200_hard_soft=load('../datasets/inloc/shortlists/sparsencnet_shortlist_3200_hard_soft.mat'); 16 | 17 | % define custom palette (kind of color-blind friendly) 18 | cb1=[0,0,0]/255 19 | cb2=[0,73,73]/255 20 | cb3=[0,146,146]/255 21 | cb4=[255,109,182]/255 22 | cb5=[255,182,119]/255 23 | cb6=[73,0,146]/255 24 | cb7=[0,109,219]/255 25 | cb8=[182,109,255]/255 26 | cb9=[109,182,255]/255 27 | cb10=[182,219,255]/255 28 | cb11=[146,0,0]/255 29 | cb12=[146,73,0]/255 30 | cb13=[219,209,0]/255 31 | cb14=[36,255,36]/255 32 | cb15=[255,255,109]/255 33 | 34 | % plot 35 | method = struct(); 36 | i=1 37 | method(i).ImgList = sparse_ncnet_3200_hard_soft.ImgList; 38 | method(i).description = 'InLoc + Sparse-NCNet (H+S, 200\times150)'; 39 | method(i).marker = '-'; 40 | method(i).color = 'black'; 41 | method(i).ms = 8 42 | i=i+1 43 | method(i).ImgList = sparse_ncnet_1600_hard.ImgList; 44 | method(i).description = 'InLoc + Sparse-NCNet (H, 100\times75)'; 45 | method(i).marker = '+-.'; 46 | method(i).color = cb7; 47 | method(i).ms = 8 48 | i=i+1 49 | method(i).ImgList = dense_ncnet.ImgList; 50 | method(i).description = 'InLoc + NCNet (H)'; 51 | method(i).marker = 's-.'; 52 | method(i).color = cb10; 53 | method(i).ms = 8 54 | i=i+1 55 | method(i).ImgList = inloc.ImgList; 56 | method(i).description = 'InLoc'; 57 | method(i).marker = 'x--'; 58 | method(i).color = cb4; 59 | method(i).ms = 8 60 | i=i+1 61 | method(i).ImgList = densePE.ImgList; 62 | method(i).description = 'DensePE'; 63 | method(i).marker = 'o--'; 64 | method(i).color = cb14; 65 | method(i).ms = 8 66 | 67 | ht_plotcurve_WUSTL 68 | 69 | -------------------------------------------------------------------------------- /eval_ncnetdense/eval_hpatches_ncnetdense_extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | import os 7 | from os.path import exists, join, basename 8 | from collections import OrderedDict 9 | 10 | import sys 11 | sys.path.append('..') 12 | 13 | from lib.model import ImMatchNet, MutualMatching 14 | from lib.normalization import NormalizeImageDict 15 | from lib.torch_util import str_to_bool 16 | from lib.point_tnf import normalize_axis,unnormalize_axis,corr_to_matches 17 | from lib.plot import plot_image 18 | 19 | import numpy as np 20 | import numpy.random 21 | from skimage.io import imread 22 | from scipy.io import loadmat 23 | from scipy.io import savemat 24 | 25 | import argparse 26 | 27 | print('NCNetDense evaluation script - HSequences dataset') 28 | 29 | use_cuda = torch.cuda.is_available() 30 | 31 | # Argument parsing 32 | parser = argparse.ArgumentParser() 33 | 34 | parser.add_argument('--checkpoint', type=str, default='../trained_models/ncnet_ivd.pth.tar') 35 | parser.add_argument('--hseq_path', type=str, default='../datasets/hpatches/hpatches-sequences-release') 36 | parser.add_argument('--k_size', type=int, default=2) 37 | parser.add_argument('--image_size', type=int, default=1600) 38 | parser.add_argument('--softmax', type=str_to_bool, default=False) 39 | parser.add_argument('--matching_both_directions', type=str_to_bool, default=True) 40 | parser.add_argument('--flip_matching_direction', 
type=str_to_bool, default=False) 41 | parser.add_argument('--experiment_name', type=str, default='ncnet_resnet101_3200k2_softmax0') 42 | parser.add_argument('--symmetric_mode', type=str_to_bool, default=True) 43 | parser.add_argument('--nchunks', type=int, default=1) 44 | parser.add_argument('--chunk_idx', type=int, default=0) 45 | parser.add_argument('--skip_up_to', type=str, default='') 46 | parser.add_argument('--feature_extraction_cnn', type=str, default='resnet101') 47 | parser.add_argument('--change_stride', type=int, default=1) 48 | parser.add_argument('--benchmark', type=int, default=0) 49 | 50 | args = parser.parse_args() 51 | 52 | image_size = args.image_size 53 | k_size = args.k_size 54 | matching_both_directions = args.matching_both_directions 55 | flip_matching_direction = args.flip_matching_direction 56 | 57 | # Load pretrained model 58 | half_precision=True # use for memory saving 59 | 60 | print(args) 61 | 62 | model = ImMatchNet(use_cuda=use_cuda, 63 | checkpoint=args.checkpoint, 64 | half_precision=half_precision, 65 | feature_extraction_cnn=args.feature_extraction_cnn, 66 | relocalization_k_size=args.k_size, 67 | symmetric_mode=args.symmetric_mode) 68 | 69 | if args.change_stride: 70 | scale_factor = 0.0625 71 | # import pdb;pdb.set_trace() 72 | model.FeatureExtraction.model[-1][0].conv1.stride=(1,1) 73 | model.FeatureExtraction.model[-1][0].conv2.stride=(1,1) 74 | model.FeatureExtraction.model[-1][0].downsample[0].stride=(1,1) 75 | else: 76 | scale_factor = 0.0625/2 77 | 78 | imreadth = lambda x: torch.Tensor(imread(x).astype(np.float32)).transpose(1,2).transpose(0,1) 79 | normalize = lambda x: NormalizeImageDict(['im'])({'im':x})['im'] 80 | 81 | # allow rectangular images. Does not modify aspect ratio. 82 | if k_size==1: 83 | resize = lambda x: nn.functional.upsample(Variable(x.unsqueeze(0).cuda(),volatile=True), 84 | size=(int(x.shape[1]/(np.max(x.shape[1:])/image_size)),int(x.shape[2]/(np.max(x.shape[1:])/image_size))),mode='bilinear') 85 | else: 86 | resize = lambda x: nn.functional.upsample(Variable(x.unsqueeze(0).cuda(),volatile=True), 87 | size=(int(np.floor(x.shape[1]/(np.max(x.shape[1:])/image_size)*scale_factor/k_size)/scale_factor*k_size), 88 | int(np.floor(x.shape[2]/(np.max(x.shape[1:])/image_size)*scale_factor/k_size)/scale_factor*k_size)),mode='bilinear') 89 | 90 | padim = lambda x,h_max: torch.cat((x,x.view(-1)[0].clone().expand(1,3,h_max-x.shape[2],x.shape[3])/1e20),dim=2) if x.shape[2]1: 144 | corr4d,delta4d=model({'source_image':src,'target_image':tgt}) 145 | else: 146 | corr4d=model({'source_image':src,'target_image':tgt}) 147 | delta4d=None 148 | if args.benchmark: 149 | mid.record() 150 | 151 | # reshape corr tensor and get matches for each point in image B 152 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() 153 | 154 | # pad image and plot 155 | if plot: 156 | h_max=int(np.max([src.shape[2],tgt.shape[2]])) 157 | im=plot_image(torch.cat((padim(src,h_max),padim(tgt,h_max)),dim=3),return_im=True) 158 | plt.imshow(im) 159 | 160 | if matching_both_directions: 161 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size) 162 | (xA2_,yA2_,xB2_,yB2_,score2_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size,invert_matching_direction=True) 163 | xA_=torch.cat((xA_,xA2_),1) 164 | yA_=torch.cat((yA_,yA2_),1) 165 | xB_=torch.cat((xB_,xB2_),1) 166 | yB_=torch.cat((yB_,yB2_),1) 167 | score_=torch.cat((score_,score2_),1) 168 | # sort in descending score (this 
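# --- illustrative aside (hypothetical sizes) on the resize lambdas above: the target
# side lengths are rounded so that the resulting feature map (side * scale_factor) is
# an exact multiple of the relocalization kernel size k_size:
import numpy as np
h, w, image_size, scale_factor, k_size = 1080, 1920, 1600, 0.0625, 2
s = np.max([h, w]) / image_size      # the long side is scaled down to image_size
h_out = int(np.floor(h / s * scale_factor / k_size) / scale_factor * k_size)
print(h_out, h_out * scale_factor)   # 896 56.0 -- and 56 is divisible by k_size=2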
will keep the max-score instance in the duplicate removal step) 169 | sorted_index=torch.sort(-score_)[1].squeeze() 170 | xA_=xA_.squeeze()[sorted_index].unsqueeze(0) 171 | yA_=yA_.squeeze()[sorted_index].unsqueeze(0) 172 | xB_=xB_.squeeze()[sorted_index].unsqueeze(0) 173 | yB_=yB_.squeeze()[sorted_index].unsqueeze(0) 174 | score_=score_.squeeze()[sorted_index].unsqueeze(0) 175 | # remove duplicates 176 | concat_coords=np.concatenate((xA_.cpu().data.numpy(),yA_.cpu().data.numpy(),xB_.cpu().data.numpy(),yB_.cpu().data.numpy()),0) 177 | _,unique_index=np.unique(concat_coords,axis=1,return_index=True) 178 | xA_=xA_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 179 | yA_=yA_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 180 | xB_=xB_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 181 | yB_=yB_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 182 | score_=score_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 183 | elif flip_matching_direction: 184 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size,invert_matching_direction=True) 185 | else: 186 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size) 187 | 188 | # recenter 189 | if k_size>1: 190 | yA_=yA_*(fs1*k_size-1)/(fs1*k_size)+0.5/(fs1*k_size) 191 | xA_=xA_*(fs2*k_size-1)/(fs2*k_size)+0.5/(fs2*k_size) 192 | yB_=yB_*(fs3*k_size-1)/(fs3*k_size)+0.5/(fs3*k_size) 193 | xB_=xB_*(fs4*k_size-1)/(fs4*k_size)+0.5/(fs4*k_size) 194 | else: 195 | yA_=yA_*(fs1-1)/fs1+0.5/fs1 196 | xA_=xA_*(fs2-1)/fs2+0.5/fs2 197 | yB_=yB_*(fs3-1)/fs3+0.5/fs3 198 | xB_=xB_*(fs4-1)/fs4+0.5/fs4 199 | 200 | if args.benchmark: 201 | end.record() 202 | torch.cuda.synchronize() 203 | total_time = start.elapsed_time(end)/1000 204 | processing_time = start.elapsed_time(mid)/1000 205 | post_processing_time = mid.elapsed_time(end)/1000 206 | max_mem = torch.cuda.max_memory_allocated()/1024/1024 207 | if first_iter: 208 | first_iter=False 209 | ttime = [] 210 | mmem = [] 211 | else: 212 | ttime.append(total_time) 213 | mmem.append(max_mem) 214 | print('cnn: {:.2f}, pp: {:.2f}, total: {:.2f}, max mem: {:.2f}MB'.format(processing_time, 215 | post_processing_time, 216 | total_time, 217 | max_mem)) 218 | 219 | xA = xA_.view(-1).data.cpu().float().numpy()*wA 220 | yA = yA_.view(-1).data.cpu().float().numpy()*hA 221 | xB = xB_.view(-1).data.cpu().float().numpy()*wB 222 | yB = yB_.view(-1).data.cpu().float().numpy()*hB 223 | score = score_.view(-1).data.cpu().float().numpy() 224 | 225 | keypoints_A=np.stack((xA,yA),axis=1) 226 | keypoints_B=np.stack((xB,yB),axis=1) 227 | 228 | Npts=len(xA) 229 | if Npts>0: 230 | # plot top N matches 231 | if plot: 232 | c=numpy.random.rand(Npts,3) 233 | for i in range(Npts): 234 | if score[i]>0.75: 235 | ax = plt.gca() 236 | ax.add_artist(plt.Circle((float(xA[i])*src.shape[3],float(yA[i])*src.shape[2]), radius=3, color=c[i,:])) 237 | ax.add_artist(plt.Circle((float(xB[i])*tgt.shape[3]+src.shape[3] ,float(yB[i])*tgt.shape[2]), radius=3, color=c[i,:])) 238 | 239 | 240 | matches_file = '{}/{}_{}.npz.{}'.format(seq_name,'1',idx,args.experiment_name) 241 | 242 | if not args.benchmark: 243 | with open(os.path.join(args.hseq_path,matches_file), 'wb') as output_file: 244 | np.savez( 245 | output_file, 246 | keypoints_A=keypoints_A, 247 | keypoints_B=keypoints_B, 248 | scores=score 249 | ) 250 | 251 | print(matches_file) 252 | 253 | del corr4d,delta4d,src,tgt 254 | del 
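# --- illustrative aside (synthetic data) on the sort-then-unique idiom used above:
# np.unique keeps the FIRST occurrence of each duplicate column, so sorting by
# descending score beforehand guarantees the surviving duplicate is the best-scoring:
import numpy as np
coords = np.array([[0, 1, 0], [2, 3, 2]])    # columns: the pair (0,2) appears twice
score = np.array([0.2, 0.9, 0.7])
order = np.argsort(-score)                   # [1, 2, 0]
coords, score = coords[:, order], score[order]
_, keep = np.unique(coords, axis=1, return_index=True)
print(score[keep])                           # [0.7 0.9]: the 0.7 copy of (0,2) beats 0.2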
xA,xB,yA,yB,score 255 | del xA_,xB_,yA_,yB_,score_ 256 | torch.cuda.empty_cache() 257 | torch.cuda.reset_max_memory_allocated() 258 | 259 | if args.benchmark: 260 | print('{}x{},{:.4f},{:.4f}'.format( 261 | wA_, 262 | hA_, 263 | torch.tensor(ttime).mean(), 264 | torch.tensor(mmem).mean())) -------------------------------------------------------------------------------- /eval_ncnetdense/eval_inloc_ncnetdense_extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | import os 6 | from os.path import exists, join, basename 7 | from collections import OrderedDict 8 | import numpy as np 9 | import numpy.random 10 | import scipy as sc 11 | import scipy.misc 12 | from scipy.io import loadmat 13 | from scipy.io import savemat 14 | import matplotlib.pyplot as plt 15 | 16 | import sys 17 | sys.path.append('..') 18 | 19 | from lib.model import ImMatchNet,MutualMatching 20 | from lib.normalization import NormalizeImageDict 21 | from lib.torch_util import str_to_bool 22 | from lib.point_tnf import normalize_axis,unnormalize_axis,corr_to_matches 23 | from lib.plot import plot_image 24 | 25 | from skimage.io import imread 26 | 27 | import argparse 28 | 29 | print('NCNetDense evaluation script - InLoc dataset') 30 | 31 | use_cuda = torch.cuda.is_available() 32 | 33 | # Argument parsing 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--checkpoint', type=str, default='../trained_models/ncnet_ivd.pth.tar') 36 | parser.add_argument('--inloc_shortlist', type=str, default='../datasets/inloc/shortlists/densePE_top100_shortlist_cvpr18.mat') 37 | parser.add_argument('--k_size', type=int, default=2) 38 | parser.add_argument('--image_size', type=int, default=3200) 39 | parser.add_argument('--n_queries', type=int, default=356) 40 | parser.add_argument('--n_panos', type=int, default=10) 41 | parser.add_argument('--softmax', type=str_to_bool, default=True) 42 | parser.add_argument('--matching_both_directions', type=str_to_bool, default=True) 43 | parser.add_argument('--flip_matching_direction', type=str_to_bool, default=False) 44 | parser.add_argument('--pano_path', type=str, default='../datasets/inloc/', help='path to InLoc panos - should contain CSE3,CSE4,CSE5,DUC1 and DUC2 folders') 45 | parser.add_argument('--query_path', type=str, default='../datasets/inloc/query/iphone7/', help='path to InLoc queries') 46 | 47 | args = parser.parse_args() 48 | 49 | image_size = args.image_size 50 | k_size = args.k_size 51 | matching_both_directions = args.matching_both_directions 52 | flip_matching_direction = args.flip_matching_direction 53 | 54 | # Load pretrained model 55 | half_precision=True # use for memory saving 56 | 57 | print(args) 58 | 59 | model = ImMatchNet(use_cuda=use_cuda, 60 | checkpoint=args.checkpoint, 61 | half_precision=half_precision, 62 | relocalization_k_size=args.k_size) 63 | 64 | # Generate output folder path 65 | output_folder = args.inloc_shortlist.split('/')[-1].split('.')[0]+'_SZ_NEW_'+str(image_size)+'_K_'+str(k_size) 66 | if matching_both_directions: 67 | output_folder += '_BOTHDIRS' 68 | elif flip_matching_direction: 69 | output_folder += '_AtoB' 70 | else: 71 | output_folder += '_BtoA' 72 | if args.softmax==True: 73 | output_folder += '_SOFTMAX' 74 | if args.checkpoint!='': 75 | checkpoint_name=args.checkpoint.split('/')[-1].split('.')[0] 76 | output_folder += '_CHECKPOINT_'+checkpoint_name 77 | print('Output matches folder: '+output_folder) 78 | 79 | # 
Data preprocessing 80 | 81 | # Manually change image resolution for this test. On training, image_size=400 was used, with squared images 82 | scale_factor = 0.0625 83 | 84 | imreadth = lambda x: torch.Tensor(imread(x).astype(np.float32)).transpose(1,2).transpose(0,1) 85 | normalize = lambda x: NormalizeImageDict(['im'])({'im':x})['im'] 86 | 87 | # allow rectangular images. Does not modify aspect ratio. 88 | if k_size==1: 89 | resize = lambda x: nn.functional.upsample(Variable(x.unsqueeze(0).cuda(),volatile=True), 90 | size=(int(x.shape[1]/(np.max(x.shape[1:])/image_size)),int(x.shape[2]/(np.max(x.shape[1:])/image_size))),mode='bilinear') 91 | else: 92 | resize = lambda x: nn.functional.upsample(Variable(x.unsqueeze(0).cuda(),volatile=True), 93 | size=(int(np.floor(x.shape[1]/(np.max(x.shape[1:])/image_size)*scale_factor/k_size)/scale_factor*k_size), 94 | int(np.floor(x.shape[2]/(np.max(x.shape[1:])/image_size)*scale_factor/k_size)/scale_factor*k_size)),mode='bilinear') 95 | 96 | padim = lambda x,h_max: torch.cat((x,x.view(-1)[0].clone().expand(1,3,h_max-x.shape[2],x.shape[3])/1e20),dim=2) if x.shape[2]1: 142 | corr4d,delta4d=model({'source_image':src,'target_image':tgt}) 143 | else: 144 | corr4d=model({'source_image':src,'target_image':tgt}) 145 | delta4d=None 146 | 147 | # reshape corr tensor and get matches for each point in image B 148 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() 149 | 150 | # pad image and plot 151 | if plot: 152 | h_max=int(np.max([src.shape[2],tgt.shape[2]])) 153 | im=plot_image(torch.cat((padim(src,h_max),padim(tgt,h_max)),dim=3),return_im=True) 154 | plt.imshow(im) 155 | 156 | if matching_both_directions: 157 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size) 158 | (xA2_,yA2_,xB2_,yB2_,score2_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size,invert_matching_direction=True) 159 | xA_=torch.cat((xA_,xA2_),1) 160 | yA_=torch.cat((yA_,yA2_),1) 161 | xB_=torch.cat((xB_,xB2_),1) 162 | yB_=torch.cat((yB_,yB2_),1) 163 | score_=torch.cat((score_,score2_),1) 164 | # sort in descending score (this will keep the max-score instance in the duplicate removal step) 165 | sorted_index=torch.sort(-score_)[1].squeeze() 166 | xA_=xA_.squeeze()[sorted_index].unsqueeze(0) 167 | yA_=yA_.squeeze()[sorted_index].unsqueeze(0) 168 | xB_=xB_.squeeze()[sorted_index].unsqueeze(0) 169 | yB_=yB_.squeeze()[sorted_index].unsqueeze(0) 170 | score_=score_.squeeze()[sorted_index].unsqueeze(0) 171 | # remove duplicates 172 | concat_coords=np.concatenate((xA_.cpu().data.numpy(),yA_.cpu().data.numpy(),xB_.cpu().data.numpy(),yB_.cpu().data.numpy()),0) 173 | _,unique_index=np.unique(concat_coords,axis=1,return_index=True) 174 | xA_=xA_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 175 | yA_=yA_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 176 | xB_=xB_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 177 | yB_=yB_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 178 | score_=score_.squeeze()[torch.cuda.LongTensor(unique_index)].unsqueeze(0) 179 | elif flip_matching_direction: 180 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size,invert_matching_direction=True) 181 | else: 182 | (xA_,yA_,xB_,yB_,score_)=corr_to_matches(corr4d,scale='positive',do_softmax=do_softmax,delta4d=delta4d,k_size=k_size) 183 | 184 | # recenter 185 | if k_size>1: 186 | 
yA_=yA_*(fs1*k_size-1)/(fs1*k_size)+0.5/(fs1*k_size) 187 | xA_=xA_*(fs2*k_size-1)/(fs2*k_size)+0.5/(fs2*k_size) 188 | yB_=yB_*(fs3*k_size-1)/(fs3*k_size)+0.5/(fs3*k_size) 189 | xB_=xB_*(fs4*k_size-1)/(fs4*k_size)+0.5/(fs4*k_size) 190 | else: 191 | yA_=yA_*(fs1-1)/fs1+0.5/fs1 192 | xA_=xA_*(fs2-1)/fs2+0.5/fs2 193 | yB_=yB_*(fs3-1)/fs3+0.5/fs3 194 | xB_=xB_*(fs4-1)/fs4+0.5/fs4 195 | 196 | xA = xA_.view(-1).data.cpu().float().numpy() 197 | yA = yA_.view(-1).data.cpu().float().numpy() 198 | xB = xB_.view(-1).data.cpu().float().numpy() 199 | yB = yB_.view(-1).data.cpu().float().numpy() 200 | score = score_.view(-1).data.cpu().float().numpy() 201 | 202 | Npts=len(xA) 203 | if Npts>0: 204 | matches[0,idx,:Npts,0]=xA 205 | matches[0,idx,:Npts,1]=yA 206 | matches[0,idx,:Npts,2]=xB 207 | matches[0,idx,:Npts,3]=yB 208 | matches[0,idx,:Npts,4]=score 209 | # import pdb;pdb.set_trace() 210 | # plot top N matches 211 | if plot: 212 | c=numpy.random.rand(Npts,3) 213 | for i in range(Npts): 214 | if score[i]>0.75: 215 | ax = plt.gca() 216 | ax.add_artist(plt.Circle((float(xA[i])*src.shape[3],float(yA[i])*src.shape[2]), radius=3, color=c[i,:])) 217 | ax.add_artist(plt.Circle((float(xB[i])*tgt.shape[3]+src.shape[3] ,float(yB[i])*tgt.shape[2]), radius=3, color=c[i,:])) 218 | 219 | corr4d=None 220 | delta4d=None 221 | 222 | if idx%10==0: 223 | print(">>>"+str(idx)) 224 | 225 | savemat(os.path.join('matches/',output_folder,str(q+1)+'.mat'),{'matches':matches,'query_fn':db[q][0].item(),'pano_fn':pano_fn_all},do_compression=True) 226 | 227 | if plot: 228 | plt.gcf().set_dpi(200) 229 | plt.show() 230 | 231 | -------------------------------------------------------------------------------- /lib/conv4d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | import torch.nn.functional as F 5 | from torch.nn import Module 6 | from torch.nn.modules.conv import _ConvNd 7 | from torch.nn.modules.utils import _quadruple 8 | from torch.autograd import Variable 9 | from torch.nn import Conv2d 10 | import torch.nn.functional as F 11 | 12 | def conv4d(data,filters,bias=None,permute_filters=True,use_half=False,split_conv=False): 13 | b,c,h,w,d,t=data.size() 14 | 15 | data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 16 | 17 | # Same permutation is done with filters, unless already provided with permutation 18 | if permute_filters: 19 | filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 20 | 21 | c_out=filters.size(1) 22 | padding=filters.size(0)//2 23 | if split_conv: 24 | h=h-2*padding 25 | 26 | if use_half: 27 | output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 28 | else: 29 | output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 30 | 31 | if use_half: 32 | Z=Variable(torch.zeros(padding,b,c,w,d,t).half()) 33 | else: 34 | Z=Variable(torch.zeros(padding,b,c,w,d,t)) 35 | 36 | if data.is_cuda: 37 | Z=Z.cuda(data.get_device()) 38 | output=output.cuda(data.get_device()) 39 | 40 | if padding>0 and split_conv==False: 41 | data_padded = torch.cat((Z,data,Z),0) 42 | else: 43 | data_padded = data 44 | 45 | for i in range(output.size(0)): # loop on first feature dimension 46 | # convolve with center channel of filter (at position=padding) 47 | output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:], 48 | filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding) 49 
| # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) 50 | for p in range(1,padding+1): 51 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:], 52 | filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding) 53 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:], 54 | filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding) 55 | 56 | output=output.permute(1,2,0,3,4,5).contiguous() 57 | return output 58 | 59 | class Conv4d(_ConvNd): 60 | """Applies a 4D convolution over an input signal composed of several input 61 | planes. 62 | """ 63 | 64 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): 65 | # stride, dilation and groups !=1 functionality not tested 66 | stride=1 67 | dilation=1 68 | groups=1 69 | # zero padding is added automatically in conv4d function to preserve tensor size 70 | padding = 0 71 | kernel_size = _quadruple(kernel_size) 72 | stride = _quadruple(stride) 73 | padding = _quadruple(padding) 74 | dilation = _quadruple(dilation) 75 | super(Conv4d, self).__init__( 76 | in_channels, out_channels, kernel_size, stride, padding, dilation, 77 | False, _quadruple(0), groups, bias,'zeros') 78 | # weights will be sliced along one dimension during convolution loop 79 | # make the looping dimension to be the first one in the tensor, 80 | # so that we don't need to call contiguous() inside the loop 81 | self.pre_permuted_filters=pre_permuted_filters 82 | if self.pre_permuted_filters: 83 | self.permute_filters() 84 | self.use_half=False 85 | 86 | def forward(self, input): 87 | return conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor 88 | # return conv4d(input, self.weight) 89 | 90 | def permute_filters(self): 91 | self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous() 92 | 93 | def unpermute_filters(self): 94 | self.weight.data=self.weight.data.permute(1,2,0,3,4,5).contiguous() 95 | 96 | -------------------------------------------------------------------------------- /lib/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as multiprocessing 3 | from torch.utils.data.sampler import SequentialSampler, RandomSampler, BatchSampler 4 | import collections 5 | import sys 6 | import traceback 7 | import threading 8 | import numpy as np 9 | import numpy.random 10 | 11 | #from torch._six import string_classes 12 | PY2 = sys.version_info[0] == 2 13 | PY3 = sys.version_info[0] == 3 14 | 15 | if PY2: 16 | string_classes = basestring 17 | else: 18 | string_classes = (str, bytes) 19 | 20 | 21 | if sys.version_info[0] == 2: 22 | import Queue as queue 23 | else: 24 | import queue 25 | 26 | 27 | _use_shared_memory = False 28 | """Whether to use shared memory in default_collate""" 29 | 30 | 31 | class ExceptionWrapper(object): 32 | "Wraps an exception plus traceback to communicate across threads" 33 | 34 | def __init__(self, exc_info): 35 | self.exc_type = exc_info[0] 36 | self.exc_msg = "".join(traceback.format_exception(*exc_info)) 37 | 38 | 39 | def _worker_loop(dataset, index_queue, data_queue, collate_fn, rng_seed): 40 | global _use_shared_memory 41 | _use_shared_memory = True 42 | 43 | np.random.seed(rng_seed) 44 | torch.set_num_threads(1) 45 | while True: 46 | r = index_queue.get() 47 | if r is None: 48 | data_queue.put(None) 49 | 
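# --- illustrative aside: the loop above realizes a 4D convolution, which PyTorch lacks
# natively, by slicing the kernel along one dimension and summing shifted 3D
# convolutions. A standalone sanity check of the same idea one dimension down
# (building a 2D convolution out of 1D ones):
import torch
import torch.nn.functional as F
x = torch.randn(1, 1, 5, 7)
w = torch.randn(1, 1, 3, 3)
ref = F.conv2d(x, w, padding=1)
xp = F.pad(x, (0, 0, 1, 1))                  # zero-pad only the sliced (row) dimension
rows = [sum(F.conv1d(xp[:, :, i + j, :], w[:, :, j, :], padding=1) for j in range(3))
        for i in range(5)]
assert torch.allclose(ref, torch.stack(rows, dim=2), atol=1e-5)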
break 50 | idx, batch_indices = r 51 | try: 52 | samples = collate_fn([dataset[i] for i in batch_indices]) 53 | except Exception: 54 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 55 | else: 56 | data_queue.put((idx, samples)) 57 | 58 | 59 | def _pin_memory_loop(in_queue, out_queue, done_event): 60 | while True: 61 | try: 62 | r = in_queue.get() 63 | except: 64 | if done_event.is_set(): 65 | return 66 | raise 67 | if r is None: 68 | break 69 | if isinstance(r[1], ExceptionWrapper): 70 | out_queue.put(r) 71 | continue 72 | idx, batch = r 73 | try: 74 | batch = pin_memory_batch(batch) 75 | except Exception: 76 | out_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 77 | else: 78 | out_queue.put((idx, batch)) 79 | 80 | 81 | numpy_type_map = { 82 | 'float64': torch.DoubleTensor, 83 | 'float32': torch.FloatTensor, 84 | 'float16': torch.HalfTensor, 85 | 'int64': torch.LongTensor, 86 | 'int32': torch.IntTensor, 87 | 'int16': torch.ShortTensor, 88 | 'int8': torch.CharTensor, 89 | 'uint8': torch.ByteTensor, 90 | } 91 | 92 | 93 | def default_collate(batch): 94 | "Puts each data field into a tensor with outer dimension batch size" 95 | if torch.is_tensor(batch[0]): 96 | out = None 97 | if _use_shared_memory: 98 | # If we're in a background process, concatenate directly into a 99 | # shared memory tensor to avoid an extra copy 100 | numel = sum([x.numel() for x in batch]) 101 | storage = batch[0].storage()._new_shared(numel) 102 | out = batch[0].new(storage) 103 | return torch.stack(batch, 0, out=out) 104 | elif type(batch[0]).__module__ == 'numpy': 105 | elem = batch[0] 106 | if type(elem).__name__ == 'ndarray': 107 | return torch.stack([torch.from_numpy(b) for b in batch], 0) 108 | if elem.shape == (): # scalars 109 | py_type = float if elem.dtype.name.startswith('float') else int 110 | return numpy_type_map[elem.dtype.name](list(map(py_type, batch))) 111 | elif isinstance(batch[0], int): 112 | return torch.LongTensor(batch) 113 | elif isinstance(batch[0], float): 114 | return torch.DoubleTensor(batch) 115 | elif isinstance(batch[0], string_classes): 116 | return batch 117 | elif isinstance(batch[0], collections.Mapping): 118 | return {key: default_collate([d[key] for d in batch]) for key in batch[0]} 119 | elif isinstance(batch[0], collections.Sequence): 120 | transposed = zip(*batch) 121 | return [default_collate(samples) for samples in transposed] 122 | 123 | raise TypeError(("batch must contain tensors, numbers, dicts or lists; found {}" 124 | .format(type(batch[0])))) 125 | 126 | 127 | def pin_memory_batch(batch): 128 | if torch.is_tensor(batch): 129 | return batch.pin_memory() 130 | elif isinstance(batch, string_classes): 131 | return batch 132 | elif isinstance(batch, collections.Mapping): 133 | return {k: pin_memory_batch(sample) for k, sample in batch.items()} 134 | elif isinstance(batch, collections.Sequence): 135 | return [pin_memory_batch(sample) for sample in batch] 136 | else: 137 | return batch 138 | 139 | 140 | class DataLoaderIter(object): 141 | "Iterates once over the DataLoader's dataset, as specified by the sampler" 142 | 143 | def __init__(self, loader): 144 | self.dataset = loader.dataset 145 | self.collate_fn = loader.collate_fn 146 | self.batch_sampler = loader.batch_sampler 147 | self.num_workers = loader.num_workers 148 | self.pin_memory = loader.pin_memory 149 | self.done_event = threading.Event() 150 | 151 | self.sample_iter = iter(self.batch_sampler) 152 | 153 | if self.num_workers > 0: 154 | self.index_queue = multiprocessing.SimpleQueue() 155 | 
self.data_queue = multiprocessing.SimpleQueue() 156 | self.batches_outstanding = 0 157 | self.shutdown = False 158 | self.send_idx = 0 159 | self.rcvd_idx = 0 160 | self.reorder_dict = {} 161 | 162 | self.workers = [ 163 | multiprocessing.Process( 164 | target=_worker_loop, 165 | args=(self.dataset, self.index_queue, self.data_queue, self.collate_fn, np.random.randint(0, 4294967296, dtype='uint32'))) 166 | for _ in range(self.num_workers)] 167 | 168 | for w in self.workers: 169 | w.daemon = True # ensure that the worker exits on process exit 170 | w.start() 171 | 172 | if self.pin_memory: 173 | in_data = self.data_queue 174 | self.data_queue = queue.Queue() 175 | self.pin_thread = threading.Thread( 176 | target=_pin_memory_loop, 177 | args=(in_data, self.data_queue, self.done_event)) 178 | self.pin_thread.daemon = True 179 | self.pin_thread.start() 180 | 181 | # prime the prefetch loop 182 | for _ in range(2 * self.num_workers): 183 | self._put_indices() 184 | 185 | def __len__(self): 186 | return len(self.batch_sampler) 187 | 188 | def __next__(self): 189 | if self.num_workers == 0: # same-process loading 190 | indices = next(self.sample_iter) # may raise StopIteration 191 | batch = self.collate_fn([self.dataset[i] for i in indices]) 192 | if self.pin_memory: 193 | batch = pin_memory_batch(batch) 194 | return batch 195 | 196 | # check if the next sample has already been generated 197 | if self.rcvd_idx in self.reorder_dict: 198 | batch = self.reorder_dict.pop(self.rcvd_idx) 199 | return self._process_next_batch(batch) 200 | 201 | if self.batches_outstanding == 0: 202 | self._shutdown_workers() 203 | raise StopIteration 204 | 205 | while True: 206 | assert (not self.shutdown and self.batches_outstanding > 0) 207 | idx, batch = self.data_queue.get() 208 | self.batches_outstanding -= 1 209 | if idx != self.rcvd_idx: 210 | # store out-of-order samples 211 | self.reorder_dict[idx] = batch 212 | continue 213 | return self._process_next_batch(batch) 214 | 215 | next = __next__ # Python 2 compatibility 216 | 217 | def __iter__(self): 218 | return self 219 | 220 | def _put_indices(self): 221 | assert self.batches_outstanding < 2 * self.num_workers 222 | indices = next(self.sample_iter, None) 223 | if indices is None: 224 | return 225 | self.index_queue.put((self.send_idx, indices)) 226 | self.batches_outstanding += 1 227 | self.send_idx += 1 228 | 229 | def _process_next_batch(self, batch): 230 | self.rcvd_idx += 1 231 | self._put_indices() 232 | if isinstance(batch, ExceptionWrapper): 233 | raise batch.exc_type(batch.exc_msg) 234 | return batch 235 | 236 | def __getstate__(self): 237 | # TODO: add limited pickling support for sharing an iterator 238 | # across multiple threads for HOGWILD. 239 | # Probably the best way to do this is by moving the sample pushing 240 | # to a separate thread and then just sharing the data queue 241 | # but signalling the end is tricky without a non-blocking API 242 | raise NotImplementedError("DataLoaderIterator cannot be pickled") 243 | 244 | def _shutdown_workers(self): 245 | if not self.shutdown: 246 | self.shutdown = True 247 | self.done_event.set() 248 | for _ in self.workers: 249 | self.index_queue.put(None) 250 | 251 | def __del__(self): 252 | if self.num_workers > 0: 253 | self._shutdown_workers() 254 | 255 | 256 | class DataLoader(object): 257 | """ 258 | Data loader. Combines a dataset and a sampler, and provides 259 | single- or multi-process iterators over the dataset. 260 | 261 | Arguments: 262 | dataset (Dataset): dataset from which to load the data. 
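# --- illustrative aside: this vendored DataLoader appears to exist chiefly so that
# every worker gets an independent numpy seed (note the np.random.randint(...) handed
# to _worker_loop above); older stock loaders let forked workers inherit identical
# numpy RNG state. A sketch of the modern equivalent ('my_dataset' is hypothetical):
import numpy as np
import torch
from torch.utils.data import DataLoader as TorchDataLoader

def seed_worker(worker_id):
    # torch assigns each worker a distinct base seed; reuse it for numpy
    np.random.seed(torch.initial_seed() % 2**32)

# loader = TorchDataLoader(my_dataset, batch_size=16, num_workers=4,
#                          worker_init_fn=seed_worker)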
263 | batch_size (int, optional): how many samples per batch to load 264 | (default: 1). 265 | shuffle (bool, optional): set to ``True`` to have the data reshuffled 266 | at every epoch (default: False). 267 | sampler (Sampler, optional): defines the strategy to draw samples from 268 | the dataset. If specified, ``shuffle`` must be False. 269 | batch_sampler (Sampler, optional): like sampler, but returns a batch of 270 | indices at a time. Mutually exclusive with batch_size, shuffle, 271 | sampler, and drop_last. 272 | num_workers (int, optional): how many subprocesses to use for data 273 | loading. 0 means that the data will be loaded in the main process 274 | (default: 0) 275 | collate_fn (callable, optional): merges a list of samples to form a mini-batch. 276 | pin_memory (bool, optional): If ``True``, the data loader will copy tensors 277 | into CUDA pinned memory before returning them. 278 | drop_last (bool, optional): set to ``True`` to drop the last incomplete batch, 279 | if the dataset size is not divisible by the batch size. If False and 280 | the size of dataset is not divisible by the batch size, then the last batch 281 | will be smaller. (default: False) 282 | """ 283 | 284 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, 285 | num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False): 286 | self.dataset = dataset 287 | self.batch_size = batch_size 288 | self.num_workers = num_workers 289 | self.collate_fn = collate_fn 290 | self.pin_memory = pin_memory 291 | self.drop_last = drop_last 292 | 293 | if batch_sampler is not None: 294 | if batch_size > 1 or shuffle or sampler is not None or drop_last: 295 | raise ValueError('batch_sampler is mutually exclusive with ' 296 | 'batch_size, shuffle, sampler, and drop_last') 297 | 298 | if sampler is not None and shuffle: 299 | raise ValueError('sampler is mutually exclusive with shuffle') 300 | 301 | if batch_sampler is None: 302 | if sampler is None: 303 | if shuffle: 304 | sampler = RandomSampler(dataset) 305 | else: 306 | sampler = SequentialSampler(dataset) 307 | batch_sampler = BatchSampler(sampler, batch_size, drop_last) 308 | 309 | self.sampler = sampler 310 | self.batch_sampler = batch_sampler 311 | 312 | def __iter__(self): 313 | return DataLoaderIter(self) 314 | 315 | def __len__(self): 316 | return len(self.batch_sampler) -------------------------------------------------------------------------------- /lib/eval_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import numpy as np 4 | import os 5 | from skimage import draw 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | from lib.pf_dataset import PFPascalDataset 9 | from lib.point_tnf import PointsToUnitCoords, PointsToPixelCoords, bilinearInterpPointTnf, nearestNeighPointTnf 10 | 11 | 12 | def pck(source_points,warped_points,L_pck,alpha=0.1): 13 | # compute precentage of correct keypoints 14 | batch_size=source_points.size(0) 15 | pck=torch.zeros((batch_size)) 16 | for i in range(batch_size): 17 | p_src = source_points[i,:] 18 | p_wrp = warped_points[i,:] 19 | N_pts = torch.sum(torch.ne(p_src[0,:],-1)*torch.ne(p_src[1,:],-1)) 20 | point_distance = torch.pow(torch.sum(torch.pow(p_src[:,:N_pts]-p_wrp[:,:N_pts],2),0),0.5) 21 | L_pck_mat = L_pck[i].expand_as(point_distance) 22 | correct_points = torch.le(point_distance,L_pck_mat*alpha) 23 | pck[i]=torch.mean(correct_points.float()) 24 | return pck 25 | 26 
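# --- illustrative aside (synthetic points) on the PCK computed above: a keypoint
# transfer counts as correct when its error is below alpha * L_pck, where L_pck is a
# per-image normalizer such as the reference object or image size:
import torch
src = torch.tensor([[[0., 10., 20.], [0., 10., 20.]]])  # 1 x 2 x 3 source keypoints
wrp = torch.tensor([[[1., 10., 28.], [0., 14., 20.]]])  # their warped estimates
dist = ((src - wrp) ** 2).sum(dim=1).sqrt()             # tensor([[1., 4., 8.]])
L_pck, alpha = 50.0, 0.1
print((dist <= alpha * L_pck).float().mean().item())    # 2 of 3 errors below 5.0 -> 0.6667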
| 27 | def pck_metric(batch,batch_start_idx,matches,stats,args,use_cuda=True, interp='bilinear'): 28 | 29 | source_im_size = batch['source_im_size'] 30 | target_im_size = batch['target_im_size'] 31 | 32 | source_points = batch['source_points'] 33 | target_points = batch['target_points'] 34 | 35 | # warp points with estimated transformations 36 | target_points_norm = PointsToUnitCoords(target_points,target_im_size) 37 | 38 | # compute points stage 1 only 39 | if interp=='bilinear': 40 | warped_points_norm = bilinearInterpPointTnf(matches,target_points_norm) 41 | elif interp=='nearest': 42 | warped_points_norm = nearestNeighPointTnf(matches,target_points_norm) 43 | else: 44 | raise ValueError('interpolation method {} invalid'.format(interp)) 45 | 46 | warped_points = PointsToPixelCoords(warped_points_norm,source_im_size) 47 | 48 | L_pck = batch['L_pck'].data 49 | 50 | current_batch_size=batch['source_im_size'].size(0) 51 | indices = range(batch_start_idx,batch_start_idx+current_batch_size) 52 | 53 | # compute PCK 54 | pck_batch = pck(source_points.data, warped_points.data, L_pck) 55 | stats['point_tnf']['pck'][indices] = pck_batch.unsqueeze(1).cpu().numpy() 56 | 57 | return stats -------------------------------------------------------------------------------- /lib/im_pair_dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.utils.data import Dataset 6 | from skimage import io 7 | import pandas as pd 8 | import numpy as np 9 | from lib.transformation import AffineTnf 10 | import torchvision 11 | import PIL 12 | from PIL import Image 13 | 14 | class ImagePairDataset(Dataset): 15 | 16 | """ 17 | 18 | Image pair dataset used for weak supervision 19 | 20 | 21 | Args: 22 | csv_file (string): Path to the csv file with image names and transformations. 23 | training_image_path (string): Directory with the images. 24 | output_size (2-tuple): Desired output size 25 | transform (callable): Transformation for post-processing the training pair (eg. 
image normalization) 26 | 27 | """ 28 | 29 | def __init__(self, dataset_csv_path, dataset_csv_file, dataset_image_path, dataset_size=0,output_size=(240,240),transform=None,random_crop=False,random_affine=False): 30 | self.random_crop=random_crop 31 | if random_affine: 32 | self.random_affine=torchvision.transforms.RandomAffine(40, translate=None, scale=(1,1.5), shear=40, resample=PIL.Image.BILINEAR) 33 | else: 34 | self.random_affine=None 35 | self.out_h, self.out_w = output_size 36 | self.train_data = pd.read_csv(os.path.join(dataset_csv_path,dataset_csv_file)) 37 | if dataset_size is not None and dataset_size!=0: 38 | dataset_size = min((dataset_size,len(self.train_data))) 39 | self.train_data = self.train_data.iloc[0:dataset_size,:] 40 | self.img_A_names = self.train_data.iloc[:,0] 41 | self.img_B_names = self.train_data.iloc[:,1] 42 | self.set = self.train_data.iloc[:,2].to_numpy() 43 | self.flip = self.train_data.iloc[:, 3].to_numpy().astype('int') 44 | self.dataset_image_path = dataset_image_path 45 | self.transform = transform 46 | # no cuda as dataset is called from CPU threads in dataloader and produces confilct 47 | self.affineTnf = AffineTnf(out_h=self.out_h, out_w=self.out_w, use_cuda = False) 48 | 49 | def __len__(self): 50 | return len(self.img_A_names) 51 | 52 | def __getitem__(self, idx): 53 | # get pre-processed images 54 | try: 55 | image_A,im_size_A = self.get_image(self.img_A_names,idx,self.flip[idx]) 56 | image_B,im_size_B = self.get_image(self.img_B_names,idx,self.flip[idx],affine=self.random_affine is not None) 57 | except: 58 | return self.__getitem__(np.random.randint(self.__len__())) 59 | 60 | image_set = self.set[idx] 61 | 62 | sample = {'source_image': image_A, 'target_image': image_B, 'source_im_size': im_size_A, 'target_im_size': im_size_B, 'set':image_set} 63 | 64 | if self.transform: 65 | sample = self.transform(sample) 66 | 67 | return sample 68 | 69 | def get_image(self,img_name_list,idx,flip,affine=False): 70 | img_name = os.path.join(self.dataset_image_path, img_name_list.iloc[idx]) 71 | #image = io.imread(img_name) 72 | image = Image.open(img_name) 73 | 74 | if affine: 75 | image = self.random_affine(image) 76 | 77 | # convert from PIL to numpy 78 | image = np.array(image) 79 | 80 | # if grayscale convert to 3-channel image 81 | if image.ndim==2: 82 | image=np.repeat(np.expand_dims(image,2),axis=2,repeats=3) 83 | 84 | # do random crop 85 | if self.random_crop: 86 | h,w,c=image.shape 87 | top=np.random.randint(h/4) 88 | bottom=int(3*h/4+np.random.randint(h/4)) 89 | left=np.random.randint(w/4) 90 | right=int(3*w/4+np.random.randint(w/4)) 91 | image = image[top:bottom,left:right,:] 92 | 93 | # flip horizontally if needed 94 | if flip: 95 | image=np.flip(image,1) 96 | 97 | # get image size 98 | im_size = np.asarray(image.shape) 99 | 100 | # convert to torch Variable 101 | image = np.expand_dims(image.transpose((2,0,1)),0) 102 | image = torch.Tensor(image.astype(np.float32)) 103 | image_var = Variable(image,requires_grad=False) 104 | 105 | # Resize image using bilinear sampling with identity affine tnf 106 | image = self.affineTnf(image_var).data.squeeze(0) 107 | 108 | im_size = torch.Tensor(im_size.astype(np.float32)) 109 | 110 | return (image, im_size) -------------------------------------------------------------------------------- /lib/knn.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import torch 3 | 4 | res = faiss.StandardGpuResources() # use a single GPU 5 | 6 | def 
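# --- illustrative aside (hypothetical image height) checking the random-crop bounds
# used above: top is drawn from [0, h/4) and bottom from [3h/4, h), so every crop
# retains strictly more than half of the image side:
import numpy as np
h = 480
top = np.random.randint(h // 4, size=1000)
bottom = 3 * h // 4 + np.random.randint(h // 4, size=1000)
assert (bottom - top).min() > h // 2 and (bottom - top).max() < h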
swig_ptr_from_FloatTensor(x): 7 | assert x.is_contiguous() 8 | assert x.dtype == torch.float32 9 | return faiss.cast_integer_to_float_ptr( 10 | x.storage().data_ptr() + x.storage_offset() * 4) 11 | 12 | def swig_ptr_from_LongTensor(x): 13 | assert x.is_contiguous() 14 | assert x.dtype == torch.int64, 'dtype=%s' % x.dtype 15 | return faiss.cast_integer_to_long_ptr( 16 | x.storage().data_ptr() + x.storage_offset() * 8) 17 | 18 | def search_index_pytorch(index, x, k, D=None, I=None): 19 | """call the search function of an index with pytorch tensor I/O (CPU 20 | and GPU supported)""" 21 | assert x.is_contiguous() 22 | n, d = x.size() 23 | assert d == index.d 24 | 25 | if D is None: 26 | D = torch.empty((n, k), dtype=torch.float32, device=x.device) 27 | else: 28 | assert D.size() == (n, k) 29 | 30 | if I is None: 31 | I = torch.empty((n, k), dtype=torch.int64, device=x.device) 32 | else: 33 | assert I.size() == (n, k) 34 | torch.cuda.synchronize() 35 | xptr = swig_ptr_from_FloatTensor(x) 36 | Iptr = swig_ptr_from_LongTensor(I) 37 | Dptr = swig_ptr_from_FloatTensor(D) 38 | index.search_c(n, xptr, 39 | k, Dptr, Iptr) 40 | torch.cuda.synchronize() 41 | return D, I 42 | 43 | def knn_faiss(feature_B, feature_A, k): 44 | b,ch,nA = feature_A.shape 45 | if b==1: 46 | feature_A = feature_A.view(ch,-1).t().contiguous() 47 | feature_B = feature_B.view(ch,-1).t().contiguous() 48 | index_cpu = faiss.IndexFlatL2(ch) 49 | index = faiss.index_cpu_to_gpu(res, 0, index_cpu) 50 | torch.cuda.synchronize() 51 | feature_B_ptr = swig_ptr_from_FloatTensor(feature_B) 52 | index.add_c(feature_B.shape[0], feature_B_ptr) 53 | dist, indx = search_index_pytorch(index, feature_A, k=k) 54 | dist = dist.t().unsqueeze(0).contiguous() 55 | indx = indx.t().unsqueeze(0).contiguous() 56 | else: 57 | feature_A = feature_A.view(b,ch,-1).permute(0,2,1).contiguous() 58 | feature_B = feature_B.view(b,ch,-1).permute(0,2,1).contiguous() 59 | dist = [] 60 | indx = [] 61 | for i in range(b): 62 | index_cpu = faiss.IndexFlatL2(ch) 63 | index = faiss.index_cpu_to_gpu(res, 0, index_cpu) 64 | torch.cuda.synchronize() 65 | feature_B_ptr = swig_ptr_from_FloatTensor(feature_B[i]) 66 | index.add_c(feature_B[i].shape[0], feature_B_ptr) 67 | dist_i, indx_i = search_index_pytorch(index, feature_A[i], k=k) 68 | dist_i = dist_i.t().unsqueeze(0).contiguous() 69 | indx_i = indx_i.t().unsqueeze(0).contiguous() 70 | dist.append(dist_i) 71 | indx.append(indx_i) 72 | dist = torch.cat(dist,dim=0) 73 | indx = torch.cat(indx,dim=0) 74 | return dist, indx 75 | 76 | # The knn_faiss_ivf function was working slower for small image resolutions 77 | # I leave it here for future reference or for other people who might find it useful 78 | 79 | def knn_faiss_ivf(feature_B, feature_A, k): 80 | b,ch,nA = feature_A.shape 81 | if b==1: 82 | feature_A = feature_A.view(ch,-1).t().contiguous() 83 | feature_B = feature_B.view(ch,-1).t().contiguous() 84 | 85 | quantizer = faiss.IndexFlatL2(ch) # the other index 86 | index_cpu = faiss.IndexIVFFlat(quantizer, ch, 100, faiss.METRIC_L2) 87 | index = faiss.index_cpu_to_gpu(res, 0, index_cpu) 88 | torch.cuda.synchronize() 89 | feature_B_ptr = swig_ptr_from_FloatTensor(feature_B) 90 | index.train(feature_B.cpu().numpy()) 91 | index.add_c(feature_B.shape[0], feature_B_ptr) 92 | 93 | dist, indx = search_index_pytorch(index, feature_A, k=k) 94 | dist = dist.t().unsqueeze(0).contiguous() 95 | indx = indx.t().unsqueeze(0).contiguous() 96 | else: 97 | feature_A = feature_A.view(b,ch,-1).permute(0,2,1).contiguous() 98 | feature_B = 
feature_B.view(b,ch,-1).permute(0,2,1).contiguous() 99 | dist = [] 100 | indx = [] 101 | for i in range(b): 102 | quantizer = faiss.IndexFlatL2(ch) # the other index 103 | index_cpu = faiss.IndexIVFFlat(quantizer, ch, 100, faiss.METRIC_L2) 104 | index = faiss.index_cpu_to_gpu(res, 0, index_cpu) 105 | torch.cuda.synchronize() 106 | feature_B_ptr = swig_ptr_from_FloatTensor(feature_B[i]) 107 | index.train(feature_B[i].cpu().numpy()) 108 | index.add_c(feature_B[i].shape[0], feature_B_ptr) 109 | 110 | dist_i, indx_i = search_index_pytorch(index, feature_A[i], k=k) 111 | dist_i = dist_i.t().unsqueeze(0).contiguous() 112 | indx_i = indx_i.t().unsqueeze(0).contiguous() 113 | dist.append(dist_i) 114 | indx.append(indx_i) 115 | dist = torch.cat(dist,dim=0) 116 | indx = torch.cat(indx,dim=0) 117 | return dist, indx -------------------------------------------------------------------------------- /lib/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Variable 6 | import torchvision.models as models 7 | import numpy as np 8 | import numpy.matlib 9 | import pickle 10 | import MinkowskiEngine as ME 11 | 12 | from .conv4d import Conv4d 13 | from .sparse import transpose_torch, transpose_me, torch_to_me, me_to_torch, corr_and_add 14 | 15 | def featureL2Norm(feature): 16 | epsilon = 1e-6 17 | norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature) 18 | return torch.div(feature,norm) 19 | 20 | class FeatureExtraction(torch.nn.Module): 21 | def __init__(self, train_fe=False, feature_extraction_cnn='resnet101', feature_extraction_model_file='', normalization=True, last_layer='', use_cuda=True): 22 | super(FeatureExtraction, self).__init__() 23 | self.normalization = normalization 24 | self.feature_extraction_cnn=feature_extraction_cnn 25 | # for resnet below 26 | resnet_feature_layers = ['conv1','bn1','relu','maxpool','layer1','layer2','layer3','layer4'] 27 | if feature_extraction_cnn.startswith('resnet'): 28 | if feature_extraction_cnn=='resnet101': 29 | self.model = models.resnet101(pretrained=True) 30 | elif feature_extraction_cnn=='resnet18': 31 | self.model = models.resnet18(pretrained=True) 32 | if last_layer=='': 33 | last_layer = 'layer3' 34 | resnet_module_list = [getattr(self.model,l) for l in resnet_feature_layers] 35 | last_layer_idx = resnet_feature_layers.index(last_layer) 36 | self.model = nn.Sequential(*resnet_module_list[:last_layer_idx+1]) 37 | if train_fe==False: 38 | # freeze parameters 39 | for param in self.model.parameters(): 40 | param.requires_grad = False 41 | # move to GPU 42 | if use_cuda: 43 | self.model = self.model.cuda() 44 | 45 | def forward(self, image_batch): 46 | features = self.model(image_batch) 47 | if self.normalization and not self.feature_extraction_cnn=='resnet101fpn': 48 | features = featureL2Norm(features) 49 | return features 50 | 51 | def change_stride(self): 52 | print('Changing FeatureExtraction stride') 53 | self.model[-1][0].conv1.stride=(1,1) 54 | self.model[-1][0].conv2.stride=(1,1) 55 | self.model[-1][0].downsample[0].stride=(1,1) 56 | 57 | def corr_dense(feature_A, feature_B): 58 | b,c,hA,wA = feature_A.size() 59 | b,c,hB,wB = feature_B.size() 60 | # reshape features for matrix multiplication 61 | feature_A = feature_A.view(b,c,hA*wA).transpose(1,2) # size [b,c,h*w] 62 | feature_B = feature_B.view(b,c,hB*wB) # 
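# --- illustrative aside: the swig-pointer helpers above let faiss search GPU-resident
# torch tensors without copies; for reference, the same k-NN search in its plainest
# CPU/numpy form (random descriptors):
import faiss
import numpy as np
d, k = 64, 5
feats_B = np.random.rand(1000, d).astype('float32')  # database descriptors
feats_A = np.random.rand(10, d).astype('float32')    # query descriptors
index = faiss.IndexFlatL2(d)
index.add(feats_B)
dist, indx = index.search(feats_A, k)                # per query: k nearest B indices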
size [b,c,h*w] 63 | # perform matrix mult. 64 | feature_mul = torch.bmm(feature_A,feature_B) 65 | # indexed [batch,row_A,col_A,row_B,col_B] 66 | correlation_tensor = feature_mul.view(b,hA,wA,hB,wB).unsqueeze(1) 67 | 68 | return correlation_tensor 69 | 70 | class SparseNeighConsensus(torch.nn.Module): 71 | def __init__(self, use_cuda=True, kernel_sizes=[3,3,3], channels=[10,10,1], symmetric_mode=True, bn=False): 72 | super(SparseNeighConsensus, self).__init__() 73 | self.symmetric_mode = symmetric_mode 74 | self.kernel_sizes = kernel_sizes 75 | self.channels = channels 76 | num_layers = len(kernel_sizes) 77 | nn_modules = list() 78 | for i in range(num_layers): 79 | if i==0: 80 | nn_modules.append(ME.MinkowskiReLU(inplace=True)) 81 | ch_in = 1 82 | else: 83 | ch_in = channels[i-1] 84 | ch_out = channels[i] 85 | k_size = kernel_sizes[i] 86 | if ch_out==1 or bn==False: 87 | nn_modules.append(ME.MinkowskiConvolution(ch_in,ch_out,kernel_size=k_size,has_bias=True,dimension=4)) 88 | elif bn==True: 89 | nn_modules.append(torch.nn.Sequential( 90 | ME.MinkowskiConvolution(ch_in,ch_out,kernel_size=k_size,has_bias=True,dimension=4), 91 | ME.MinkowskiBatchNorm(ch_out))) 92 | nn_modules.append(ME.MinkowskiReLU(inplace=True)) 93 | self.conv = nn.Sequential(*nn_modules) 94 | # self.add = ME.MinkowskiUnion() 95 | 96 | if use_cuda: 97 | self.conv.cuda() 98 | 99 | def forward(self, x): 100 | if self.symmetric_mode: 101 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 102 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 103 | x = me_to_torch(self.conv(x)) + transpose_torch(me_to_torch(self.conv(transpose_me(x)))) 104 | x = x.coalesce() 105 | else: 106 | x = me_to_torch(self.conv(x)) 107 | return x 108 | 109 | class DenseNeighConsensus(torch.nn.Module): 110 | def __init__(self, use_cuda=True, kernel_sizes=[3,3,3], channels=[10,10,1], symmetric_mode=True): 111 | super(DenseNeighConsensus, self).__init__() 112 | self.symmetric_mode = symmetric_mode 113 | self.kernel_sizes = kernel_sizes 114 | self.channels = channels 115 | num_layers = len(kernel_sizes) 116 | nn_modules = list() 117 | for i in range(num_layers): 118 | if i==0: 119 | ch_in = 1 120 | else: 121 | ch_in = channels[i-1] 122 | ch_out = channels[i] 123 | k_size = kernel_sizes[i] 124 | nn_modules.append(Conv4d(in_channels=ch_in,out_channels=ch_out,kernel_size=k_size,bias=True)) 125 | nn_modules.append(nn.ReLU(inplace=True)) 126 | self.conv = nn.Sequential(*nn_modules) 127 | if use_cuda: 128 | self.conv.cuda() 129 | 130 | def forward(self, x): 131 | if self.symmetric_mode: 132 | # apply network on the input and its "transpose" (swapping A-B to B-A ordering of the correlation tensor), 133 | # this second result is "transposed back" to the A-B ordering to match the first result and be able to add together 134 | x = self.conv(x)+self.conv(x.permute(0,1,4,5,2,3)).permute(0,1,4,5,2,3) 135 | else: 136 | x = me_to_torch(self.conv(x)) 137 | return x 138 | 139 | def MutualMatching(corr4d): 140 | # mutual matching 141 | batch_size,ch,fs1,fs2,fs3,fs4 = corr4d.size() 142 | 143 | corr4d_B=corr4d.view(batch_size,fs1*fs2,fs3,fs4) # [batch_idx,k_A,i_B,j_B] 144 | corr4d_A=corr4d.view(batch_size,fs1,fs2,fs3*fs4) 145 | 146 | # get max 147 | corr4d_B_max,_=torch.max(corr4d_B,dim=1,keepdim=True) 148 | corr4d_A_max,_=torch.max(corr4d_A,dim=3,keepdim=True) 149 | 150 | eps = 1e-5 151 | corr4d_B=corr4d_B/(corr4d_B_max+eps) 152 | 
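# --- illustrative aside (shape-level sketch) of the symmetric_mode trick above:
# "transposing" the 4D correlation swaps the roles of images A and B (dims 2,3 <-> 4,5);
# filtering both orderings and adding makes the output equivariant to swapping the
# input pair. 'f' stands in for the learned self.conv:
import torch
corr = torch.randn(1, 1, 4, 4, 5, 5)                 # [b, c, hA, wA, hB, wB]
f = lambda t: t * 2.0
T = lambda t: t.permute(0, 1, 4, 5, 2, 3)
sym = f(corr) + T(f(T(corr)))
sym_for_swapped_pair = f(T(corr)) + T(f(corr))       # same filter, images fed as B-A
assert torch.equal(T(sym), sym_for_swapped_pair)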
corr4d_A=corr4d_A/(corr4d_A_max+eps) 153 | 154 | corr4d_B=corr4d_B.view(batch_size,1,fs1,fs2,fs3,fs4) 155 | corr4d_A=corr4d_A.view(batch_size,1,fs1,fs2,fs3,fs4) 156 | 157 | corr4d=corr4d*(corr4d_A*corr4d_B) # parenthesis are important for symmetric output 158 | 159 | return corr4d 160 | 161 | def maxpool4d(corr4d_hres,k_size=4): 162 | slices=[] 163 | for i in range(k_size): 164 | for j in range(k_size): 165 | for k in range(k_size): 166 | for l in range(k_size): 167 | slices.append(corr4d_hres[:,0,i::k_size,j::k_size,k::k_size,l::k_size].unsqueeze(0)) 168 | slices=torch.cat(tuple(slices),dim=1) 169 | corr4d,max_idx=torch.max(slices,dim=1,keepdim=True) 170 | max_l=torch.fmod(max_idx,k_size) 171 | max_k=torch.fmod(max_idx.sub(max_l).div(k_size),k_size) 172 | max_j=torch.fmod(max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size),k_size) 173 | max_i=max_idx.sub(max_l).div(k_size).sub(max_k).div(k_size).sub(max_j).div(k_size) 174 | # i,j,k,l represent the *relative* coords of the max point in the box of size k_size*k_size*k_size*k_size 175 | return (corr4d,max_i,max_j,max_k,max_l) 176 | 177 | class ImMatchNet(nn.Module): 178 | def __init__(self, 179 | feature_extraction_cnn='resnet101', 180 | feature_extraction_last_layer='', 181 | feature_extraction_model_file='', 182 | return_correlation=False, 183 | ncons_kernel_sizes=[3,3,3], 184 | ncons_channels=[10,10,1], 185 | normalize_features=True, 186 | train_fe=False, 187 | use_cuda=True, 188 | relocalization_k_size=0, 189 | half_precision=False, 190 | checkpoint=None, 191 | sparse=False, 192 | symmetric_mode=True, 193 | k = 10, 194 | bn=False, 195 | return_fs=False, 196 | change_stride=False 197 | ): 198 | 199 | super(ImMatchNet, self).__init__() 200 | # Load checkpoint 201 | if checkpoint is not None and checkpoint is not '': 202 | ncons_channels, ncons_kernel_sizes, checkpoint = self.get_checkpoint_parameters(checkpoint) 203 | 204 | self.use_cuda = use_cuda 205 | self.normalize_features = normalize_features 206 | self.return_correlation = return_correlation 207 | self.relocalization_k_size = relocalization_k_size 208 | self.half_precision = half_precision 209 | self.sparse = sparse 210 | self.k = k 211 | self.d2 = feature_extraction_cnn=='d2' 212 | self.return_fs = return_fs 213 | self.Npts = None 214 | 215 | self.FeatureExtraction = FeatureExtraction(train_fe=train_fe, 216 | feature_extraction_cnn=feature_extraction_cnn, 217 | feature_extraction_model_file=feature_extraction_model_file, 218 | last_layer=feature_extraction_last_layer, 219 | normalization=normalize_features, 220 | use_cuda=self.use_cuda) 221 | self.FeatureExtraction.eval() 222 | 223 | if sparse: 224 | self.NeighConsensus = SparseNeighConsensus(use_cuda=self.use_cuda, 225 | kernel_sizes=ncons_kernel_sizes, 226 | channels=ncons_channels, 227 | symmetric_mode=symmetric_mode, 228 | bn = bn) 229 | else: 230 | self.NeighConsensus = DenseNeighConsensus(use_cuda=self.use_cuda, 231 | kernel_sizes=ncons_kernel_sizes, 232 | channels=ncons_channels, 233 | symmetric_mode=symmetric_mode) 234 | 235 | if checkpoint is not None and checkpoint is not '': self.load_weights(checkpoint) 236 | if self.half_precision: self.set_half_precision() 237 | if change_stride: self.FeatureExtraction.change_stride() 238 | 239 | def get_checkpoint_parameters(self, checkpoint): 240 | print('Loading checkpoint...') 241 | checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage) 242 | checkpoint['state_dict'] = OrderedDict([(k.replace('vgg', 'model'), v) for k, v in 
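# --- illustrative aside: a 2x2 toy version (made-up scores) of MutualMatching above.
# Each score is rescaled by its row and column maxima, so a pair that is the mutual
# best keeps its score while one-sided matches are softly down-weighted:
import torch
s = torch.tensor([[0.9, 0.4], [0.8, 0.3]])             # s[i, j]: score of match A_i -> B_j
r = s / (s.max(dim=1, keepdim=True).values + 1e-5)     # ratio to the best match of each A_i
c = s / (s.max(dim=0, keepdim=True).values + 1e-5)     # ratio to the best match of each B_j
print(s * r * c)                                       # (0,0) stays ~0.9; (1,1) drops to ~0.08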
checkpoint['state_dict'].items()]) 243 | # override relevant parameters 244 | print('Using checkpoint parameters: ') 245 | ncons_channels=checkpoint['args'].ncons_channels 246 | print(' ncons_channels: '+str(ncons_channels)) 247 | ncons_kernel_sizes=checkpoint['args'].ncons_kernel_sizes 248 | print(' ncons_kernel_sizes: '+str(ncons_kernel_sizes)) 249 | return ncons_channels, ncons_kernel_sizes, checkpoint 250 | 251 | def load_weights(self, checkpoint): 252 | # Load weights 253 | print('Copying weights...') 254 | for name, param in self.FeatureExtraction.state_dict().items(): 255 | if 'num_batches_tracked' not in name: 256 | self.FeatureExtraction.state_dict()[name].copy_(checkpoint['state_dict']['FeatureExtraction.' + name]) 257 | for name, param in self.NeighConsensus.state_dict().items(): 258 | self.NeighConsensus.state_dict()[name].copy_(checkpoint['state_dict']['NeighConsensus.' + name]) 259 | print('Done!') 260 | 261 | def set_half_precision(self): 262 | for p in self.NeighConsensus.parameters(): 263 | p.data=p.data.half() 264 | for l in self.NeighConsensus.conv: 265 | if isinstance(l,Conv4d): 266 | l.use_half=True 267 | 268 | def forward(self, tnf_batch): 269 | # feature extraction 270 | feature_A = self.FeatureExtraction(tnf_batch['source_image']) 271 | feature_B = self.FeatureExtraction(tnf_batch['target_image']) 272 | 273 | if self.sparse: 274 | return self.process_sparse(feature_A, feature_B) 275 | else: 276 | return self.process_dense(feature_A, feature_B) 277 | 278 | def process_sparse(self, feature_A, feature_B): 279 | corr4d = corr_and_add(feature_A, feature_B, k = self.k, Npts = self.Npts) 280 | corr4d = self.NeighConsensus(corr4d) 281 | if self.return_fs: 282 | fs1, fs2 = feature_A.shape[-2:] 283 | fs3, fs4 = feature_B.shape[-2:] 284 | return corr4d, fs1, fs2, fs3, fs4 285 | else: 286 | return corr4d 287 | 288 | def process_dense(self, feature_A, feature_B): 289 | if self.half_precision: 290 | feature_A=feature_A.half() 291 | feature_B=feature_B.half() 292 | corr4d = corr_dense(feature_A,feature_B) 293 | if self.relocalization_k_size>1: 294 | corr4d,max_i,max_j,max_k,max_l=maxpool4d(corr4d,k_size=self.relocalization_k_size) 295 | corr4d = MutualMatching(corr4d) 296 | corr4d = self.NeighConsensus(corr4d) 297 | corr4d = MutualMatching(corr4d) 298 | if self.relocalization_k_size>1: 299 | delta4d = (max_i,max_j,max_k,max_l) 300 | return (corr4d, delta4d) 301 | else: 302 | return corr4d 303 | 304 | 305 | 306 | 307 | -------------------------------------------------------------------------------- /lib/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | from skimage.io import imread 6 | import numpy as np 7 | 8 | def normalize_caffe(image): 9 | if image.ndim==3: return normalize_caffe(image.unsqueeze(0)).squeeze(0) 10 | # RGB -> BGR 11 | image = image[:, [2,1,0], :, :] 12 | # Zero-center by mean pixel 13 | mean = torch.tensor([103.939, 116.779, 123.68], device = image.device) 14 | image = image - mean.view(1,3,1,1) 15 | return image 16 | 17 | normalize_image_dict_caffe = lambda x: {k:normalize_caffe(v) if k in ['source_image', 18 | 'target_image'] else v for k,v in x.items()} 19 | 20 | class NormalizeImageDict(object): 21 | """ 22 | 23 | Normalizes Tensor images in dictionary 24 | 25 | Args: 26 | image_keys (list): dict. 
keys of the images to be normalized 27 | normalizeRange (bool): if True the image is divided by 255.0s 28 | 29 | """ 30 | 31 | def __init__(self,image_keys,normalizeRange=True): 32 | self.image_keys = image_keys 33 | self.normalizeRange=normalizeRange 34 | self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 35 | std=[0.229, 0.224, 0.225]) 36 | 37 | def __call__(self, sample): 38 | for key in self.image_keys: 39 | if self.normalizeRange: 40 | sample[key] /= 255.0 41 | sample[key] = self.normalize(sample[key]) 42 | return sample 43 | 44 | def normalize_image(image, forward=True, mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]): 45 | im_size = image.size() 46 | mean=torch.FloatTensor(mean).unsqueeze(1).unsqueeze(2) 47 | std=torch.FloatTensor(std).unsqueeze(1).unsqueeze(2) 48 | if image.is_cuda: 49 | mean = mean.cuda() 50 | std = std.cuda() 51 | if isinstance(image,torch.autograd.variable.Variable): 52 | mean = Variable(mean,requires_grad=False) 53 | std = Variable(std,requires_grad=False) 54 | if forward: 55 | if len(im_size)==3: 56 | result = image.sub(mean.expand(im_size)).div(std.expand(im_size)) 57 | elif len(im_size)==4: 58 | result = image.sub(mean.unsqueeze(0).expand(im_size)).div(std.unsqueeze(0).expand(im_size)) 59 | else: 60 | if len(im_size)==3: 61 | result = image.mul(std.expand(im_size)).add(mean.expand(im_size)) 62 | elif len(im_size)==4: 63 | result = image.mul(std.unsqueeze(0).expand(im_size)).add(mean.unsqueeze(0).expand(im_size)) 64 | 65 | return result 66 | 67 | imreadth = lambda x: torch.Tensor(imread(x).astype(np.float32)).transpose(1,2).transpose(0,1) 68 | normalize = lambda x: NormalizeImageDict(['im'])({'im':x})['im'] 69 | 70 | # allow rectangular images. Does not modify aspect ratio. 71 | resize = lambda x, image_size, scale_factor: F.interpolate(x.unsqueeze(0).cuda(), 72 | size=(int(np.floor(x.shape[1]/(np.max(x.shape[1:])/image_size)*scale_factor)/scale_factor), 73 | int(np.floor(x.shape[2]/(np.max(x.shape[1:])/image_size)*scale_factor)/scale_factor)),mode='bilinear', align_corners=False) 74 | 75 | padim = lambda x, h_max: torch.cat((x,x.view(-1)[0].clone().expand(1,3,h_max-x.shape[2],x.shape[3])/1e20),dim=2) if x.shape[2] BGR 28 | image = image[:: -1, :, :] 29 | elif preprocessing == 'torch': 30 | mean = np.array([0.485, 0.456, 0.406]) 31 | std = np.array([0.229, 0.224, 0.225]) 32 | image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1]) 33 | image *= 255.0 34 | else: 35 | raise ValueError('Unknown preprocessing parameter.') 36 | image = np.transpose(image, [1, 2, 0]) 37 | image = np.round(image).astype(np.uint8) 38 | return image 39 | 40 | def save_plot(filename): 41 | plt.gca().set_axis_off() 42 | plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, 43 | hspace = 0, wspace = 0) 44 | plt.margins(0,0) 45 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 46 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 47 | plt.savefig(filename, bbox_inches = 'tight', 48 | pad_inches = 0) 49 | 50 | padim = lambda x,h_max: torch.cat((x,x.view(-1)[0].clone().expand(1,3,h_max-x.shape[2],x.shape[3])/1e20),dim=2) if x.shape[2]0).long(),dim=1,keepdim=True)-1 114 | x_minus[x_minus<0]=0 # fix edge case 115 | x_plus = x_minus+1 116 | 117 | y_minus = torch.sum(((target_points_norm[:,1,:]-grid)>0).long(),dim=1,keepdim=True)-1 118 | y_minus[y_minus<0]=0 # fix edge case 119 | y_plus = y_minus+1 120 | 121 | toidx = lambda x,y,L: y*L+x 122 | 123 | m_m_idx = toidx(x_minus,y_minus,feature_size) 124 | p_p_idx = toidx(x_plus,y_plus,feature_size) 125 | 
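# toidx maps (x, y) grid coordinates to row-major indices over an L x L grid, so the
# four *_idx tensors here address the grid corners surrounding each target point; the
# corresponding bilinear weights are computed as the f_* terms below. A quick sanity
# check of the indexing convention (illustrative only, not executed here):
#
#   toidx = lambda x, y, L: y*L + x
#   assert toidx(1, 2, 4) == 9   # (x=1, y=2) on a 4x4 grid -> flat index 9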
p_m_idx = toidx(x_plus,y_minus,feature_size) 126 | m_p_idx = toidx(x_minus,y_plus,feature_size) 127 | 128 | topoint = lambda idx, X, Y: torch.cat((X[idx.view(-1)].view(b,1,N).contiguous(), 129 | Y[idx.view(-1)].view(b,1,N).contiguous()),dim=1) 130 | 131 | P_m_m = topoint(m_m_idx,X_,Y_) 132 | P_p_p = topoint(p_p_idx,X_,Y_) 133 | P_p_m = topoint(p_m_idx,X_,Y_) 134 | P_m_p = topoint(m_p_idx,X_,Y_) 135 | 136 | multrows = lambda x: x[:,0,:]*x[:,1,:] 137 | 138 | f_p_p=multrows(torch.abs(target_points_norm-P_m_m)) 139 | f_m_m=multrows(torch.abs(target_points_norm-P_p_p)) 140 | f_m_p=multrows(torch.abs(target_points_norm-P_p_m)) 141 | f_p_m=multrows(torch.abs(target_points_norm-P_m_p)) 142 | 143 | Q_m_m = topoint(m_m_idx,xA.view(-1),yA.view(-1)) 144 | Q_p_p = topoint(p_p_idx,xA.view(-1),yA.view(-1)) 145 | Q_p_m = topoint(p_m_idx,xA.view(-1),yA.view(-1)) 146 | Q_m_p = topoint(m_p_idx,xA.view(-1),yA.view(-1)) 147 | 148 | warped_points_norm = (Q_m_m*f_m_m+Q_p_p*f_p_p+Q_m_p*f_m_p+Q_p_m*f_p_m)/(f_p_p+f_m_m+f_m_p+f_p_m) 149 | return warped_points_norm 150 | 151 | 152 | def PointsToUnitCoords(P,im_size): 153 | h,w = im_size[:,0],im_size[:,1] 154 | P_norm = P.clone() 155 | # normalize Y 156 | P_norm[:,0,:] = normalize_axis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 157 | # normalize X 158 | P_norm[:,1,:] = normalize_axis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 159 | return P_norm 160 | 161 | def PointsToPixelCoords(P,im_size): 162 | h,w = im_size[:,0],im_size[:,1] 163 | P_norm = P.clone() 164 | # normalize Y 165 | P_norm[:,0,:] = unnormalize_axis(P[:,0,:],w.unsqueeze(1).expand_as(P[:,0,:])) 166 | # normalize X 167 | P_norm[:,1,:] = unnormalize_axis(P[:,1,:],h.unsqueeze(1).expand_as(P[:,1,:])) 168 | return P_norm -------------------------------------------------------------------------------- /lib/py_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | 5 | def imshow_image(image, preprocessing=None): 6 | if preprocessing is None: 7 | pass 8 | elif preprocessing == 'caffe': 9 | mean = np.array([103.939, 116.779, 123.68]) 10 | image = image + mean.reshape([3, 1, 1]) 11 | # RGB -> BGR 12 | image = image[:: -1, :, :] 13 | elif preprocessing == 'torch': 14 | mean = np.array([0.485, 0.456, 0.406]) 15 | std = np.array([0.229, 0.224, 0.225]) 16 | image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1]) 17 | image *= 255.0 18 | else: 19 | raise ValueError('Unknown preprocessing parameter.') 20 | image = np.transpose(image, [1, 2, 0]) 21 | image = np.round(image).astype(np.uint8) 22 | return image 23 | 24 | def create_file_path(filename): 25 | if not os.path.exists(os.path.dirname(filename)): 26 | try: 27 | os.makedirs(os.path.dirname(filename)) 28 | except OSError as exc: # Guard against race condition 29 | if exc.errno != errno.EEXIST: 30 | raise -------------------------------------------------------------------------------- /lib/relocalize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.ops as O 3 | import torch.nn.functional as F 4 | from lib.model import featureL2Norm, corr_dense 5 | from lib.sparse import corr_and_add 6 | 7 | def relocalize(xA_,yA_,xB_,yB_,score_,feature_A_2x, feature_B_2x, N_matches=None, upsample_positions=True, crop_size = 2): 8 | assert crop_size==3 or crop_size==2 9 | 10 | if N_matches is None: 11 | N_matches = xA_.shape[1] 12 | else: 13 | idx = torch.argsort(-score_.view(-1)) 14 | N_matches = 
min(N_matches, idx.shape[0]) 15 | idx = idx[:N_matches] 16 | score_ = score_[:,idx] 17 | xA_ = xA_[:,idx] 18 | yA_ = yA_[:,idx] 19 | xB_ = xB_[:,idx] 20 | yB_ = yB_[:,idx] 21 | 22 | if upsample_positions: 23 | xA_ = xA_*2 24 | yA_ = yA_*2 25 | xB_ = xB_*2 26 | yB_ = yB_*2 27 | 28 | coords_A = torch.cat( 29 | (torch.zeros(1,N_matches).to(xA_.device), 30 | xA_-(crop_size%2), 31 | yA_-(crop_size%2), 32 | xA_+1, 33 | yA_+1), 34 | dim = 0 35 | ).t() 36 | 37 | coords_B = torch.cat( 38 | (torch.zeros(1,N_matches).to(xB_.device), 39 | xB_-(crop_size%2), 40 | yB_-(crop_size%2), 41 | xB_+1, 42 | yB_+1), 43 | dim = 0 44 | ).t() 45 | 46 | ch = feature_A_2x.shape[1] 47 | feature_A_local = O.roi_pool(feature_A_2x, 48 | coords_A, 49 | output_size=(crop_size,crop_size)).view(N_matches,ch,-1,1) 50 | feature_B_local = O.roi_pool(feature_B_2x, 51 | coords_B, 52 | output_size=(crop_size,crop_size)).view(N_matches,ch,1,-1) 53 | 54 | deltaY, deltaX = torch.meshgrid(torch.linspace(-(crop_size%2),1,crop_size), 55 | torch.linspace(-(crop_size%2),1,crop_size)) 56 | 57 | deltaX = deltaX.contiguous().view(-1).to(xA_.device) 58 | deltaY = deltaY.contiguous().view(-1).to(xA_.device) 59 | 60 | corr_local = (feature_A_local * feature_B_local).sum(dim=1) 61 | 62 | delta_A_idx = torch.argmax(corr_local.max(dim=2,keepdim=True)[0],dim=1) 63 | delta_B_idx = torch.argmax(corr_local.max(dim=1,keepdim=True)[0],dim=2) 64 | 65 | xA_ = xA_ + deltaX[delta_A_idx].t() 66 | yA_ = yA_ + deltaY[delta_A_idx].t() 67 | xB_ = xB_ + deltaX[delta_B_idx].t() 68 | yB_ = yB_ + deltaY[delta_B_idx].t() 69 | 70 | return xA_, yA_, xB_, yB_, score_ 71 | 72 | def relocalize_soft(xA_,yA_,xB_,yB_,score_, feature_A_2x, feature_B_2x, N_matches=None, sigma=10, upsample_positions=True): 73 | if N_matches is None: 74 | N_matches = xA_.shape[1] 75 | else: 76 | idx = torch.argsort(-score_.view(-1)) 77 | N_matches = min(N_matches, idx.shape[0]) 78 | idx = idx[:N_matches] 79 | score_ = score_[:,idx] 80 | xA_ = xA_[:,idx] 81 | yA_ = yA_[:,idx] 82 | xB_ = xB_[:,idx] 83 | yB_ = yB_[:,idx] 84 | 85 | if upsample_positions: 86 | xA_ = xA_*2 87 | yA_ = yA_*2 88 | xB_ = xB_*2 89 | yB_ = yB_*2 90 | 91 | coords_A = torch.cat( 92 | (torch.zeros(1,N_matches).to(xA_.device), 93 | xA_-1, 94 | yA_-1, 95 | xA_+1, 96 | yA_+1), 97 | dim = 0 98 | ).t() 99 | 100 | coords_B = torch.cat( 101 | (torch.zeros(1,N_matches).to(xB_.device), 102 | xB_-1, 103 | yB_-1, 104 | xB_+1, 105 | yB_+1), 106 | dim = 0 107 | ).t() 108 | 109 | ch = feature_A_2x.shape[1] 110 | feature_A_local = O.roi_pool(feature_A_2x,coords_A,output_size=(3,3)) 111 | feature_B_local = O.roi_pool(feature_B_2x,coords_B,output_size=(3,3)) 112 | 113 | deltaY, deltaX = torch.meshgrid(torch.linspace(-1,1,3),torch.linspace(-1,1,3)) 114 | 115 | deltaX = deltaX.contiguous().to(xA_.device).unsqueeze(0) 116 | deltaY = deltaY.contiguous().to(xA_.device).unsqueeze(0) 117 | 118 | corrA_B = (feature_A_local[:,:,1:2,1:2] * feature_B_local).sum(dim=1).mul(sigma).view(N_matches,-1).softmax(dim=1).view(N_matches,3,3) 119 | corrB_A = (feature_B_local[:,:,1:2,1:2] * feature_A_local).sum(dim=1).mul(sigma).view(N_matches,-1).softmax(dim=1).view(N_matches,3,3) 120 | 121 | deltaX_B = (corrA_B * deltaX).view(N_matches,-1).sum(dim=1).unsqueeze(0) 122 | deltaY_B = (corrA_B * deltaY).view(N_matches,-1).sum(dim=1).unsqueeze(0) 123 | 124 | deltaX_A = (corrB_A * deltaX).view(N_matches,-1).sum(dim=1).unsqueeze(0) 125 | deltaY_A = (corrB_A * deltaY).view(N_matches,-1).sum(dim=1).unsqueeze(0) 126 | 127 | xA_ = xA_ + deltaX_A 128 | yA_ = yA_ + deltaY_A 129 
| xB_ = xB_ + deltaX_B 130 | yB_ = yB_ + deltaY_B 131 | 132 | return xA_, yA_, xB_, yB_, score_ 133 | 134 | # redefine forward function for evaluation with relocalization 135 | def eval_model_reloc(model, batch, args=None): 136 | 137 | benchmark = False if args is None else args.benchmark 138 | relocalize = True if args is None else args.relocalize 139 | reloc_hard_crop_size = 2 if args is None else args.reloc_hard_crop_size 140 | no_ncnet = False if args is None else args.no_ncnet 141 | 142 | if benchmark: 143 | start = torch.cuda.Event(enable_timing=True) 144 | mid = torch.cuda.Event(enable_timing=True) 145 | end = torch.cuda.Event(enable_timing=True) 146 | start.record() 147 | # feature extraction 148 | if relocalize: 149 | feature_A_2x = model.FeatureExtraction.model(batch['source_image']) 150 | feature_B_2x = model.FeatureExtraction.model(batch['target_image']) 151 | 152 | if reloc_hard_crop_size==3: 153 | feature_A = F.max_pool2d(feature_A_2x, kernel_size=3, stride=2, padding=1) 154 | feature_B = F.max_pool2d(feature_B_2x, kernel_size=3, stride=2, padding=1) 155 | elif reloc_hard_crop_size==2: 156 | feature_A = F.max_pool2d(feature_A_2x, kernel_size=2, stride=2, padding=0) 157 | feature_B = F.max_pool2d(feature_B_2x, kernel_size=2, stride=2, padding=0) 158 | 159 | feature_A_2x = featureL2Norm(feature_A_2x) 160 | feature_B_2x = featureL2Norm(feature_B_2x) 161 | else: 162 | feature_A = model.FeatureExtraction.model(batch['source_image']) 163 | feature_B = model.FeatureExtraction.model(batch['target_image']) 164 | feature_A_2x, feature_B_2x = None, None 165 | 166 | feature_A = featureL2Norm(feature_A) 167 | feature_B = featureL2Norm(feature_B) 168 | 169 | fs1, fs2 = feature_A.shape[-2:] 170 | fs3, fs4 = feature_B.shape[-2:] 171 | 172 | if benchmark: 173 | mid.record() 174 | 175 | 176 | if no_ncnet: 177 | corr4d = None 178 | else: 179 | corr4d = corr_and_add(feature_A, feature_B, k = model.k, Npts=None) 180 | corr4d = model.NeighConsensus(corr4d) 181 | 182 | if benchmark: 183 | end.record() 184 | torch.cuda.synchronize() 185 | fe_time = start.elapsed_time(mid)/1000 186 | cnn_time = mid.elapsed_time(end)/1000 187 | return corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4, fe_time, cnn_time 188 | 189 | return corr4d, feature_A_2x, feature_B_2x, fs1, fs2, fs3, fs4 190 | -------------------------------------------------------------------------------- /lib/sparse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | from collections import defaultdict 4 | from .point_tnf import normalize_axis 5 | import math 6 | from .knn import knn_faiss 7 | import numpy as np 8 | 9 | def sparse_corr(feature_A, 10 | feature_B, 11 | k=10, 12 | coords_A=None, 13 | coords_B=None, 14 | reverse=False, 15 | ratio=False, 16 | sparse_type='torch', 17 | return_indx=False, 18 | fsize = None, 19 | bidx = None): 20 | 21 | b,ch=feature_B.shape[:2] 22 | 23 | if fsize is None: 24 | hA, wA = feature_A.shape[2:] 25 | hB, wB = feature_B.shape[2:] 26 | else: 27 | hA, wA = fsize 28 | hB, wB = fsize 29 | 30 | feature_A = feature_A.view(b,ch,-1) 31 | feature_B = feature_B.view(b,ch,-1) 32 | 33 | nA = feature_A.shape[2] 34 | nB = feature_B.shape[2] 35 | 36 | with torch.no_grad(): 37 | dist_squared, indx = knn_faiss(feature_B, feature_A, k) 38 | 39 | if bidx is None: bidx = torch.arange(b).view(b,1,1) 40 | bidx = bidx.expand_as(indx).contiguous() 41 | 42 | if feature_A.requires_grad: 43 | corr = (feature_A.permute(1,0,2).unsqueeze(2) * \ 44 | 
feature_B.permute(1,0,2)[:,bidx.view(-1),indx.view(-1)].view(ch,b,k,nA)).sum(dim=0).contiguous() 45 | else: 46 | corr = 1-dist_squared/2 # [b,k,nA] 47 | 48 | 49 | if ratio: 50 | corr_ratio=corr/corr[:,:1,:] 51 | 52 | if coords_A is None: 53 | YA,XA=torch.meshgrid(torch.arange(hA),torch.arange(wA)) 54 | YA=YA.contiguous() 55 | XA=XA.contiguous() 56 | yA=YA.view(-1).unsqueeze(0).unsqueeze(0).expand(b,k,nA).contiguous().view(-1,1) 57 | xA=XA.view(-1).unsqueeze(0).unsqueeze(0).expand(b,k,nA).contiguous().view(-1,1) 58 | else: 59 | yA,xA = coords_A 60 | yA=yA.view(-1).unsqueeze(0).unsqueeze(0).expand(b,k,nA).contiguous().view(-1,1) 61 | xA=xA.view(-1).unsqueeze(0).unsqueeze(0).expand(b,k,nA).contiguous().view(-1,1) 62 | 63 | if coords_B is None: 64 | YB,XB=torch.meshgrid(torch.arange(hB),torch.arange(wB)) 65 | YB=YB.contiguous() 66 | XB=XB.contiguous() 67 | yB=YB.view(-1)[indx.view(-1).cpu()].view(-1,1) 68 | xB=XB.view(-1)[indx.view(-1).cpu()].view(-1,1) 69 | else: 70 | yB,xB = coords_B 71 | yB=yB.view(-1)[indx.view(-1).cpu()].view(-1,1) 72 | xB=xB.view(-1)[indx.view(-1).cpu()].view(-1,1) 73 | 74 | bidx = bidx.view(-1,1) 75 | corr=corr.view(-1,1) 76 | if ratio: corr_ratio = corr_ratio.view(-1,1) 77 | 78 | if reverse: 79 | yA,xA,yB,xB=yB,xB,yA,xA 80 | hA,wA,hB,wB=hB,wB,hA,wA 81 | 82 | if sparse_type == 'me': 83 | coords = torch.cat((bidx, yA, xA, yB, xB),dim=1).int() 84 | scorr = ME.SparseTensor(corr, coords) 85 | 86 | if ratio: scorr_ratio = ME.SparseTensor(corr_ratio,coords) 87 | 88 | elif sparse_type == 'torch': 89 | coords = torch.cat((bidx, yA, xA, yB, xB),dim=1).long().to(corr.device).t() 90 | scorr = torch.sparse.FloatTensor(coords,corr,torch.Size([b,hA,wA,hB,wB,1])) 91 | 92 | if ratio: scorr_ratio = torch.sparse.FloatTensor(coords,corr_ratio,torch.Size([b,hA,wA,hB,wB,1])) 93 | 94 | elif sparse_type == 'raw': 95 | coords = torch.cat((bidx, yA, xA, yB, xB),dim=1).int() 96 | scorr = (corr, coords) 97 | 98 | if ratio: scorr_ratio = (corr_ratio,coords) 99 | 100 | else: 101 | raise ValueError('sparse type {} not recognized'.format(sparse_type)) 102 | 103 | if ratio: return scorr, scorr_ratio 104 | if return_indx: return scorr, indx 105 | return scorr 106 | 107 | def torch_to_me(sten): 108 | sten = sten.coalesce() 109 | indices = sten.indices().t().int().cpu() 110 | 111 | return ME.SparseTensor(sten.values(), indices) 112 | 113 | def me_to_torch(sten): 114 | values = sten.feats 115 | indices = sten.coords.t().long().to(values.device) 116 | 117 | sten = torch.sparse.FloatTensor(indices,values).coalesce() 118 | 119 | return sten 120 | 121 | def corr_and_add( 122 | feature_A, 123 | feature_B, 124 | k=10, 125 | coords_A=None, 126 | coords_B=None, 127 | Npts=None): 128 | 129 | # compute sparse correlation from A to B 130 | scorr = sparse_corr( 131 | feature_A, 132 | feature_B, 133 | k=k, 134 | ratio=False, sparse_type='raw') 135 | 136 | # compute sparse correlation from B to A 137 | scorr2 = sparse_corr( 138 | feature_B, 139 | feature_A, 140 | k=k, 141 | ratio=False, 142 | reverse=True, sparse_type='raw') 143 | 144 | scorr = ME.SparseTensor(scorr[0],scorr[1]) 145 | scorr2 = ME.SparseTensor(scorr2[0],scorr2[1],coords_manager=scorr.coords_man,force_creation=True) 146 | 147 | scorr = ME.MinkowskiUnion()(scorr,scorr2) 148 | 149 | return scorr 150 | 151 | def transpose_me(sten): 152 | return ME.SparseTensor(sten.feats.clone(), 153 | sten.coords[:,[0,3,4,1,2]].clone()) 154 | 155 | def transpose_torch(sten): 156 | sten = sten.coalesce() 157 | indices = sten.indices()[[0,3,4,1,2],:] 158 | values = 
sten.values() 159 | 160 | return torch.sparse.FloatTensor(indices,values).coalesce() 161 | 162 | def get_scores(corr, reverse=False, k=10): 163 | if reverse: 164 | c=[3,4] 165 | else: 166 | c=[1,2] 167 | 168 | coords = corr.indices() 169 | values = corr.values().squeeze().clone() 170 | 171 | #knn = KNN(k=k, transpose_mode=False) 172 | batch_size = coords[:1,:].max()+1 173 | feature_size = coords[1:,:].max()+1 174 | 175 | loss = [] 176 | for b in range(batch_size): 177 | batch_indices = torch.nonzero(coords[0,:].view(-1)==b).view(-1) 178 | ref = coords[c,:][:,batch_indices] 179 | uniq_coords = torch.unique(ref,dim=1) 180 | dist, idx = knn_faiss(ref.unsqueeze(0).float(), uniq_coords.unsqueeze(0).float(), k) 181 | #dist, idx = knn(ref.unsqueeze(0), uniq_coords.unsqueeze(0)) 182 | mask = (dist == 0) 183 | 184 | curr_vals = values[batch_indices[idx]]*mask 185 | zeros = torch.zeros((1,feature_size**2-curr_vals.shape[1],curr_vals.shape[2]),device=curr_vals.device) 186 | curr_vals_extended = torch.cat((curr_vals,zeros),dim=1) 187 | max_vals = torch.softmax(curr_vals_extended,dim=1).max(dim=1)[0].mean().unsqueeze(0) 188 | loss.append(max_vals) 189 | 190 | scores = torch.cat(loss,dim=0).mean() 191 | 192 | return scores 193 | 194 | 195 | def unique(ar, return_index=False, return_inverse=False, 196 | return_counts=False): # torch port of np.unique for 1-D tensors 197 | 198 | ar = ar.view(-1) 199 | 200 | perm = ar.argsort() 201 | aux = ar[perm] 202 | 203 | mask = torch.zeros(aux.shape, dtype=torch.bool, device=ar.device) 204 | mask[0] = True 205 | mask[1:] = aux[1:] != aux[:-1] 206 | 207 | ret = aux[mask] 208 | if not return_index and not return_inverse and not return_counts: 209 | return ret 210 | 211 | ret = ret, 212 | if return_index: 213 | ret += perm[mask], 214 | if return_inverse: 215 | imask = torch.cumsum(mask.long(), dim=0) - 1 # torch.cumsum requires an explicit dim 216 | inv_idx = torch.zeros(mask.shape, dtype=torch.int64, device=ar.device) 217 | inv_idx[perm] = imask 218 | ret += inv_idx, 219 | if return_counts: 220 | nonzero = torch.nonzero(mask).view(-1) 221 | idx = torch.zeros(nonzero.shape[0] + 1, dtype=nonzero.dtype, device=ar.device) 222 | idx[:-1] = nonzero 223 | idx[-1] = mask.shape[0] # total number of elements (np-style mask.size) 224 | ret += idx[1:] - idx[:-1], 225 | return ret 226 | 227 | 228 | def get_matches(out, reverse=True, fsize=40, scale='centered'): 229 | if isinstance(fsize,tuple): 230 | fs1, fs2, fs3, fs4 = fsize 231 | else: 232 | fs1, fs2, fs3, fs4 = fsize, fsize, fsize, fsize 233 | 234 | if reverse: 235 | c=[3,4] 236 | fh, fw = fs3, fs4 237 | else: 238 | c=[1,2] 239 | fh, fw = fs1, fs2 240 | 241 | coords = out.coords[:,c].cuda() 242 | feats = out.feats 243 | sorted_idx = torch.argsort(-feats,dim=0).view(-1) 244 | coords = coords[sorted_idx] 245 | 246 | coords_idx = coords[:,0]*fw+coords[:,1] 247 | _, matches_idx = unique(coords_idx, return_index=True) 248 | matches_idx = sorted_idx[matches_idx] 249 | 250 | matches_scores = feats[matches_idx].t() 251 | matches = out.coords.to(out.device)[matches_idx,1:] 252 | 253 | if scale=='centered': 254 | yA = normalize_axis(matches[:,0]+1,fs1).unsqueeze(0).to(out.device) 255 | xA = normalize_axis(matches[:,1]+1,fs2).unsqueeze(0).to(out.device) 256 | yB = normalize_axis(matches[:,2]+1,fs3).unsqueeze(0).to(out.device) 257 | xB = normalize_axis(matches[:,3]+1,fs4).unsqueeze(0).to(out.device) 258 | elif scale=='positive': 259 | yA = (matches[:,0].float()/(fs1-1)).unsqueeze(0).to(out.device) 260 | xA = (matches[:,1].float()/(fs2-1)).unsqueeze(0).to(out.device) 261 | yB = (matches[:,2].float()/(fs3-1)).unsqueeze(0).to(out.device) 262 | xB = (matches[:,3].float()/(fs4-1)).unsqueeze(0).to(out.device) 263 | elif 
scale=='none': 264 | yA = (matches[:,0].float()).unsqueeze(0).to(out.device) 265 | xA = (matches[:,1].float()).unsqueeze(0).to(out.device) 266 | yB = (matches[:,2].float()).unsqueeze(0).to(out.device) 267 | xB = (matches[:,3].float()).unsqueeze(0).to(out.device) 268 | 269 | return xA,yA,xB,yB,matches_scores 270 | 271 | 272 | def get_matches_both_dirs(corr4d, fs1, fs2, fs3, fs4): 273 | corr4d = torch_to_me(corr4d) 274 | (xA_,yA_,xB_,yB_,score_)=get_matches(corr4d, fsize=(fs1, fs2, fs3, fs4), reverse=False, scale='none') 275 | 276 | (xA2_,yA2_,xB2_,yB2_,score2_)=get_matches(corr4d, fsize=(fs1, fs2, fs3, fs4), reverse=True, scale='none') 277 | # fuse matches 278 | xA_=torch.cat((xA_,xA2_),1) 279 | yA_=torch.cat((yA_,yA2_),1) 280 | xB_=torch.cat((xB_,xB2_),1) 281 | yB_=torch.cat((yB_,yB2_),1) 282 | score_=torch.cat((score_,score2_),1) 283 | # remove duplicates 284 | all_matches = torch.cat((xA_,yA_,xB_,yB_),dim=0) 285 | _, matches_idx = np.unique(all_matches.cpu().numpy(),axis=1,return_index=True) 286 | score_ = score_[:,matches_idx] 287 | xA_,yA_,xB_,yB_ = all_matches[:1,matches_idx], all_matches[1:2,matches_idx], all_matches[2:3,matches_idx], all_matches[3:,matches_idx] 288 | return xA_, yA_, xB_, yB_, score_ 289 | -------------------------------------------------------------------------------- /lib/torch_util.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import torch 3 | from torch.autograd import Variable 4 | from os import makedirs, remove 5 | from os.path import exists, join, basename, dirname 6 | import collections 7 | from .dataloader import default_collate 8 | 9 | def collate_custom(batch): 10 | """ Custom collate function for the Dataset class 11 | * It doesn't convert numpy arrays to stacked-tensors, but rather combines them in a list 12 | * This is useful for processing annotations of different sizes 13 | """ 14 | # this case will occur in first pass, and will convert a 15 | # list of dictionaries (returned by the threads by sampling dataset[idx]) 16 | # to a unified dictionary of collated values 17 | if isinstance(batch[0], collections.Mapping): 18 | return {key: collate_custom([d[key] for d in batch]) for key in batch[0]} 19 | # these cases will occur in recursion 20 | elif torch.is_tensor(batch[0]): # for tensors, use standrard collating function 21 | return default_collate(batch) 22 | else: # for other types (i.e. 
lists), return as is 23 | return batch 24 | 25 | class BatchTensorToVars(object): 26 | """Convert tensors in dict batch to vars 27 | """ 28 | def __init__(self, use_cuda=True): 29 | self.use_cuda=use_cuda 30 | 31 | def __call__(self, batch): 32 | batch_var = {} 33 | for key,value in batch.items(): 34 | if isinstance(value,torch.Tensor) and not self.use_cuda: 35 | batch_var[key] = Variable(value,requires_grad=False) 36 | elif isinstance(value,torch.Tensor) and self.use_cuda: 37 | batch_var[key] = Variable(value,requires_grad=False).cuda() 38 | else: 39 | batch_var[key] = value 40 | return batch_var 41 | 42 | def Softmax1D(x,dim): 43 | x_k = torch.max(x,dim)[0].unsqueeze(dim) 44 | x -= x_k.expand_as(x) 45 | exp_x = torch.exp(x) 46 | return torch.div(exp_x,torch.sum(exp_x,dim).unsqueeze(dim).expand_as(x)) 47 | 48 | def save_checkpoint(state, is_best, file, save_all_epochs=False): 49 | model_dir = dirname(file) 50 | model_fn = basename(file) 51 | # make dir if needed (should be non-empty) 52 | if model_dir!='' and not exists(model_dir): 53 | makedirs(model_dir) 54 | if save_all_epochs: 55 | torch.save(state, join(model_dir,str(state['epoch'])+'_' + model_fn)) 56 | if is_best: 57 | shutil.copyfile(join(model_dir,str(state['epoch'])+'_' + model_fn), join(model_dir,'best_' + model_fn)) 58 | else: 59 | torch.save(state, file) 60 | if is_best: 61 | shutil.copyfile(file, join(model_dir,'best_' + model_fn)) 62 | 63 | 64 | def str_to_bool(v): 65 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 66 | return True 67 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 68 | return False 69 | else: 70 | raise argparse.ArgumentTypeError('Boolean value expected.') 71 | 72 | def expand_dim(tensor,dim,desired_dim_len): 73 | sz = list(tensor.size()) 74 | sz[dim]=desired_dim_len 75 | return tensor.expand(tuple(sz)) 76 | -------------------------------------------------------------------------------- /lib/transformation.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | from skimage import io 5 | import pandas as pd 6 | import numpy as np 7 | import torch 8 | from torch.nn.modules.module import Module 9 | from torch.utils.data import Dataset 10 | from torch.autograd import Variable 11 | import torch.nn.functional as F 12 | 13 | from lib.torch_util import expand_dim 14 | 15 | class AffineTnf(object): 16 | def __init__(self, out_h=240, out_w=240, use_cuda=True): 17 | self.out_h = out_h 18 | self.out_w = out_w 19 | self.use_cuda = use_cuda 20 | self.gridGen = AffineGridGen(out_h=out_h, out_w=out_w, use_cuda=use_cuda) 21 | self.theta_identity = torch.Tensor(np.expand_dims(np.array([[1,0,0],[0,1,0]]),0).astype(np.float32)) 22 | if use_cuda: 23 | self.theta_identity = self.theta_identity.cuda() 24 | 25 | def __call__(self, image_batch, theta_batch=None, out_h=None, out_w=None): 26 | if image_batch is None: 27 | b=1 28 | else: 29 | b=image_batch.size(0) 30 | if theta_batch is None: 31 | theta_batch = self.theta_identity 32 | theta_batch = theta_batch.expand(b,2,3).contiguous() 33 | theta_batch = Variable(theta_batch,requires_grad=False) 34 | 35 | # check if output dimensions have been specified at call time and have changed 36 | if (out_h is not None and out_w is not None) and (out_h!=self.out_h or out_w!=self.out_w): 37 | gridGen = AffineGridGen(out_h, out_w) 38 | else: 39 | gridGen = self.gridGen 40 | 41 | sampling_grid = gridGen(theta_batch) 42 | 43 | # sample transformed image 44 | 
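# F.grid_sample expects a sampling grid of shape [b, out_h, out_w, 2] holding (x, y)
# locations normalized to [-1, 1], which is exactly what AffineGridGen produces via
# F.affine_grid. A minimal identity-warp sketch (hypothetical tensors, not executed here):
#
#   theta = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]])       # identity affine
#   grid = F.affine_grid(theta, torch.Size((1, 3, 240, 240)))  # [1, 240, 240, 2]
#   out = F.grid_sample(img, grid)                             # img: [1, 3, H, W]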
warped_image_batch = F.grid_sample(image_batch, sampling_grid) 45 | 46 | return warped_image_batch 47 | 48 | 49 | class AffineGridGen(Module): 50 | def __init__(self, out_h=240, out_w=240, out_ch = 3, use_cuda=True): 51 | super(AffineGridGen, self).__init__() 52 | self.out_h = out_h 53 | self.out_w = out_w 54 | self.out_ch = out_ch 55 | 56 | def forward(self, theta): 57 | b=theta.size()[0] 58 | if not theta.size()==(b,2,3): 59 | theta = theta.view(-1,2,3) 60 | theta = theta.contiguous() 61 | batch_size = theta.size()[0] 62 | out_size = torch.Size((batch_size,self.out_ch,self.out_h,self.out_w)) 63 | return F.affine_grid(theta, out_size) 64 | -------------------------------------------------------------------------------- /lib_matlab/at_imageresize_nc4d.m: -------------------------------------------------------------------------------- 1 | function im1 = at_imageresize_nc4d(im1,imax) 2 | if nargin < 2, imax = 1920; end 3 | 4 | isz = size(im1(:,:,1)); 5 | if max(isz) > imax 6 | if isz(1) > isz(2) 7 | im1 = imresize(im1,[imax NaN]); 8 | else 9 | im1 = imresize(im1,[NaN imax]); 10 | end 11 | end 12 | 13 | 14 | % isz = size(im1(:,:,1)); 15 | % if (1920*1440) < prod(isz) 16 | % if isz(1) > isz(2) 17 | % im1 = imresize(im1,[1920 NaN]); 18 | % else 19 | % im1 = imresize(im1,[NaN 1920]); 20 | % end 21 | % end 22 | 23 | % isz = size(im1(:,:,1)); 24 | % if (1600*1200) < prod(isz) 25 | % scale = 1600/max(isz); 26 | % im1 = imresize(im1, scale); 27 | % % if isz(1) > isz(2) 28 | % % im1 = imresize(im1,[640 NaN]); 29 | % % else 30 | % % im1 = imresize(im1,[NaN 640]); 31 | % % end 32 | % end 33 | 34 | % isz = size(im1(:,:,1)); 35 | % if (1920*1440) < prod(isz) 36 | % im1 = imresize(im1,sqrt((1920*1440)/prod(isz))); 37 | % end 38 | -------------------------------------------------------------------------------- /lib_matlab/at_pv_wrapper.m: -------------------------------------------------------------------------------- 1 | function at_pv_wrapper(ii,dbscanlist_uniq,dbscantranslist_uniq,qlist_uniq,dblist_uniq,Plist_uniq,params) 2 | 3 | this_dbscan = dbscanlist_uniq{ii}; 4 | this_dbscantrans = dbscantranslist_uniq{ii}; 5 | this_qlist = qlist_uniq{ii}; 6 | this_dblist = dblist_uniq{ii}; 7 | this_Plist = Plist_uniq{ii}; 8 | %load scan 9 | load(fullfile(params.data.dir, params.data.db.scan.dir, this_dbscan), 'A'); 10 | 11 | [ ~, P_after ] = load_WUSTL_transformation(fullfile(params.data.dir, params.data.db.trans.dir, this_dbscantrans)); 12 | RGB = [A{5}, A{6}, A{7}]'; 13 | XYZ = [A{1}, A{2}, A{3}]'; 14 | XYZ = P_after * [XYZ; ones(1, length(XYZ))]; 15 | XYZ = bsxfun(@rdivide, XYZ(1:3, :), XYZ(4, :)); 16 | 17 | %compute synthesized images and similarity scores 18 | for jj = 1:1:length(this_qlist) 19 | parfor_nc4d_PV( this_qlist{jj}, this_dblist{jj}, this_Plist{jj}, RGB, XYZ, params ); 20 | % fprintf('densePV: %d / %d done. \n', jj, length(this_qlist)); 21 | end 22 | fprintf('ncnetPV: scan %s (%d / %d) done. \n', this_dbscan, ii, length(dbscanlist_uniq)); 23 | -------------------------------------------------------------------------------- /lib_matlab/ht_plotcurve_WUSTL.m: -------------------------------------------------------------------------------- 1 | %Note: It plots localization rates with varying position thresholds. 
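%The script assumes a pre-populated struct array `method` whose fields are read below
%(inferred from usage in this file; these are assumptions, set them up before running):
%
%  method(1).description = 'Sparse-NCNet';  %legend label
%  method(1).ImgList     = ImgList_NC4DPV;  %localization results to evaluate
%  method(1).marker      = '-o';            %plot line/marker style
%  method(1).color       = [0 0 1];         %RGB line colour
%  method(1).ms          = 8;               %marker size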
2 | set(0,'defaultAxesFontName','sans-serif'); 3 | set(0,'defaultTextFontName','sans-serif'); 4 | 5 | %% groundtruth 6 | refposes_matname = fullfile(params.gt.dir, params.gt.matname); 7 | load(refposes_matname, 'DUC1_RefList', 'DUC2_RefList'); 8 | 9 | %% evaluation 10 | vefig = figure();hold on; 11 | linehandles = zeros(1, length(method)); 12 | for ii = 1:1:length(method) 13 | ImgList = method(ii).ImgList; 14 | 15 | fp = fopen(sprintf('error_%s.txt',method(ii).description),'w'); 16 | 17 | %1. DUC1 18 | poserr_DUC1 = zeros(1, length(DUC1_RefList)); 19 | orierr_DUC1 = zeros(1, length(DUC1_RefList)); 20 | for jj = 1:1:length(DUC1_RefList) 21 | this_locid = strcmp(DUC1_RefList(jj).queryname, {ImgList.queryname}); 22 | if sum(this_locid) == 0 23 | poserr_DUC1(jj) = inf; 24 | orierr_DUC1(jj) = inf; 25 | keyboard; 26 | else 27 | top1_floor = strsplit(ImgList(this_locid).topNname{1}, '/');top1_floor = top1_floor{1}; 28 | isfloorcorrect = strcmp('DUC1', top1_floor); 29 | if isfloorcorrect && ~isnan(ImgList(this_locid).P{1}(1)) 30 | [poserr_DUC1(jj), orierr_DUC1(jj)] = p2dist(DUC1_RefList(jj).P, ImgList(this_locid).P{1}); 31 | else 32 | poserr_DUC1(jj) = inf; 33 | orierr_DUC1(jj) = inf; 34 | end 35 | 36 | fprintf(fp,'%s %f %f\n',DUC1_RefList(jj).queryname,poserr_DUC1(jj),orierr_DUC1(jj)); 37 | 38 | end 39 | end 40 | 41 | %2. DUC2 42 | poserr_DUC2 = zeros(1, length(DUC2_RefList)); 43 | orierr_DUC2 = zeros(1, length(DUC2_RefList)); 44 | for jj = 1:1:length(DUC2_RefList) 45 | this_locid = strcmp(DUC2_RefList(jj).queryname, {ImgList.queryname}); 46 | if sum(this_locid) == 0 47 | poserr_DUC2(jj) = inf; 48 | orierr_DUC2(jj) = inf; 49 | keyboard; 50 | else 51 | top1_floor = strsplit(ImgList(this_locid).topNname{1}, '/');top1_floor = top1_floor{1}; 52 | isfloorcorrect = strcmp('DUC2', top1_floor); 53 | if isfloorcorrect && ~isnan(ImgList(this_locid).P{1}(1)) 54 | [poserr_DUC2(jj), orierr_DUC2(jj)] = p2dist(DUC2_RefList(jj).P, ImgList(this_locid).P{1}); 55 | else 56 | poserr_DUC2(jj) = inf; 57 | orierr_DUC2(jj) = inf; 58 | end 59 | end 60 | 61 | % if poserr_DUC2(jj) > 1.0 62 | % fprintf(fp,'%s\n',DUC2_RefList(jj).queryname); 63 | % end 64 | 65 | fprintf(fp,'%s %f %f\n',DUC2_RefList(jj).queryname,poserr_DUC2(jj),orierr_DUC2(jj)); 66 | 67 | end 68 | fclose(fp); 69 | 70 | %3. 
all 71 | poserr_all = [poserr_DUC1, poserr_DUC2]; 72 | orierr_all = [orierr_DUC1, orierr_DUC2]; 73 | 74 | 75 | %plot curve 76 | eval_poserr = poserr_all; 77 | eval_orierr = orierr_all*180/pi; 78 | max_orierr = 10;%deg 79 | eval_poserr(eval_orierr>max_orierr) = inf; 80 | eval_err = sort(eval_poserr, 'ascend'); 81 | errthr_dot = [0:0.0625:1, 1.125:0.125:2]; 82 | localized_rate_dot = sum(repmat(eval_err', 1, size(errthr_dot, 2)) < repmat(errthr_dot, size(eval_err, 2), 1), 1) / size(eval_err, 2); 83 | 84 | figure(vefig); 85 | if mod(ii,2)==0 86 | start_ind=1; 87 | else 88 | start_ind=2; 89 | end 90 | linehandles(ii) = plot(errthr_dot, localized_rate_dot*100, method(ii).marker, 'color', method(ii).color, 'MarkerSize', method(ii).ms, 'LineWidth', 2.0, 'MarkerIndices',start_ind:2:length(errthr_dot)); 91 | set(linehandles(ii), 'MarkerFaceColor', get(linehandles(ii), 'Color')); 92 | drawnow; 93 | 94 | 95 | end 96 | 97 | figure(vefig); 98 | xlim([0 2]);ylim([0 90]);grid on; 99 | % xlim([0 0.5]);ylim([0 80]);grid on; 100 | lgnd = legend(linehandles, {method.description}, 'Location', 'southeast', 'FontSize', 10); %, 'Interpreter', 'latex' 101 | %lgnd.FontName = 'Times New Roman'; 102 | xlabel('Distance threshold [meters]');ylabel('Correctly localized queries [%]'); 103 | set(gca, 'FontSize', 18); 104 | % set(gca, 'XTick', 0:0.5:3); 105 | set(gca, 'XTick', 0:0.25:2); 106 | % set(gca, 'XTick', 0:0.1:0.5); 107 | set(get(gcf,'CurrentAxes'),'Position',[0.15 0.13 0.8 0.8]); 108 | set(gcf,'PaperUnits','Inches','PaperPosition',[0 0 5 5]); 109 | drawnow; 110 | 111 | %save fig 112 | figname = fullfile(params.output.dir, sprintf('athr%.4f_%d.fig', max_orierr, length(poserr_all))); 113 | epsname = fullfile(params.output.dir, sprintf('athr%.4f_%d.eps', max_orierr, length(poserr_all))); 114 | pdfname = fullfile(params.output.dir, sprintf('athr%.4f_%d.pdf', max_orierr, length(poserr_all))); 115 | 116 | saveas(gcf,pdfname) 117 | print(vefig, '-depsc',epsname,'-r160'); 118 | savefig(vefig, figname); 119 | 120 | -------------------------------------------------------------------------------- /lib_matlab/ht_top10_NC4D_PV_localization.m: -------------------------------------------------------------------------------- 1 | %Note: It first synthesizes query views according to the top-10 pose candidates 2 | %and computes the similarity between the original query and the synthesized views. Pose 3 | %candidates are then re-scored by this similarity.
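%In outline: (1) take the top-10 (queryname, topNname, P) candidates per query from
%ImgList_NC4D; (2) group the candidates by the RGBD scan they come from, so that each
%scan is loaded only once; (3) synthesize a view for every candidate pose and score it
%against the query (at_pv_wrapper -> parfor_nc4d_PV); (4) re-rank the candidates by the
%resulting score. The output filename nc4dPV_matname is expected to be set by the
%calling script (the commented line below shows the intended default).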
4 | 5 | %% densePV (top10 pose candidate -> pose verification) 6 | PV_topN = 10; 7 | % nc4dPV_matname = fullfile(params.output.dir, 'densePV_top10_shortlist.mat'); 8 | if exist(nc4dPV_matname, 'file') ~= 2 9 | 10 | %synthesis list 11 | qlist = cell(1, PV_topN*length(ImgList_NC4D)); 12 | dblist = cell(1, PV_topN*length(ImgList_NC4D)); 13 | Plist = cell(1, PV_topN*length(ImgList_NC4D)); 14 | for ii = 1:1:length(ImgList_NC4D) 15 | for jj = 1:1:PV_topN 16 | qlist{PV_topN*(ii-1)+jj} = ImgList_NC4D(ii).queryname; 17 | dblist{PV_topN*(ii-1)+jj} = ImgList_NC4D(ii).topNname{jj}; 18 | Plist{PV_topN*(ii-1)+jj} = ImgList_NC4D(ii).P{jj}; 19 | end 20 | end 21 | %find unique scans 22 | dbscanlist = cell(size(dblist)); 23 | dbscantranslist = cell(size(dblist)); 24 | for ii = 1:1:length(dblist) 25 | this_floorid = strsplit(dblist{ii}, '/');this_floorid = this_floorid{1}; 26 | info = parse_WUSTL_cutoutname( dblist{ii} ); 27 | dbscanlist{ii} = fullfile(this_floorid, [info.scene_id, '_scan_', info.scan_id, params.data.db.scan.matformat]); 28 | dbscantranslist{ii} = fullfile(this_floorid, 'transformations', [info.scene_id, '_trans_', info.scan_id, '.txt']); 29 | end 30 | [dbscanlist_uniq, sort_idx, uniq_idx] = unique(dbscanlist); 31 | dbscantranslist_uniq = dbscantranslist(sort_idx); 32 | qlist_uniq = cell(size(dbscanlist_uniq)); 33 | dblist_uniq = cell(size(dbscanlist_uniq)); 34 | Plist_uniq = cell(size(dbscanlist_uniq)); 35 | for ii = 1:1:length(dbscanlist_uniq) 36 | idx = uniq_idx == ii; 37 | qlist_uniq{ii} = qlist(idx); 38 | dblist_uniq{ii} = dblist(idx); 39 | Plist_uniq{ii} = Plist(idx); 40 | end 41 | 42 | %compute synthesized views and similarity 43 | parfor ii = 1:1:length(dbscanlist_uniq) 44 | at_pv_wrapper(ii,dbscanlist_uniq,dbscantranslist_uniq,qlist_uniq,dblist_uniq,Plist_uniq,params) 45 | end 46 | 47 | %load similarity score and reranking 48 | ImgList = struct('queryname', {}, 'topNname', {}, 'topNscore', {}, 'P', {}); 49 | for ii = 1:1:length(ImgList_NC4D) 50 | ImgList(ii).queryname = ImgList_NC4D(ii).queryname; 51 | ImgList(ii).topNname = ImgList_NC4D(ii).topNname(1:PV_topN); 52 | ImgList(ii).topNscore = zeros(1, PV_topN); 53 | ImgList(ii).P = ImgList_NC4D(ii).P(1:PV_topN); 54 | for jj = 1:1:PV_topN 55 | [~, dbbasename, ~] = fileparts(ImgList(ii).topNname{jj}); 56 | load(fullfile(params.output.synth.dir, ImgList(ii).queryname, [dbbasename, params.output.synth.matformat]), 'score'); 57 | ImgList(ii).topNscore(jj) = score; 58 | end 59 | 60 | %reranking 61 | [sorted_score, idx] = sort(ImgList(ii).topNscore, 'descend'); 62 | ImgList(ii).topNname = ImgList(ii).topNname(idx); 63 | ImgList(ii).topNscore = ImgList(ii).topNscore(idx); 64 | ImgList(ii).P = ImgList(ii).P(idx); 65 | end 66 | 67 | if exist(params.output.dir, 'dir') ~= 7 68 | mkdir(params.output.dir); 69 | end 70 | save('-v6', nc4dPV_matname, 'ImgList'); 71 | 72 | else 73 | load(nc4dPV_matname, 'ImgList'); 74 | end 75 | ImgList_NC4DPV = ImgList; 76 | -------------------------------------------------------------------------------- /lib_matlab/ir_top100_NC4D_localization_pnponly.m: -------------------------------------------------------------------------------- 1 | 2 | if exist(NC4D_matname, 'file') ~= 2 3 | %copy results from NC4D 4 | ImgList = struct('queryname', {}, 'topNname', {}, 'topNscore', {}, 'P', {}); 5 | for ii = 1:Nq 6 | 7 | ImgList(ii).queryname = sorted_list.ImgList(ii).queryname; 8 | ImgList(ii).topNname = sorted_list.ImgList(ii).topNname(1:pnp_topN); 9 | 10 | end 11 | 12 | %build list for pnp 13 | qlist = cell(1, Nq*pnp_topN); 14 
| dblist = cell(1, Nq*pnp_topN); 15 | for ii = 1:Nq 16 | for jj = 1:pnp_topN 17 | qlist{pnp_topN*(ii-1)+jj} = ImgList(ii).queryname; 18 | dblist{pnp_topN*(ii-1)+jj} = ImgList(ii).topNname{jj}; 19 | end 20 | end 21 | 22 | %run dense pnp in parpool 23 | dbfname = fullfile(params.data.dir, params.data.db.cutout.dir, ImgList(ii).topNname{1}); 24 | imsize_db = size(at_imageresize_nc4d(imread(dbfname))); 25 | parfor ii = 1:Nq 26 | % load matches for this query 27 | nc4d_matches = load(fullfile(matches_path,experiment,[num2str(ii) '.mat']),'matches'); 28 | 29 | for jj = 1:pnp_topN 30 | %preload query feature 31 | qfname = fullfile(params.data.dir, params.data.q.dir, ImgList(ii).queryname); 32 | imsize_q = size(at_imageresize_nc4d(imread(qfname))); 33 | 34 | kk = pnp_topN*(ii-1)+jj; 35 | parfor_NC4D_PE_pnponly( qlist{kk}, dblist{kk}, params, ... 36 | nc4d_matches.matches(1,jj,:,:), ... 37 | imsize_q, imsize_db); 38 | fprintf('nc4dPE: %s vs %s DONE. \n', qlist{kk}, dblist{kk}); 39 | end 40 | end 41 | 42 | %load top pnp 43 | for ii = 1:Nq 44 | ImgList(ii).P = cell(1, pnp_topN); 45 | for jj = 1:pnp_topN 46 | [~, dbbasename, ~] = fileparts(ImgList(ii).topNname{jj}); 47 | this_nc4dpe_matname = fullfile(params.output.pnp_nc4d_inlier.dir, ImgList(ii).queryname, [dbbasename, params.output.pnp_nc4d.matformat]); 48 | load(this_nc4dpe_matname, 'P', 'inls'); 49 | ImgList(ii).P{jj} = P; 50 | ImgList(ii).inls{jj} = inls; 51 | end 52 | end 53 | 54 | if exist(params.output.dir, 'dir') ~= 7 55 | mkdir(params.output.dir); 56 | end 57 | save('-v6', NC4D_matname, 'ImgList'); 58 | else 59 | load(NC4D_matname, 'ImgList'); 60 | end 61 | ImgList_NC4D = ImgList; 62 | -------------------------------------------------------------------------------- /lib_matlab/p2c.m: -------------------------------------------------------------------------------- 1 | function C = p2c(P) 2 | % P = [R t], no K involved 3 | 4 | C = -P(1:3,1:3)'*P(1:3,4); -------------------------------------------------------------------------------- /lib_matlab/p2dist.m: -------------------------------------------------------------------------------- 1 | function [dpos,dori] = p2dist(P1,P2) 2 | 3 | 4 | c1 = p2c(P1); 5 | c2 = p2c(P2); 6 | 7 | dpos = sqrt(sum((c1 - c2).^2)); 8 | 9 | R = P1(1:3,1:3)\P2(1:3,1:3); 10 | dori = acos((trace(R)-1)/2); 11 | 12 | 13 | % r1 = vl_irodr(P1(1:3,1:3)); 14 | % r2 = vl_irodr(P2(1:3,1:3)); 15 | % 16 | % % dori = acos((trace(P1(1:3,1:3) * P2(1:3,1:3)') - 1)/2); 17 | % 18 | % dori = acos((r1./vnorm(r1))'*(r2./vnorm(r2))); 19 | % 20 | % % v1 = P1(1:3,1:3)*[0;0;1]; 21 | % % v2 = P2(1:3,1:3)*[0;0;1]; 22 | % % dori = acos(v1'*v2); 23 | % 24 | % keyboard; -------------------------------------------------------------------------------- /lib_matlab/parfor_NC4D_PE_pnponly.m: -------------------------------------------------------------------------------- 1 | function parfor_NC4D_PE_pnponly( qname, dbname, params, matches, imsize_q, imsize_db) 2 | 3 | [~, dbbasename, ~] = fileparts(dbname); 4 | this_nc4dpe_matname = fullfile(params.output.pnp_nc4d_inlier.dir, qname, [dbbasename, params.output.pnp_nc4d.matformat]); 5 | 6 | if exist(this_nc4dpe_matname, 'file') ~= 2 7 | 8 | %geometric verification results 9 | 10 | M = squeeze(matches); 11 | 12 | f1 = M(:,1:2)'; 13 | f2 = M(:,3:4)'; 14 | scr = M(:,5); 15 | 16 | thr = params.ncnet.thr; 17 | if thr>0 18 | f1 = f1(:,scr>thr); 19 | f2 = f2(:,scr>thr); 20 | end 21 | 22 | if isfield(params.ncnet,'N_subsample') 23 | subsample=randperm(size(f1,2)); 24 | 
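%(random permutation + truncation below keeps at most params.ncnet.N_subsample matches,
%drawn uniformly at random without replacement)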
subsample=subsample(1:min(size(f1,2),params.ncnet.N_subsample)); 25 | f1=f1(:,subsample); 26 | f2=f2(:,subsample); 27 | end 28 | 29 | tent_xq2d = f1; 30 | tent_xdb2d = f2; 31 | 32 | %depth information 33 | this_db_matname = fullfile(params.data.dir, params.data.db.cutout.dir, [dbname, params.data.db.cutout.matformat]); 34 | load(this_db_matname, 'XYZcut'); 35 | %load transformation matrix (local to global) 36 | this_floorid = strsplit(dbname, '/');this_floorid = this_floorid{1}; 37 | info = parse_WUSTL_cutoutname( dbname ); 38 | transformation_txtname = fullfile(params.data.dir, params.data.db.trans.dir, this_floorid, 'transformations', ... 39 | sprintf('%s_trans_%s.txt', info.scene_id, info.scan_id)); 40 | [ ~, P_after ] = load_WUSTL_transformation(transformation_txtname); 41 | 42 | %Feature upsampling 43 | Iqsize = size(imread(fullfile(params.data.dir, params.data.q.dir, qname))); 44 | Idbsize = size(XYZcut); 45 | 46 | tent_xq2d(1,:) = Iqsize(2)*tent_xq2d(1,:); 47 | tent_xq2d(2,:) = Iqsize(1)*tent_xq2d(2,:); 48 | tent_xdb2d(1,:) = floor(Idbsize(2)*tent_xdb2d(1,:)); 49 | tent_xdb2d(2,:) = floor(Idbsize(1)*tent_xdb2d(2,:)); 50 | tent_xdb2d(1,tent_xdb2d(1,:) == 0) = 1; % fix zeros 51 | tent_xdb2d(2,tent_xdb2d(2,:) == 0) = 1; 52 | 53 | 54 | %query ray 55 | Kq = [params.data.q.fl, 0, Iqsize(2)/2.0; ... 56 | 0, params.data.q.fl, Iqsize(1)/2.0; ... 57 | 0, 0, 1]; 58 | tent_ray2d = Kq^-1 * [tent_xq2d; ones(1, size(tent_xq2d, 2))]; 59 | %DB 3d points 60 | indx = sub2ind(size(XYZcut(:,:,1)),tent_xdb2d(2,:),tent_xdb2d(1,:)); 61 | X = XYZcut(:,:,1);Y = XYZcut(:,:,2);Z = XYZcut(:,:,3); 62 | tent_xdb3d = [X(indx); Y(indx); Z(indx)]; 63 | tent_xdb3d = bsxfun(@plus, P_after(1:3, 1:3)*tent_xdb3d, P_after(1:3, 4)); 64 | %Select keypoint correspond to 3D 65 | idx_3d = all(~isnan(tent_xdb3d), 1); 66 | tent_xq2d = tent_xq2d(:, idx_3d); 67 | tent_xdb2d = tent_xdb2d(:, idx_3d); 68 | tent_ray2d = tent_ray2d(:, idx_3d); 69 | tent_xdb3d = tent_xdb3d(:, idx_3d); 70 | 71 | tentatives_2d = [tent_xq2d; tent_xdb2d]; 72 | tentatives_3d = [tent_ray2d; tent_xdb3d]; 73 | 74 | %solver 75 | if size(tentatives_2d, 2) < 3 76 | P = nan(3, 4); 77 | inls = false(1, size(tentatives_2d, 2)); 78 | else 79 | [ P, inls ] = ht_lo_ransac_p3p( tent_ray2d, tent_xdb3d, params.ncnet.pnp_thr*pi/180, 10000); 80 | if isempty(P) 81 | P = nan(3, 4); 82 | end 83 | end 84 | 85 | if exist(fullfile(params.output.pnp_nc4d_inlier.dir, qname), 'dir') ~= 7 86 | mkdir(fullfile(params.output.pnp_nc4d_inlier.dir, qname)); 87 | end 88 | save('-v7.3', this_nc4dpe_matname, 'P', 'inls', 'tentatives_2d', 'tentatives_3d', 'idx_3d'); 89 | 90 | if 0 91 | % %debug 92 | close all; 93 | 94 | Iq = imread(fullfile(params.data.dir, params.data.q.dir, qname)); 95 | Idb = imread(fullfile(params.data.dir, params.data.db.cutout.dir, dbname)); 96 | points.x2 = tentatives_2d(3, :); 97 | points.y2 = tentatives_2d(4, :); 98 | points.x1 = tentatives_2d(1, :); 99 | points.y1 = tentatives_2d(2, :); 100 | points.color = 'g'; 101 | points.facecolor = 'g'; 102 | points.markersize = 10; 103 | points.linestyle = '-'; 104 | points.linewidth = 0.5; 105 | show_matches2_horizontal( Iq, Idb, points, ... 
106 | [1:size(tentatives_2d,2); 1:size(tentatives_2d,2)], inls ); 107 | 108 | if ~exist('eg','dir'), mkdir('eg'); end 109 | print('-dpng',['eg/' qname '.png'],'-r60'); 110 | % keyboard; 111 | end 112 | end 113 | 114 | 115 | end 116 | 117 | -------------------------------------------------------------------------------- /lib_matlab/parfor_nc4d_PV.m: -------------------------------------------------------------------------------- 1 | function parfor_nc4d_PV( qname, dbname, P, RGB, XYZ, params ) 2 | dslevel = 8^-1; 3 | 4 | [~, dbbasename, ~] = fileparts(dbname); 5 | this_nc4dPV_matname = fullfile(params.output.synth.dir, qname, [dbbasename, params.output.synth.matformat]); 6 | 7 | if exist(this_nc4dPV_matname, 'file') ~= 2 8 | if all(~isnan(P(:))) 9 | 10 | %load downsampled images 11 | Iq = imresize(imread(fullfile(params.data.dir, params.data.q.dir, qname)), dslevel); 12 | fl = params.data.q.fl * dslevel; 13 | K = [fl, 0, size(Iq, 2)/2.0; 0, fl, size(Iq, 1)/2.0; 0, 0, 1]; 14 | [ RGBpersp, XYZpersp ] = ht_Points2Persp( RGB, XYZ, K*P, size(Iq, 1), size(Iq, 2) ); 15 | RGB_flag = all(~isnan(XYZpersp), 3); 16 | 17 | %compute DSIFT error 18 | if any(RGB_flag(:)) 19 | %normalization 20 | Iq_norm = image_normalization( double(rgb2gray(Iq)), RGB_flag ); 21 | I_synth = double(rgb2gray(RGBpersp)); 22 | I_synth(~RGB_flag) = nan; 23 | I_synth = image_normalization( inpaint_nans(I_synth), RGB_flag ); 24 | 25 | %compute DSIFT 26 | [fq, dq] = vl_phow(im2single(Iq_norm),'sizes',8,'step',4); 27 | [fsynth, dsynth] = vl_phow(im2single(I_synth),'sizes',8,'step',4); 28 | f_linind = sub2ind(size(I_synth), fsynth(2, :), fsynth(1, :)); 29 | iseval = RGB_flag(f_linind); 30 | dq = relja_rootsift(single(dq)); dsynth = relja_rootsift(single(dsynth)); 31 | 32 | %error 33 | err = sqrt(sum((dq(:, iseval) - dsynth(:, iseval)).^2, 1)); 34 | score = quantile(err, 0.5)^-1; 35 | errmap = nan(size(I_synth));errmap(f_linind(iseval)) = err; 36 | xuni = sort(unique(fsynth(1, :)), 'ascend');yuni = sort(unique(fsynth(2, :)), 'ascend'); 37 | errmap = errmap(yuni, xuni); 38 | 39 | % %debug 40 | % figure();set(gcf, 'Position', [0 0 1000, 300]); 41 | % ultimateSubplot( 3, 1, 1, 1, 0.01, 0.05 ); 42 | % imshow(Iq); 43 | % ultimateSubplot( 3, 1, 2, 1, 0.01, 0.05 ); 44 | % imshow(RGBpersp); 45 | % ultimateSubplot( 3, 1, 3, 1, 0.01, 0.05 ); 46 | % imagesc(errmap);colormap('jet');axis image off; 47 | % keyboard; 48 | 49 | else 50 | score = 0; 51 | errmap = []; 52 | end 53 | else 54 | Iq = []; 55 | RGBpersp = []; 56 | RGB_flag = []; 57 | score = 0; 58 | errmap = 0; 59 | end 60 | 61 | if exist(fullfile(params.output.synth.dir, qname), 'dir') ~= 7 62 | mkdir(fullfile(params.output.synth.dir, qname)); 63 | end 64 | save(this_nc4dPV_matname, 'Iq', 'RGBpersp', 'RGB_flag', 'score', 'errmap'); 65 | 66 | 67 | I1 = imresize(imread(fullfile(params.data.dir, params.data.q.dir, qname)), dslevel); 68 | if isempty(RGBpersp) 69 | I2 = uint8(zeros(size(I1))); 70 | else 71 | I2 = RGBpersp; 72 | end 73 | 74 | [I1ylen, I1xlen] = size(I1); 75 | [I2ylen, I2xlen] = size(I2); 76 | 77 | %cat image 78 | if I1ylen <= I2ylen 79 | scale1 = 1; 80 | scale2 = I1ylen/I2ylen; 81 | I2 = imresize(I2,scale2); 82 | else 83 | scale1 = I2ylen/I1ylen; 84 | scale2 = 1; 85 | I1 = imresize(I1,scale1); 86 | end 87 | catI = cat(2, I1, I2); 88 | 89 | imwrite(catI,[this_nc4dPV_matname '.jpg'],'Quality',90) 90 | 91 | end 92 | 93 | 94 | end 95 | 96 | -------------------------------------------------------------------------------- /lib_matlab/show_matches2_horizontal.m: 
-------------------------------------------------------------------------------- 1 | function [ h ] = show_matches2_horizontal( I1, I2, showkeys, match12, inls12 ) 2 | % 3 | 4 | if size(I1, 3) == 3 5 | I1 = rgb2gray(I1); 6 | end 7 | if size(I2, 3) == 3 8 | I2 = rgb2gray(I2); 9 | end 10 | 11 | [I1ylen, I1xlen] = size(I1); 12 | [I2ylen, I2xlen] = size(I2); 13 | 14 | %cat image 15 | if I1ylen <= I2ylen 16 | scale1 = 1; 17 | scale2 = I1ylen/I2ylen; 18 | I2 = imresize(I2,scale2); 19 | else 20 | scale1 = I2ylen/I1ylen; 21 | scale2 = 1; 22 | I1 = imresize(I1,scale1); 23 | end 24 | catI = cat(2, I1, I2); 25 | 26 | h = figure('Visible','off'); 27 | imagesc(catI); 28 | colormap(gray); 29 | set(gca, 'Position', [0 0 1 1]); 30 | set(gcf, 'Position', [0 0 size(catI,2) size(catI,1)]); 31 | grid off; 32 | axis equal tight; 33 | axis off; 34 | hold on; 35 | 36 | %plot matches 37 | style = length(showkeys); 38 | for s = 1:1:style 39 | x1 = scale1*showkeys(s).x1(1,match12(1,:)); 40 | y1 = scale1*showkeys(s).y1(1,match12(1,:)); 41 | x2 = scale2*showkeys(s).x2(1,match12(2,:)) + size(I1,2) + 10; 42 | y2 = scale2*showkeys(s).y2(1,match12(2,:)); 43 | 44 | mh = scatter([x1'; x2'], [y1'; y2']); 45 | set(mh, 'MarkerEdgeColor', 'b', 'MarkerFaceColor', 'b', 'SizeData', 10); 46 | 47 | x1 = scale1*showkeys(s).x1(1,inls12(1,:)); 48 | y1 = scale1*showkeys(s).y1(1,inls12(1,:)); 49 | x2 = scale2*showkeys(s).x2(1,inls12(1,:)) + size(I1,2) + 10; 50 | y2 = scale2*showkeys(s).y2(1,inls12(1,:)); 51 | 52 | mh = scatter([x1'; x2'], [y1'; y2']); 53 | set(mh, 'MarkerEdgeColor', 'g', 'MarkerFaceColor', 'g', 'SizeData', 10); 54 | 55 | lh = line([x1; x2], [y1; y2]); 56 | set(lh, 'Color', showkeys(s).color, 'LineStyle', showkeys(s).linestyle, 'LineWidth', showkeys(s).linewidth); 57 | 58 | end 59 | 60 | 61 | end 62 | 63 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import os 3 | from os.path import exists, join, basename 4 | from collections import OrderedDict 5 | import numpy as np 6 | import numpy.random 7 | import datetime 8 | import torch 9 | import torch.nn as nn 10 | import torch.optim as optim 11 | from torch.nn.functional import relu 12 | from torch.utils.data import Dataset 13 | 14 | from lib.dataloader import DataLoader # modified dataloader 15 | from lib.model import ImMatchNet 16 | from lib.im_pair_dataset import ImagePairDataset 17 | from lib.normalization import NormalizeImageDict, normalize_image_dict_caffe 18 | from lib.torch_util import save_checkpoint, str_to_bool 19 | from lib.torch_util import BatchTensorToVars, str_to_bool 20 | from lib.sparse import get_scores 21 | 22 | from lib.sparse import corr_and_add 23 | import torch.nn.functional as F 24 | from lib.model import featureL2Norm 25 | 26 | import argparse 27 | 28 | 29 | # Seed and CUDA 30 | use_cuda = torch.cuda.is_available() 31 | torch.manual_seed(10) 32 | if use_cuda: 33 | torch.cuda.manual_seed(10) 34 | np.random.seed(10) 35 | 36 | print('Sparse-NCNet training script') 37 | 38 | # Argument parsing 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument('--checkpoint', type=str, default='') 41 | parser.add_argument('--image_size', type=int, default=400) 42 | parser.add_argument('--dataset_image_path', type=str, default='datasets/ivd', help='path to IVD dataset') 43 | parser.add_argument('--dataset_csv_path', type=str, default='datasets/ivd/image_pairs/', help='path to IVD training csv') 44 | 
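# Example invocation (a sketch; flags as defined in this parser, paths assuming the
# IVD dataset layout given by the defaults above and below):
#
#   python train.py --ncons_kernel_sizes 3 3 --ncons_channels 16 1 \
#       --num_epochs 5 --batch_size 16 --lr 0.0005 --k 10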
parser.add_argument('--num_epochs', type=int, default=5, help='number of training epochs') 45 | parser.add_argument('--batch_size', type=int, default=16, help='training batch size') 46 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 47 | parser.add_argument('--ncons_kernel_sizes', nargs='+', type=int, default=[3,3], help='kernels sizes in neigh. cons.') 48 | parser.add_argument('--ncons_channels', nargs='+', type=int, default=[16,1], help='channels in neigh. cons') 49 | parser.add_argument('--result_model_fn', type=str, default='sparsencnet', help='trained model filename') 50 | parser.add_argument('--result-model-dir', type=str, default='trained_models', help='path to trained models folder') 51 | parser.add_argument('--fe_finetune_params', type=int, default=0, help='number of layers to finetune') 52 | parser.add_argument('--k', type=int, default=10, help='number of nearest neighs') 53 | parser.add_argument('--symmetric_mode', type=int, default=1, help='use symmetric mode') 54 | parser.add_argument('--bn', type=int, default=0, help='use batch norm') 55 | parser.add_argument('--random_affine', type=int, default=0, help='use affine data augmentation') 56 | parser.add_argument('--feature_extraction_cnn', type=str, default='resnet101', help='type of feature extractor') 57 | parser.add_argument('--relocalize', type=int, default=0) 58 | parser.add_argument('--change_stride', type=int, default=0) 59 | 60 | 61 | args = parser.parse_args() 62 | print(args) 63 | 64 | # Create model 65 | print('Creating CNN model...') 66 | model = ImMatchNet(use_cuda=use_cuda, 67 | checkpoint=args.checkpoint, 68 | ncons_kernel_sizes=args.ncons_kernel_sizes, 69 | ncons_channels=args.ncons_channels, 70 | sparse=True, 71 | symmetric_mode=bool(args.symmetric_mode), 72 | feature_extraction_cnn=args.feature_extraction_cnn, 73 | bn=bool(args.bn), 74 | k=args.k) 75 | 76 | if args.change_stride: 77 | model.FeatureExtraction.model[-1][0].conv1.stride=(1,1) 78 | model.FeatureExtraction.model[-1][0].conv2.stride=(1,1) 79 | model.FeatureExtraction.model[-1][0].downsample[0].stride=(1,1) 80 | 81 | def eval_model_fn(batch): 82 | # feature extraction 83 | if args.relocalize: 84 | feature_A_2x = model.FeatureExtraction(batch['source_image']) 85 | feature_B_2x = model.FeatureExtraction(batch['target_image']) 86 | 87 | feature_A = F.max_pool2d(feature_A_2x, kernel_size=3, stride=2, padding=1) 88 | feature_B = F.max_pool2d(feature_B_2x, kernel_size=3, stride=2, padding=1) 89 | else: 90 | feature_A = model.FeatureExtraction(batch['source_image']) 91 | feature_B = model.FeatureExtraction(batch['target_image']) 92 | 93 | feature_A = featureL2Norm(feature_A) 94 | feature_B = featureL2Norm(feature_B) 95 | 96 | fs1, fs2 = feature_A.shape[-2:] 97 | fs3, fs4 = feature_B.shape[-2:] 98 | 99 | corr4d = corr_and_add(feature_A, feature_B, k = model.k) 100 | corr4d = model.NeighConsensus(corr4d) 101 | 102 | return corr4d 103 | 104 | # Set which parts of the model to train 105 | if args.fe_finetune_params>0: 106 | for i in range(args.fe_finetune_params): 107 | for p in model.FeatureExtraction.model[-1][-(i+1)].parameters(): 108 | p.requires_grad=True 109 | 110 | print('Trainable parameters:') 111 | for i,p in enumerate(filter(lambda p: p.requires_grad, model.parameters())): 112 | print(str(i+1)+": "+str(p.shape)) 113 | 114 | # Optimizer 115 | print('using Adam optimizer') 116 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr) 117 | 118 | 
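# Only parameters with requires_grad=True reach the optimizer: by default these are the
# neighbourhood-consensus convolutions, since the feature extractor stays frozen unless
# --fe_finetune_params > 0 unfreezes its last layers (see the loop above). A quick way
# to count them (illustrative only, not executed here):
#
#   n_train = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   print('number of trainable parameters: {}'.format(n_train))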
cnn_image_size = (args.image_size, args.image_size)

Dataset = ImagePairDataset
train_csv = 'train_pairs.csv'
test_csv = 'val_pairs.csv'
# the 'd2' feature extractor expects Caffe-style image normalization
if args.feature_extraction_cnn == 'd2':
    normalization_tnf = normalize_image_dict_caffe
else:
    normalization_tnf = NormalizeImageDict(['source_image', 'target_image'])

batch_preprocessing_fn = BatchTensorToVars(use_cuda=use_cuda)

# Dataset and dataloader
dataset = Dataset(transform=normalization_tnf,
                  dataset_image_path=args.dataset_image_path,
                  dataset_csv_path=args.dataset_csv_path,
                  dataset_csv_file=train_csv,
                  output_size=cnn_image_size,
                  random_affine=bool(args.random_affine))

dataloader = DataLoader(dataset, batch_size=args.batch_size,
                        shuffle=True,
                        num_workers=0)

dataset_test = Dataset(transform=normalization_tnf,
                       dataset_image_path=args.dataset_image_path,
                       dataset_csv_path=args.dataset_csv_path,
                       dataset_csv_file=test_csv,
                       output_size=cnn_image_size)

dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size,
                             shuffle=True, num_workers=4)

# Define checkpoint name (avoid ':' in the timestamp, which is invalid in
# filenames on some platforms)
checkpoint_name = os.path.join(args.result_model_dir,
                               datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")+'_'+args.result_model_fn+'.pth.tar')

print('Checkpoint name: '+checkpoint_name)

# Train
best_test_loss = float("inf")

def weak_loss(model, batch):
    """Weakly-supervised loss: the mean matching score should be high for
    matching image pairs (positives) and low for non-matching ones (negatives)."""
    b = batch['source_image'].size(0)

    # positives: the batch as given
    corr4d = model(batch)
    scores_A = get_scores(corr4d, k=args.k)
    scores_B = get_scores(corr4d, reverse=True, k=args.k)
    score_pos = (scores_A + scores_B)/2

    # negatives: roll the source images by one position within the batch,
    # pairing each target image with a non-matching source image
    batch['source_image'] = batch['source_image'][np.roll(np.arange(b), -1), :]
    corr4d = model(batch)
    scores_A = get_scores(corr4d, k=args.k)
    scores_B = get_scores(corr4d, reverse=True, k=args.k)
    score_neg = (scores_A + scores_B)/2

    # minimizing this loss raises positive scores and lowers negative ones
    loss = score_neg - score_pos
    return loss

loss_fn = weak_loss
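# Illustration of the negative mining in weak_loss for a hypothetical batch of
# b = 4 pairs: np.roll(np.arange(4), -1) gives [1, 2, 3, 0], so after rolling,
# target image i is scored against source image (i + 1) % 4. Treating rolled
# pairs as negatives assumes that the pairs within a batch come from different
# scenes.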
# define epoch function
def process_epoch(mode, epoch, model, loss_fn, optimizer, dataloader, batch_preprocessing_fn, log_interval=50):
    epoch_loss = 0
    for batch_idx, batch in enumerate(dataloader):
        if mode == 'train':
            optimizer.zero_grad()
        tnf_batch = batch_preprocessing_fn(batch)

        loss = loss_fn(model, tnf_batch)
        loss_np = loss.item()
        epoch_loss += loss_np

        if mode == 'train':
            loss.backward()
            optimizer.step()
        else:
            loss = None  # release the graph

        if batch_idx % log_interval == 0:
            print(mode.capitalize()+' Epoch: {} [{}/{} ({:.0f}%)]\t\tLoss: {:.6f}'.format(
                epoch, batch_idx, len(dataloader),
                100. * batch_idx / len(dataloader), loss_np))
    epoch_loss /= len(dataloader)
    print(mode.capitalize()+' set: Average loss: {:.4f}'.format(epoch_loss))
    return epoch_loss

train_loss = np.zeros(args.num_epochs)
test_loss = np.zeros(args.num_epochs)

print('Starting training...')

# keep the feature extractor in eval mode so that batch-norm statistics stay frozen
model.FeatureExtraction.eval()

for epoch in range(1, args.num_epochs+1):
    # eval_model_fn is passed in place of the model: it runs feature extraction,
    # sparse correlation and neighbourhood consensus as defined above
    train_loss[epoch-1] = process_epoch('train', epoch, eval_model_fn, loss_fn, optimizer, dataloader, batch_preprocessing_fn, log_interval=1)
    test_loss[epoch-1] = process_epoch('test', epoch, eval_model_fn, loss_fn, optimizer, dataloader_test, batch_preprocessing_fn, log_interval=1)

    # remember best loss
    is_best = test_loss[epoch-1] < best_test_loss
    best_test_loss = min(test_loss[epoch-1], best_test_loss)
    save_checkpoint({
        'epoch': epoch,
        'args': args,
        'state_dict': model.state_dict(),
        'best_test_loss': best_test_loss,
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'test_loss': test_loss,
    }, is_best, checkpoint_name)

print('Done!')
--------------------------------------------------------------------------------
/trained_models/.gitignore:
--------------------------------------------------------------------------------
# Ignore everything in this directory
*
# Except this file and the download script
!.gitignore
!download.sh
--------------------------------------------------------------------------------
/trained_models/download.sh:
--------------------------------------------------------------------------------
wget https://www.di.ens.fr/willow/research/sparse-ncnet/models/sparsencnet_k10.pth.tar
--------------------------------------------------------------------------------