├── .gitignore ├── .gitmodules ├── README.md ├── assets └── new_arch_final.png ├── clean_logs.py ├── commons.py ├── configs.py ├── download.py ├── eval_pose_file.py ├── gloc ├── __init__.py ├── datasets │ ├── __init__.py │ ├── dataset.py │ ├── dataset_nolabels.py │ ├── get_dataset.py │ └── imlist_dataset.py ├── extraction │ ├── __init__.py │ ├── extract_feats.py │ └── utils.py ├── initialization.py ├── models │ ├── __init__.py │ ├── features.py │ ├── get_model.py │ ├── layers.py │ └── refinement_model.py ├── rendering │ ├── __init__.py │ ├── base_renderer.py │ ├── mesh_renderer.py │ ├── nerf_renderer.py │ ├── rend_conf.py │ ├── rend_utils.py │ └── splatting_renderer.py ├── resamplers │ ├── __init__.py │ ├── get_protocol.py │ ├── samplers.py │ ├── sampling_utils.py │ ├── scalers.py │ ├── scalers_conf.py │ └── strategies.py └── utils │ ├── __init__.py │ ├── camera_utils.py │ ├── utils.py │ └── visualization.py ├── parse_args.py ├── path_configs.py ├── refine_pose.py ├── refine_pose_aachen.py ├── render_dataset_from_script.py ├── requirements.txt └── submit_poses.py /.gitignore: -------------------------------------------------------------------------------- 1 | .spyproject 2 | .idea 3 | __pycache__ 4 | *__pycache__* 5 | logs 6 | cache 7 | .ipynb_checkpoints 8 | *.ipynb 9 | *.torch 10 | renderings 11 | pretrained 12 | descr_cache 13 | rendered_views.txt 14 | fake* 15 | old* 16 | render_check 17 | masks* 18 | sphere_render* 19 | notebook_code.py 20 | extrinsic2pyramid 21 | *.pth 22 | triplet_renders 23 | gifs 24 | feat_steps 25 | extrinsic2pyramid 26 | plots 27 | .vscode 28 | nerf_renders 29 | render_straight 30 | renders_rotate 31 | measure_nerf* 32 | jonas* 33 | *test_keypoints* 34 | viz* 35 | vreph_stuff 36 | vreph_coms 37 | ign_* 38 | pose_priors 39 | to_refine* 40 | conv_analy* 41 | renderings_dense_post 42 | query_mask* 43 | pt2_1* 44 | paper_plots 45 | rebuttal_plots 46 | chess_7* 47 | examples 48 | *.zip 49 | data -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "gloc/models/third_party/RoMA"] 2 | path = gloc/models/third_party/RoMA 3 | url = https://github.com/ga1i13o/RoMa 4 | [submodule "third_party/gaussian-splatting"] 5 | path = third_party/gaussian-splatting 6 | url = https://github.com/ga1i13o/gaussian-splatting 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement 3 | 4 | This is the official pyTorch implementation of the CVPR24 paper "The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement". 5 | In this work, we present a simple approach for Pose Refinement that combines pre-trained features with a particle filter and a renderable representation of the scene. 6 | 7 | 8 | [[CVPR 2024 Open Access](https://openaccess.thecvf.com/content/CVPR2024/html/Trivigno_The_Unreasonable_Effectiveness_of_Pre-Trained_Features_for_Camera_Pose_Refinement_CVPR_2024_paper.html)] [[ArXiv](https://arxiv.org/abs/2404.10438)] 9 | 10 |

11 | 12 |
Our proposed Pose Refinement algorithm. 13 |

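For a quick mental model of the refinement loop sketched above: at each step, candidate poses are sampled around the current best estimates, rendered with the scene representation, scored against the query image using frozen pre-trained features, and only the best few candidates survive to seed the next step. Below is a minimal, self-contained sketch of this loop. It is illustrative only: the `score` stub, the pose format, and the `refine` helper are placeholders rather than this repository's API, and the default numbers mirror the first Cambridge entry in `configs.py`. The actual pipeline lives in `refine_pose.py` and the `gloc` package.

```python
# Illustrative sketch only: every name here is a placeholder, not the repo's actual API.
import numpy as np

rng = np.random.default_rng(0)

def score(pose):
    # Stand-in for the real scoring: render `pose` with the scene model, extract
    # dense pre-trained features, and compare them to the query features
    # (lower score = better match).
    return float(np.linalg.norm(pose["t"]) + 0.1 * np.linalg.norm(pose["rot_deg"]))

def refine(seeds, steps=40, N=52, M=2, center_std=(1.2, 1.2, 0.1), teta=10.0):
    survivors = seeds
    for _ in range(steps):
        candidates = list(survivors)
        for pose in survivors:
            for _ in range(N // len(survivors)):
                candidates.append({
                    "t": pose["t"] + rng.normal(0.0, center_std),                # translation noise
                    "rot_deg": pose["rot_deg"] + rng.normal(0.0, teta, size=3),  # rotation noise (deg)
                })
        candidates.sort(key=score)   # rank all candidate renderings against the query
        survivors = candidates[:M]   # the best M seed the sampling of the next step
    return survivors[0]

best = refine([{"t": np.ones(3), "rot_deg": np.zeros(3)}])
print(best)
```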
14 | 
15 | ## Download data
16 | 
17 | The following command
18 | 
19 | `$ python download.py`
20 | 
21 | will download COLMAP models and pre-trained Gaussian Splatting models for the scene representation, as well as the Cambridge Landmarks dataset.
22 | 
23 | ## Environment
24 | 
25 | Follow the instructions to install the `gaussian_splatting` environment from the [official repo](https://github.com/graphdeco-inria/gaussian-splatting/tree/main?tab=readme-ov-file#setup).
26 | 
27 | Then, activate the environment and execute:
28 | 
29 | `$ pip install -r requirements.txt`
30 | 
31 | ## Reproduce our results
32 | 
33 | ### Cambridge Landmarks
34 | 
35 | To reproduce results on Cambridge Landmarks, e.g., for KingsCollege:
36 | 
37 | `$ python refine_pose.py KingsCollege --exp_name kings_college_refine --renderer g_splatting --clean_logs`
38 | 
39 | The script loads the number of steps and the hyperparameters of the Monte Carlo optimization from the `configs.py` file. It uses a Gaussian Splatting model to render candidate poses; other options, such as a colored mesh or a NeRF model, will be uploaded soon.
40 | 
41 | ### 7scenes
42 | 
43 | Coming soon
44 | 
45 | ### Aachen Day-Night
46 | 
47 | Coming soon
48 | 
49 | ## Cite
50 | Here is the BibTeX entry to cite our paper:
51 | ```
52 | @inproceedings{trivigno2024unreasonable,
53 | title={The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement},
54 | author={Trivigno, Gabriele and Masone, Carlo and Caputo, Barbara and Sattler, Torsten},
55 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
56 | pages={12786--12798},
57 | year={2024}
58 | }
59 | ```
60 | 
61 | 
62 | ## Acknowledgements
63 | Parts of this repo are inspired by the following repositories:
64 | - [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting)
65 | - [Nerfstudio](https://github.com/nerfstudio-project/nerfstudio)
66 | 
--------------------------------------------------------------------------------
/assets/new_arch_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ga1i13o/mcloc_poseref/6f0610c3572486a4a0a4a8268e7844efd66c94e9/assets/new_arch_final.png
--------------------------------------------------------------------------------
/clean_logs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | from os.path import join
5 | from glob import glob
6 | from tqdm import tqdm
7 | 
8 | 
9 | def main(folder, only_step=-1):
10 |     if not os.path.isdir(folder):
11 |         raise ValueError(f'{folder} is not a directory')
12 |     if not 'renderings' in os.listdir(folder):
13 |         raise ValueError(f'{folder} should be a log directory')
14 | 
15 |     steps_folders = glob(join(folder, 'renderings', '*'))
16 |     print(f'found {len(steps_folders)} folders')
17 |     dir_to_step = lambda x: int(x.split('/')[-1].split('_')[2].split('s')[-1])
18 | 
19 |     for step_f in tqdm(steps_folders, ncols=100):
20 |         s_num = dir_to_step(step_f)
21 |         if (only_step != -1) and (s_num != only_step):
22 |             continue
23 | 
24 |         query_renders = glob(join(step_f, '**', '*.png'), recursive=True)
25 |         print(f'\nStep {s_num}: found {len(query_renders)} renderings, deleting them...')
26 |         for q_r in tqdm(query_renders, ncols=100):
27 |             os.remove(q_r)
28 | 
29 |         vreph_dir = join(step_f, 'vreph_conf')
30 |         if os.path.isdir(vreph_dir):
31 |             n_vreph = len(os.listdir(vreph_dir))
32 |             print(f'Found {n_vreph} files in vreph conf, 
deleting them...') 33 | shutil.rmtree(vreph_dir) 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser(description='Argument parser') 38 | parser.add_argument('folder', type=str, help='log folder') 39 | parser.add_argument('-s', '--step', type=int, default=-1, help='delete only step n.') 40 | 41 | args = parser.parse_args() 42 | main(args.folder, args.step) 43 | 44 | 45 | """ 46 | Example 47 | python clean_logs.py logs/pt2_1_kings_cpl3_320_colmaporig/2023-08-24_23-40-16 48 | python clean_logs.py logs/pt2_1_V2_T6_N40_M2_kings_nerf_cpl3_nozstd_320_1/2023-09-19_14-05-18 49 | """ -------------------------------------------------------------------------------- /commons.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains some functions and classes which can be useful in very diverse projects. 3 | """ 4 | import os 5 | import sys 6 | import torch 7 | import logging 8 | import traceback 9 | from os.path import join 10 | import random 11 | import numpy as np 12 | import sys 13 | import requests 14 | from bs4 import BeautifulSoup 15 | 16 | 17 | def submit_poses(method, path, dataset='aachenv11'): 18 | session = open('ign_secret.txt', 'r').readline().strip() 19 | resp = requests.get("https://www.visuallocalization.net/submission/", 20 | headers={"Cookie": f"sessionid={session}"}) 21 | 22 | bs = BeautifulSoup(resp.content, features="html.parser") 23 | csrf = bs.select('form > input[name="csrfmiddlewaretoken"]')[0].attrs['value'] 24 | 25 | url = "https://www.visuallocalization.net/submission/" 26 | 27 | headers = { 28 | "Cookie": f"csrftoken={resp.cookies['csrftoken']}; sessionid={session}", 29 | "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.93 Safari/537.36", 30 | } 31 | 32 | data = { 33 | "csrfmiddlewaretoken": csrf, 34 | "method_name": method, #"test", 35 | "publication_url": "", 36 | "code_url": "", 37 | "info_field": "", 38 | "dataset": dataset, 39 | "workshop_submission": "default", 40 | } 41 | 42 | files = { 43 | "result_file": open(path, "r"), 44 | } 45 | resp = requests.post(url, files=files, data=data, headers=headers) 46 | print(resp.status_code) 47 | 48 | 49 | def setup_logging(output_folder, console="debug", 50 | info_filename="info.log", debug_filename="debug.log"): 51 | """Set up logging files and console output. 52 | Creates one file for INFO logs and one for DEBUG logs. 53 | Args: 54 | output_folder (str): creates the folder where to save the files. 55 | debug (str): 56 | if == "debug" prints on console debug messages and higher 57 | if == "info" prints on console info messages and higher 58 | if == None does not use console (useful when a logger has already been set) 59 | info_filename (str): the name of the info file. if None, don't create info file 60 | debug_filename (str): the name of the debug file. 
if None, don't create debug file 61 | """ 62 | if os.path.exists(output_folder): 63 | raise FileExistsError(f"{output_folder} already exists!") 64 | os.makedirs(output_folder, exist_ok=True) 65 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use 66 | logging.getLogger('PIL').setLevel(logging.INFO) # turn off logging tag for some images 67 | logging.getLogger('matplotlib.font_manager').disabled = True 68 | 69 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S") 70 | logger = logging.getLogger('') 71 | logger.setLevel(logging.DEBUG) 72 | 73 | if info_filename != None: 74 | info_file_handler = logging.FileHandler(join(output_folder, info_filename)) 75 | info_file_handler.setLevel(logging.INFO) 76 | info_file_handler.setFormatter(base_formatter) 77 | logger.addHandler(info_file_handler) 78 | 79 | if debug_filename != None: 80 | debug_file_handler = logging.FileHandler(join(output_folder, debug_filename)) 81 | debug_file_handler.setLevel(logging.DEBUG) 82 | debug_file_handler.setFormatter(base_formatter) 83 | logger.addHandler(debug_file_handler) 84 | 85 | if console != None: 86 | console_handler = logging.StreamHandler() 87 | if console == "debug": console_handler.setLevel(logging.DEBUG) 88 | if console == "info": console_handler.setLevel(logging.INFO) 89 | console_handler.setFormatter(base_formatter) 90 | logger.addHandler(console_handler) 91 | 92 | def exception_handler(type_, value, tb): 93 | logger.info("\n" + "".join(traceback.format_exception(type, value, tb))) 94 | sys.excepthook = exception_handler 95 | 96 | 97 | def make_deterministic(seed=0): 98 | """Make results deterministic. If seed == -1, do not make deterministic. 99 | Running your script in a deterministic way might slow it down. 100 | Note that for some packages (eg: sklearn's PCA) this function is not enough. 
101 | """ 102 | seed = int(seed) 103 | if seed == -1: 104 | return 105 | print(f'setting seed {seed}') 106 | random.seed(seed) 107 | np.random.seed(seed) 108 | torch.manual_seed(seed) 109 | torch.cuda.manual_seed_all(seed) 110 | torch.backends.cudnn.deterministic = True 111 | torch.backends.cudnn.benchmark = False 112 | 113 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | conf = { 2 | 'cambridge': [ 3 | { 4 | # Set N.1 5 | 'beams': 2, 6 | 'steps': 40, 7 | 'N': 52, 8 | 'M': 2, 9 | 'feat_model': 'cosplace_r18_l3', 10 | 'protocol': '2_1', 11 | 'center_std': [1.2, 1.2, 0.1], 12 | 'teta': [10], 13 | 'gamma': 0.3, 14 | 'res': 320, 15 | 'colmap_res': 320, 16 | }, 17 | { 18 | # Set N.2 19 | 'beams': 2, 20 | 'steps': 30, 21 | 'N': 40, 22 | 'M': 2, 23 | 'feat_model': 'cosplace_r18_l2', 24 | 'protocol': '2_1', 25 | 'center_std': [.4, .4, .04], 26 | 'teta': [4], 27 | 'gamma': 0.3, 28 | 'res': 320, 29 | 'colmap_res': 320, 30 | }, 31 | { 32 | # Set N.3 33 | 'beams': 2, 34 | 'steps': 60, 35 | 'N': 32, 36 | 'M': 1, 37 | 'feat_model': 'cosplace_r18_l2', 38 | 'protocol': '2_0', 39 | 'center_std': [.15, .15, 0.02], 40 | 'teta': [1.5], 41 | 'gamma': 0.3, 42 | 'res': 480, 43 | 'colmap_res': 480, 44 | }, 45 | ] 46 | } 47 | 48 | 49 | def get_config(ds_name): 50 | cambridge_scenes = [ 51 | 'StMarysChurch', 'OldHospital', 'KingsCollege', 'ShopFacade' 52 | ] 53 | 54 | if ds_name in cambridge_scenes: 55 | return conf['cambridge'] 56 | else: 57 | return NotImplementedError 58 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | import gdown 2 | import os 3 | import shutil 4 | import urllib.request 5 | 6 | 7 | cambridge_scenes = { 8 | 'OldHospital': 'https://www.repository.cam.ac.uk/bitstreams/ae577bfb-bdce-488c-8ce6-3765eabe420e/download', 9 | 'KingsCollege': 'https://www.repository.cam.ac.uk/bitstreams/1cd2b04b-ada9-4841-8023-8207f1f3519b/download', 10 | 'StMarysChurch': 'https://www.repository.cam.ac.uk/bitstreams/2559ba20-c4d1-4295-b77f-183f580dbc56/download', 11 | 'ShopFacade': 'https://www.repository.cam.ac.uk/bitstreams/4e5c67dd-9497-4a1d-add4-fd0e00bcb8cb/download' 12 | } 13 | DATA_DIR = 'data' 14 | os.makedirs(DATA_DIR, exist_ok=True) 15 | 16 | ## Download splatting models for Cambridge 17 | print('Downloading Gaussian Splatting models for cambrdige...') 18 | url = 'https://drive.google.com/uc?id=1iCyizI0jZwZ7mdXG9wGGh2fuwdsbQVkL' 19 | output = 'g_down.zip' 20 | gdown.download(url, output, quiet=False) 21 | 22 | unzip_cmd = f'unzip {output}' 23 | os.system(unzip_cmd) 24 | os.rename('out_gp', 'cambridge_splats') 25 | shutil.move('cambridge_splats', DATA_DIR) 26 | os.remove(output) # remove zip file 27 | 28 | # Download undistorted colmap models for Cambridge and 7scenes 29 | print('Downloading colmap models for Cambridge and 7 scenes...') 30 | url = 'https://drive.google.com/uc?id=1BPUU7Z_Xc4SIQJwfIiIHwUA9ud9c9tCU' 31 | output = 'all_colmaps.zip' 32 | gdown.download(url, output, quiet=False) 33 | unzip_cmd = f'unzip {output}' 34 | os.system(unzip_cmd) 35 | shutil.move('all_colmaps', DATA_DIR) 36 | os.remove(output) # remove zip file 37 | 38 | ## Download cambridge scenes 39 | wget_cmd = 'wget {url} -O {out_name}' 40 | for cs, cs_url in cambridge_scenes.items(): 41 | print(f'Downloading dataset for {cs}...') 42 | # urllib.request.urlretrieve(cs_url, 
f'{cs}.zip') 43 | os.system(wget_cmd.format(url=cs_url, out_name=f'{cs}.zip')) 44 | unzip_cmd = f'unzip {cs}.zip' 45 | os.system(unzip_cmd) 46 | shutil.move(cs, DATA_DIR) 47 | os.remove(f'{cs}.zip') # remove zip file 48 | -------------------------------------------------------------------------------- /eval_pose_file.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname 3 | import argparse 4 | 5 | import commons 6 | from parse_args import parse_args 7 | from configs import get_path_conf 8 | from gloc import extraction 9 | from gloc import rendering 10 | from gloc.rendering import get_renderer 11 | from gloc.datasets import get_dataset, RenderedImagesDataset, get_transform 12 | from gloc.utils import utils, visualization 13 | from gloc.build_fine_model import get_fine_model 14 | from gloc.resamplers import get_protocol 15 | 16 | 17 | def main(args): 18 | DS = args.name 19 | print(f"Arguments: {args}") 20 | 21 | paths_conf = get_path_conf(320, None) 22 | pd = get_dataset(DS, paths_conf[DS], None) 23 | 24 | print(f'Loading pose prior from {args.pose_prior}') 25 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pd) 26 | 27 | all_true_t, all_true_R = pd.get_q_poses() 28 | errors_t, errors_R = utils.get_all_errors_first_estimate(all_true_t, all_true_R, all_pred_t, all_pred_R) 29 | out_str, _ = utils.eval_poses(errors_t, errors_R, descr='Retrieval first estimate') 30 | print(out_str) 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser(description='Argument parser') 35 | parser.add_argument('name', type=str, help='colmap') 36 | parser.add_argument('pose_prior', type=str, help='colmap') 37 | 38 | args = parser.parse_args() 39 | main(args) 40 | 41 | 42 | """ 43 | Example: 44 | python eval_pose_file.py chess_dslam logs/reb_chess_2_bms2_pt2_0_N36_M2_V004_T2_gsplat_dinol4_320/2024-01-27_16-21-05/renderings/pt2_0_s19_sz320_theta2,0_t0,0_0,0_0,0/est_poses.txt 45 | """ 46 | -------------------------------------------------------------------------------- /gloc/__init__.py: -------------------------------------------------------------------------------- 1 | # submodules 2 | from . import datasets 3 | from . import models 4 | from . import extraction 5 | from . 
import utils 6 | -------------------------------------------------------------------------------- /gloc/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['dataset', 'get_dataset'] 2 | 3 | from gloc.datasets.dataset import PoseDataset, RenderedImagesDataset 4 | from gloc.datasets.dataset import get_query_id, get_render_id 5 | from gloc.datasets.get_dataset import get_dataset, get_transform 6 | from gloc.datasets.imlist_dataset import ImListDataset, find_candidates_paths -------------------------------------------------------------------------------- /gloc/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | from os.path import join 5 | import torch.utils.data as data 6 | from PIL import Image 7 | import torchvision.transforms as T 8 | 9 | from gloc.utils import read_model_nopoints as read_model, parse_cam_model, qvec2rotmat 10 | from gloc.utils import Image as RImage 11 | 12 | 13 | def get_query_id(name): 14 | if isinstance(name, tuple): 15 | name = name[0] 16 | name = name.split('/')[-1].split('.')[0] 17 | return name 18 | 19 | 20 | def get_render_id(name): 21 | if isinstance(name, tuple): 22 | name = name[0] 23 | if ('night' in name) or ('day' in name): 24 | i = name.find('IMG') 25 | name = name[i:].split('_')[:-1] 26 | name = '_'.join(name) 27 | return name 28 | 29 | name = name.split('/')[-1].split('.')[0].split('_')[:-1] 30 | if len(name) > 2 and name[1] == 'nexus4': 31 | name = '_'.join(name[5:]) 32 | elif len(name) > 2 and name[1] == 'gopro3': 33 | name = '_'.join(name[3:]) 34 | else: 35 | name = name[1] 36 | return name 37 | 38 | 39 | class PoseDataset(data.Dataset): 40 | def __init__(self, name, paths_conf, transform=None, rendered_db=None): 41 | self.name = name 42 | self.root = paths_conf['root'] 43 | self.colmap_model = paths_conf['colmap'] 44 | self.transform = transform 45 | self.rendered_db = rendered_db 46 | self.use_render = False 47 | 48 | self.images, self.intrinsics = PoseDataset.load_colmap(self.colmap_model) 49 | 50 | queries = paths_conf['q_file'] 51 | db = paths_conf['db_file'] 52 | if (queries != '') and (db != ''): 53 | all_frames = np.array(list(map(lambda x: x.name, self.images))) 54 | self.db_frames_idxs, self.db_tvecs, self.db_qvecs = self.load_txt(db, all_frames) 55 | self.q_frames_idxs, self.q_tvecs, self.q_qvecs = self.load_txt(queries, all_frames) 56 | 57 | all_frames = np.array(list(map(lambda x: x.name, self.images))) 58 | self.db_frames_idxs, self.db_tvecs, self.db_qvecs = self.load_txt(db, all_frames) 59 | self.q_frames_idxs, self.q_tvecs, self.q_qvecs = self.load_txt(queries, all_frames) 60 | self.n_q = len(self.q_frames_idxs) 61 | 62 | def load_txt(self, fpath, all_frames): 63 | with open(fpath, 'r') as f: 64 | lines = f.readlines() 65 | if lines[0].startswith('Visual Landmark'): 66 | lines = lines[3:] 67 | frames = np.array(list(map(lambda x: x.split(' ')[0].strip(), lines))) 68 | frames_idxs = list(np.where(np.in1d(all_frames, frames))[0]) 69 | tvecs = np.array(list(map(lambda x: x.tvec, self.images)))[frames_idxs] 70 | qvecs = np.array(list(map(lambda x: x.qvec, self.images)))[frames_idxs] 71 | 72 | return frames_idxs, tvecs, qvecs 73 | 74 | def get_basename(self, im_idx): 75 | q_name = self.images[im_idx].name.replace('/', '_').split('.')[0] 76 | return q_name 77 | 78 | def get_pose(self, idx): 79 | R = qvec2rotmat(self.images[idx].qvec) 80 | t = self.images[idx].tvec 81 | 
82 | return t, R 83 | 84 | def get_pose_by_name(self, name): 85 | names =np.array(list(map(lambda x: x.name, self.images))) 86 | idx = np.argwhere(name == names)[0,0] 87 | qvec, tvec = self.images[idx].qvec, self.images[idx].tvec 88 | 89 | return qvec, tvec 90 | 91 | def num_queries(self): 92 | return len(self.q_frames_idxs) 93 | 94 | def get_pose_by_name(self, name): 95 | names =np.array(list(map(lambda x: x.name, self.images))) 96 | idx = np.argwhere(name == names)[0,0] 97 | qvec, tvec = self.images[idx].qvec, self.images[idx].tvec 98 | 99 | return qvec, tvec 100 | 101 | def get_intrinsics(self, q_key_name): 102 | assert q_key_name in self.intrinsics, f'{q_key_name} is not a valid image name' 103 | 104 | w = self.intrinsics[q_key_name]['w'] 105 | h = self.intrinsics[q_key_name]['h'] 106 | K = self.intrinsics[q_key_name]['K'] 107 | 108 | return K, w, h 109 | 110 | def get_q_poses(self): 111 | Rs = [] 112 | ts = [] 113 | 114 | for q_idx in range(len(self.q_frames_idxs)): 115 | idx = self.q_frames_idxs[q_idx] 116 | 117 | R = qvec2rotmat(self.images[idx].qvec) 118 | t = self.images[idx].tvec 119 | Rs.append(R) 120 | ts.append(t) 121 | 122 | return np.array(ts), np.array(Rs) 123 | 124 | def get_all_poses(self): 125 | Rs = [] 126 | ts = [] 127 | 128 | for image in self.images: 129 | R = qvec2rotmat(image.qvec) 130 | t = image.tvec 131 | Rs.append(R) 132 | ts.append(t) 133 | 134 | return np.array(ts), np.array(Rs) 135 | 136 | def __getitem__(self, idx): 137 | """Return: 138 | dict:'im' is the image tensor 139 | 'xyz' is the absolute position of the image 140 | 'wpqr' is the absolute rotation quaternion of the image 141 | """ 142 | data_dict = {} 143 | im_data = self.images[idx] 144 | data_dict['im_ref'] = im_data 145 | if not self.use_render: 146 | im = Image.open(join(self.root, im_data.name)) 147 | else: 148 | im = Image.open(join(self.rendered_db, im_data.name.replace('/', '_')).replace('.jpg', '.png')) 149 | if self.transform: 150 | im = self.transform(im) 151 | data_dict['im'] = im 152 | 153 | return data_dict 154 | 155 | def __len__(self): 156 | return len(self.images) 157 | 158 | @staticmethod 159 | def load_colmap(colmap_model): 160 | # Load the images 161 | logging.info(f'Loading colmap from {colmap_model}') 162 | cam_list = {} 163 | cameras, images = read_model(colmap_model) 164 | for i in images: 165 | qvec = images[i].qvec 166 | tvec = images[i].tvec 167 | cam_data = cameras[images[i].camera_id] 168 | cam_dict = parse_cam_model(cam_data) 169 | 170 | K = np.array([ 171 | [cam_dict["fx"], 0.0, cam_dict["cx"]], 172 | [0.0, cam_dict["fy"], cam_dict["cy"]], 173 | [0.0, 0.0, 1.0]]) 174 | 175 | R = qvec2rotmat(qvec) 176 | T = np.eye(4) 177 | T[0:3, 0:3] = R 178 | T[0:3, 3] = tvec 179 | w, h = cam_dict["width"], cam_dict["height"] 180 | 181 | basename = os.path.splitext(images[i].name)[0] 182 | 183 | cam_list[basename] = {'K':K, 'T':T, 'w':w, 'h':h, 'model': cam_data.model} 184 | return list(images.values()), cam_list 185 | 186 | 187 | class RenderedImagesDataset(data.Dataset): 188 | def __init__(self, im_root, transform=None, query_res=None, verbose=True): 189 | self.root = im_root 190 | self.descr_file = os.path.join(im_root, 'rendered_views.txt') 191 | 192 | self.images = RenderedImagesDataset.load_images(self.descr_file, verbose) 193 | self.final_resize = None 194 | if query_res is not None: 195 | # this to ensure that renderings end up having same resolution 196 | # as the query once resized 197 | self.final_resize = T.Resize(query_res, antialias=True) 198 | 199 | self.transform = 
transform 200 | 201 | def __getitem__(self, idx): 202 | """Return: 203 | dict:'im' is the image tensor 204 | 'xyz' is the absolute position of the image 205 | 'wpqr' is the absolute rotation quaternion of the image 206 | """ 207 | data_dict = {} 208 | im_data = self.images[idx] 209 | data_dict['im_ref'] = im_data 210 | im = Image.open(join(self.root, im_data.name)) 211 | 212 | if self.transform: 213 | im = self.transform(im) 214 | if self.final_resize: 215 | im = self.final_resize(im) 216 | data_dict['im'] = im 217 | 218 | return data_dict 219 | 220 | def __len__(self): 221 | return len(self.images) 222 | 223 | def get_full_paths(self): 224 | paths = list(map(lambda x: join(self.root, x.name), self.images)) 225 | return paths 226 | 227 | def get_names(self): 228 | image_names = list(map(lambda x: x.name, self.images)) 229 | image_names = np.array(image_names) 230 | 231 | return image_names 232 | 233 | def get_camera_centers(self): 234 | centers = [] 235 | for im in self.images: 236 | R = qvec2rotmat(im.qvec) 237 | tvec = im.tvec 238 | im_center = - R.T @ tvec 239 | centers.append(im_center) 240 | centers = np.array(centers) 241 | 242 | return centers 243 | 244 | def get_poses(self): 245 | Rs = [] 246 | ts = [] 247 | 248 | for image in self.images: 249 | R = qvec2rotmat(image.qvec) 250 | t = image.tvec 251 | Rs.append(R) 252 | ts.append(t) 253 | 254 | return np.array(ts), np.array(Rs) 255 | 256 | @staticmethod 257 | def load_images(fpath, verbose): 258 | # Load the images 259 | if verbose: 260 | print('Loading the rendered and cameras') 261 | images = [] 262 | with open(fpath, 'r') as rv: 263 | while True: 264 | line = rv.readline() 265 | if not line: 266 | break 267 | line = line.strip() 268 | fields = line.split(' ') 269 | 270 | name = fields[0]+'.png' 271 | tvec = np.array(tuple(map(float, fields[1:4]))) 272 | qvec = np.array(tuple(map(float, fields[4:8]))) 273 | 274 | im = RImage(id=-1, qvec=qvec, tvec=tvec, name=name, 275 | camera_id='', xys={}, point3D_ids={}) 276 | images.append(im) 277 | return images 278 | -------------------------------------------------------------------------------- /gloc/datasets/dataset_nolabels.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch.utils.data as data 5 | from PIL import Image 6 | import torchvision.transforms as T 7 | 8 | from gloc.utils import read_model_nopoints as read_model, parse_cam_model, qvec2rotmat 9 | from gloc.utils.camera_utils import read_cameras_intrinsics, Image as RImage 10 | 11 | 12 | class IntrinsicsDataset(data.Dataset): 13 | def __init__(self, name, paths_conf, transform=None): 14 | self.name = name 15 | self.root = paths_conf['root'] 16 | self.colmap_model = paths_conf['colmap'] 17 | self.q_files = paths_conf['q_intrinsics'] 18 | self.transform = transform 19 | 20 | self.db_images, self.db_intrinsics = IntrinsicsDataset.load_colmap(self.colmap_model) 21 | self.q_list, self.q_intrinsics = IntrinsicsDataset.load_queries(self.q_files) 22 | self.intrinsics = self.db_intrinsics.copy() 23 | self.intrinsics.update(self.q_intrinsics) 24 | 25 | self.images = self.db_images.copy() 26 | for q in self.q_list: 27 | im = RImage(id=-1, qvec=-1, tvec=-1, name=q.id, camera_id='', xys={}, point3D_ids={}) 28 | self.images.append(im) 29 | 30 | self.db_frames_idxs = list(range(len(self.db_images))) 31 | self.q_frames_idxs = list(np.arange(len(self.q_list)) + len(self.db_images)) 32 | self.db_qvecs = np.array(list(map(lambda x:x.qvec, 
self.db_images))) 33 | self.db_tvecs = np.array(list(map(lambda x:x.tvec, self.db_images))) 34 | 35 | self.n_db = len(self.db_images) 36 | self.n_q = len(self.q_list) 37 | logging.info(f'Loaded dataset with {self.n_db} db images and {self.n_q} queries w/ intrinsics') 38 | 39 | def get_basename(self, im_idx): 40 | q_name = self.images[im_idx].name.replace('/', '_').split('.')[0] 41 | return q_name 42 | 43 | def get_pose(self, idx): 44 | assert idx < self.n_db, f'Only db images have extrinsics, {idx} is a query' 45 | R = qvec2rotmat(self.images[idx].qvec) 46 | t = self.images[idx].tvec 47 | 48 | return t, R 49 | 50 | def num_queries(self): 51 | return self.n_q 52 | 53 | def get_db_poses(self): 54 | Rs = [] 55 | ts = [] 56 | 57 | for q_idx in range(len(self.db_frames_idxs)): 58 | idx = self.db_frames_idxs[q_idx] 59 | 60 | R = qvec2rotmat(self.images[idx].qvec) 61 | t = self.images[idx].tvec 62 | Rs.append(R) 63 | ts.append(t) 64 | 65 | return np.array(ts), np.array(Rs) 66 | 67 | def __getitem__(self, idx): 68 | """Return: 69 | dict:'im' is the image tensor 70 | 'xyz' is the absolute position of the image 71 | 'wpqr' is the absolute rotation quaternion of the image 72 | """ 73 | data_dict = {} 74 | im_data = self.images[idx] 75 | data_dict['im_ref'] = im_data 76 | im = Image.open(os.path.join(self.root, im_data.name)) 77 | if self.transform: 78 | im = self.transform(im) 79 | data_dict['im'] = im 80 | 81 | return data_dict 82 | 83 | def __len__(self): 84 | return len(self.images) 85 | 86 | def get_pose_by_name(self, name): 87 | names =np.array(list(map(lambda x: x.name, self.images))) 88 | idx = np.argwhere(name == names)[0,0] 89 | qvec, tvec = self.images[idx].qvec, self.images[idx].tvec 90 | 91 | return qvec, tvec 92 | 93 | @staticmethod 94 | def load_queries(q_file_list): 95 | q_intrinsincs = {} 96 | q_list = [] 97 | for q_file in q_file_list: 98 | data, intrinsics_dict = IntrinsicsDataset.load_camera_intrinsics(q_file) 99 | q_list += data 100 | q_intrinsincs.update(intrinsics_dict) 101 | 102 | return q_list, q_intrinsincs 103 | 104 | @staticmethod 105 | def load_camera_intrinsics(intrinsic_file): 106 | # Load the images 107 | logging.info(f'Loading intrinsics from {intrinsic_file}') 108 | cam_list = {} 109 | cameras = read_cameras_intrinsics(intrinsic_file) 110 | for cam_data in cameras: 111 | cam_dict = parse_cam_model(cam_data) 112 | K = np.array([ 113 | [cam_dict["fx"], 0.0, cam_dict["cx"]], 114 | [0.0, cam_dict["fy"], cam_dict["cy"]], 115 | [0.0, 0.0, 1.0]]) 116 | w, h = cam_dict["width"], cam_dict["height"] 117 | 118 | basename = os.path.splitext(cam_data.id)[0] 119 | cam_list[basename] = {'K':K, 'T':T, 'w':w, 'h':h, 120 | 'model': cam_data.model, 'params':cam_data.params} 121 | 122 | return cameras, cam_list 123 | 124 | @staticmethod 125 | def load_colmap(colmap_model): 126 | # Load the images 127 | logging.info(f'Loading colmap from {colmap_model}') 128 | cam_list = {} 129 | cameras, images = read_model(colmap_model) 130 | for i in images: 131 | qvec = images[i].qvec 132 | tvec = images[i].tvec 133 | cam_data = cameras[images[i].camera_id] 134 | cam_dict = parse_cam_model(cam_data) 135 | 136 | K = np.array([ 137 | [cam_dict["fx"], 0.0, cam_dict["cx"]], 138 | [0.0, cam_dict["fy"], cam_dict["cy"]], 139 | [0.0, 0.0, 1.0]]) 140 | 141 | R = qvec2rotmat(qvec) 142 | T = np.eye(4) 143 | T[0:3, 0:3] = R 144 | T[0:3, 3] = tvec 145 | w, h = cam_dict["width"], cam_dict["height"] 146 | 147 | basename = os.path.splitext(images[i].name)[0] 148 | 149 | cam_list[basename] = {'K':K, 'T':T, 'w':w, 
'h':h} 150 | return list(images.values()), cam_list 151 | -------------------------------------------------------------------------------- /gloc/datasets/get_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | import torchvision.transforms as T 3 | 4 | from gloc.datasets import PoseDataset 5 | from gloc.datasets.dataset_nolabels import IntrinsicsDataset 6 | 7 | 8 | def get_dataset(name, paths_conf, transform=None): 9 | # if 'Aachen' in name: 10 | if name in ['Aachen_night', 'Aachen_day', 'Aachen_real', 'Aachen_real_und']: 11 | dataset = IntrinsicsDataset(name, paths_conf, transform) 12 | else: 13 | dataset = PoseDataset(name, paths_conf, transform) 14 | 15 | return dataset 16 | 17 | 18 | def get_transform(args, colmap_dir=''): 19 | res = args.res 20 | if args.feat_model == 'Dinov2': 21 | cam_file = join(colmap_dir, 'cameras.txt') 22 | random_line = open(cam_file, 'r').readlines()[10].split(' ') 23 | w, h = int(random_line[2]), int(random_line[3]) 24 | patch_size = 14 25 | new_h = patch_size * (h // patch_size) 26 | new_w = patch_size * (w // patch_size) 27 | transform = T.Compose([ 28 | T.ToTensor(), 29 | T.Resize((new_h, new_w), antialias=True), 30 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 31 | ]) 32 | 33 | elif ('Aachen' not in args.name) and (colmap_dir != ''): 34 | cam_file = join(colmap_dir, 'cameras.txt') 35 | random_line = open(cam_file, 'r').readlines()[10].split(' ') 36 | w, h = int(random_line[2]), int(random_line[3]) 37 | ratio = min(h, w) / res 38 | new_h = int(h/ratio) 39 | new_w = int(w/ratio) 40 | transform = T.Compose([ 41 | T.ToTensor(), 42 | T.Resize((new_h, new_w), antialias=True), 43 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 44 | ]) 45 | 46 | else: 47 | transform = T.Compose([ 48 | T.ToTensor(), 49 | T.Resize(res, antialias=True), 50 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 51 | ]) 52 | 53 | return transform 54 | -------------------------------------------------------------------------------- /gloc/datasets/imlist_dataset.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | from PIL import Image 3 | import os 4 | import torch.utils.data as data 5 | from tqdm import tqdm 6 | 7 | from gloc.datasets import RenderedImagesDataset 8 | 9 | 10 | class ImListDataset(data.Dataset): 11 | def __init__(self, path_list, transform=None): 12 | self.path_list = path_list 13 | self.transform = transform 14 | 15 | def __getitem__(self, idx): 16 | im = Image.open(self.path_list[idx]) 17 | 18 | if self.transform: 19 | im = self.transform(im) 20 | return im 21 | 22 | def __len__(self): 23 | return len(self.path_list) 24 | 25 | 26 | def find_candidates_paths(pose_dataset, n_beams, render_dir): 27 | candidates_pathlist = [] 28 | for q_idx in tqdm(range(len(pose_dataset.q_frames_idxs)), ncols=100): 29 | q_name = pose_dataset.get_basename(pose_dataset.q_frames_idxs[q_idx]) 30 | query_dir = os.path.join(render_dir, q_name) 31 | 32 | for beam_i in range(n_beams): 33 | beam_dir = join(query_dir, f'beam_{beam_i}') 34 | 35 | rd = RenderedImagesDataset(beam_dir, verbose=False) 36 | paths = rd.get_full_paths() 37 | candidates_pathlist += paths 38 | 39 | # return last query res; assumes they are all the same 40 | query_tensor = pose_dataset[pose_dataset.q_frames_idxs[q_idx]]['im'] 41 | query_res = tuple(query_tensor.shape[-2:]) 42 | 43 | return candidates_pathlist, query_res 44 | 
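# Illustrative usage sketch. `find_candidates_paths` walks <render_dir>/<query_name>/beam_<i>/
# for every query in the pose dataset and returns a flat list of rendered-candidate image paths
# plus the query resolution; `ImListDataset` then wraps that path list for batched feature
# extraction. The render_dir value and the transform below are placeholders:
#
#     from torchvision import transforms as T
#     paths, query_res = find_candidates_paths(pose_dataset, n_beams=2, render_dir='renderings/step_0')
#     transform = T.Compose([T.ToTensor(), T.Resize(query_res, antialias=True)])
#     candidates_ds = ImListDataset(paths, transform=transform)
#     # candidates_ds[i] -> a transformed candidate image tensor, ordered as in `paths`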
-------------------------------------------------------------------------------- /gloc/extraction/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['extract_real', 'extract_render', 'utils'] 2 | 3 | 4 | from gloc.extraction.extract_feats import extract_features, get_query_features, get_candidates_features 5 | from gloc.extraction.utils import get_predictions_from_step_dir, get_retrieval_predictions, split_renders_into_chunks, get_feat_dim 6 | -------------------------------------------------------------------------------- /gloc/extraction/extract_feats.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from tqdm import tqdm 5 | from os.path import join 6 | import numpy as np 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.dataset import Subset 9 | 10 | 11 | def extract_features(model, model_name, pose_dataset, res, bs=32, check_cache=True): 12 | pd = pose_dataset 13 | DS = pd.name 14 | 15 | q_cache_file = join('descr_cache',f'{DS}_{model_name}_{res}_q_descriptors.pth') 16 | db_cache_file = join('descr_cache', f'{DS}_{model_name}_{res}_db_descriptors.pth') 17 | if check_cache: 18 | if (os.path.isfile(db_cache_file) and os.path.isfile(q_cache_file)): 19 | 20 | logging.info(f"Loading {db_cache_file}") 21 | db_descriptors = torch.load(db_cache_file) 22 | q_descriptors = torch.load(q_cache_file) 23 | 24 | return db_descriptors, q_descriptors 25 | 26 | model = model.eval() 27 | 28 | queries_subset_ds = Subset(pd, pd.q_frames_idxs) 29 | database_subset_ds = Subset(pd, pd.db_frames_idxs) 30 | 31 | db_descriptors = get_query_features(model, database_subset_ds, bs) 32 | q_descriptors = get_query_features(model, queries_subset_ds, bs) 33 | db_descriptors = np.vstack(db_descriptors) 34 | q_descriptors = np.vstack(q_descriptors) 35 | 36 | if check_cache: 37 | os.makedirs('descr_cache', exist_ok=True) 38 | torch.save(db_descriptors, db_cache_file) 39 | torch.save(q_descriptors, q_cache_file) 40 | 41 | return db_descriptors, q_descriptors 42 | 43 | 44 | def get_query_features(model, dataset, bs=1): 45 | """ 46 | Separate function for the queries as they might have different 47 | resolutions; thus it does not use a matrix to store descriptors 48 | but a list of arrays 49 | """ 50 | model = model.eval() 51 | # bs = 1 as resolution might differ 52 | dataloader = DataLoader(dataset=dataset, num_workers=4, batch_size=bs) 53 | 54 | iterator = tqdm(dataloader, ncols=100) 55 | descriptors = [] 56 | with torch.no_grad(): 57 | for images in iterator: 58 | images = images['im'].cuda() 59 | 60 | descr = model(images) 61 | descr = descr.cpu().numpy() 62 | descriptors.append(descr) 63 | 64 | return descriptors 65 | 66 | 67 | def get_candidates_features(model, dataset, descr_dim, bs=32): 68 | dl = DataLoader(dataset=dataset, num_workers=8, batch_size=bs) 69 | 70 | len_ds = len(dataset) 71 | descriptors = np.empty((len_ds, *descr_dim), dtype=np.float32) 72 | 73 | with torch.no_grad(): 74 | for i, images in enumerate(tqdm(dl, ncols=100)): 75 | descr = model(images.cuda()) 76 | 77 | descr = descr.cpu().numpy() 78 | descriptors[i*bs:(i+1)*bs] = descr 79 | 80 | return descriptors 81 | -------------------------------------------------------------------------------- /gloc/extraction/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import numpy as np 4 | from tqdm import tqdm 5 | 
import torch 6 | import einops 7 | from os.path import join 8 | from torch.utils.data.dataset import Subset 9 | 10 | from gloc.utils import utils, qvec2rotmat 11 | from gloc import extraction 12 | from gloc.datasets import RenderedImagesDataset 13 | from gloc import models 14 | 15 | 16 | def get_retrieval_predictions(model_name, res, pose_dataset, topk=5): 17 | model = models.get_retrieval_model(model_name) 18 | 19 | db_descriptors, q_descriptors = extraction.extract_features(model, model_name, pose_dataset, res, bs=1) 20 | logging.info(f'N. db descriptors: {db_descriptors.shape[0]}, N. Q descriptors: {q_descriptors.shape[0]}') 21 | logging.info('Computing first retrieval prediction...') 22 | all_pred_t, all_pred_R = utils.get_predictions(db_descriptors, q_descriptors, pose_dataset, top_k=topk) 23 | # all_true_t, all_true_R, all_pred_t, all_pred_R = utils.get_predictions_w_truths(db_descriptors, q_descriptors, pd, top_k=topk) 24 | 25 | return all_pred_t, all_pred_R 26 | 27 | 28 | def get_predictions_from_step_dir(fine_model, pd, render_dir, transform, N_per_beam, n_beams): 29 | queries_subset = Subset(pd, pd.q_frames_idxs) 30 | q_descriptors = extraction.get_features_from_dataset(fine_model, queries_subset) 31 | 32 | all_pred_t = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3)) 33 | all_pred_R = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3, 3)) 34 | names = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam), dtype=object) 35 | all_scores = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam)) 36 | 37 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100): 38 | q_name = pd.get_basename(pd.q_frames_idxs[q_idx]) 39 | query_dir = os.path.join(render_dir, q_name) 40 | query_tensor = pd[pd.q_frames_idxs[q_idx]]['im'] 41 | query_res = tuple(query_tensor.shape[-2:]) 42 | 43 | q_feats = extraction.get_query_descriptor_by_idx(q_descriptors, q_idx) 44 | for beam_i in range(n_beams): 45 | if n_beams == 1: 46 | beam_dir = query_dir 47 | else: 48 | beam_dir = join(query_dir, f'beam_{beam_i}') 49 | rd = RenderedImagesDataset(beam_dir, transform, query_res, verbose=False) 50 | 51 | scores_file = join(beam_dir, 'scores.pth') 52 | if not os.path.isfile(scores_file): 53 | r_db_descriptors = extraction.get_features_from_dataset(fine_model, rd, bs=fine_model.conf.bs, is_render=True) 54 | predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True) 55 | torch.save((predictions, scores), scores_file) 56 | else: 57 | predictions, scores = torch.load(scores_file) 58 | pred_t, pred_R = utils.get_pose_from_preds(q_idx, pd, rd, predictions, N_per_beam) 59 | 60 | all_pred_t[q_idx, beam_i] = pred_t 61 | all_pred_R[q_idx, beam_i] = pred_R 62 | all_scores[q_idx, beam_i] = scores[:N_per_beam] 63 | 64 | nn = [] 65 | for pr in predictions[:N_per_beam]: 66 | nn.append(rd.images[pr].name) 67 | names[q_idx, beam_i] = np.array(nn) 68 | 69 | #del q_feats, r_db_descriptors 70 | if n_beams > 1: 71 | # flatten stuff to sort predictions based on similarity 72 | flatten_beams = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)') 73 | flatten_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3) 74 | flatten_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3) 75 | 76 | flat_preds = np.argsort(flatten_beams(all_scores)) 77 | names = np.take_along_axis(flatten_beams(names), flat_preds, axis=1) 78 | 79 | flat_t = flatten_t(all_pred_t) 80 | flat_R = flatten_R(all_pred_R) 81 | 82 | stacked_t, stacked_R = [], [] 83 | for i in 
range(len(all_pred_t)): 84 | sorted_t = flat_t[i, flat_preds[i]] 85 | sorted_R = flat_R[i, flat_preds[i]] 86 | 87 | stacked_t.append(sorted_t) 88 | stacked_R.append(sorted_R) 89 | all_pred_t = np.stack(stacked_t) 90 | all_pred_R = np.stack(stacked_R) 91 | else: 92 | all_pred_t = all_pred_t.squeeze() 93 | all_pred_R = all_pred_R.squeeze() 94 | names = names.squeeze() 95 | 96 | return all_pred_t, all_pred_R, names 97 | 98 | 99 | def split_renders_into_chunks(num_queries, num_candidates, n_beams, N_per_beam, im_per_chunk): 100 | chunk_limit = im_per_chunk 101 | if num_queries > chunk_limit: 102 | q_range = np.arange(num_queries) 103 | # this will be [0, 1100, 2200..] 104 | start_chunk_q_idx = q_range[::chunk_limit] 105 | # start from [1], so the first chunk is from 0 to 1100*beams*n_beam 106 | chunk_idx_ims = [start_q*n_beams*N_per_beam for start_q in start_chunk_q_idx[1:]] 107 | chunks = [np.arange(chunk_idx_ims[0])] 108 | for ic in range(1, len(chunk_idx_ims)): 109 | this_chunk = np.arange(chunk_idx_ims[ic-1], chunk_idx_ims[ic]) 110 | chunks.append(this_chunk) 111 | last_chunk = np.arange(chunk_idx_ims[-1], num_candidates) 112 | chunks.append(last_chunk) 113 | 114 | chunk_start_q_idx = start_chunk_q_idx 115 | chunk_end_q_idx = list(start_chunk_q_idx[1:]) + [num_queries] 116 | else: 117 | chunks = [np.arange(num_candidates)] 118 | chunk_start_q_idx = [0] 119 | chunk_end_q_idx = [num_queries] 120 | 121 | return chunk_start_q_idx, chunk_end_q_idx, chunks 122 | 123 | 124 | def get_feat_dim(fine_model, query_res): 125 | x = torch.rand(1, 3, *query_res) 126 | dim = tuple(fine_model(x.cuda()).shape)[1:] 127 | 128 | return dim 129 | -------------------------------------------------------------------------------- /gloc/initialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join, dirname 3 | import torch 4 | import logging 5 | 6 | from gloc import extraction 7 | from gloc.utils import utils 8 | 9 | 10 | def init_refinement(args, pose_dataset): 11 | first_step = 0 12 | scores = {} 13 | 14 | if (args.pose_prior is None) and (args.resume_step is None): 15 | # if no pose prior, init with retrieval 16 | all_pred_t, all_pred_R = extraction.get_retrieval_predictions(args.retr_model, args.res, pose_dataset, args.beams*args.M) 17 | elif args.pose_prior is not None: 18 | assert os.path.isfile(args.pose_prior), f'{args.pose_prior} does not exist as a file' 19 | 20 | logging.info(f'Loading pose prior from {args.pose_prior}') 21 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pose_dataset, args.beams*args.M) 22 | 23 | if args.resume_step is None: 24 | all_true_t, all_true_R = pose_dataset.get_q_poses() 25 | errors_t, errors_R = utils.get_all_errors_first_estimate(all_true_t, all_true_R, all_pred_t, all_pred_R) 26 | 27 | out_str, out_vals = utils.eval_poses(errors_t, errors_R, descr='Retrieval first estimate') 28 | scores['baseline'] = out_vals 29 | logging.info(out_str) 30 | 31 | # go from (NQ, M*beams, 3/3,3) to (NQ, beams, M, 3/3, 3) 32 | all_pred_t = utils.reshape_preds_per_beam(args.beams, args.M, all_pred_t) 33 | all_pred_R = utils.reshape_preds_per_beam(args.beams, args.M, all_pred_R) 34 | else: 35 | all_pred_t, all_pred_R = None, None 36 | 37 | scores['steps'] = [] 38 | if args.first_step is not None: 39 | first_step = args.first_step 40 | if args.resume_step is not None: 41 | score_path = join(dirname(dirname(args.resume_step)), 'scores.pth') 42 | if os.path.isfile(score_path): 43 | scores = 
torch.load(score_path) 44 | if len(scores['steps']) > first_step: 45 | logging.info(f"Cutting score file to {first_step}, from {len(scores['steps'])}") 46 | scores['steps'] = scores['steps'][:first_step] 47 | 48 | return first_step, all_pred_t, all_pred_R, scores 49 | -------------------------------------------------------------------------------- /gloc/models/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['layers', 'get_model', 'fine_models'] 2 | 3 | 4 | from gloc.models.get_model import get_retrieval_model, get_ref_model 5 | -------------------------------------------------------------------------------- /gloc/models/features.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | from collections import OrderedDict 4 | import torch 5 | from torch import nn 6 | from torchvision.models import resnet18, resnet50 7 | 8 | from gloc.models.layers import L2Norm 9 | 10 | 11 | class BaseFeaturesClass(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | self.model = nn.Identity() 15 | 16 | def forward(self, x): 17 | """ 18 | To be used by subclasses, each specifying their own `self.model` 19 | Args: 20 | x (torch.tensor): batch of images shape Bx3xHxW 21 | Returns: 22 | torch.tensor: Features maps of shape BxDxHxW 23 | """ 24 | return self.model(x) 25 | 26 | 27 | class CosplaceFeatures(BaseFeaturesClass): 28 | def __init__(self, model_name): 29 | super().__init__() 30 | if 'r18' in model_name: 31 | arch = 'ResNet18' 32 | else: # 'r50' in model_name 33 | arch = 'ResNet50' 34 | # FC dim set to 512 as a placeholder, it will be truncated anyway before the last FC 35 | model = torch.hub.load("gmberton/cosplace", "get_trained_model", 36 | backbone=arch, fc_output_dim=512) 37 | 38 | backbone = model.backbone 39 | if '_l1' in model_name: 40 | backbone = backbone[:-3] 41 | elif '_l2' in model_name: 42 | backbone = backbone[:-2] 43 | elif '_l3' in model_name: 44 | backbone = backbone[:-1] 45 | 46 | self.model = backbone.eval() 47 | 48 | 49 | class ResnetFeatures(BaseFeaturesClass): 50 | def __init__(self, model_name): 51 | super().__init__() 52 | 53 | if model_name.startswith('resnet18'): 54 | model = resnet18(weights='DEFAULT') 55 | elif model_name.startswith('resnet50'): 56 | model = resnet50(weights='DEFAULT') 57 | else: 58 | raise NotImplementedError 59 | 60 | layers = list(model.children())[:-2] # Remove avg pooling and FC layer 61 | backbone = torch.nn.Sequential(*layers) 62 | 63 | if '_l1' in model_name: 64 | backbone = backbone[:-3] 65 | elif '_l2' in model_name: 66 | backbone = backbone[:-2] 67 | elif '_l3' in model_name: 68 | backbone = backbone[:-1] 69 | 70 | self.model = backbone.eval() 71 | 72 | 73 | class AlexnetFeatures(BaseFeaturesClass): 74 | def __init__(self, model_name): 75 | super().__init__() 76 | 77 | model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True) 78 | backbone = model.features 79 | 80 | if '_l1' in model_name: 81 | backbone = backbone[:4] 82 | elif '_l2' in model_name: 83 | backbone = backbone[:7] 84 | elif '_l3' in model_name: 85 | backbone = backbone[:9] 86 | 87 | self.model = backbone.eval() 88 | 89 | 90 | class DinoFeatures(BaseFeaturesClass): 91 | def __init__(self, conf): 92 | super().__init__() 93 | self.conf = conf 94 | self.clamp = conf.clamp 95 | self.norm = L2Norm() 96 | self.conf.bs = 32 97 | self.feat_level = conf.level[0] 98 | 99 | dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 
'dinov2_vits14') 100 | self.ref_model = dinov2_vits14 101 | 102 | # override 103 | def forward(self, x): 104 | desc = self.ref_model.get_intermediate_layers(x, n=self.feat_level, reshape=True)[-1] 105 | desc = self.norm(desc) 106 | 107 | return desc 108 | 109 | 110 | class RomaFeatures(BaseFeaturesClass): 111 | weight_urls = { 112 | "roma": { 113 | "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth", 114 | "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth", 115 | }, 116 | "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D 117 | } 118 | 119 | def __init__(self, conf): 120 | super().__init__() 121 | sys.path.append(str(Path(__file__).parent.joinpath('third_party/RoMa'))) 122 | from roma.models.encoders import CNNandDinov2 123 | 124 | self.conf = conf 125 | weights = torch.hub.load_state_dict_from_url(self.weight_urls["roma"]["outdoor"]) 126 | dinov2_weights = torch.hub.load_state_dict_from_url(self.weight_urls["dinov2"]) 127 | 128 | ww = OrderedDict({k.replace('encoder.', ''): v for (k, v) in weights.items() if k.startswith('encoder') }) 129 | encoder = CNNandDinov2( 130 | cnn_kwargs = dict( 131 | pretrained=False, 132 | amp = True), 133 | amp = True, 134 | use_vgg = True, 135 | dinov2_weights = dinov2_weights 136 | ) 137 | encoder.load_state_dict(ww) 138 | 139 | self.ref_model = encoder.cnn 140 | self.clamp = conf.clamp 141 | self.scale_n = conf.scale_n 142 | self.norm = L2Norm() 143 | self.conf.bs = 16 144 | self.feat_level = conf.level[0] 145 | 146 | # override 147 | def forward(self, x): 148 | f_pyramid = self.ref_model(x) 149 | fmaps = f_pyramid[self.feat_level] 150 | 151 | if self.scale_n != -1: 152 | # optionally scale down fmaps 153 | nh, nw = tuple(fmaps.shape[-2:]) 154 | half = nn.AdaptiveAvgPool2d((nh // self.scale_n, nw // self.scale_n)) 155 | fmaps = half(fmaps) 156 | 157 | desc = self.norm(fmaps) 158 | return desc 159 | -------------------------------------------------------------------------------- /gloc/models/get_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from dataclasses import dataclass 4 | import torch 5 | from torch import nn 6 | import os 7 | from torchvision.models import resnet18, resnet50 8 | 9 | from gloc.models.layers import L2Norm, FlattenFeatureMaps 10 | from gloc.models import features 11 | 12 | 13 | def get_retrieval_model(model_name, cuda=True, eval=True): 14 | if model_name.startswith('cosplace'): 15 | model = torch.hub.load("gmberton/cosplace", "get_trained_model", 16 | backbone="ResNet18", fc_output_dim=512) 17 | else: 18 | raise NotImplementedError() 19 | 20 | if cuda: 21 | model = model.cuda() 22 | if eval: 23 | model = model.eval() 24 | 25 | return model 26 | 27 | 28 | def get_feature_model(args, model_name, cuda=True): 29 | if model_name.startswith('cosplace'): 30 | feat_model = features.CosplaceFeatures(model_name) 31 | 32 | elif model_name.startswith('resnet'): 33 | feat_model = features.ResnetFeatures(model_name) 34 | 35 | elif model_name.startswith('alexnet'): 36 | feat_model = features.AlexnetFeatures(model_name) 37 | 38 | elif model_name == 'Dinov2': 39 | conf = DinoConf(clamp=args.clamp_score, level=args.feat_level) 40 | feat_model = features.DinoFeatures(conf) 41 | 42 | elif model_name == 'Roma': 43 | conf = RomaConf(clamp=args.clamp_score, level=args.feat_level, scale_n=args.scale_fmaps) 44 | feat_model = 
features.RomaFeatures(conf) 45 | 46 | else: 47 | raise NotImplementedError() 48 | 49 | if cuda: 50 | feat_model = feat_model.cuda() 51 | 52 | return feat_model 53 | 54 | 55 | def get_ref_model(args, cuda=True): 56 | model_name = args.ref_model 57 | feat_model = get_feature_model(args, args.feat_model, cuda) 58 | 59 | if model_name == 'DenseFeatures': 60 | from gloc.models.refinement_model import DenseFeaturesRefiner 61 | model_class = DenseFeaturesRefiner 62 | conf = DenseFeaturesConf(clamp=args.clamp_score) 63 | 64 | else: 65 | raise NotImplementedError() 66 | 67 | model = model_class(conf, feat_model) 68 | if cuda: 69 | model = model.cuda() 70 | 71 | return model 72 | 73 | 74 | @dataclass 75 | class DenseFeaturesConf: 76 | clamp: float = -1 77 | def get_str__conf(self): 78 | repr = f"_cl{self.clamp}" 79 | return repr 80 | 81 | 82 | @dataclass 83 | class DinoConf: 84 | clamp: float = -1 85 | level: int = 8 86 | def get_str__conf(self): 87 | repr = f"_l{self.level}_cl{self.clamp}" 88 | return repr 89 | 90 | @dataclass 91 | class RomaConf: 92 | clamp: float = -1 93 | level: int = 4 # 1 2 4 8 94 | # pool feature maps to 1/n 95 | scale_n: int = -1 96 | def get_str__conf(self): 97 | repr = f"_l{self.level}_sn{self.scale_n}_cl{self.clamp}" 98 | return repr 99 | -------------------------------------------------------------------------------- /gloc/models/layers.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | from torch import nn 4 | from torch.nn import Parameter 5 | import torch.nn.functional as F 6 | 7 | 8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6): 9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p) 10 | 11 | 12 | class GeM(nn.Module): 13 | def __init__(self, p=3, eps=1e-6): 14 | super().__init__() 15 | self.p = Parameter(torch.ones(1)*p) 16 | self.eps = eps 17 | 18 | def forward(self, x): 19 | return gem(x, p=self.p, eps=self.eps) 20 | 21 | def __repr__(self): 22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})" 23 | 24 | 25 | class Flatten(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def forward(self, x): 30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1" 31 | return x[:, :, 0, 0] 32 | 33 | 34 | class FlattenFeatureMaps(nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | self.flatten = lambda x: einops.rearrange(x, 'b c h w -> b (c h w)') 38 | def forward(self, x): 39 | return self.flatten(x) 40 | 41 | 42 | class L2Norm(nn.Module): 43 | def __init__(self, dim=1): 44 | super().__init__() 45 | self.dim = dim 46 | def forward(self, x): 47 | return F.normalize(x, p=2, dim=self.dim) 48 | -------------------------------------------------------------------------------- /gloc/models/refinement_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from gloc.models.layers import L2Norm 6 | 7 | 8 | class DenseFeaturesRefiner(nn.Module): 9 | def __init__(self, conf, ref_model): 10 | super().__init__() 11 | self.conf = conf 12 | self.ref_model = ref_model 13 | self.clamp = conf.clamp 14 | self.norm = L2Norm() 15 | self.conf.bs = 32 16 | 17 | def forward(self, x): 18 | """ 19 | Args: 20 | x (torch.tensor): batch of images shape Bx3xHxW 21 | Returns: 22 | torch.tensor: Features of shape BxDxHxW 23 | """ 24 | with torch.no_grad(): 25 | desc = self.ref_model(x) 26 | desc = 
self.norm(desc) 27 | 28 | return desc 29 | 30 | def score_candidates(self, q_feats, r_db_descriptors): 31 | """_summary_ 32 | 33 | Args: 34 | q_feats (np.array): shape 1 x C x H x W 35 | r_db_descriptors (np.array): shape N_cand x C x H x W 36 | 37 | Returns: 38 | torch.tensor : vector of shape (N_cand, ), score of each one 39 | """ 40 | q_feats = torch.tensor(q_feats) 41 | 42 | # this version is faster than looped, but requires much more memory due to broadcasting 43 | # r_db = torch.tensor(r_db_descriptors).squeeze(1) 44 | # scores = torch.linalg.norm(q_feats - r_db, dim=1) 45 | scores = torch.zeros(len(r_db_descriptors), q_feats.shape[-2], q_feats.shape[-1]) 46 | for i, desc in enumerate(r_db_descriptors): 47 | # q_feats : 1, D, H, W 48 | # desc : D, H, W 49 | # score : 1, H, W 50 | score = torch.linalg.norm(q_feats - torch.tensor(desc), dim=1) 51 | scores[i] = score[0] 52 | 53 | if self.clamp > 0: 54 | scores = scores.clamp(max=self.clamp) 55 | scores = scores.sum(dim=(1,2)) / np.prod(scores.shape[-2:]) 56 | 57 | return scores 58 | 59 | def rank_candidates(self, q_feats, r_db_descriptors, get_scores=False): 60 | scores = self.score_candidates(q_feats, r_db_descriptors) 61 | preds = torch.argsort(scores) 62 | if get_scores: 63 | return preds, scores[preds] 64 | return preds 65 | -------------------------------------------------------------------------------- /gloc/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['render_conf', 'rend_utils'] 2 | 3 | 4 | from gloc.rendering.rend_conf import get_renderer 5 | from gloc.rendering.rend_utils import log_poses, split_to_beam_folder 6 | -------------------------------------------------------------------------------- /gloc/rendering/base_renderer.py: -------------------------------------------------------------------------------- 1 | class BaseRenderer: 2 | def __init__(self, conf): 3 | pass 4 | 5 | def load_model(self): 6 | return None 7 | 8 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh): 9 | pass 10 | 11 | def end_epoch(self, step_dir): 12 | pass 13 | 14 | @staticmethod 15 | def clean_file_names(r_dir, r_names, verbose): 16 | pass 17 | -------------------------------------------------------------------------------- /gloc/rendering/mesh_renderer.py: -------------------------------------------------------------------------------- 1 | import PIL 2 | import logging 3 | import numpy as np 4 | import open3d as o3d 5 | from os.path import join 6 | 7 | from gloc.rendering.base_renderer import BaseRenderer 8 | 9 | 10 | class MeshRenderer(BaseRenderer): 11 | def __init__(self, conf): 12 | super().__init__(conf) 13 | self.mesh_path = conf.mesh_path 14 | self.background = 'black' 15 | self.supports_deferred_rendering = False 16 | logging.info(f'Using mesh from {self.mesh_path}') 17 | 18 | # override 19 | def load_model(self): 20 | mesh = o3d.io.read_triangle_model(self.mesh_path, False) 21 | 22 | for iter in range(len(mesh.materials)): 23 | mesh.materials[iter].shader = "defaultLit" 24 | # mesh.materials[iter].shader = "defaultUnlit" 25 | 26 | # - the original colors make the textures too dark - set to white 27 | mesh.materials[iter].base_color = [1.0, 1.0, 1.0, 1.0] 28 | 29 | return mesh 30 | 31 | # override 32 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, **kwargs): 33 | w, h = wh 34 | 35 | renderer = o3d.visualization.rendering.OffscreenRenderer(w, h) 36 | renderer.scene.add_model("Scene mesh", 
model) 37 | # - setup lighting 38 | renderer.scene.scene.enable_sun_light(True) 39 | if self.background == 'black': 40 | renderer.scene.set_background([0., 0., 0.,0.]) 41 | 42 | for i, fname in enumerate(r_names): 43 | output_path = join(out_dir, f"{fname}.png") 44 | T, K = pose_list[i] 45 | 46 | renderer.setup_camera(K, T, w, h) 47 | color = np.array(renderer.render_to_image()) 48 | 49 | img_rendering = PIL.Image.fromarray(color) 50 | img_rendering.save(output_path) 51 | 52 | del renderer 53 | -------------------------------------------------------------------------------- /gloc/rendering/nerf_renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import shutil 4 | import numpy as np 5 | from os.path import join 6 | from pathlib import Path 7 | import mediapy as media 8 | import torch 9 | import logging 10 | from typing import List 11 | 12 | from nerfstudio.cameras.cameras import Cameras 13 | from nerfstudio.pipelines.base_pipeline import Pipeline 14 | from nerfstudio.utils import colormaps 15 | from nerfstudio.utils.eval_utils import eval_setup 16 | from nerfstudio.cameras.camera_paths import get_path_from_json 17 | from nerfstudio.utils import colormaps 18 | 19 | from gloc.rendering.base_renderer import BaseRenderer 20 | from gloc.utils import get_c2w_nerfconv 21 | 22 | 23 | class NerfRenderer(BaseRenderer): 24 | def __init__(self, conf): 25 | super().__init__(conf) 26 | self.ns_config = conf.ns_config 27 | self.ns_transform = conf.ns_transform 28 | self.supports_deferred_rendering = False 29 | logging.info(f'Using nerf from {self.ns_config}') 30 | 31 | # override 32 | def load_model(self): 33 | _, pipeline, _, _ = eval_setup( 34 | Path(self.ns_config), 35 | eval_num_rays_per_chunk=None, 36 | test_mode="inference", 37 | ) 38 | 39 | return pipeline 40 | 41 | # override 42 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, deferred=False): 43 | # for refinement experiments, 44 | # the K matrix is the same throughout the pose list, so take the 1st 45 | K = pose_list[0][1] 46 | traj_file = self._gen_trajectory_file(out_dir, wh, K, r_names, render_ts, render_qvecs) 47 | 48 | with open(traj_file, "r", encoding="utf-8") as f: 49 | camera_path = json.load(f) 50 | camera_path = get_path_from_json(camera_path) 51 | 52 | output_path = Path(out_dir) 53 | rendered_output_names = ['rgb'] 54 | self._render_trajectory( 55 | model, 56 | camera_path, 57 | output_filename=output_path, 58 | rendered_output_names=rendered_output_names, 59 | ) 60 | 61 | def _gen_trajectory_file(self, out_dir, wh, K, r_names, render_ts, render_qvecs): 62 | width, height = wh 63 | f_len = K[0,0] 64 | traj_file = join(out_dir, 'trajectory.json') 65 | 66 | out = { 67 | 'camera_path': [] 68 | } 69 | out["render_height"] = height 70 | out["render_width"] = width 71 | out["seconds"] = len(r_names) 72 | 73 | # load the transformation from dataparser_transforms.json 74 | with open(self.ns_transform) as f: 75 | json_data_txt = f.read() 76 | json_data = json.loads(json_data_txt) 77 | 78 | T_dp = np.eye(4) 79 | T_dp[0:3,:] = np.array(json_data["transform"]) 80 | s_dp = json_data["scale"] 81 | 82 | for i in range(len(r_names)): 83 | cam_dict = {} 84 | name = r_names[i] 85 | tvec, qvec = render_ts[i], render_qvecs[i] 86 | 87 | # get camera to world matrix, in nerf convention 88 | c2w = get_c2w_nerfconv(qvec, tvec) 89 | # apply dataparser transform 90 | c2w = T_dp @ c2w 91 | c2w[0:3, 3] = s_dp * c2w[0:3, 3] 92 | 93 | 
cam_dict['camera_to_world'] = c2w.tolist() 94 | cam_dict['fov'] = np.degrees(2*np.arctan2(height, 2*f_len)) 95 | cam_dict['file_path'] = name 96 | 97 | out['camera_path'].append(cam_dict) 98 | 99 | with open(traj_file, 'w') as f: 100 | json.dump(out, f, indent=4) 101 | 102 | return traj_file 103 | 104 | @staticmethod 105 | def _render_trajectory( 106 | pipeline: Pipeline, 107 | cameras: Cameras, 108 | output_filename: Path, 109 | rendered_output_names: List[str], 110 | colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(), 111 | ) -> None: 112 | """Helper function to create a video of the spiral trajectory. 113 | 114 | Args: 115 | pipeline: Pipeline to evaluate with. 116 | cameras: Cameras to render. 117 | output_filename: Name of the output file. 118 | rendered_output_names: List of outputs to visualise. 119 | colormap_options: Options for colormap. 120 | """ 121 | cameras = cameras.to(pipeline.device) 122 | output_image_dir = output_filename.parent / output_filename.stem 123 | for camera_idx in range(cameras.size): 124 | aabb_box = None 125 | camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx, aabb_box=aabb_box) 126 | 127 | with torch.no_grad(): 128 | outputs = pipeline.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle) 129 | 130 | render_image = [] 131 | for rendered_output_name in rendered_output_names: 132 | output_image = outputs[rendered_output_name] 133 | 134 | output_image = colormaps.apply_colormap( 135 | image=output_image, 136 | colormap_options=colormap_options, 137 | ).cpu().numpy() 138 | 139 | render_image.append(output_image) 140 | render_image = np.concatenate(render_image, axis=1) 141 | 142 | media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png") 143 | 144 | @staticmethod 145 | def clean_file_names(r_dir, r_names, verbose=False): 146 | if verbose: 147 | print('Changing filenames format...') 148 | 149 | f_names = sorted(filter(lambda x: x.endswith('.png'), os.listdir(r_dir))) 150 | for i, f_name in enumerate(f_names): 151 | src = join(r_dir, f_name) 152 | dst = join(r_dir, r_names[i]+'.png') 153 | shutil.move(src, dst) 154 | -------------------------------------------------------------------------------- /gloc/rendering/rend_conf.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from gloc.rendering.mesh_renderer import MeshRenderer 4 | 5 | 6 | @dataclass 7 | class O3DConf(): 8 | mesh_path:str 9 | 10 | @dataclass 11 | class NeRFConf(): 12 | ns_config:str 13 | ns_transform:str 14 | 15 | @dataclass 16 | class GSplattingConf(): 17 | gaussians:str 18 | 19 | 20 | def get_renderer(args, paths_conf): 21 | 22 | if args.renderer == 'o3d': 23 | rend_class = MeshRenderer 24 | conf = O3DConf(paths_conf[args.name]['mesh_path']) 25 | 26 | elif args.renderer == 'nerf': 27 | from gloc.rendering.nerf_renderer import NerfRenderer 28 | conf = NeRFConf(ns_config=paths_conf[args.name]['ns_config'], 29 | ns_transform=paths_conf[args.name]['ns_transform']) 30 | rend_class = NerfRenderer 31 | 32 | elif args.renderer == 'g_splatting': 33 | from gloc.rendering.splatting_renderer import GaussianSplattingRenderer 34 | conf = GSplattingConf(paths_conf[args.name]['gaussians']) 35 | rend_class = GaussianSplattingRenderer 36 | 37 | else: 38 | raise NotImplementedError() 39 | 40 | renderer = rend_class(conf) 41 | 42 | return renderer 43 | -------------------------------------------------------------------------------- /gloc/rendering/rend_utils.py: 
-------------------------------------------------------------------------------- 1 | from os.path import join 2 | import shutil 3 | import os 4 | 5 | 6 | def log_poses(r_dir, r_names, render_ts, render_qvecs, renderer): 7 | if renderer == 'nerf': 8 | width = 5 9 | r_names = [str(idx).zfill(width) for idx in range(len(r_names))] 10 | 11 | with open(join(r_dir, 'rendered_views.txt'), 'w') as rv: 12 | for i in range(len(r_names)): 13 | line_data = [r_names[i], *tuple(render_ts[i]), *tuple(render_qvecs[i])] 14 | line = " ".join(map(str, line_data)) 15 | rv.write(line+'\n') 16 | 17 | 18 | def split_to_beam_folder(r_dir, n_beams, r_names_per_beam_q_idx, create_beams=False): 19 | for beam_i in range(n_beams): 20 | beam_dir = join(r_dir, f'beam_{beam_i}') 21 | if create_beams: 22 | os.makedirs(beam_dir, exist_ok=True) 23 | beam_names = r_names_per_beam_q_idx[beam_i] 24 | for b_name in beam_names: 25 | src = join(r_dir, b_name+'.png') 26 | dst = join(beam_dir, b_name+'.png') 27 | shutil.move(src, dst) 28 | -------------------------------------------------------------------------------- /gloc/rendering/splatting_renderer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from os.path import join 3 | import os 4 | import sys 5 | from pathlib import Path 6 | from argparse import Namespace 7 | import torch 8 | import torchvision 9 | import numpy as np 10 | from tqdm import tqdm 11 | import shutil 12 | from glob import glob 13 | from typing import List 14 | 15 | from gloc.utils import camera_utils, Image as RImage 16 | from gloc.rendering.base_renderer import BaseRenderer 17 | 18 | sys.path.append(str(Path(__file__).parent.parent.parent.joinpath('third_party/gaussian-splatting'))) 19 | 20 | from scene import Scene 21 | from gaussian_renderer import render 22 | from gaussian_renderer import GaussianModel 23 | 24 | 25 | class GaussianSplattingRenderer(BaseRenderer): 26 | def __init__(self, conf): 27 | super().__init__(conf) 28 | self.pipeline = Namespace(convert_SHs_python=False, compute_cov3D_python=False, debug=False) 29 | self.gaussians = conf.gaussians 30 | self.sh_degree = 3 31 | self.supports_deferred_rendering = True 32 | self.iteration = 7000 33 | self.buf_deferred = self.init_buffer() 34 | logging.info(f'Using Gaussians from from {self.gaussians}') 35 | self.dataset = Namespace(data_device= 'cuda', eval= False, images='dont', 36 | model_path=self.gaussians, resolution='dont', sh_degree= 3, 37 | source_path='', white_background=False) 38 | 39 | # override 40 | def load_model(self): 41 | gaussians = GaussianModel(self.sh_degree) 42 | return gaussians 43 | 44 | # override 45 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, deferred=True): 46 | # for refinement experiments, 47 | # the K matrix is the same throughout the pose list, so take the 1st 48 | K = pose_list[0][1] 49 | if deferred: 50 | self.update_buffer(wh, K, r_names, render_ts, render_qvecs) 51 | else: 52 | if not isinstance(wh, list): 53 | # wh, r_names, render_ts, render_qvecs = [wh], [r_names], [render_ts], [render_qvecs] 54 | wh = [wh] 55 | 56 | # BEFORE it was only THIS 57 | colmap_dir = self._gen_colmap(out_dir, wh, [K], [r_names], [render_ts], [render_qvecs]) 58 | self.dataset.source_path = colmap_dir 59 | scene = Scene(self.dataset, model, load_iteration=self.iteration, shuffle=False) 60 | bg_color = [0, 0, 0] 61 | 62 | with torch.no_grad(): 63 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 64 | views = 
scene.getTrainCameras() 65 | for view in tqdm(views, ncols=100): 66 | rendering = render(view, model, self.pipeline, background)["render"] 67 | torchvision.utils.save_image(rendering, join(out_dir, view.image_name+'.png')) 68 | 69 | # override 70 | def end_epoch(self, step_dir): 71 | if self.is_buffer_empty(): 72 | return 73 | 74 | model = GaussianModel(self.sh_degree) 75 | 76 | wh_l, K_l, r_names_l, render_ts_l, render_qvecs_l = self.read_buffer() 77 | colmap_dir = self._gen_colmap(step_dir, wh_l, K_l, r_names_l, render_ts_l, render_qvecs_l) 78 | self.dataset.source_path = colmap_dir 79 | scene = Scene(self.dataset, model, load_iteration=self.iteration, shuffle=False) 80 | bg_color = [0, 0, 0] 81 | 82 | with torch.no_grad(): 83 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") 84 | views = scene.getTrainCameras() 85 | for view in tqdm(views, ncols=100): 86 | rendering = render(view, model, self.pipeline, background)["render"] 87 | torchvision.utils.save_image(rendering, join(step_dir, view.image_name+'.png')) 88 | 89 | # clear buffer 90 | self.buf_deferred = self.init_buffer() 91 | print(f'Sorting out the step dir....') 92 | self.sort_step_dir(step_dir) 93 | 94 | def sort_step_dir(self, step_dir): 95 | # put all the renders inside their query directory 96 | all_images = glob(join(step_dir, '*.png')) 97 | im_per_query = {} 98 | for im in all_images: 99 | r_name = im.split('.png')[0] 100 | end_name = r_name.rfind('_') 101 | q_name = r_name[:end_name] 102 | q_name = os.path.basename(q_name) 103 | if q_name not in im_per_query: 104 | im_per_query[q_name] = [r_name] 105 | else: 106 | im_per_query[q_name].append(r_name) 107 | 108 | for q_name in tqdm(im_per_query, ncols=100): 109 | q_renders = im_per_query[q_name] 110 | os.makedirs(join(step_dir, q_name), exist_ok=True) 111 | for q_r in q_renders: 112 | rbg_r = q_r + '.png' 113 | shutil.move(rbg_r, join(step_dir, q_name)) 114 | 115 | def is_buffer_empty(self): 116 | return (len(self.buf_deferred['K']) == 0) 117 | 118 | def read_buffer(self): 119 | bf = (self.buf_deferred['wh'], self.buf_deferred['K'], self.buf_deferred['r_names'], 120 | self.buf_deferred['render_ts'], self.buf_deferred['render_qvecs']) 121 | return bf 122 | 123 | def update_buffer(self, wh, K, r_names, render_ts, render_qvecs): 124 | self.buf_deferred['wh'].append(wh) 125 | self.buf_deferred['K'].append(K) 126 | self.buf_deferred['r_names'].append(r_names) 127 | self.buf_deferred['render_ts'].append(render_ts) 128 | self.buf_deferred['render_qvecs'].append(render_qvecs) 129 | 130 | @staticmethod 131 | def init_buffer(): 132 | buffer = { 133 | 'wh': [], 134 | 'K': [], 135 | 'r_names': [], 136 | 'render_ts': [], 137 | 'render_qvecs': [] 138 | } 139 | return buffer 140 | 141 | @staticmethod 142 | def _gen_colmap(out_dir: str, wh_l: List[tuple], K_l: List[np.array], 143 | r_names_l: List[str], render_ts_l: List[np.array], render_qvecs_l: List[np.array]): 144 | out_cameras = {} 145 | out_images = {} 146 | # print(type(wh_l), type(wh_l[0]), len(wh_l)) 147 | # print(len(r_names_l), len(render_ts_l), len(render_qvecs_l)) 148 | n_cameras = len(K_l) 149 | im_per_camera = len(r_names_l[0]) 150 | # print(n_cameras) 151 | # print(K_l) 152 | for c_id in range(n_cameras): 153 | wh, K, r_names, render_ts, render_qvecs = wh_l[c_id], K_l[c_id], r_names_l[c_id], render_ts_l[c_id], render_qvecs_l[c_id] 154 | width, height = wh 155 | # assumes PINHOLE model 156 | model = 'PINHOLE' 157 | fx, fy = K[0,0], K[1,1] 158 | cx, cy = K[0,2], K[1,2] 159 | params = np.array([fx, fy, 
cx, cy]) 160 | c = camera_utils.Camera(id=c_id, model=model, width=width, height=height, params=params) 161 | out_cameras[c_id] = c 162 | 163 | for r_id in range(len(r_names)): 164 | name = r_names[r_id] 165 | im_id = c_id*im_per_camera + r_id 166 | im = RImage(id=im_id, qvec=render_qvecs[r_id], 167 | tvec=render_ts[r_id], camera_id=c_id, 168 | name=name, xys={}, point3D_ids={}) 169 | out_images[im_id] = im 170 | 171 | colmap_subdir = join(out_dir, 'sparse', '0') 172 | os.makedirs(colmap_subdir) 173 | camera_utils.write_model_nopoints(out_cameras, out_images, colmap_subdir, ext='.txt') 174 | 175 | return out_dir 176 | -------------------------------------------------------------------------------- /gloc/resamplers/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['get_protocol', 'strategies'] 2 | 3 | from gloc.resamplers.get_protocol import get_protocol 4 | from gloc.resamplers.sampling_utils import gen_rotations, gen_translations, parse_pose_data 5 | -------------------------------------------------------------------------------- /gloc/resamplers/get_protocol.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from os.path import join 3 | import torch 4 | 5 | from gloc.resamplers.scalers_conf import get_sampler 6 | from gloc.resamplers.strategies import Protocol1, Protocol2 7 | 8 | 9 | def get_protocol(args, n_views, protocol): 10 | scaler_name = protocol.split('_')[-1] 11 | sampler, scaler = get_sampler(args, args.sampler, scaler_name) 12 | 13 | if protocol.startswith('1'): 14 | protocol_conf = Protocol1Conf(N_steps=args.steps, n_views=n_views) 15 | protocol_class = Protocol1 16 | 17 | elif protocol.startswith('2'): 18 | protocol_conf = Protocol2Conf(N_steps=args.steps, n_views=n_views, M_candidates=args.M) 19 | protocol_class = Protocol2 20 | 21 | else: 22 | raise NotImplementedError 23 | 24 | protocol_obj = protocol_class(protocol_conf, sampler, scaler, protocol) 25 | return protocol_obj 26 | 27 | 28 | @dataclass 29 | class ProtocolConf: 30 | N_steps: int = 20 31 | n_views: int = 20 32 | 33 | 34 | @dataclass 35 | class Protocol1Conf(ProtocolConf): 36 | pass 37 | 38 | 39 | @dataclass 40 | class Protocol2Conf(ProtocolConf): 41 | M_candidates: int = 4 42 | -------------------------------------------------------------------------------- /gloc/resamplers/samplers.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import numpy as np 4 | from pyquaternion import Quaternion 5 | 6 | from gloc.utils import rotmat2qvec, qvec2rotmat 7 | 8 | 9 | class RandomConstantSampler(): 10 | def __init__(self, conf): 11 | pass 12 | 13 | def sample_batch(self, n_views, center_noise, angle_noise, old_t, old_R): 14 | qvecs = [] 15 | tvecs = [] 16 | poses = [] 17 | 18 | for _ in range(n_views): 19 | new_tvec, new_qvec, new_T = self.sample(center_noise, angle_noise, old_t, old_R) 20 | 21 | qvecs.append(new_qvec) 22 | tvecs.append(new_tvec) 23 | poses.append(new_T) 24 | 25 | return tvecs, qvecs, poses 26 | 27 | @staticmethod 28 | def sample(center_noise, angle_noise, old_t, old_R, low_std_ax=1): 29 | old_qvec = rotmat2qvec(old_R) # transform to qvec 30 | 31 | r_axis = Quaternion.random().axis # sample random axis 32 | teta = angle_noise # sample random angle smaller than theta 33 | r_quat = Quaternion(axis=r_axis, degrees=teta) 34 | new_qvec = r_quat * old_qvec # perturb the original pose 35 | 36 | # 
convert from Quaternion to np.array 37 | new_qvec = new_qvec.elements 38 | new_R = qvec2rotmat(new_qvec) 39 | 40 | old_center = - old_R.T @ old_t # get image center using original pose 41 | perturb_c = np.random.rand(2) 42 | perturb_low_ax = np.random.rand(1)*0.1 43 | perturb_c = np.insert(perturb_c, low_std_ax, perturb_low_ax) 44 | perturb_c /= np.linalg.norm(perturb_c) # normalize noise vector 45 | 46 | # move along the noise direction for a fixed magnitude 47 | new_center = old_center + perturb_c*center_noise 48 | new_t = - new_R @ new_center # use the new pose to convert to translation vector 49 | 50 | new_T = np.eye(4) 51 | new_T[0:3, 0:3] = new_R 52 | new_T[0:3, 3] = new_t 53 | 54 | return new_t, new_qvec, new_T 55 | 56 | 57 | class RandomGaussianSampler(): 58 | def __init__(self, conf): 59 | pass 60 | 61 | def sample_batch(self, n_views, center_std, max_angle, old_t, old_R): 62 | qvecs = [] 63 | tvecs = [] 64 | poses = [] 65 | 66 | for _ in range(n_views): 67 | new_tvec, new_qvec, new_T = self.sample(center_std, max_angle, old_t, old_R) 68 | 69 | qvecs.append(new_qvec) 70 | tvecs.append(new_tvec) 71 | poses.append(new_T) 72 | 73 | return tvecs, qvecs, poses 74 | 75 | @staticmethod 76 | def sample(center_std, max_angle, old_t, old_R): 77 | old_qvec = rotmat2qvec(old_R) # transform to qvec 78 | 79 | r_axis = Quaternion.random().axis # sample random axis 80 | teta = np.random.rand()*max_angle # sample random angle smaller than theta 81 | r_quat = Quaternion(axis=r_axis, degrees=teta) 82 | new_qvec = r_quat * old_qvec # perturb the original pose 83 | 84 | # convert from Quaternion to np.array 85 | new_qvec = new_qvec.elements 86 | new_R = qvec2rotmat(new_qvec) 87 | 88 | old_center = - old_R.T @ old_t # get image center using original pose 89 | perturb_c = torch.normal(0., center_std) 90 | new_center = old_center + np.array(perturb_c) # perturb it 91 | new_t = - new_R @ new_center # use the new pose to convert to translation vector 92 | 93 | new_T = np.eye(4) 94 | new_T[0:3, 0:3] = new_R 95 | new_T[0:3, 3] = new_t 96 | 97 | return new_t, new_qvec, new_T 98 | 99 | 100 | class RandomDoubleAxisSampler(): 101 | rotate_axis = { 102 | 'pitch': [1, 0, 0], # x, pitch 103 | 'yaw': [0, 1, 0] # y, yaw 104 | } 105 | 106 | def __init__(self, conf): 107 | pass 108 | 109 | def sample_batch(self, n_views: int, center_std: torch.tensor, max_angle: torch.tensor, 110 | old_t: np.array, old_R: np.array): 111 | qvecs = [] 112 | tvecs = [] 113 | poses = [] 114 | 115 | for _ in range(n_views): 116 | # apply yaw first 117 | ax = self.rotate_axis['yaw'] 118 | new_tvec, _, new_R, _ = self.sample(ax, center_std, float(max_angle[0]), old_t, old_R) 119 | 120 | # apply pitch then 121 | ax = self.rotate_axis['pitch'] 122 | new_tvec, new_qvec, _, new_T = self.sample(ax, center_std, float(max_angle[1]), new_tvec, new_R) 123 | 124 | qvecs.append(new_qvec) 125 | tvecs.append(new_tvec) 126 | poses.append(new_T) 127 | 128 | return tvecs, qvecs, poses 129 | 130 | @staticmethod 131 | def sample(axis, center_std: torch.tensor, max_angle: float, old_t: np.array, old_R: np.array, rot_only: bool =False): 132 | old_qvec = rotmat2qvec(old_R) # transform to qvec 133 | 134 | teta = np.random.rand()*max_angle # sample random angle smaller than theta 135 | r_quat = Quaternion(axis=axis, degrees=teta) 136 | new_qvec = r_quat * old_qvec # perturb the original pose 137 | 138 | # convert from Quaternion to np.array 139 | new_qvec = new_qvec.elements 140 | new_R = qvec2rotmat(new_qvec) 141 | 142 | if not rot_only: 143 | old_center = - 
old_R.T @ old_t # get image center using original pose 144 | perturb_c = torch.normal(0., center_std) 145 | new_center = old_center + np.array(perturb_c) # perturb it 146 | new_t = - new_R @ new_center # use the new pose to convert to translation vector 147 | else: 148 | new_t = - new_R @ old_center # use the new pose to convert to translation vector 149 | 150 | new_T = np.eye(4) 151 | new_T[0:3, 0:3] = new_R 152 | new_T[0:3, 3] = new_t 153 | 154 | return new_t, new_qvec, new_R, new_T 155 | 156 | 157 | class RandomSamplerByAxis(): 158 | rotate_axis = [ 159 | [1, 0, 0], # x, pitch 160 | [0, 1, 0] # y, yaw 161 | ] 162 | 163 | def __init__(self, conf): 164 | pass 165 | 166 | def sample_batch(self, n_views, center_std, max_angle, old_t, old_R): 167 | qvecs = [] 168 | tvecs = [] 169 | poses = [] 170 | 171 | for i in range(n_views): 172 | # use first axis half the time, the other for the rest 173 | ax_i = i // ((n_views+1) // 2) 174 | ax = self.rotate_axis[ax_i] 175 | 176 | new_tvec, new_qvec, new_T = self.sample(ax, center_std, max_angle, old_t, old_R) 177 | 178 | qvecs.append(new_qvec) 179 | tvecs.append(new_tvec) 180 | poses.append(new_T) 181 | 182 | return tvecs, qvecs, poses 183 | 184 | @staticmethod 185 | def sample(axis, center_std, max_angle, old_t, old_R): 186 | old_qvec = rotmat2qvec(old_R) # transform to qvec 187 | 188 | teta = np.random.rand()*max_angle # sample random angle smaller than theta 189 | r_quat = Quaternion(axis=axis, degrees=teta) 190 | new_qvec = r_quat * old_qvec # perturb the original pose 191 | 192 | # convert from Quaternion to np.array 193 | new_qvec = new_qvec.elements 194 | new_R = qvec2rotmat(new_qvec) 195 | 196 | old_center = - old_R.T @ old_t # get image center using original pose 197 | perturb_c = torch.normal(0., center_std) 198 | new_center = old_center + np.array(perturb_c) # perturb it 199 | new_t = - new_R @ new_center # use the new pose to convert to translation vector 200 | 201 | new_T = np.eye(4) 202 | new_T[0:3, 0:3] = new_R 203 | new_T[0:3, 3] = new_t 204 | 205 | return new_t, new_qvec, new_T 206 | 207 | 208 | class RandomAndDoubleAxisSampler(): 209 | rotate_axis = { 210 | 'pitch': [1, 0, 0], # x, pitch 211 | 'yaw': [0, 1, 0] # y, yaw 212 | } 213 | 214 | def __init__(self, conf): 215 | pass 216 | 217 | def sample_batch(self, n_views: int, center_std: torch.tensor, max_angle: torch.tensor, 218 | old_t: np.array, old_R: np.array): 219 | qvecs = [] 220 | tvecs = [] 221 | poses = [] 222 | 223 | for i in range(n_views): 224 | # use DoubleRotation half the time, Random the rest 225 | ax_i = i // ((n_views+1) // 2) 226 | if ax_i == 0: 227 | # double ax rotation 228 | 229 | # apply yaw first 230 | ax = self.rotate_axis['yaw'] 231 | new_tvec, _, new_R, _ = self.sample(ax, center_std, float(max_angle[0]), old_t, old_R) 232 | 233 | # apply pitch then 234 | ax = self.rotate_axis['pitch'] 235 | new_tvec, new_qvec, _, new_T = self.sample(ax, center_std, float(max_angle[1]), new_tvec, new_R) 236 | 237 | qvecs.append(new_qvec) 238 | tvecs.append(new_tvec) 239 | poses.append(new_T) 240 | else: 241 | # use random axis, with yaw magnitude 242 | new_tvec, new_qvec, _, new_T = self.sample(None, center_std, float(max_angle[0]), old_t, old_R) 243 | 244 | qvecs.append(new_qvec) 245 | tvecs.append(new_tvec) 246 | poses.append(new_T) 247 | 248 | return tvecs, qvecs, poses 249 | 250 | @staticmethod 251 | def sample(axis, center_std: torch.tensor, max_angle: float, old_t: np.array, old_R: np.array, rot_only: bool =False): 252 | old_qvec = rotmat2qvec(old_R) # transform to qvec 
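# note: rotmat2qvec returns the quaternion as (w, x, y, z), the same ordering pyquaternion
# expects, so the r_quat * old_qvec product below composes the sampled perturbation with the
# original rotation before converting back to a rotation matrix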
253 | 254 | if axis is None: 255 | # if no axis provided, use a random one 256 | axis = Quaternion.random().axis # sample random axis 257 | 258 | teta = np.random.rand()*max_angle # sample random angle smaller than theta 259 | r_quat = Quaternion(axis=axis, degrees=teta) 260 | new_qvec = r_quat * old_qvec # perturb the original pose 261 | 262 | # convert from Quaternion to np.array 263 | new_qvec = new_qvec.elements 264 | new_R = qvec2rotmat(new_qvec) 265 | 266 | if not rot_only: 267 | old_center = - old_R.T @ old_t # get image center using original pose 268 | perturb_c = torch.normal(0., center_std) 269 | new_center = old_center + np.array(perturb_c) # perturb it 270 | new_t = - new_R @ new_center # use the new pose to convert to translation vector 271 | else: 272 | new_t = - new_R @ old_center # use the new pose to convert to translation vector 273 | 274 | new_T = np.eye(4) 275 | new_T[0:3, 0:3] = new_R 276 | new_T[0:3, 3] = new_t 277 | 278 | return new_t, new_qvec, new_R, new_T 279 | -------------------------------------------------------------------------------- /gloc/resamplers/sampling_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyquaternion import Quaternion 3 | 4 | from gloc.utils import qvec2rotmat 5 | 6 | 7 | def gen_translations(border_points, radius, points_per_meter, q_center, axis_up): 8 | theta = np.linspace(0, 2 * np.pi, border_points, endpoint=False) 9 | 10 | x = np.cos(theta)*radius 11 | if axis_up == 'y': 12 | y = np.zeros(x.shape) 13 | z = np.sin(theta)*radius 14 | elif axis_up == 'z': 15 | y = np.sin(theta)*radius 16 | z = np.zeros(x.shape) 17 | else: 18 | raise NotImplementedError() 19 | 20 | points = np.array([x, y, z]).transpose() 21 | n_points = int(radius * points_per_meter+2) 22 | for point in points: 23 | x0 = np.linspace(0, point[0], n_points)[1:-1] 24 | if axis_up == 'y': 25 | y0 = np.zeros(x0.shape) 26 | z0 = np.linspace(0, point[2], n_points)[1:-1] 27 | elif axis_up == 'z': 28 | y0 = np.linspace(0, point[1], n_points)[1:-1] 29 | z0 = np.zeros(x0.shape) 30 | else: 31 | raise NotImplementedError() 32 | 33 | new_p=np.stack((x0, y0, z0)).transpose() 34 | points = np.vstack((points, new_p)) 35 | 36 | points += q_center 37 | points = np.insert(points, 0, q_center, axis=0) 38 | camera_centers = list(points) 39 | return camera_centers 40 | 41 | 42 | def gen_rotations(qvec, R, tvec, c_center, points_per_axis, max_angle, n_axis): 43 | axis_set = [ 44 | [1, 0, 0], 45 | [0, 1, 0], 46 | [1, -1, 0], 47 | [1, 1, 0], 48 | ] 49 | 50 | gen_poses = [(qvec, R, tvec)] 51 | for axis in axis_set[:n_axis]: 52 | theta = np.linspace(-max_angle, max_angle, points_per_axis) 53 | theta = np.delete(theta, points_per_axis // 2) 54 | 55 | for th in theta: 56 | my_quaternion = Quaternion(axis=axis, angle=th) 57 | new_qvec = my_quaternion * qvec 58 | new_R = qvec2rotmat(new_qvec) 59 | new_t = - new_R @ c_center 60 | gen_poses.append((new_qvec, new_R, new_t)) 61 | return gen_poses 62 | 63 | 64 | def parse_pose_data(q_basename, gen_poses, K, sub_index): 65 | render_qvecs = [] 66 | render_ts = [] 67 | calibr_pose = [] 68 | r_names = [] 69 | for i, pose in enumerate(gen_poses): 70 | new_qvec, new_R, new_t = pose 71 | T = np.eye(4) 72 | T[0:3, 0:3] = new_R 73 | T[0:3, 3] = new_t 74 | 75 | render_qvecs.append(new_qvec) 76 | render_ts.append(new_t) 77 | r_names.append(q_basename + f'_{sub_index}' + f'_{i}') 78 | calibr_pose.append((T, K)) 79 | 80 | return render_qvecs, render_ts, r_names, calibr_pose 81 | 
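# --------------------------------------------------------------------------
# Illustrative usage sketch (not a file of this repository): how the helpers in
# sampling_utils.py above could expand a single pose prior into a grid of
# candidate poses to render. The intrinsics, the prior pose and every
# hyper-parameter value below are assumptions chosen only for this example.
import numpy as np

from gloc.resamplers.sampling_utils import gen_translations, gen_rotations, parse_pose_data
from gloc.utils import qvec2rotmat

K = np.array([[800., 0., 320.],
              [0., 800., 240.],
              [0., 0., 1.]])                      # assumed pinhole intrinsics
qvec = np.array([1., 0., 0., 0.])                 # prior rotation as (w, x, y, z)
R = qvec2rotmat(qvec)
tvec = np.array([0.5, 0.2, 1.0])                  # prior world-to-camera translation
c_center = -R.T @ tvec                            # camera center in world coordinates

# candidate centers on (and inside) a 2 m circle around the prior, y axis up
centers = gen_translations(border_points=8, radius=2.0, points_per_meter=2,
                           q_center=c_center, axis_up='y')

# for every candidate center, also sweep small rotations around the first two axes
r_names, render_ts, render_qvecs, calibr_pose = [], [], [], []
for i, center in enumerate(centers):
    t_i = -R @ center                             # candidate center back to a translation vector
    poses = gen_rotations(qvec, R, t_i, center, points_per_axis=3,
                          max_angle=np.deg2rad(5.0), n_axis=2)
    qv, ts, names, calib = parse_pose_data('query_0001', poses, K, sub_index=i)
    render_qvecs += qv
    render_ts += ts
    r_names += names
    calibr_pose += calib
# r_names / render_ts / render_qvecs / calibr_pose match the arguments expected by
# the renderers' render_poses() methods defined earlier in this repository.
# --------------------------------------------------------------------------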
-------------------------------------------------------------------------------- /gloc/resamplers/scalers.py: -------------------------------------------------------------------------------- 1 | class ConstantScaler(): 2 | def __init__(self, conf): 3 | self.max_angle = conf.max_angle 4 | self.center_std = conf.max_center_std 5 | 6 | def step(self, i): 7 | pass 8 | 9 | def get_noise(self): 10 | return self.center_std, self.max_angle 11 | 12 | def get_max_noise(self, multiplier=1): 13 | return self.center_std*multiplier, self.max_angle*multiplier 14 | 15 | 16 | class UniformScaler(): 17 | def __init__(self, conf): 18 | # gamma is the minimum multiplier that will be applied 19 | self.max_angle = conf.max_angle 20 | self.max_center_std = conf.max_center_std 21 | self.current_angle = conf.max_angle 22 | self.current_center_std = conf.max_center_std 23 | 24 | self.n_steps = conf.N_steps 25 | self.gamma = conf.gamma 26 | 27 | def step(self, i): 28 | scale_noise = max(self.gamma, (self.n_steps - i)/self.n_steps) 29 | 30 | self.current_center_std = self.max_center_std*scale_noise 31 | self.current_angle = self.max_angle* scale_noise 32 | 33 | def get_noise(self): 34 | return self.current_center_std, self.current_angle 35 | 36 | def get_max_noise(self, multiplier=1): 37 | return self.max_center_std*multiplier, self.max_angle*multiplier 38 | -------------------------------------------------------------------------------- /gloc/resamplers/scalers_conf.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | from gloc.resamplers.samplers import (RandomGaussianSampler, RandomSamplerByAxis, 5 | RandomDoubleAxisSampler, RandomAndDoubleAxisSampler) 6 | from gloc.resamplers.scalers import ConstantScaler, UniformScaler 7 | 8 | 9 | def get_sampler(args, sampler_name, scaler_name): 10 | max_angle_delta = torch.tensor(args.teta) 11 | max_center_std = torch.tensor(args.center_std) 12 | 13 | sampler_conf = SamplerConf() 14 | if sampler_name == 'rand': 15 | sampler_class = RandomGaussianSampler 16 | elif sampler_name == 'rand_yaw_or_pitch': 17 | sampler_class = RandomSamplerByAxis 18 | elif sampler_name == 'rand_yaw_and_pitch': 19 | sampler_class = RandomDoubleAxisSampler 20 | elif sampler_name == 'rand_and_yaw_and_pitch': 21 | sampler_class = RandomAndDoubleAxisSampler 22 | else: 23 | raise NotImplementedError() 24 | 25 | if scaler_name == '0': 26 | scaler_conf = ConstantScalerConf(max_center_std=max_center_std, max_angle=max_angle_delta) 27 | scaler_class = ConstantScaler 28 | elif scaler_name == '1': 29 | scaler_conf = UniformScalerConf(max_center_std=max_center_std, max_angle=max_angle_delta, 30 | N_steps=args.steps, gamma=args.gamma) 31 | scaler_class = UniformScaler 32 | else: 33 | raise NotImplementedError() 34 | 35 | sampler = sampler_class(sampler_conf) 36 | scaler = scaler_class(scaler_conf) 37 | 38 | return sampler, scaler 39 | 40 | 41 | @dataclass 42 | class ConstantScalerConf: 43 | max_center_std: torch.tensor = torch.tensor([1, 1, 1]) 44 | max_angle: torch.tensor = torch.tensor([8]) 45 | 46 | 47 | @dataclass 48 | class UniformScalerConf(ConstantScalerConf): 49 | N_steps: int = 20 50 | gamma: float = 0.1 51 | 52 | 53 | @dataclass 54 | class SamplerConf: 55 | pass -------------------------------------------------------------------------------- /gloc/resamplers/strategies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gloc.utils import 
rotmat2qvec 4 | 5 | 6 | class BaseProtocol: 7 | """This base dummy class serves as template for subclasses. it always returns 8 | the same poses without perturbing them""" 9 | def __init__(self, conf, sampler, scaler, protocol_name): 10 | self.sampler = sampler 11 | self.scaler = scaler 12 | self.n_steps = conf.N_steps 13 | self.n_views = conf.n_views 14 | self.protocol = protocol_name 15 | # init for later 16 | self.center_std = None 17 | self.max_angle = None 18 | 19 | def init_step(self, i): 20 | self.scaler.step(i) 21 | self.center_std, self.max_angle = self.scaler.get_noise() 22 | 23 | def get_pertubr_str(self, step, res): 24 | c_str = "_".join(list(map(lambda x: f'{x:.1f}'.replace('.', ','), map(float, self.center_std)))) 25 | angle_str = "_".join(list(map(lambda x: f'{x:.1f}'.replace('.', ','), map(float, self.max_angle)))) 26 | 27 | perturb_str = f'pt{self.protocol}_s{step}_sz{res}_theta{angle_str}_t{c_str}' 28 | return perturb_str 29 | 30 | @staticmethod 31 | def get_r_name(q_name, r_i, beam_i): 32 | r_name = q_name+f'_{r_i}beam{beam_i}' 33 | return r_name 34 | 35 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs): 36 | # this base class returns the same pose all over again 37 | render_qvecs = [] 38 | render_ts = [] 39 | calibr_pose = [] 40 | r_names = [] 41 | 42 | for i in range(self.n_views): 43 | t = pred_t[i] 44 | R = pred_R[i] 45 | qvec = rotmat2qvec(R) 46 | 47 | render_qvecs.append(qvec) 48 | render_ts.append(t) 49 | r_names.append(BaseProtocol.get_r_name(q_name, i, beam_i)) 50 | T = np.eye(4) 51 | T[0:3, 0:3] = R 52 | T[0:3, 3] = t 53 | calibr_pose.append((T, K)) 54 | 55 | return r_names, render_ts, render_qvecs, calibr_pose 56 | 57 | 58 | class Protocol1(BaseProtocol): 59 | """ 60 | This protocol keeps only the first prediction, to perturb N times 61 | """ 62 | def __init__(self, conf, scaler, sampler, protocol_name): 63 | super().__init__(conf, scaler, sampler, protocol_name) 64 | 65 | # override 66 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs): 67 | render_qvecs = [] 68 | render_ts = [] 69 | calibr_pose = [] 70 | r_names = [] 71 | 72 | t = pred_t[0] # take first prediction 73 | R = pred_R[0] # take first prediction 74 | qvec = rotmat2qvec(R) # transform to qvec 75 | 76 | #### keep previous estimate ##### 77 | render_qvecs.append(qvec) 78 | render_ts.append(t) 79 | r_names.append(BaseProtocol.get_r_name(q_name, 0, beam_i)) 80 | T = np.eye(4) 81 | T[0:3, 0:3] = R 82 | T[0:3, 3] = t 83 | calibr_pose.append((T, K)) 84 | #################### 85 | views_per_candidate = self.n_views - 1 86 | new_ts, new_qvecs, new_poses = self.sampler.sample_batch(views_per_candidate, 87 | self.center_std, self.max_angle, 88 | t, R) 89 | render_ts += new_ts 90 | render_qvecs += new_qvecs 91 | for j in range(views_per_candidate): 92 | r_name = BaseProtocol.get_r_name(q_name, j + 1, beam_i) 93 | r_names.append(r_name) 94 | calibr_pose.append((new_poses[j], K)) 95 | 96 | return r_names, render_ts, render_qvecs, calibr_pose 97 | 98 | 99 | class Protocol2(BaseProtocol): 100 | """ 101 | This protocol keeps the first M predictions, perturbing them N // M times 102 | """ 103 | def __init__(self, conf, scaler, sampler, protocol_name): 104 | super().__init__(conf, scaler, sampler, protocol_name) 105 | self.M = conf.M_candidates 106 | 107 | # override 108 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs): 109 | render_qvecs = [] 110 | render_ts = [] 111 | calibr_pose = [] 112 | r_names = [] 113 | 114 | #### keep previous first M 
estimates ##### 115 | for i in range(self.M): 116 | t = pred_t[i] 117 | R = pred_R[i] 118 | qvec = rotmat2qvec(R) 119 | 120 | render_qvecs.append(qvec) 121 | render_ts.append(t) 122 | r_names.append(BaseProtocol.get_r_name(q_name, i, beam_i)) 123 | T = np.eye(4) 124 | T[0:3, 0:3] = R 125 | T[0:3, 3] = t 126 | calibr_pose.append((T, K)) 127 | #################### 128 | 129 | views_per_candidate = self.n_views // self.M - 1 130 | for i in range(self.M): 131 | t = pred_t[i] 132 | R = pred_R[i] 133 | 134 | new_ts, new_qvecs, new_poses = self.sampler.sample_batch(views_per_candidate, self.center_std, self.max_angle, 135 | t, R) 136 | render_ts += new_ts 137 | render_qvecs += new_qvecs 138 | for j in range(views_per_candidate): 139 | r_name = BaseProtocol.get_r_name(q_name, self.M+i*views_per_candidate+j, beam_i) 140 | r_names.append(r_name) 141 | calibr_pose.append((new_poses[j], K)) 142 | return r_names, render_ts, render_qvecs, calibr_pose 143 | -------------------------------------------------------------------------------- /gloc/utils/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ['camera_utils', 'utils', 'visualization'] 2 | 3 | 4 | from gloc.utils.camera_utils import (qvec2rotmat, rotmat2qvec, get_c2w_nerfconv, 5 | read_model_nopoints, parse_cam_model, Image) 6 | from gloc.utils import utils 7 | from gloc.utils import visualization 8 | -------------------------------------------------------------------------------- /gloc/utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import collections 4 | import struct 5 | 6 | 7 | #### Code taken from Colmap: 8 | # from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py 9 | CameraModel = collections.namedtuple( 10 | "CameraModel", ["model_id", "model_name", "num_params"]) 11 | Camera = collections.namedtuple( 12 | "Camera", ["id", "model", "width", "height", "params"]) 13 | BaseImage = collections.namedtuple( 14 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]) 15 | Point3D = collections.namedtuple( 16 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]) 17 | 18 | 19 | class Image(BaseImage): 20 | def qvec2rotmat(self): 21 | return qvec2rotmat(self.qvec) 22 | 23 | 24 | CAMERA_MODELS = { 25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3), 26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4), 27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4), 28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5), 29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8), 30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8), 31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12), 32 | CameraModel(model_id=7, model_name="FOV", num_params=5), 33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4), 34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5), 35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12) 36 | } 37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) 38 | for camera_model in CAMERA_MODELS]) 39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model) 40 | for camera_model in CAMERA_MODELS]) 41 | 42 | 43 | 44 | def parse_cam_model(cam_data): 45 | model = cam_data.model 46 | width = cam_data.width 47 | height = cam_data.height 48 | 49 | if 
model == "SIMPLE_PINHOLE" or model == "SIMPLE_RADIAL" or model == "RADIAL" or model == "SIMPLE_RADIAL_FISHEYE" or model == "RADIAL_FISHEYE": 50 | fx = cam_data.params[0] 51 | fy = fx 52 | cx = cam_data.params[1] 53 | cy = cam_data.params[2] 54 | elif model == "PINHOLE" or model == "OPENCV" or model == "OPENCV_FISHEYE" or model == "FULL_OPENCV" or model == "FOV" or model == "THIN_PRISM_FISHEYE": 55 | fx = cam_data.params[0] 56 | fy = cam_data.params[1] 57 | cx = cam_data.params[2] 58 | cy = cam_data.params[3] 59 | 60 | return {"width":width, "height":height, "fx":fx, "fy":fy, "cx":cx, "cy":cy} 61 | 62 | 63 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"): 64 | """Read and unpack the next bytes from a binary file. 65 | :param fid: 66 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc. 67 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 68 | :param endian_character: Any of {@, =, <, >, !} 69 | :return: Tuple of read and unpacked values. 70 | """ 71 | data = fid.read(num_bytes) 72 | return struct.unpack(endian_character + format_char_sequence, data) 73 | 74 | 75 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"): 76 | """pack and write to a binary file. 77 | :param fid: 78 | :param data: data to send, if multiple elements are sent at the same time, 79 | they should be encapsuled either in a list or a tuple 80 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}. 81 | should be the same length as the data list or tuple 82 | :param endian_character: Any of {@, =, <, >, !} 83 | """ 84 | if isinstance(data, (list, tuple)): 85 | bytes = struct.pack(endian_character + format_char_sequence, *data) 86 | else: 87 | bytes = struct.pack(endian_character + format_char_sequence, data) 88 | fid.write(bytes) 89 | 90 | 91 | def get_c2w_nerfconv(qvec, tvec): 92 | R = qvec2rotmat(qvec) 93 | 94 | c_T_w = np.eye(4) 95 | c_T_w[0:3, 0:3] = R 96 | c_T_w[0:3, 3] = tvec 97 | 98 | c2w = np.linalg.inv(c_T_w) 99 | c2w[0:3,2] *= -1 # flip the y and z axis 100 | c2w[0:3,1] *= -1 101 | # for Aachen, axis in the 3D model are different, 102 | # so comment these 2 lines below 103 | c2w=c2w[[1,0,2,3],:] # swap y and z 104 | c2w[2,:] *= -1 # flip whole world upside down 105 | 106 | return c2w 107 | 108 | 109 | def qvec2rotmat(qvec): 110 | return np.array([ 111 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 114 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 115 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 116 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 117 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 118 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 119 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 120 | 121 | 122 | def rotmat2qvec(R): 123 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 124 | K = np.array([ 125 | [Rxx - Ryy - Rzz, 0, 0, 0], 126 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 127 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 128 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 129 | eigvals, eigvecs = np.linalg.eigh(K) 130 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 131 | if qvec[0] < 0: 132 | qvec *= -1 133 | return qvec 134 | 135 | 136 | def read_cameras_intrinsics(path): 137 | """ 138 | see: src/base/reconstruction.cc 139 | void Reconstruction::WriteCamerasText(const std::string& path) 140 | void Reconstruction::ReadCamerasText(const std::string& path) 141 | 
""" 142 | cameras = [] 143 | with open(path, "r") as fid: 144 | while True: 145 | line = fid.readline() 146 | if not line: 147 | break 148 | line = line.strip() 149 | if len(line) > 0 and line[0] != "#": 150 | elems = line.split() 151 | #camera_id = int(elems[0]) 152 | path = elems[0] 153 | model = elems[1] 154 | width = int(elems[2]) 155 | height = int(elems[3]) 156 | params = np.array(tuple(map(float, elems[4:]))) 157 | cameras.append(Camera(id=path, model=model, 158 | width=width, height=height, 159 | params=params)) 160 | return cameras 161 | 162 | 163 | def read_cameras_text(path): 164 | """ 165 | see: src/base/reconstruction.cc 166 | void Reconstruction::WriteCamerasText(const std::string& path) 167 | void Reconstruction::ReadCamerasText(const std::string& path) 168 | """ 169 | cameras = {} 170 | with open(path, "r") as fid: 171 | while True: 172 | line = fid.readline() 173 | if not line: 174 | break 175 | line = line.strip() 176 | if len(line) > 0 and line[0] != "#": 177 | elems = line.split() 178 | camera_id = int(elems[0]) 179 | model = elems[1] 180 | width = int(elems[2]) 181 | height = int(elems[3]) 182 | params = np.array(tuple(map(float, elems[4:]))) 183 | cameras[camera_id] = Camera(id=camera_id, model=model, 184 | width=width, height=height, 185 | params=params) 186 | return cameras 187 | 188 | 189 | def read_cameras_binary(path_to_model_file): 190 | """ 191 | see: src/base/reconstruction.cc 192 | void Reconstruction::WriteCamerasBinary(const std::string& path) 193 | void Reconstruction::ReadCamerasBinary(const std::string& path) 194 | """ 195 | cameras = {} 196 | with open(path_to_model_file, "rb") as fid: 197 | num_cameras = read_next_bytes(fid, 8, "Q")[0] 198 | for _ in range(num_cameras): 199 | camera_properties = read_next_bytes( 200 | fid, num_bytes=24, format_char_sequence="iiQQ") 201 | camera_id = camera_properties[0] 202 | model_id = camera_properties[1] 203 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name 204 | width = camera_properties[2] 205 | height = camera_properties[3] 206 | num_params = CAMERA_MODEL_IDS[model_id].num_params 207 | params = read_next_bytes(fid, num_bytes=8*num_params, 208 | format_char_sequence="d"*num_params) 209 | cameras[camera_id] = Camera(id=camera_id, 210 | model=model_name, 211 | width=width, 212 | height=height, 213 | params=np.array(params)) 214 | assert len(cameras) == num_cameras 215 | return cameras 216 | 217 | 218 | 219 | def read_images_text(path): 220 | """ 221 | see: src/base/reconstruction.cc 222 | void Reconstruction::ReadImagesText(const std::string& path) 223 | void Reconstruction::WriteImagesText(const std::string& path) 224 | """ 225 | images = {} 226 | with open(path, "r") as fid: 227 | while True: 228 | line = fid.readline() 229 | if not line: 230 | break 231 | line = line.strip() 232 | if len(line) > 0 and line[0] != "#": 233 | elems = line.split() 234 | image_id = int(elems[0]) 235 | qvec = np.array(tuple(map(float, elems[1:5]))) 236 | tvec = np.array(tuple(map(float, elems[5:8]))) 237 | camera_id = int(elems[8]) 238 | image_name = elems[9] 239 | elems = fid.readline().split() 240 | # xys = np.column_stack([tuple(map(float, elems[0::3])), 241 | # tuple(map(float, elems[1::3]))]) 242 | # point3D_ids = np.array(tuple(map(int, elems[2::3]))) 243 | # images[image_id] = Image( 244 | # id=image_id, qvec=qvec, tvec=tvec, 245 | # camera_id=camera_id, name=image_name, 246 | # xys=xys, point3D_ids=point3D_ids) 247 | images[image_id] = Image( 248 | id=image_id, qvec=qvec, tvec=tvec, 249 | camera_id=camera_id, 
name=image_name, 250 | xys={}, point3D_ids={}) 251 | return images 252 | 253 | 254 | def read_images_binary(path_to_model_file): 255 | """ 256 | see: src/base/reconstruction.cc 257 | void Reconstruction::ReadImagesBinary(const std::string& path) 258 | void Reconstruction::WriteImagesBinary(const std::string& path) 259 | """ 260 | images = {} 261 | with open(path_to_model_file, "rb") as fid: 262 | num_reg_images = read_next_bytes(fid, 8, "Q")[0] 263 | for _ in range(num_reg_images): 264 | binary_image_properties = read_next_bytes( 265 | fid, num_bytes=64, format_char_sequence="idddddddi") 266 | image_id = binary_image_properties[0] 267 | qvec = np.array(binary_image_properties[1:5]) 268 | tvec = np.array(binary_image_properties[5:8]) 269 | camera_id = binary_image_properties[8] 270 | image_name = "" 271 | current_char = read_next_bytes(fid, 1, "c")[0] 272 | while current_char != b"\x00": # look for the ASCII 0 entry 273 | image_name += current_char.decode("utf-8") 274 | current_char = read_next_bytes(fid, 1, "c")[0] 275 | num_points2D = read_next_bytes(fid, num_bytes=8, 276 | format_char_sequence="Q")[0] 277 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D, 278 | format_char_sequence="ddq"*num_points2D) 279 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])), 280 | tuple(map(float, x_y_id_s[1::3]))]) 281 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3]))) 282 | images[image_id] = Image( 283 | id=image_id, qvec=qvec, tvec=tvec, 284 | camera_id=camera_id, name=image_name, 285 | xys={}, point3D_ids={}) 286 | return images 287 | 288 | 289 | def read_points3D_text(path): 290 | """ 291 | see: src/colmap/scene/reconstruction.cc 292 | void Reconstruction::ReadPoints3DText(const std::string& path) 293 | void Reconstruction::WritePoints3DText(const std::string& path) 294 | """ 295 | points3D = {} 296 | with open(path, "r") as fid: 297 | while True: 298 | line = fid.readline() 299 | if not line: 300 | break 301 | line = line.strip() 302 | if len(line) > 0 and line[0] != "#": 303 | elems = line.split() 304 | point3D_id = int(elems[0]) 305 | xyz = np.array(tuple(map(float, elems[1:4]))) 306 | rgb = np.array(tuple(map(int, elems[4:7]))) 307 | error = float(elems[7]) 308 | image_ids = np.array(tuple(map(int, elems[8::2]))) 309 | point2D_idxs = np.array(tuple(map(int, elems[9::2]))) 310 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb, 311 | error=error, image_ids=image_ids, 312 | point2D_idxs=point2D_idxs) 313 | return points3D 314 | 315 | 316 | def read_points3D_binary(path_to_model_file): 317 | """ 318 | see: src/colmap/scene/reconstruction.cc 319 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 320 | void Reconstruction::WritePoints3DBinary(const std::string& path) 321 | """ 322 | points3D = {} 323 | with open(path_to_model_file, "rb") as fid: 324 | num_points = read_next_bytes(fid, 8, "Q")[0] 325 | for _ in range(num_points): 326 | binary_point_line_properties = read_next_bytes( 327 | fid, num_bytes=43, format_char_sequence="QdddBBBd") 328 | point3D_id = binary_point_line_properties[0] 329 | xyz = np.array(binary_point_line_properties[1:4]) 330 | rgb = np.array(binary_point_line_properties[4:7]) 331 | error = np.array(binary_point_line_properties[7]) 332 | track_length = read_next_bytes( 333 | fid, num_bytes=8, format_char_sequence="Q")[0] 334 | track_elems = read_next_bytes( 335 | fid, num_bytes=8*track_length, 336 | format_char_sequence="ii"*track_length) 337 | image_ids = np.array(tuple(map(int, track_elems[0::2]))) 338 | point2D_idxs = 
np.array(tuple(map(int, track_elems[1::2]))) 339 | points3D[point3D_id] = Point3D( 340 | id=point3D_id, xyz=xyz, rgb=rgb, 341 | error=error, image_ids=image_ids, 342 | point2D_idxs=point2D_idxs) 343 | return points3D 344 | 345 | 346 | def detect_model_format(path, ext): 347 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \ 348 | os.path.isfile(os.path.join(path, "images" + ext)): 349 | print("Detected model format: '" + ext + "'") 350 | return True 351 | 352 | return False 353 | 354 | 355 | def read_model_nopoints(path, ext=""): 356 | # try to detect the extension automatically 357 | if ext == "": 358 | if detect_model_format(path, ".bin"): 359 | ext = ".bin" 360 | elif detect_model_format(path, ".txt"): 361 | ext = ".txt" 362 | else: 363 | print("Provide model format: '.bin' or '.txt'") 364 | return 365 | 366 | if ext == ".txt": 367 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 368 | images = read_images_text(os.path.join(path, "images" + ext)) 369 | else: 370 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 371 | images = read_images_binary(os.path.join(path, "images" + ext)) 372 | return cameras, images 373 | 374 | 375 | def read_model(path, ext=""): 376 | # try to detect the extension automatically 377 | if ext == "": 378 | if detect_model_format(path, ".bin"): 379 | ext = ".bin" 380 | elif detect_model_format(path, ".txt"): 381 | ext = ".txt" 382 | else: 383 | print("Provide model format: '.bin' or '.txt'") 384 | return 385 | 386 | if ext == ".txt": 387 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext)) 388 | images = read_images_text(os.path.join(path, "images" + ext)) 389 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext) 390 | else: 391 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext)) 392 | images = read_images_binary(os.path.join(path, "images" + ext)) 393 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext) 394 | return cameras, images, points3D 395 | 396 | 397 | def write_cameras_text(cameras, path, header=True): 398 | """ 399 | see: src/base/reconstruction.cc 400 | void Reconstruction::WriteCamerasText(const std::string& path) 401 | void Reconstruction::ReadCamerasText(const std::string& path) 402 | """ 403 | HEADER = "# Camera list with one line of data per camera:\n" + \ 404 | "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \ 405 | "# Number of cameras: {}\n".format(len(cameras)) 406 | with open(path, "w") as fid: 407 | if header: 408 | fid.write(HEADER) 409 | for _, cam in cameras.items(): 410 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params] 411 | line = " ".join([str(elem) for elem in to_write]) 412 | fid.write(line + "\n") 413 | 414 | 415 | def write_cameras_binary(cameras, path_to_model_file): 416 | """ 417 | see: src/colmap/scene/reconstruction.cc 418 | void Reconstruction::WriteCamerasBinary(const std::string& path) 419 | void Reconstruction::ReadCamerasBinary(const std::string& path) 420 | """ 421 | with open(path_to_model_file, "wb") as fid: 422 | write_next_bytes(fid, len(cameras), "Q") 423 | for _, cam in cameras.items(): 424 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id 425 | camera_properties = [cam.id, 426 | model_id, 427 | cam.width, 428 | cam.height] 429 | write_next_bytes(fid, camera_properties, "iiQQ") 430 | for p in cam.params: 431 | write_next_bytes(fid, float(p), "d") 432 | return cameras 433 | 434 | 435 | def write_images_text(images, path): 436 | """ 437 | see: src/colmap/scene/reconstruction.cc 438 | 
void Reconstruction::ReadImagesText(const std::string& path) 439 | void Reconstruction::WriteImagesText(const std::string& path) 440 | """ 441 | if len(images) == 0: 442 | mean_observations = 0 443 | else: 444 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images) 445 | HEADER = "# Image list with two lines of data per image:\n" + \ 446 | "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \ 447 | "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \ 448 | "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations) 449 | 450 | with open(path, "w") as fid: 451 | fid.write(HEADER) 452 | for _, img in images.items(): 453 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name] 454 | first_line = " ".join(map(str, image_header)) 455 | fid.write(first_line + "\n") 456 | 457 | points_strings = [] 458 | for xy, point3D_id in zip(img.xys, img.point3D_ids): 459 | points_strings.append(" ".join(map(str, [*xy, point3D_id]))) 460 | fid.write(" ".join(points_strings) + "\n") 461 | 462 | 463 | def write_images_binary(images, path_to_model_file): 464 | """ 465 | see: src/colmap/scene/reconstruction.cc 466 | void Reconstruction::ReadImagesBinary(const std::string& path) 467 | void Reconstruction::WriteImagesBinary(const std::string& path) 468 | """ 469 | with open(path_to_model_file, "wb") as fid: 470 | write_next_bytes(fid, len(images), "Q") 471 | for _, img in images.items(): 472 | write_next_bytes(fid, img.id, "i") 473 | write_next_bytes(fid, img.qvec.tolist(), "dddd") 474 | write_next_bytes(fid, img.tvec.tolist(), "ddd") 475 | write_next_bytes(fid, img.camera_id, "i") 476 | for char in img.name: 477 | write_next_bytes(fid, char.encode("utf-8"), "c") 478 | write_next_bytes(fid, b"\x00", "c") 479 | write_next_bytes(fid, len(img.point3D_ids), "Q") 480 | for xy, p3d_id in zip(img.xys, img.point3D_ids): 481 | write_next_bytes(fid, [*xy, p3d_id], "ddq") 482 | 483 | 484 | def write_points3D_text(points3D, path): 485 | """ 486 | see: src/colmap/scene/reconstruction.cc 487 | void Reconstruction::ReadPoints3DText(const std::string& path) 488 | void Reconstruction::WritePoints3DText(const std::string& path) 489 | """ 490 | if len(points3D) == 0: 491 | mean_track_length = 0 492 | else: 493 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D) 494 | HEADER = "# 3D point list with one line of data per point:\n" + \ 495 | "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \ 496 | "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length) 497 | 498 | with open(path, "w") as fid: 499 | fid.write(HEADER) 500 | for _, pt in points3D.items(): 501 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error] 502 | fid.write(" ".join(map(str, point_header)) + " ") 503 | track_strings = [] 504 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs): 505 | track_strings.append(" ".join(map(str, [image_id, point2D]))) 506 | fid.write(" ".join(track_strings) + "\n") 507 | 508 | 509 | def write_points3D_binary(points3D, path_to_model_file): 510 | """ 511 | see: src/colmap/scene/reconstruction.cc 512 | void Reconstruction::ReadPoints3DBinary(const std::string& path) 513 | void Reconstruction::WritePoints3DBinary(const std::string& path) 514 | """ 515 | with open(path_to_model_file, "wb") as fid: 516 | write_next_bytes(fid, len(points3D), "Q") 517 | for _, pt in points3D.items(): 518 | write_next_bytes(fid, pt.id, "Q") 519 | write_next_bytes(fid, 
pt.xyz.tolist(), "ddd") 520 | write_next_bytes(fid, pt.rgb.tolist(), "BBB") 521 | write_next_bytes(fid, pt.error, "d") 522 | track_length = pt.image_ids.shape[0] 523 | write_next_bytes(fid, track_length, "Q") 524 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs): 525 | write_next_bytes(fid, [image_id, point2D_id], "ii") 526 | 527 | 528 | def write_model(cameras, images, points3D, path, ext=".bin"): 529 | if ext == ".txt": 530 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 531 | write_images_text(images, os.path.join(path, "images" + ext)) 532 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext) 533 | else: 534 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 535 | write_images_binary(images, os.path.join(path, "images" + ext)) 536 | write_points3D_binary(points3D, os.path.join(path, "points3D") + ext) 537 | return cameras, images, points3D 538 | 539 | 540 | def write_model_nopoints(cameras, images, path, ext=".txt"): 541 | if ext == ".txt": 542 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext)) 543 | write_images_text(images, os.path.join(path, "images" + ext)) 544 | else: 545 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext)) 546 | write_images_binary(images, os.path.join(path, "images" + ext)) 547 | -------------------------------------------------------------------------------- /gloc/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from math import ceil 3 | import numpy as np 4 | import einops 5 | import faiss 6 | from os.path import join 7 | 8 | from gloc.utils import qvec2rotmat, rotmat2qvec 9 | 10 | threshs_t = [ 0.25, 0.5, 5.0, 10.0] 11 | threshs_R = [2.0, 5.0, 10.0, 15.0] 12 | 13 | 14 | def log_pose_estimate(render_dir, pd, pred_R, pred_t, flat_preds=None, top_ns=[3, 6]): 15 | f_results = join(render_dir, 'est_poses.txt') 16 | is_aachen = 'Aachen' in pd.name 17 | print(f'WRITING TO {f_results}') 18 | with open(f_results, 'w') as f: 19 | for q_idx in range(len(pd.q_frames_idxs)): 20 | idx = pd.q_frames_idxs[q_idx] 21 | name = pd.images[idx].name 22 | if flat_preds is not None: 23 | qvec = rotmat2qvec(pred_R[q_idx][flat_preds[q_idx,0]]) 24 | tvec = pred_t[q_idx][flat_preds[q_idx,0]] 25 | else: 26 | qvec = rotmat2qvec(pred_R[q_idx][0]) 27 | tvec = pred_t[q_idx][0] 28 | 29 | if is_aachen: 30 | name = os.path.basename(name) 31 | qvec = ' '.join(map(str, qvec)) 32 | tvec = ' '.join(map(str, tvec)) 33 | f.write(f'{name} {qvec} {tvec}\n') 34 | 35 | for topn in top_ns: 36 | tn_results = join(render_dir, f'top{topn}_est_poses.txt') 37 | print(f'WRITING TO {tn_results}') 38 | 39 | with open(tn_results, 'w') as f: 40 | for q_idx in range(len(pd.q_frames_idxs)): 41 | idx = pd.q_frames_idxs[q_idx] 42 | name = pd.images[idx].name 43 | for i in range(topn): 44 | if flat_preds is not None: 45 | qvec = rotmat2qvec(pred_R[q_idx][flat_preds[q_idx, i]]) 46 | tvec = pred_t[q_idx][flat_preds[q_idx, i]] 47 | else: 48 | qvec = rotmat2qvec(pred_R[q_idx, i]) 49 | tvec = pred_t[q_idx, i] 50 | 51 | name = os.path.basename(name) 52 | qvec = ' '.join(map(str, qvec)) 53 | tvec = ' '.join(map(str, tvec)) 54 | f.write(f'{name} {qvec} {tvec}\n') 55 | 56 | return f_results 57 | 58 | 59 | def load_pose_prior(pose_file, pd, M=1): 60 | with open(pose_file, 'r') as fp: 61 | est_poses = fp.readlines() 62 | # the format is: 'basename_img.ext qw qx qy qz tx ty tz\n' 63 | est_poses = list(map(lambda x: x.strip().split(' '), est_poses)) 64 | poses_dict = {} 
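# group the estimates by image name: each name maps to a list of (qvec, tvec) tuples,
# since a query can appear on several lines when more than one candidate pose was exported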
65 | for pose in est_poses: 66 | qvec_float = list(map(float, pose[1:5])) 67 | tvec_float = list(map(float, pose[5:8])) 68 | if pose[0] not in poses_dict: 69 | poses_dict[pose[0]] = [] 70 | 71 | poses_dict[pose[0]].append( (np.array(qvec_float), np.array(tvec_float)) ) 72 | 73 | all_pred_t, all_pred_R = np.empty((pd.n_q, M, 3)), np.empty((pd.n_q, M, 3, 3)) 74 | 75 | if pd.name in ['Aachen_night', 'Aachen_day', 'Aachen_real', 'Aachen_real_und']: 76 | get_q_key = lambda x: os.path.basename(x) 77 | else: 78 | get_q_key = lambda x: x 79 | for q_idx in range(len(pd.q_frames_idxs)): 80 | idx = pd.q_frames_idxs[q_idx] 81 | #q_key = os.path.basename(pd.images[idx].name) 82 | q_key = get_q_key(pd.images[idx].name) 83 | poses_q = poses_dict[q_key] 84 | 85 | if len(poses_q) == 1: 86 | qvec, tvec = poses_q[0] 87 | R = qvec2rotmat(qvec).reshape(-1, 3, 3) 88 | R_rp = np.repeat(R, M, axis=0) 89 | tvec_rp = np.repeat(tvec.reshape(-1, 3), M, axis=0) 90 | 91 | all_pred_t[q_idx] = tvec_rp 92 | all_pred_R[q_idx] = R_rp 93 | 94 | else: 95 | assert len(poses_q) >= M, f'This query has {len(poses_q)} poses and you asked for {M}' 96 | for i in range(M): 97 | qvec, tvec = poses_q[i] 98 | R = qvec2rotmat(qvec) 99 | tvec = tvec 100 | all_pred_t[q_idx, i] = tvec 101 | all_pred_R[q_idx, i] = R 102 | 103 | return all_pred_t, all_pred_R 104 | 105 | 106 | def reshape_preds_per_beam(n_beams, M, preds): 107 | n_dims = len(preds.shape) 108 | to_stack = [] 109 | for i in range(n_beams): 110 | # with n-beams, beam i gets the i-th cand., and the i+n, i+2, so on 111 | # so take one every n 112 | ts = preds[:,i::n_beams][:,:M, :] 113 | to_stack.append(ts) 114 | 115 | if n_dims == 3: 116 | stacked = np.hstack(to_stack).reshape(-1, n_beams, M, 3) 117 | elif n_dims == 4: 118 | stacked = np.hstack(to_stack).reshape(-1, n_beams, M, 3, 3) 119 | 120 | return stacked 121 | 122 | 123 | def repeat_first_preds_per_beam(n_beams, M, preds): 124 | first_preds = preds[:, :n_beams, :] 125 | # go from (Q, n_beams, 3/3,3), to (Q, n_beams, 1, 3/3,3) 126 | # so that then the first pred. 
in each beam can be repeated N times 127 | first_preds = np.expand_dims(first_preds, axis=2) 128 | result = np.repeat(first_preds, M, axis=2) 129 | 130 | return result 131 | 132 | 133 | def get_n_steps(num_queries, render_per_step, max_steps, renderer, hard_stop): 134 | if hard_stop > 0: 135 | return hard_stop 136 | if renderer != 'o3d': 137 | return max_steps 138 | # due to open3d bug, black images after 1e5 renders 139 | max_renders = 1e5 140 | n_steps = ceil(max_renders / (num_queries*render_per_step)) 141 | return n_steps 142 | 143 | 144 | def eval_poses(errors_t, errors_R, descr=''): 145 | med_t = np.median(errors_t) 146 | med_R = np.median(errors_R) 147 | out = f'Results {descr}:' 148 | out += f'\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg' 149 | out_vals = [] 150 | 151 | out += '\nPercentage of test images localized within:' 152 | for th_t, th_R in zip(threshs_t, threshs_R): 153 | ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) 154 | out += f'\n\t{th_t:.2f}m, {th_R:.0f}deg : {ratio*100:.2f}%' 155 | out_vals.append(ratio) 156 | 157 | return out, np.array(out_vals) 158 | 159 | 160 | def eval_poses_top_n(all_errors_t, all_errors_R, descr=''): 161 | best_candidates = [1, 5, max(20, all_errors_R.shape[1])] 162 | 163 | med_t = np.median(all_errors_t[:, 0]) 164 | med_best_t = np.median(all_errors_t.min(axis=1)) 165 | 166 | med_R = np.median(all_errors_R[:, 0]) 167 | med_best_R = np.median(all_errors_R.min(axis=1)) 168 | 169 | out = f'Results {descr}:' 170 | out += f'\nMedian errors on first/best: {med_t:.2f}m, {med_R:.2f}deg // {med_best_t:.2f}m, {med_best_R:.2f}deg' 171 | out_vals = np.zeros((len(best_candidates), len(threshs_t))) 172 | 173 | out += f"\nPercentage of test images localized within (TOP {'|TOP '.join(map(str, best_candidates))}):" 174 | for i, (th_t, th_R) in enumerate(zip(threshs_t, threshs_R)): 175 | out += f'\n\t{th_t:.2f}m, {int(th_R):2d}deg :' 176 | for j, best in enumerate(best_candidates): 177 | ratio = np.mean( ((all_errors_t[:, :best] < th_t) & (all_errors_R[:, :best] < th_R)).any(axis=1) ) 178 | out += f' {ratio*100:4.1f}% |' 179 | out_vals[j, i] = ratio 180 | 181 | return out, out_vals 182 | 183 | 184 | def get_predictions(db_descriptors, q_descriptors, pose_dataset, fc_output_dim=512, top_k=20): 185 | pd = pose_dataset 186 | 187 | # Use a kNN to find predictions 188 | faiss_index = faiss.IndexFlatL2(fc_output_dim) 189 | faiss_index.add(db_descriptors) 190 | 191 | _, predictions = faiss_index.search(q_descriptors, top_k) 192 | pred_Rs = [] 193 | pred_ts = [] 194 | for q_idx in range(len(pd.q_frames_idxs)): 195 | pred_Rs_per_query = [] 196 | pred_ts_per_query = [] 197 | for k in range(top_k): 198 | pred_idx = pd.db_frames_idxs[predictions[q_idx, k]] 199 | R = qvec2rotmat(pd.images[pred_idx].qvec) 200 | t = pd.images[pred_idx].tvec 201 | pred_Rs_per_query.append(R) 202 | pred_ts_per_query.append(t) 203 | 204 | pred_Rs.append(np.array(pred_Rs_per_query)) 205 | pred_ts.append(np.array(pred_ts_per_query)) 206 | 207 | return np.array(pred_ts), np.array(pred_Rs) 208 | 209 | 210 | def get_error(R, t, R_gt, t_gt): 211 | e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) 212 | cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1., 1.) 
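    # relative rotation angle, from trace(R_gt^T R) = 1 + 2*cos(angle); the clip guards arccos against numerical drift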
213 | e_R = np.rad2deg(np.abs(np.arccos(cos))) 214 | 215 | return e_t, e_R 216 | 217 | 218 | def get_errors_from_preds(true_t, true_R, pred_t, pred_R, top_k=20): 219 | errors_t = [] 220 | errors_R = [] 221 | 222 | top_k = min(top_k, pred_t.shape[1]) 223 | for k in range(top_k): 224 | e_t, e_R = get_error(pred_R[0, k], pred_t[0, k], true_R[0], true_t[0]) 225 | 226 | errors_t.append(e_t) 227 | errors_R.append(e_R) 228 | errors_t, errors_R = np.array(errors_t), np.array(errors_R) 229 | 230 | return errors_t, errors_R 231 | 232 | 233 | def get_all_errors_first_estimate(true_t, true_R, pred_t, pred_R): 234 | errors_t = [] 235 | errors_R = [] 236 | n_queries = len(pred_R) 237 | for q_idx in range(n_queries): 238 | e_t, e_R = get_error(pred_R[q_idx, 0], pred_t[q_idx, 0], true_R[q_idx], true_t[q_idx]) 239 | 240 | errors_t.append(e_t) 241 | errors_R.append(e_R) 242 | errors_t = np.array(errors_t) 243 | errors_R = np.array(errors_R) 244 | 245 | return errors_t, errors_R 246 | 247 | 248 | def get_pose_from_preds_w_truth(q_idx, pd, rd, predictions, top_k=20): 249 | true_Rs = [] 250 | true_ts = [] 251 | pred_Rs = [] 252 | pred_ts = [] 253 | 254 | idx = pd.q_frames_idxs[q_idx] 255 | R_gt = qvec2rotmat(pd.images[idx].qvec) 256 | t_gt = pd.images[idx].tvec 257 | true_Rs.append(R_gt) 258 | true_ts.append(t_gt) 259 | 260 | pred_Rs_per_query = [] 261 | pred_ts_per_query = [] 262 | 263 | top_k = min(top_k, len(predictions)) 264 | for k in range(top_k): 265 | pred_idx = predictions[k] 266 | 267 | R = qvec2rotmat(rd.images[pred_idx].qvec) 268 | t = rd.images[pred_idx].tvec 269 | pred_Rs_per_query.append(R) 270 | pred_ts_per_query.append(t) 271 | 272 | pred_Rs.append(np.array(pred_Rs_per_query)) 273 | pred_ts.append(np.array(pred_ts_per_query)) 274 | 275 | true_t, true_R, pred_t, pred_R= np.array(true_ts), np.array(true_Rs), np.array(pred_ts), np.array(pred_Rs) 276 | return true_t, true_R, pred_t, pred_R 277 | 278 | 279 | def get_pose_from_preds(q_idx, pd, rd, predictions, top_k=20): 280 | pred_Rs = [] 281 | pred_ts = [] 282 | 283 | pred_Rs_per_query = [] 284 | pred_ts_per_query = [] 285 | 286 | top_k = min(top_k, len(predictions)) 287 | for k in range(top_k): 288 | pred_idx = predictions[k] 289 | 290 | R = qvec2rotmat(rd.images[pred_idx].qvec) 291 | t = rd.images[pred_idx].tvec 292 | pred_Rs_per_query.append(R) 293 | pred_ts_per_query.append(t) 294 | 295 | pred_Rs.append(np.array(pred_Rs_per_query)) 296 | pred_ts.append(np.array(pred_ts_per_query)) 297 | 298 | pred_t, pred_R= np.array(pred_ts), np.array(pred_Rs) 299 | return pred_t, pred_R 300 | 301 | 302 | def sort_preds_across_beams(all_scores, all_pred_t, all_pred_R, all_errors_t, all_errors_R): 303 | # flatten stuff to sort predictions based on similarity 304 | flat_err = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)') 305 | flat_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3) 306 | flat_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3) 307 | flat_preds = np.argsort(flat_err(all_scores)) 308 | all_errors_t = np.take_along_axis(flat_err(all_errors_t), flat_preds, axis=1) 309 | all_errors_R = np.take_along_axis(flat_err(all_errors_R), flat_preds, axis=1) 310 | flat_pred_t = flat_t(all_pred_t) 311 | flat_pred_R = flat_R(all_pred_R) 312 | 313 | return flat_pred_R, flat_pred_t, flat_preds, all_errors_t, all_errors_R 314 | 315 | 316 | def update_scores(scores, scores_temp): 317 | if scores is None: 318 | scores = scores_temp 319 | else: 320 | scores['steps'] += scores_temp['steps'] 321 | 322 | return scores 323 | 
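For reference, a minimal usage sketch of the error helpers above (toy poses, not part of the pipeline): get_error returns metre/degree errors for a single pose pair, and eval_poses turns arrays of such errors into the recall summary printed in the logs.

import numpy as np
from gloc.utils import qvec2rotmat, utils

# ground truth: identity rotation at the origin
R_gt, t_gt = np.eye(3), np.zeros(3)
# hypothetical estimate: 2 deg rotation about x, 10 cm translation offset
ang = np.deg2rad(2.0)
R_est = qvec2rotmat(np.array([np.cos(ang / 2), np.sin(ang / 2), 0.0, 0.0]))
t_est = np.array([0.1, 0.0, 0.0])

e_t, e_R = utils.get_error(R_est, t_est, R_gt, t_gt)   # ~0.1 m, ~2.0 deg
report, recalls = utils.eval_poses(np.array([e_t]), np.array([e_R]), descr='toy example')
print(report)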
-------------------------------------------------------------------------------- /gloc/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from os.path import join 4 | import matplotlib as mpl 5 | import matplotlib.pyplot as plt 6 | import cv2 7 | 8 | from gloc.utils.utils import threshs_R, threshs_t 9 | 10 | 11 | def plot_error_distr(err_R, err_t, step_num, save_dir, f_name): 12 | fig, axi = plt.subplots(1, 2, figsize=(12,6), dpi=80) 13 | 14 | meds = [np.median(err_R[:,0]), np.median(err_t[:,0])] 15 | for i, errors in enumerate([err_R, err_t]): 16 | ax = axi[i] 17 | for label in (ax.get_xticklabels() + ax.get_yticklabels()): 18 | label.set_fontsize(16) 19 | med = meds[i] 20 | ax.hist(errors[:, 0], bins=50, range=(0, 20), alpha=0.8) 21 | ax.vlines(med, 0, 50, colors='red', label='50th perc.') 22 | ax.legend(prop={"size":15}) 23 | 24 | fig.suptitle(f'Step {step_num}\nMedian errors: ( {meds[1]:.1f} m, {meds[0]:.1f}° )', fontsize=18) 25 | plt.tight_layout() 26 | plt.savefig(join(save_dir, f_name), bbox_inches='tight', pad_inches=0.0) 27 | plt.close() 28 | 29 | 30 | def plot_scores(scores, out_dir): 31 | threshs = list(zip(threshs_t, threshs_R)) 32 | 33 | steps = np.array(scores['steps']) 34 | fig, axis = plt.subplots(1, 2, figsize=(14, 7), dpi=100) 35 | x=list(range(len(steps))) 36 | 37 | for i, idx in enumerate([1, 2]): 38 | ax = axis[i] 39 | 40 | ax.plot(x, steps[:, 0, idx], label=f'Recall') 41 | ax.plot(x, steps[:, 2, idx], label=f'Upper bound') 42 | ax.hlines(scores['baseline'][idx], x[0], x[-1], colors='red', linestyles='dashed', label='baseline') 43 | ax.set_title(f'Threshold= {threshs[idx]}') 44 | ax.legend() 45 | plt.tight_layout() 46 | plt.savefig(f'{out_dir}/scores_plot.png') 47 | -------------------------------------------------------------------------------- /parse_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os.path import join 3 | from datetime import datetime 4 | 5 | 6 | def parse_args(): 7 | parser = argparse.ArgumentParser(description='Argument parser') 8 | # exp args 9 | parser.add_argument('name', type=str, help='DS name', 10 | choices=['Aachen', 'Aachen_real', 'Aachen_day', 'Aachen_night', 'Aachen_real_und', 'Aachen_small', 11 | 'KingsCollege', 'KingsCollege_und', 'StMarysChurch_und', 12 | 'ShopFacade', 'ShopFacade_und', 13 | 'OldHospital', 'OldHospital_und', 14 | 'chess', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin', 'heads']) 15 | parser.add_argument('--exp_name', type=str, help='log folder', default='default') 16 | parser.add_argument('--res', type=int, help='resolution', default=320) 17 | parser.add_argument('--seed', type=int, help='seed', default=0) 18 | parser.add_argument('--first_step', type=int, help='start from', default=None) 19 | parser.add_argument('--hard_stop', type=int, help='interrupt at step N, but dont consider it for scaling noise', default=-1) 20 | parser.add_argument('--resume_step', type=str, help='resume folder', default=None) 21 | parser.add_argument('--save_feats', action='store_true', help='seed', default=False) 22 | parser.add_argument('--pose_prior', type=str, help='start from a pose prior in this file', default=None) 23 | parser.add_argument('--clean_logs', action='store_true', help='remove renderings in the end', default=False) 24 | parser.add_argument('--chunk_size', type=int, help='n feats at a time', default=1100) 25 | 26 | # model args 27 | 
parser.add_argument('--retr_model', type=str, help='retrieval model', default='cosplace', choices=['cosplace']) 28 |     parser.add_argument('--ref_model', type=str, help='fine model', default='DenseFeatures', 29 |                         choices=['DenseFeatures']) 30 |     parser.add_argument('--feat_model', type=str, help='refinement model arch', default='cosplace', 31 |                         choices=['', 32 |                                  'cosplace_r18_l1', 'cosplace_r18_l2', 'cosplace_r18_l3', 33 |                                  'cosplace_r50_l1', 'cosplace_r50_l2', 'cosplace_r50_l3', 34 |                                  'resnet18_l1', 'resnet18_l2', 'resnet18_l3', 35 |                                  'resnet50_l1', 'resnet50_l2', 'resnet50_l3', 36 |                                  'alexnet_l1', 'alexnet_l2', 'alexnet_l3', 37 |                                  'Dinov2', 'Roma', 38 |                                  ]) 39 | 40 |     # fine models args 41 |     parser.add_argument('--clamp_score', type=float, help='thresholded scoring function', default=-1) 42 |     parser.add_argument('--feat_level', nargs='+', type=int, help='Level of features for ALIKED', default=[-1]) 43 |     parser.add_argument('--scale_fmaps', type=int, help='Scale F.maps to 1/n', default=6) 44 | 45 |     # path args 46 |     parser.add_argument("--storage_dir", type=str, default='/storage/gtrivigno/vloc/renderings', help='model path') 47 |     parser.add_argument("--fix_storage", action='store_true', default=False, help='model path') 48 | 49 |     # render args 50 |     parser.add_argument('--colmap_res', type=int, help='res', default=320) 51 |     parser.add_argument('--mesh', type=str, help='mesh type', choices=['colored', 'colored_14', 'colored_15', 'textured'], default='colored') 52 |     parser.add_argument('--renderer', type=str, help='renderer type', choices=['o3d', 'nerf', 'g_splatting'], default='o3d') 53 | 54 |     # perturb args 55 |     parser.add_argument('-pt', '--protocol', type=str, help='protocol', 56 |                         choices=['1_0', '1_1', '2_0', '2_1'], default='2_1') 57 |     parser.add_argument('--sampler', type=str, help='sampler', default='rand', 58 |                         choices=['rand', 'rand_yaw_or_pitch', 'rand_yaw_and_pitch', 'rand_and_yaw_and_pitch']) 59 |     parser.add_argument('--beams', type=int, help='N. beams to optimize independently', default=1) 60 |     parser.add_argument('--steps', type=int, help='iterations', default=20) 61 |     parser.add_argument('--teta', nargs='+', type=float, help='max angle', default=[8]) 62 |     parser.add_argument('--center_std', nargs='+', type=float, default=[1., 1., 1.]) 63 |     parser.add_argument('--N', type=int, help='N views to render in total, per query', default=20) 64 |     parser.add_argument('--M', type=int, help='In each beam, perturb the first M rather than only first cand.', default=4) 65 |     parser.add_argument('--gamma', type=float, help='min scale', default=0.1) 66 | 67 |     # eval scripts 68 |     parser.add_argument("--eval_renders_dir", type=str, default='', help='eval render dir') 69 | 70 |     # random args 71 |     parser.add_argument('--only_step', type=int, help='iterations', default=-1) 72 | 73 |     args = parser.parse_args() 74 |     args.save_dir = join("logs", args.exp_name, datetime.now().strftime('%Y-%m-%d_%H-%M-%S')) 75 | 76 |     ## some consistency checks 77 |     assert args.N % args.beams == 0, 'N (total views to rend) has to be a multiple of N. 
beams' 78 | if args.mesh == 'textured' and not args.name.startswith('Aachen'): 79 | raise ValueError('Textured mesh is only available for Aachen') 80 | if args.protocol[0] not in ['0', '1']: 81 | assert args.N % args.M == 0, f'In protocol 2, N ({args.N}) has to be a multiple of M ({args.M})' 82 | assert (args.N // args.beams) % args.M == 0, f'In protocols with M!=1, N/beams ({args.N//args.beams}) has to be a multiple of M ({args.M})' 83 | if 'yaw_and_pitch' in args.sampler: 84 | assert len(args.teta) == 2, f'Sampler {args.sampler} requires 2 angles, 1 for yaw, 1 for pitch' 85 | else: 86 | assert len(args.teta) == 1, f'Sampler {args.sampler} requires only 1 angle' 87 | 88 | return args -------------------------------------------------------------------------------- /path_configs.py: -------------------------------------------------------------------------------- 1 | from os.path import join 2 | 3 | 4 | def get_paths(base_path, colmap_res, mesh_type): 5 | aachen_meshes = { 6 | 'colored': 'AC13_colored.ply', 7 | 'colored_14': 'AC14_colored.ply', 8 | 'colored_15': 'AC15_colored.ply', 9 | 'textured': 'AC13-C_textured/AC13-C_textured.obj', 10 | } 11 | 12 | paths_conf = { 13 | # Aachen 14 | 'Aachen': { 15 | 'root': f'{base_path}/Aachen-Day-Night/images/images_upright', 16 | 'colmap': f'{base_path}/all_colmaps/{colmap_res}_undist', 17 | 'q_file': '', 18 | 'db_file': '', 19 | 'q_intrinsics': [f'{base_path}/Aachen-Day-Night/queries/undist_rs{colmap_res}_allquery_intrinsics.txt'], 20 | 'mesh_path': f'{base_path}/meshes', 21 | }, 22 | } 23 | 24 | # Cambridge scenes 25 | cambridge_scenes = ['StMarysChurch', 'OldHospital', 'KingsCollege', 'ShopFacade'] 26 | gs_models = { 27 | 'StMarysChurch': 'gauss_church', 28 | 'OldHospital': 'gauss_hosp', 29 | 'KingsCollege': 'gauss_kings', 30 | 'ShopFacade': 'gauss_shop', 31 | } 32 | for cs in cambridge_scenes: 33 | paths_conf[cs] = { 34 | 'root': f'{base_path}/{cs}', 35 | 'colmap': f'{base_path}/all_colmaps/{cs}/{colmap_res}_undist/sparse/0', 36 | 'q_file': f'{base_path}/{cs}/dataset_test.txt', 37 | 'db_file': f'{base_path}/{cs}/dataset_train.txt', 38 | 'mesh_path': NotImplemented, 39 | 'ns_config': NotImplemented, 40 | 'ns_transform': NotImplemented, 41 | 'gaussians': f'{base_path}/cambridge_splats/{gs_models[cs]}', 42 | } 43 | 44 | # 7scenes 45 | scenes7 = ['chess', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin', 'heads'] 46 | for sc in scenes7: 47 | paths_conf[sc] = { 48 | 'root': f'{base_path}/7scenes/{sc}', 49 | 'colmap': f'{base_path}/all_colmaps/{sc}/colmap_{colmap_res}', 50 | 'q_file': f'{base_path}/7scenes/{sc}/test.txt', 51 | 'db_file': f'{base_path}/7scenes/{sc}/train.txt', 52 | 'mesh_path': NotImplemented, 53 | 'ns_config': NotImplemented, 54 | 'ns_transform': NotImplemented, 55 | 'gaussians': f'{base_path}/7scenes_splats/{sc}', 56 | } 57 | ###################################### 58 | 59 | paths_conf['Aachen']['mesh_path'] = join(paths_conf['Aachen']['mesh_path'], aachen_meshes.get(mesh_type, '')) 60 | 61 | return paths_conf 62 | 63 | 64 | def get_path_conf(colmap_res, mesh_type): 65 | base_path = 'data' 66 | temp_path = 'data/temp' 67 | 68 | conf = get_paths(base_path, colmap_res, mesh_type) 69 | conf['temp'] = temp_path 70 | return conf 71 | -------------------------------------------------------------------------------- /refine_pose.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import shutil 5 | import torch 6 | from os.path import join 7 | from tqdm import 
tqdm 8 | import numpy as np 9 | from torch.utils.data.dataset import Subset 10 | import torchvision.transforms as T 11 | 12 | import commons 13 | from parse_args import parse_args 14 | from path_configs import get_path_conf 15 | from gloc import extraction 16 | from gloc import initialization 17 | from gloc import rendering 18 | from gloc.models import get_ref_model 19 | from gloc.rendering import get_renderer 20 | from gloc.datasets import get_dataset, find_candidates_paths, get_transform, RenderedImagesDataset, ImListDataset 21 | from gloc.utils import utils, visualization 22 | from gloc.resamplers import get_protocol 23 | from configs import get_config 24 | 25 | 26 | def main(args): 27 | commons.make_deterministic(args.seed) 28 | commons.setup_logging(args.save_dir, console="info") 29 | logging.info(" ".join(sys.argv)) 30 | logging.info(f"Arguments: {args}") 31 | logging.info(f"The outputs are being saved in {args.save_dir}") 32 | 33 | paths_conf = get_path_conf(args.colmap_res, args.mesh) 34 | temp_dir = join(paths_conf['temp'], args.exp_name) 35 | os.makedirs(temp_dir) 36 | 37 | exp_config = get_config(args.name) 38 | scores = None 39 | for i in range(len(exp_config)): 40 | ref_args = exp_config[i] 41 | args.__dict__.update(ref_args) 42 | scores_temp, render_dir = refinement_loop(args) 43 | 44 | scores = utils.update_scores(scores, scores_temp) 45 | args.pose_prior = join(render_dir, 'est_poses.txt') 46 | 47 | visualization.plot_scores(scores, args.save_dir) 48 | 49 | ### cleaning up... 50 | logging.info(f'Moving rendering from temp dir {temp_dir} to {args.save_dir}') 51 | shutil.move(join(temp_dir, 'renderings'), args.save_dir, copy_function=shutil.move) 52 | shutil.rmtree(temp_dir) 53 | logging.info('Terminating without errors!') 54 | 55 | 56 | def refinement_loop(args): 57 | DS = args.name 58 | res = args.res 59 | 60 | paths_conf = get_path_conf(args.colmap_res, args.mesh) 61 | transform = get_transform(args, paths_conf[DS]['colmap']) 62 | pose_dataset = get_dataset(DS, paths_conf[DS], transform) 63 | temp_dir = join(paths_conf['temp'], args.exp_name) 64 | 65 | first_step, all_pred_t, all_pred_R, scores = initialization.init_refinement(args, pose_dataset) 66 | ######### START REFINEMENT LOOP 67 | N_steps = args.steps 68 | N_per_beam = args.N // args.beams 69 | n_beams = args.beams 70 | N_views = args.N 71 | fine_model = get_ref_model(args) 72 | 73 | logging.info('Recomputing query features with refinement model...') 74 | queries_subset = Subset(pose_dataset, pose_dataset.q_frames_idxs) 75 | q_descriptors = extraction.get_query_features(fine_model, queries_subset) 76 | 77 | resampler = get_protocol(args, N_per_beam, args.protocol) 78 | renderer = get_renderer(args, paths_conf) 79 | 80 | max_step = utils.get_n_steps(pose_dataset.num_queries(), N_views, N_steps, args.renderer, args.hard_stop) 81 | for step in range(first_step, N_steps): 82 | if (step - first_step) == max_step: 83 | logging.info('Stopping due to Open3D bug') 84 | break 85 | 86 | resampler.init_step(step) 87 | center_std, angle_delta = resampler.scaler.get_noise() 88 | 89 | logging.info(f'[||] Starting iteration n.{step+1}/{N_steps} [||]') 90 | logging.info(f'Perturbing poses with Theta {angle_delta} and center STD {center_std}. 
Resolution {res}') 91 | 92 | if (first_step == step) and (args.resume_step is not None): 93 | render_dir = args.resume_step 94 | logging.info(f'Resuming from step dir {render_dir}...') 95 | else: 96 | perturb_str = resampler.get_pertubr_str(step, res) 97 | render_dir = perturb_step(perturb_str, pose_dataset, renderer, resampler, all_pred_t, all_pred_R, temp_dir, n_beams) 98 | renderer.end_epoch(render_dir) 99 | 100 | (all_pred_t, all_pred_R, 101 | all_errors_t, all_errors_R) = rank_candidates(fine_model, pose_dataset, render_dir, pose_dataset.transform, 102 | q_descriptors, N_per_beam, n_beams, chunk_limit=args.chunk_size) 103 | result_str, results = utils.eval_poses_top_n(all_errors_t, all_errors_R, descr=f'step {step}') 104 | logging.info(result_str) 105 | 106 | scores['steps'].append(results) 107 | torch.save(scores, join(args.save_dir, 'scores.pth')) 108 | 109 | if args.clean_logs: 110 | from clean_logs import main as cl_logs 111 | logging.info('Removing rendering files...') 112 | cl_logs(temp_dir, only_step=step) 113 | 114 | return scores, render_dir 115 | 116 | 117 | def perturb_step(perturb_str, pose_dataset, renderer, resampler, pred_t, pred_R, basepath, n_beams=1): 118 | out_dir = os.path.join(basepath, 'renderings', perturb_str) 119 | os.makedirs(out_dir) 120 | logging.info(f'Generating renders in {out_dir}') 121 | 122 | rend_model = renderer.load_model() 123 | r_names_per_beam = {} 124 | for q_idx in tqdm(range(len(pose_dataset.q_frames_idxs)), ncols=100): 125 | idx = pose_dataset.q_frames_idxs[q_idx] 126 | q_name = pose_dataset.get_basename(idx) 127 | q_key_name = os.path.splitext(pose_dataset.images[idx].name)[0] 128 | r_names_per_beam[q_idx] = {} 129 | 130 | K, w, h = pose_dataset.get_intrinsics(q_key_name) 131 | 132 | r_dir = os.path.join(out_dir, q_name) 133 | os.makedirs(r_dir) 134 | for beam_i in range(n_beams): 135 | beam_dir = join(r_dir, f'beam_{beam_i}') 136 | os.makedirs(beam_dir) 137 | pred_t_beam = pred_t[q_idx, beam_i] 138 | pred_R_beam = pred_R[q_idx, beam_i] 139 | 140 | r_names, render_ts, render_qvecs, calibr_pose = resampler.resample(K, q_name, pred_t_beam, pred_R_beam, q_idx=q_idx, beam_i=beam_i) 141 | r_names_per_beam[q_idx][beam_i] = r_names 142 | # poses have to be logged in 'beam_dir', but rendered in 'r_dir', so that 143 | # they can be rendered all together in 'deferred' mode, thus being more efficient 144 | rendering.log_poses(beam_dir, r_names, render_ts, render_qvecs, args.renderer) 145 | if renderer.supports_deferred_rendering: 146 | to_render_dir = r_dir 147 | else: 148 | to_render_dir = beam_dir 149 | renderer.render_poses(to_render_dir, rend_model, r_names, render_ts, render_qvecs, calibr_pose, (w, h), 150 | deferred=renderer.supports_deferred_rendering) 151 | del rend_model 152 | 153 | renderer.end_epoch(out_dir) 154 | logging.info('Moving each renders into their beams folder') 155 | if renderer.supports_deferred_rendering: 156 | for q_idx in range(len(pose_dataset.q_frames_idxs)): 157 | idx = pose_dataset.q_frames_idxs[q_idx] 158 | q_name = pose_dataset.get_basename(idx) 159 | 160 | r_dir = join(out_dir, q_name) 161 | rendering.split_to_beam_folder(r_dir, n_beams, r_names_per_beam[q_idx]) 162 | 163 | return out_dir 164 | 165 | 166 | def rank_candidates(fine_model, pose_dataset, render_dir, transform, q_descriptors, N_per_beam, n_beams, chunk_limit=1100): 167 | all_pred_t = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam, 3)) 168 | all_pred_R = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam, 3, 3)) 169 | 
all_errors_t = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam)) 170 | all_errors_R = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam)) 171 | all_scores = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam)) 172 | 173 | logging.info(f'Extracting candidates paths') 174 | candidates_pathlist, query_res = find_candidates_paths(pose_dataset, n_beams, render_dir) 175 | 176 | logging.info(f'Found {len(candidates_pathlist)} images for {pose_dataset.n_q} queries, now extracting features altogether') 177 | same_res_transform = T.Compose(transform.transforms.copy()) 178 | same_res_transform.transforms[1] = T.Resize(query_res, antialias=True) 179 | imlist_ds = ImListDataset(candidates_pathlist, same_res_transform) 180 | 181 | chunk_start_q_idx, chunk_end_q_idx, chunks = extraction.split_renders_into_chunks( 182 | pose_dataset.n_q, len(imlist_ds), n_beams, N_per_beam, chunk_limit 183 | ) 184 | dim = extraction.get_feat_dim(fine_model, query_res) 185 | 186 | logging.info(f'Query splits: {chunk_start_q_idx}, {chunk_end_q_idx}') 187 | logging.info(f'Chunk splits: {[c[-1] for c in chunks]}') 188 | for ic, chunk in enumerate(chunks): 189 | q_idx_start = chunk_start_q_idx[ic] 190 | q_idx_end = chunk_end_q_idx[ic] 191 | 192 | logging.info(f'Chunk n.{ic}') 193 | logging.info(f'Query from {q_idx_start} to {q_idx_end}') 194 | logging.info(f'Images from {chunk[0]} to {chunk[-1]}') 195 | 196 | chunk_ds = Subset(imlist_ds, chunk) 197 | descriptors = extraction.get_candidates_features(fine_model, chunk_ds, dim) 198 | 199 | logging.info(f'Extracted shape {descriptors.shape}, now computing predictions') 200 | for q_idx in tqdm(range(q_idx_start, q_idx_end), ncols=100): 201 | q_name = pose_dataset.get_basename(pose_dataset.q_frames_idxs[q_idx]) 202 | query_dir = os.path.join(render_dir, q_name) 203 | q_feats = q_descriptors[q_idx] 204 | 205 | for beam_i in range(n_beams): 206 | beam_dir = join(query_dir, f'beam_{beam_i}') 207 | rd = RenderedImagesDataset(beam_dir, verbose=False) 208 | 209 | start_idx = (q_idx-q_idx_start)*n_beams*N_per_beam + beam_i*N_per_beam 210 | end_idx = start_idx + N_per_beam 211 | r_db_descriptors = descriptors[start_idx:end_idx] 212 | predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True) 213 | true_t, true_R, pred_t, pred_R = utils.get_pose_from_preds_w_truth(q_idx, pose_dataset, rd, predictions, N_per_beam) 214 | errors_t, errors_R = utils.get_errors_from_preds(true_t, true_R, pred_t, pred_R, N_per_beam) 215 | 216 | all_pred_t[q_idx, beam_i] = pred_t 217 | all_pred_R[q_idx, beam_i] = pred_R 218 | all_errors_t[q_idx, beam_i] = errors_t 219 | all_errors_R[q_idx, beam_i] = errors_R 220 | all_scores[q_idx, beam_i] = scores[:N_per_beam] 221 | 222 | # save scores within each beam so renders can be deleted afterwards 223 | torch.save((predictions, scores), join(beam_dir, 'scores.pth')) 224 | 225 | del q_feats, r_db_descriptors, descriptors 226 | 227 | # sort predictions/errors according the score across beams 228 | # only needed to log and eval poses, the optimization is beam-independent 229 | flat_pred_R, flat_pred_t, flat_preds, all_errors_t, all_errors_R = utils.sort_preds_across_beams(all_scores, all_pred_t, all_pred_R, all_errors_t, all_errors_R) 230 | # log pose estimate 231 | if flat_preds.shape[-1] > 6: 232 | # if there are at least 6 preds per query 233 | logging.info(f'Generating pose file...') 234 | utils.log_pose_estimate(render_dir, pose_dataset, flat_pred_R, flat_pred_t, flat_preds=flat_preds) 235 | 236 | 
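    # note: the pose arrays keep their (n_queries, n_beams, N_per_beam, ...) layout for the next perturbation step,
    # while the error arrays were flattened and re-sorted across beams above, purely for logging and evaluation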
return all_pred_t, all_pred_R, all_errors_t, all_errors_R 237 | 238 | 239 | if __name__ == '__main__': 240 | args = parse_args() 241 | main(args) 242 | -------------------------------------------------------------------------------- /refine_pose_aachen.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import shutil 5 | import torch 6 | from os.path import join 7 | from tqdm import tqdm 8 | import numpy as np 9 | import einops 10 | from torch.utils.data.dataset import Subset 11 | import torchvision.transforms as T 12 | 13 | import commons 14 | from parse_args import parse_args 15 | from path_configs import get_path_conf 16 | from gloc.models import get_ref_model 17 | from gloc import extraction 18 | from gloc import rendering 19 | from gloc.rendering import get_renderer 20 | from gloc.utils import utils, rotmat2qvec 21 | from gloc.resamplers import get_protocol 22 | from gloc.datasets import RenderedImagesDataset, get_dataset, get_transform 23 | 24 | 25 | def main(args): 26 | DS = args.name 27 | res = args.res 28 | 29 | commons.make_deterministic(args.seed) 30 | commons.setup_logging(args.save_dir, console="info") 31 | logging.info(" ".join(sys.argv)) 32 | logging.info(f"Arguments: {args}") 33 | logging.info(f"The outputs are being saved in {args.save_dir}") 34 | 35 | paths_conf = get_path_conf(args.colmap_res, args.mesh) 36 | temp_dir = join(paths_conf['temp'], args.exp_name) 37 | os.makedirs(temp_dir) 38 | 39 | transform = get_transform(args) 40 | pd = get_dataset(DS, paths_conf[DS], transform) 41 | 42 | if args.pose_prior == '': 43 | all_pred_t, all_pred_R = extraction.get_retrieval_predictions(args.model, args.outdim, args.res, pd, topk=args.beams*args.M) 44 | else: 45 | logging.info(f'Loading pose prior from {args.pose_prior}') 46 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pd, args.beams*args.M) 47 | 48 | ######### START REFINEMENT LOOP 49 | N_steps = args.steps 50 | n_beams = args.beams 51 | N_per_beam = args.N // args.beams 52 | M = args.M 53 | fine_model = get_ref_model(args) 54 | 55 | logging.info('Recomputing query features with refinement model...') 56 | queries_subset = Subset(pd, pd.q_frames_idxs) 57 | q_descriptors = extraction.get_features_from_dataset(fine_model, queries_subset, use_tqdm=True) 58 | 59 | resampler = get_protocol(args, N_per_beam, args.protocol) 60 | renderer = get_renderer(args, paths_conf) 61 | 62 | first_step = 0 63 | if args.first_step is not None: 64 | first_step = args.first_step 65 | 66 | max_step = utils.get_n_steps(pd.num_queries(), args.N, N_steps, args.renderer, args.hard_stop) 67 | # go from (NQ, M*beams, 3/3,3) to (NQ, beams, M, 3/3, 3) 68 | all_pred_t = utils.reshape_preds_per_beam(n_beams, M, all_pred_t) 69 | all_pred_R = utils.reshape_preds_per_beam(n_beams, M, all_pred_R) 70 | for step in range(first_step, N_steps): 71 | if (step - first_step) == max_step: 72 | logging.info('Stopping due to Open3D bug') 73 | break 74 | 75 | resampler.init_step(step) 76 | center_std, angle_delta = resampler.scaler.get_noise() 77 | 78 | logging.info(f'[||] Starting iteration n.{step+1}/{N_steps} [||]') 79 | logging.info(f'Perturbing poses with Theta {angle_delta} and center STD {center_std}. 
Resolution {res}') 80 | 81 | if (first_step == step) and (args.resume_step is not None): 82 | render_dir = args.resume_step 83 | else: 84 | perturb_str = resampler.get_pertubr_str(step, res) 85 | render_dir = perturb_step(perturb_str, pd, renderer, resampler, all_pred_t, all_pred_R, temp_dir, n_beams) 86 | 87 | all_pred_t, all_pred_R = rank_candidates(fine_model, pd, render_dir, 88 | pd.transform, q_descriptors, N_per_beam, n_beams) 89 | 90 | logging.info(f'[!] Concluded iteration n.{step+1}/{N_steps} [!]') 91 | 92 | ### cleaning up... 93 | if args.clean_logs: 94 | from clean_logs import main as cl_logs 95 | logging.info('Removing rendering files...') 96 | cl_logs(temp_dir) 97 | 98 | logging.info(f'Moving rendering from temp dir {temp_dir} to {args.save_dir}') 99 | shutil.move(join(temp_dir, 'renderings'), args.save_dir, copy_function=shutil.move) 100 | shutil.rmtree(temp_dir) 101 | logging.info('Terminating without errors!') 102 | 103 | 104 | def perturb_step(perturb_str, pd, renderer, resampler, pred_t, pred_R, basepath, n_beams=1): 105 | out_dir = os.path.join(basepath, 'renderings', perturb_str) 106 | os.makedirs(out_dir) 107 | logging.info(f'Generating renders in {out_dir}') 108 | 109 | rend_model = renderer.load_model() 110 | r_names_per_beam = {} 111 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100): 112 | idx = pd.q_frames_idxs[q_idx] 113 | q_name = pd.get_basename(idx) 114 | q_key_name = pd.images[idx].name.split('.')[0] 115 | r_names_per_beam[q_idx] = {} 116 | 117 | w = pd.q_intrinsics[q_key_name]['w'] 118 | h = pd.q_intrinsics[q_key_name]['h'] 119 | K = pd.q_intrinsics[q_key_name]['K'] 120 | r_dir = os.path.join(out_dir, q_name) 121 | os.makedirs(r_dir) 122 | for beam_i in range(n_beams): 123 | beam_dir = join(r_dir, f'beam_{beam_i}') 124 | os.makedirs(beam_dir) 125 | pred_t_beam = pred_t[q_idx, beam_i] 126 | pred_R_beam = pred_R[q_idx, beam_i] 127 | 128 | r_names, render_ts, render_qvecs, calibr_pose = resampler.resample(K, q_name, pred_t_beam, pred_R_beam, 129 | q_idx=q_idx, beam_i=beam_i) 130 | 131 | r_names_per_beam[q_idx][beam_i] = r_names 132 | # poses have to be logged in 'beam_dir', but rendered in 'r_dir', so that 133 | # they can be rendered all together by deferred renderers such as ibmr 134 | rendering.log_poses(beam_dir, r_names, render_ts, render_qvecs, args.renderer) 135 | renderer.render_poses(r_dir, rend_model, r_names, render_ts, render_qvecs, calibr_pose, (w, h)) 136 | del rend_model 137 | 138 | renderer.end_epoch(out_dir) 139 | logging.info('Moving each renders into their beams folder') 140 | for q_idx in range(len(pd.q_frames_idxs)): 141 | idx = pd.q_frames_idxs[q_idx] 142 | q_name = pd.get_basename(idx) 143 | 144 | r_dir = join(out_dir, q_name) 145 | for beam_i in range(n_beams): 146 | beam_dir = join(r_dir, f'beam_{beam_i}') 147 | beam_names = r_names_per_beam[q_idx][beam_i] 148 | for b_name in beam_names: 149 | src = join(r_dir, b_name+'.png') 150 | dst = join(beam_dir, b_name+'.png') 151 | shutil.move(src, dst) 152 | 153 | return out_dir 154 | 155 | 156 | def rank_candidates(fine_model, pd, render_dir, transform, q_descriptors, N_per_beam, n_beams): 157 | all_pred_t = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3)) 158 | all_pred_R = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3, 3)) 159 | all_scores = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam)) 160 | 161 | logging.info('Extracting features from rendered images...') 162 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100): 163 | q_name = 
pd.get_basename(pd.q_frames_idxs[q_idx]) 164 |         query_dir = os.path.join(render_dir, q_name) 165 |         query_tensor = pd[pd.q_frames_idxs[q_idx]]['im'] 166 |         query_res = tuple(query_tensor.shape[-2:]) 167 |         q_feats = extraction.get_query_descriptor_by_idx(q_descriptors, q_idx) 168 | 169 |         for beam_i in range(n_beams): 170 |             beam_dir = join(query_dir, f'beam_{beam_i}') 171 |             rd = RenderedImagesDataset(beam_dir, transform, query_res, verbose=False) 172 |             r_db_descriptors = extraction.get_features_from_dataset(fine_model, rd, bs=fine_model.conf.bs, is_render=True) 173 |             predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True) 174 |             pred_t, pred_R = utils.get_pose_from_preds(q_idx, pd, rd, predictions, N_per_beam) 175 | 176 |             all_pred_t[q_idx][beam_i] = pred_t 177 |             all_pred_R[q_idx][beam_i] = pred_R 178 |             all_scores[q_idx][beam_i] = scores[:N_per_beam] 179 | 180 |             scores_file = join(beam_dir, 'scores.pth') 181 |             torch.save((predictions, scores), scores_file) 182 | 183 |         del q_feats, r_db_descriptors 184 |     # flatten stuff for eval 185 |     flat_score = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)') 186 |     flat_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3) 187 |     flat_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3) 188 |     flat_preds = np.argsort(flat_score(all_scores)) 189 |     flat_pred_t = flat_t(all_pred_t) 190 |     flat_pred_R = flat_R(all_pred_R) 191 | 192 |     # log pose estimate 193 |     logging.info(f'Generating pose file...') 194 |     f_results = utils.log_pose_estimate(render_dir, pd, flat_pred_R, flat_pred_t, flat_preds=flat_preds) 195 | 196 |     method_name = os.path.dirname(render_dir).split('/')[-2] 197 |     method_name = method_name.replace('_aachenreal_ibmr', '') 198 |     dir_to_step = lambda x: int(x.split('/')[-1].split('_')[2].split('s')[-1]) 199 |     step_n = dir_to_step(render_dir) 200 |     if (step_n+1) % 5 == 0: 201 |         method_name += f'_s{step_n}' 202 |     try: 203 |         commons.submit_poses(method=method_name, path=f_results) 204 |     except Exception: 205 |         logging.info('Submit script failed') 206 | 207 |     return all_pred_t, all_pred_R 208 | 209 | 210 | if __name__ == '__main__': 211 |     args = parse_args() 212 |     main(args) 213 | -------------------------------------------------------------------------------- /render_dataset_from_script.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os.path import join 4 | import numpy as np 5 | 6 | from path_configs import get_path_conf 7 | from gloc import rendering 8 | from gloc.rendering import get_renderer 9 | from gloc.datasets import PoseDataset 10 | from gloc.utils import rotmat2qvec 11 | 12 | """ 13 | Example: 14 | python render_dataset_from_script.py stairs --out_dir nerf_renders --colmap_res 320 --renderer g_splatting 15 | """ 16 | 17 | parser = argparse.ArgumentParser(description='Argument parser') 18 | # general args 19 | parser.add_argument('name', type=str, help='') 20 | parser.add_argument('--out_dir', type=str, help='', default='') 21 | parser.add_argument('--renderer', type=str, help='', default='nerf') 22 | parser.add_argument('--colmap_res', type=int, default=320, help='') 23 | args = parser.parse_args() 24 | 25 | print(args) 26 | 27 | DS = args.name 28 | out_dir = join(args.out_dir, args.name, args.renderer, str(args.colmap_res)) 29 | os.makedirs(out_dir, exist_ok=True) 30 | print(f'Renders will be in {out_dir}') 31 | 32 | ##### parse path info and instantiate dataset 33 | paths_conf = 
get_path_conf(args.colmap_res, None) 34 | pd = PoseDataset(DS, paths_conf[DS]) 35 | #################### 36 | 37 | #### parse pose and intrinsics metadata 38 | all_tvecs, all_Rs = pd.get_q_poses() 39 | names = [pd.images[pd.q_frames_idxs[q_idx]].name.replace('/', '_') for q_idx in range(len(pd.q_frames_idxs)) ] 40 | key = os.path.splitext(pd.images[0].name)[0] 41 | chosen_camera = pd.intrinsics[key] 42 | height = chosen_camera['h'] 43 | width = chosen_camera['w'] 44 | K = chosen_camera['K'] 45 | 46 | calibr_pose = [] 47 | all_qvecs = [] 48 | for tvec, R in zip(all_tvecs, all_Rs): 49 | all_qvecs.append(rotmat2qvec(R)) 50 | 51 | T = np.eye(4) 52 | T[0:3, 0:3] = R 53 | T[0:3, 3] = tvec 54 | calibr_pose.append((T, K)) 55 | #################### 56 | renderer = get_renderer(args, paths_conf) 57 | mod = renderer.load_model() 58 | renderer.render_poses(out_dir, mod, names, all_tvecs, all_qvecs, calibr_pose, (width, height), deferred=False) 59 | 60 | renderer.clean_file_names(out_dir, names, verbose=True) 61 | # specify 'mesh' as renderer because we converted filenames 62 | rendering.log_poses(out_dir, names, all_tvecs, all_qvecs, 'mesh') 63 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.6.1 2 | faiss-cpu==1.7.4 3 | gdown==4.7.1 4 | googledrivedownloader==0.4 5 | h5py==3.8.0 6 | huggingface-hub==0.14.1 7 | imageio==2.28.1 8 | open3d==0.17.0 9 | opencv-python==4.7.0.72 10 | pandas==2.0.1 11 | plotly==5.14.1 12 | requests==2.28.1 13 | scikit-image==0.20.0 14 | scikit-learn==1.2.2 15 | scipy==1.9.1 16 | tqdm==4.65.0 17 | utm==0.7.0 18 | -------------------------------------------------------------------------------- /submit_poses.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import commons 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("result_path", type=str) 6 | parser.add_argument('-m', "--method_name", required=True, type=str) 7 | args = parser.parse_args() 8 | 9 | commons.submit_poses(method=args.method_name, path=args.result_path) 10 | --------------------------------------------------------------------------------
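For reference, a hypothetical programmatic equivalent of this script (the method name and pose-file path below are placeholders): it calls the same commons.submit_poses helper on an est_poses.txt produced by log_pose_estimate, whose lines follow the 'name qw qx qy qz tx ty tz' format.

import commons

# placeholders: point these at your own experiment
commons.submit_poses(method='my_method_name', path='logs/my_exp/renderings/step_0/est_poses.txt')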