├── .gitignore
├── .gitmodules
├── README.md
├── assets
│   └── new_arch_final.png
├── clean_logs.py
├── commons.py
├── configs.py
├── download.py
├── eval_pose_file.py
├── gloc
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── dataset_nolabels.py
│   │   ├── get_dataset.py
│   │   └── imlist_dataset.py
│   ├── extraction
│   │   ├── __init__.py
│   │   ├── extract_feats.py
│   │   └── utils.py
│   ├── initialization.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── features.py
│   │   ├── get_model.py
│   │   ├── layers.py
│   │   └── refinement_model.py
│   ├── rendering
│   │   ├── __init__.py
│   │   ├── base_renderer.py
│   │   ├── mesh_renderer.py
│   │   ├── nerf_renderer.py
│   │   ├── rend_conf.py
│   │   ├── rend_utils.py
│   │   └── splatting_renderer.py
│   ├── resamplers
│   │   ├── __init__.py
│   │   ├── get_protocol.py
│   │   ├── samplers.py
│   │   ├── sampling_utils.py
│   │   ├── scalers.py
│   │   ├── scalers_conf.py
│   │   └── strategies.py
│   └── utils
│       ├── __init__.py
│       ├── camera_utils.py
│       ├── utils.py
│       └── visualization.py
├── parse_args.py
├── path_configs.py
├── refine_pose.py
├── refine_pose_aachen.py
├── render_dataset_from_script.py
├── requirements.txt
└── submit_poses.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .spyproject
2 | .idea
3 | __pycache__
4 | *__pycache__*
5 | logs
6 | cache
7 | .ipynb_checkpoints
8 | *.ipynb
9 | *.torch
10 | renderings
11 | pretrained
12 | descr_cache
13 | rendered_views.txt
14 | fake*
15 | old*
16 | render_check
17 | masks*
18 | sphere_render*
19 | notebook_code.py
20 | extrinsic2pyramid
21 | *.pth
22 | triplet_renders
23 | gifs
24 | feat_steps
25 | extrinsic2pyramid
26 | plots
27 | .vscode
28 | nerf_renders
29 | render_straight
30 | renders_rotate
31 | measure_nerf*
32 | jonas*
33 | *test_keypoints*
34 | viz*
35 | vreph_stuff
36 | vreph_coms
37 | ign_*
38 | pose_priors
39 | to_refine*
40 | conv_analy*
41 | renderings_dense_post
42 | query_mask*
43 | pt2_1*
44 | paper_plots
45 | rebuttal_plots
46 | chess_7*
47 | examples
48 | *.zip
49 | data
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "gloc/models/third_party/RoMA"]
2 | path = gloc/models/third_party/RoMA
3 | url = https://github.com/ga1i13o/RoMa
4 | [submodule "third_party/gaussian-splatting"]
5 | path = third_party/gaussian-splatting
6 | url = https://github.com/ga1i13o/gaussian-splatting
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement
3 |
4 | This is the official PyTorch implementation of the CVPR 2024 paper "The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement".
5 | In this work, we present a simple approach for Pose Refinement that combines pre-trained features with a particle filter and a renderable representation of the scene.
6 |
7 |
8 | [[CVPR 2024 Open Access](https://openaccess.thecvf.com/content/CVPR2024/html/Trivigno_The_Unreasonable_Effectiveness_of_Pre-Trained_Features_for_Camera_Pose_Refinement_CVPR_2024_paper.html)] [[ArXiv](https://arxiv.org/abs/2404.10438)]
9 |
10 |
11 | ![Our proposed Pose Refinement algorithm](assets/new_arch_final.png)
12 |
13 | Our proposed Pose Refinement algorithm.
14 |
15 | ## Download data
16 |
17 | The following command
18 |
19 | `$ python download.py`
20 |
21 | will download the COLMAP models and the pre-trained Gaussian Splatting models used as the scene representation, as well as the Cambridge Landmarks dataset.
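Once the script finishes, the `data/` folder should contain the pre-trained splatting models (`cambridge_splats`), the COLMAP models (`all_colmaps`), and one folder per Cambridge scene. A quick, illustrative sanity check (not part of the repo):

```python
import os

expected = ['cambridge_splats', 'all_colmaps',
            'KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch']
missing = [d for d in expected if not os.path.isdir(os.path.join('data', d))]
print('missing folders:', missing or 'none')
```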
22 |
23 | ## Environment
24 |
25 | Follow the instructions to install the `gaussian_splatting` environment from the [official repo](https://github.com/graphdeco-inria/gaussian-splatting/tree/main?tab=readme-ov-file#setup)
26 |
27 | Then, activate the environment and execute:
28 |
29 | `$ pip install -r requirements.txt`
30 |
31 | ## Reproduce our results
32 |
33 | ### Cambridge Landmarks
34 |
35 | To reproduce results on Cambridge Landmarks, e.g. for KingsCollege:
36 |
37 | `$ python refine_pose.py KingsCollege --exp_name kings_college_refine --renderer g_splatting --clean_logs`
38 |
39 | The script loads the number of steps and the hyperparameters of the Monte Carlo optimization from `configs.py`. It uses a Gaussian Splatting model to render candidate poses; other options, such as a colored mesh or a NeRF model, will be uploaded soon.
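For reference, a minimal sketch (illustrative only; the actual orchestration lives in `refine_pose.py`) of how the per-scene refinement schedule returned by `get_config` in `configs.py` can be inspected:

```python
from configs import get_config

# Each list entry is one coarse-to-fine refinement stage.
for i, stage in enumerate(get_config('KingsCollege'), start=1):
    print(f"Stage {i}: {stage['steps']} steps at res {stage['res']}, "
          f"N={stage['N']}, beams={stage['beams']}, features={stage['feat_model']}")
```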
40 |
41 | ### 7scenes
42 |
43 | Coming soon
44 |
45 | ### Aachen Day-Night
46 |
47 | Coming soon
48 |
49 | ## Cite
50 | Here is the BibTeX entry to cite our paper:
51 | ```@inproceedings{trivigno2024unreasonable,
52 | title={The Unreasonable Effectiveness of Pre-Trained Features for Camera Pose Refinement},
53 | author={Trivigno, Gabriele and Masone, Carlo and Caputo, Barbara and Sattler, Torsten},
54 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
55 | pages={12786--12798},
56 | year={2024}
57 | }
58 | ```
59 |
60 |
61 | ## Acknowledgements
62 | Parts of this repo are inspired by the following repositories:
63 | - [Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting)
64 | - [Nerfstudio](https://github.com/nerfstudio-project/nerfstudio)
65 |
--------------------------------------------------------------------------------
/assets/new_arch_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ga1i13o/mcloc_poseref/6f0610c3572486a4a0a4a8268e7844efd66c94e9/assets/new_arch_final.png
--------------------------------------------------------------------------------
/clean_logs.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import shutil
4 | from os.path import join
5 | from glob import glob
6 | from tqdm import tqdm
7 |
8 |
9 | def main(folder, only_step=-1):
10 | if not os.path.isdir(folder):
11 | raise ValueError(f'{folder} is not a directory')
12 |     if 'renderings' not in os.listdir(folder):
13 | raise ValueError(f'{folder} should be a log directory')
14 |
15 | steps_folders = glob(join(folder, 'renderings', '*'))
16 | print(f'found {len(steps_folders)} folders')
17 | dir_to_step = lambda x: int(x.split('/')[-1].split('_')[2].split('s')[-1])
18 |
19 | for step_f in tqdm(steps_folders, ncols=100):
20 | s_num = dir_to_step(step_f)
21 | if (only_step != -1) and (s_num != only_step):
22 | continue
23 |
24 | query_renders = glob(join(step_f, '**', '*.png'), recursive=True)
25 | print(f'\nStep {s_num}: found {len(query_renders)} renderings, deleting them...')
26 | for q_r in tqdm(query_renders, ncols=100):
27 | os.remove(q_r)
28 |
29 | vreph_dir = join(step_f, 'vreph_conf')
30 | if os.path.isdir(vreph_dir):
31 | n_vreph = len(os.listdir(vreph_dir))
32 | print(f'Found {n_vreph} files in vreph conf, deleting them...')
33 | shutil.rmtree(vreph_dir)
34 |
35 |
36 | if __name__ == '__main__':
37 | parser = argparse.ArgumentParser(description='Argument parser')
38 | parser.add_argument('folder', type=str, help='log folder')
39 | parser.add_argument('-s', '--step', type=int, default=-1, help='delete only step n.')
40 |
41 | args = parser.parse_args()
42 | main(args.folder, args.step)
43 |
44 |
45 | """
46 | Example
47 | python clean_logs.py logs/pt2_1_kings_cpl3_320_colmaporig/2023-08-24_23-40-16
48 | python clean_logs.py logs/pt2_1_V2_T6_N40_M2_kings_nerf_cpl3_nozstd_320_1/2023-09-19_14-05-18
49 | """
--------------------------------------------------------------------------------
/commons.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains some functions and classes which can be useful in very diverse projects.
3 | """
4 | import os
5 | import sys
6 | import torch
7 | import logging
8 | import traceback
9 | from os.path import join
10 | import random
11 | import numpy as np
13 | import requests
14 | from bs4 import BeautifulSoup
15 |
16 |
17 | def submit_poses(method, path, dataset='aachenv11'):
18 | session = open('ign_secret.txt', 'r').readline().strip()
19 | resp = requests.get("https://www.visuallocalization.net/submission/",
20 | headers={"Cookie": f"sessionid={session}"})
21 |
22 | bs = BeautifulSoup(resp.content, features="html.parser")
23 | csrf = bs.select('form > input[name="csrfmiddlewaretoken"]')[0].attrs['value']
24 |
25 | url = "https://www.visuallocalization.net/submission/"
26 |
27 | headers = {
28 | "Cookie": f"csrftoken={resp.cookies['csrftoken']}; sessionid={session}",
29 | "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/90.0.4430.93 Safari/537.36",
30 | }
31 |
32 | data = {
33 | "csrfmiddlewaretoken": csrf,
34 | "method_name": method, #"test",
35 | "publication_url": "",
36 | "code_url": "",
37 | "info_field": "",
38 | "dataset": dataset,
39 | "workshop_submission": "default",
40 | }
41 |
42 | files = {
43 | "result_file": open(path, "r"),
44 | }
45 | resp = requests.post(url, files=files, data=data, headers=headers)
46 | print(resp.status_code)
47 |
48 |
49 | def setup_logging(output_folder, console="debug",
50 | info_filename="info.log", debug_filename="debug.log"):
51 | """Set up logging files and console output.
52 | Creates one file for INFO logs and one for DEBUG logs.
53 | Args:
54 | output_folder (str): creates the folder where to save the files.
55 |         console (str):
56 | if == "debug" prints on console debug messages and higher
57 | if == "info" prints on console info messages and higher
58 | if == None does not use console (useful when a logger has already been set)
59 | info_filename (str): the name of the info file. if None, don't create info file
60 | debug_filename (str): the name of the debug file. if None, don't create debug file
61 | """
62 | if os.path.exists(output_folder):
63 | raise FileExistsError(f"{output_folder} already exists!")
64 | os.makedirs(output_folder, exist_ok=True)
65 | # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
66 | logging.getLogger('PIL').setLevel(logging.INFO) # turn off logging tag for some images
67 | logging.getLogger('matplotlib.font_manager').disabled = True
68 |
69 | base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
70 | logger = logging.getLogger('')
71 | logger.setLevel(logging.DEBUG)
72 |
73 | if info_filename != None:
74 | info_file_handler = logging.FileHandler(join(output_folder, info_filename))
75 | info_file_handler.setLevel(logging.INFO)
76 | info_file_handler.setFormatter(base_formatter)
77 | logger.addHandler(info_file_handler)
78 |
79 | if debug_filename != None:
80 | debug_file_handler = logging.FileHandler(join(output_folder, debug_filename))
81 | debug_file_handler.setLevel(logging.DEBUG)
82 | debug_file_handler.setFormatter(base_formatter)
83 | logger.addHandler(debug_file_handler)
84 |
85 | if console != None:
86 | console_handler = logging.StreamHandler()
87 | if console == "debug": console_handler.setLevel(logging.DEBUG)
88 | if console == "info": console_handler.setLevel(logging.INFO)
89 | console_handler.setFormatter(base_formatter)
90 | logger.addHandler(console_handler)
91 |
92 | def exception_handler(type_, value, tb):
93 |         logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))
94 | sys.excepthook = exception_handler
95 |
96 |
97 | def make_deterministic(seed=0):
98 | """Make results deterministic. If seed == -1, do not make deterministic.
99 | Running your script in a deterministic way might slow it down.
100 | Note that for some packages (eg: sklearn's PCA) this function is not enough.
101 | """
102 | seed = int(seed)
103 | if seed == -1:
104 | return
105 | print(f'setting seed {seed}')
106 | random.seed(seed)
107 | np.random.seed(seed)
108 | torch.manual_seed(seed)
109 | torch.cuda.manual_seed_all(seed)
110 | torch.backends.cudnn.deterministic = True
111 | torch.backends.cudnn.benchmark = False
112 |
113 |
--------------------------------------------------------------------------------
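The helpers above are used by the refinement scripts to set up logging and reproducibility; a minimal usage sketch (the log folder below is a hypothetical example):

```python
import logging
import commons

# Creates <folder>/info.log and <folder>/debug.log and mirrors messages to the console;
# the folder must not exist yet.
commons.setup_logging('logs/example_run/2024-01-01_00-00-00', console='info')
commons.make_deterministic(seed=0)
logging.info('logging is set up')
```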
/configs.py:
--------------------------------------------------------------------------------
1 | conf = {
2 | 'cambridge': [
3 | {
4 | # Set N.1
5 | 'beams': 2,
6 | 'steps': 40,
7 | 'N': 52,
8 | 'M': 2,
9 | 'feat_model': 'cosplace_r18_l3',
10 | 'protocol': '2_1',
11 | 'center_std': [1.2, 1.2, 0.1],
12 | 'teta': [10],
13 | 'gamma': 0.3,
14 | 'res': 320,
15 | 'colmap_res': 320,
16 | },
17 | {
18 | # Set N.2
19 | 'beams': 2,
20 | 'steps': 30,
21 | 'N': 40,
22 | 'M': 2,
23 | 'feat_model': 'cosplace_r18_l2',
24 | 'protocol': '2_1',
25 | 'center_std': [.4, .4, .04],
26 | 'teta': [4],
27 | 'gamma': 0.3,
28 | 'res': 320,
29 | 'colmap_res': 320,
30 | },
31 | {
32 | # Set N.3
33 | 'beams': 2,
34 | 'steps': 60,
35 | 'N': 32,
36 | 'M': 1,
37 | 'feat_model': 'cosplace_r18_l2',
38 | 'protocol': '2_0',
39 | 'center_std': [.15, .15, 0.02],
40 | 'teta': [1.5],
41 | 'gamma': 0.3,
42 | 'res': 480,
43 | 'colmap_res': 480,
44 | },
45 | ]
46 | }
47 |
48 |
49 | def get_config(ds_name):
50 | cambridge_scenes = [
51 | 'StMarysChurch', 'OldHospital', 'KingsCollege', 'ShopFacade'
52 | ]
53 |
54 | if ds_name in cambridge_scenes:
55 | return conf['cambridge']
56 | else:
57 |         raise NotImplementedError(f'No config available for dataset {ds_name}')
58 |
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import gdown
2 | import os
3 | import shutil
4 | import urllib.request
5 |
6 |
7 | cambridge_scenes = {
8 | 'OldHospital': 'https://www.repository.cam.ac.uk/bitstreams/ae577bfb-bdce-488c-8ce6-3765eabe420e/download',
9 | 'KingsCollege': 'https://www.repository.cam.ac.uk/bitstreams/1cd2b04b-ada9-4841-8023-8207f1f3519b/download',
10 | 'StMarysChurch': 'https://www.repository.cam.ac.uk/bitstreams/2559ba20-c4d1-4295-b77f-183f580dbc56/download',
11 | 'ShopFacade': 'https://www.repository.cam.ac.uk/bitstreams/4e5c67dd-9497-4a1d-add4-fd0e00bcb8cb/download'
12 | }
13 | DATA_DIR = 'data'
14 | os.makedirs(DATA_DIR, exist_ok=True)
15 |
16 | ## Download splatting models for Cambridge
17 | print('Downloading Gaussian Splatting models for Cambridge...')
18 | url = 'https://drive.google.com/uc?id=1iCyizI0jZwZ7mdXG9wGGh2fuwdsbQVkL'
19 | output = 'g_down.zip'
20 | gdown.download(url, output, quiet=False)
21 |
22 | unzip_cmd = f'unzip {output}'
23 | os.system(unzip_cmd)
24 | os.rename('out_gp', 'cambridge_splats')
25 | shutil.move('cambridge_splats', DATA_DIR)
26 | os.remove(output) # remove zip file
27 |
28 | # Download undistorted colmap models for Cambridge and 7scenes
29 | print('Downloading colmap models for Cambridge and 7 scenes...')
30 | url = 'https://drive.google.com/uc?id=1BPUU7Z_Xc4SIQJwfIiIHwUA9ud9c9tCU'
31 | output = 'all_colmaps.zip'
32 | gdown.download(url, output, quiet=False)
33 | unzip_cmd = f'unzip {output}'
34 | os.system(unzip_cmd)
35 | shutil.move('all_colmaps', DATA_DIR)
36 | os.remove(output) # remove zip file
37 |
38 | ## Download cambridge scenes
39 | wget_cmd = 'wget {url} -O {out_name}'
40 | for cs, cs_url in cambridge_scenes.items():
41 | print(f'Downloading dataset for {cs}...')
42 | # urllib.request.urlretrieve(cs_url, f'{cs}.zip')
43 | os.system(wget_cmd.format(url=cs_url, out_name=f'{cs}.zip'))
44 | unzip_cmd = f'unzip {cs}.zip'
45 | os.system(unzip_cmd)
46 | shutil.move(cs, DATA_DIR)
47 | os.remove(f'{cs}.zip') # remove zip file
48 |
--------------------------------------------------------------------------------
/eval_pose_file.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, dirname
3 | import argparse
4 |
5 | import commons
6 | from parse_args import parse_args
7 | from path_configs import get_path_conf
8 | from gloc import extraction
9 | from gloc import rendering
10 | from gloc.rendering import get_renderer
11 | from gloc.datasets import get_dataset, RenderedImagesDataset, get_transform
12 | from gloc.utils import utils, visualization
14 | from gloc.resamplers import get_protocol
15 |
16 |
17 | def main(args):
18 | DS = args.name
19 | print(f"Arguments: {args}")
20 |
21 | paths_conf = get_path_conf(320, None)
22 | pd = get_dataset(DS, paths_conf[DS], None)
23 |
24 | print(f'Loading pose prior from {args.pose_prior}')
25 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pd)
26 |
27 | all_true_t, all_true_R = pd.get_q_poses()
28 | errors_t, errors_R = utils.get_all_errors_first_estimate(all_true_t, all_true_R, all_pred_t, all_pred_R)
29 | out_str, _ = utils.eval_poses(errors_t, errors_R, descr='Retrieval first estimate')
30 | print(out_str)
31 |
32 |
33 | if __name__ == '__main__':
34 | parser = argparse.ArgumentParser(description='Argument parser')
35 |     parser.add_argument('name', type=str, help='dataset name')
36 |     parser.add_argument('pose_prior', type=str, help='path to a txt file with estimated poses')
37 |
38 | args = parser.parse_args()
39 | main(args)
40 |
41 |
42 | """
43 | Example:
44 | python eval_pose_file.py chess_dslam logs/reb_chess_2_bms2_pt2_0_N36_M2_V004_T2_gsplat_dinol4_320/2024-01-27_16-21-05/renderings/pt2_0_s19_sz320_theta2,0_t0,0_0,0_0,0/est_poses.txt
45 | """
46 |
--------------------------------------------------------------------------------
/gloc/__init__.py:
--------------------------------------------------------------------------------
1 | # submodules
2 | from . import datasets
3 | from . import models
4 | from . import extraction
5 | from . import utils
6 |
--------------------------------------------------------------------------------
/gloc/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['dataset', 'get_dataset']
2 |
3 | from gloc.datasets.dataset import PoseDataset, RenderedImagesDataset
4 | from gloc.datasets.dataset import get_query_id, get_render_id
5 | from gloc.datasets.get_dataset import get_dataset, get_transform
6 | from gloc.datasets.imlist_dataset import ImListDataset, find_candidates_paths
--------------------------------------------------------------------------------
/gloc/datasets/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | from os.path import join
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import torchvision.transforms as T
8 |
9 | from gloc.utils import read_model_nopoints as read_model, parse_cam_model, qvec2rotmat
10 | from gloc.utils import Image as RImage
11 |
12 |
13 | def get_query_id(name):
14 | if isinstance(name, tuple):
15 | name = name[0]
16 | name = name.split('/')[-1].split('.')[0]
17 | return name
18 |
19 |
20 | def get_render_id(name):
21 | if isinstance(name, tuple):
22 | name = name[0]
23 | if ('night' in name) or ('day' in name):
24 | i = name.find('IMG')
25 | name = name[i:].split('_')[:-1]
26 | name = '_'.join(name)
27 | return name
28 |
29 | name = name.split('/')[-1].split('.')[0].split('_')[:-1]
30 | if len(name) > 2 and name[1] == 'nexus4':
31 | name = '_'.join(name[5:])
32 | elif len(name) > 2 and name[1] == 'gopro3':
33 | name = '_'.join(name[3:])
34 | else:
35 | name = name[1]
36 | return name
37 |
38 |
39 | class PoseDataset(data.Dataset):
40 | def __init__(self, name, paths_conf, transform=None, rendered_db=None):
41 | self.name = name
42 | self.root = paths_conf['root']
43 | self.colmap_model = paths_conf['colmap']
44 | self.transform = transform
45 | self.rendered_db = rendered_db
46 | self.use_render = False
47 |
48 | self.images, self.intrinsics = PoseDataset.load_colmap(self.colmap_model)
49 |
50 | queries = paths_conf['q_file']
51 | db = paths_conf['db_file']
52 |         all_frames = np.array(list(map(lambda x: x.name, self.images)))
53 |         self.db_frames_idxs, self.db_tvecs, self.db_qvecs = self.load_txt(db, all_frames)
54 |         self.q_frames_idxs, self.q_tvecs, self.q_qvecs = self.load_txt(queries, all_frames)
60 | self.n_q = len(self.q_frames_idxs)
61 |
62 | def load_txt(self, fpath, all_frames):
63 | with open(fpath, 'r') as f:
64 | lines = f.readlines()
65 | if lines[0].startswith('Visual Landmark'):
66 | lines = lines[3:]
67 | frames = np.array(list(map(lambda x: x.split(' ')[0].strip(), lines)))
68 | frames_idxs = list(np.where(np.in1d(all_frames, frames))[0])
69 | tvecs = np.array(list(map(lambda x: x.tvec, self.images)))[frames_idxs]
70 | qvecs = np.array(list(map(lambda x: x.qvec, self.images)))[frames_idxs]
71 |
72 | return frames_idxs, tvecs, qvecs
73 |
74 | def get_basename(self, im_idx):
75 | q_name = self.images[im_idx].name.replace('/', '_').split('.')[0]
76 | return q_name
77 |
78 | def get_pose(self, idx):
79 | R = qvec2rotmat(self.images[idx].qvec)
80 | t = self.images[idx].tvec
81 |
82 | return t, R
83 |
84 | def get_pose_by_name(self, name):
85 |         names = np.array(list(map(lambda x: x.name, self.images)))
86 | idx = np.argwhere(name == names)[0,0]
87 | qvec, tvec = self.images[idx].qvec, self.images[idx].tvec
88 |
89 | return qvec, tvec
90 |
91 | def num_queries(self):
92 | return len(self.q_frames_idxs)
100 |
101 | def get_intrinsics(self, q_key_name):
102 | assert q_key_name in self.intrinsics, f'{q_key_name} is not a valid image name'
103 |
104 | w = self.intrinsics[q_key_name]['w']
105 | h = self.intrinsics[q_key_name]['h']
106 | K = self.intrinsics[q_key_name]['K']
107 |
108 | return K, w, h
109 |
110 | def get_q_poses(self):
111 | Rs = []
112 | ts = []
113 |
114 | for q_idx in range(len(self.q_frames_idxs)):
115 | idx = self.q_frames_idxs[q_idx]
116 |
117 | R = qvec2rotmat(self.images[idx].qvec)
118 | t = self.images[idx].tvec
119 | Rs.append(R)
120 | ts.append(t)
121 |
122 | return np.array(ts), np.array(Rs)
123 |
124 | def get_all_poses(self):
125 | Rs = []
126 | ts = []
127 |
128 | for image in self.images:
129 | R = qvec2rotmat(image.qvec)
130 | t = image.tvec
131 | Rs.append(R)
132 | ts.append(t)
133 |
134 | return np.array(ts), np.array(Rs)
135 |
136 | def __getitem__(self, idx):
137 |         """Return:
138 |             dict: 'im' is the (optionally transformed) image tensor,
139 |                   'im_ref' is the colmap Image record with its name and pose (qvec, tvec)
140 |         """
142 | data_dict = {}
143 | im_data = self.images[idx]
144 | data_dict['im_ref'] = im_data
145 | if not self.use_render:
146 | im = Image.open(join(self.root, im_data.name))
147 | else:
148 | im = Image.open(join(self.rendered_db, im_data.name.replace('/', '_')).replace('.jpg', '.png'))
149 | if self.transform:
150 | im = self.transform(im)
151 | data_dict['im'] = im
152 |
153 | return data_dict
154 |
155 | def __len__(self):
156 | return len(self.images)
157 |
158 | @staticmethod
159 | def load_colmap(colmap_model):
160 | # Load the images
161 | logging.info(f'Loading colmap from {colmap_model}')
162 | cam_list = {}
163 | cameras, images = read_model(colmap_model)
164 | for i in images:
165 | qvec = images[i].qvec
166 | tvec = images[i].tvec
167 | cam_data = cameras[images[i].camera_id]
168 | cam_dict = parse_cam_model(cam_data)
169 |
170 | K = np.array([
171 | [cam_dict["fx"], 0.0, cam_dict["cx"]],
172 | [0.0, cam_dict["fy"], cam_dict["cy"]],
173 | [0.0, 0.0, 1.0]])
174 |
175 | R = qvec2rotmat(qvec)
176 | T = np.eye(4)
177 | T[0:3, 0:3] = R
178 | T[0:3, 3] = tvec
179 | w, h = cam_dict["width"], cam_dict["height"]
180 |
181 | basename = os.path.splitext(images[i].name)[0]
182 |
183 | cam_list[basename] = {'K':K, 'T':T, 'w':w, 'h':h, 'model': cam_data.model}
184 | return list(images.values()), cam_list
185 |
186 |
187 | class RenderedImagesDataset(data.Dataset):
188 | def __init__(self, im_root, transform=None, query_res=None, verbose=True):
189 | self.root = im_root
190 | self.descr_file = os.path.join(im_root, 'rendered_views.txt')
191 |
192 | self.images = RenderedImagesDataset.load_images(self.descr_file, verbose)
193 | self.final_resize = None
194 | if query_res is not None:
195 | # this to ensure that renderings end up having same resolution
196 | # as the query once resized
197 | self.final_resize = T.Resize(query_res, antialias=True)
198 |
199 | self.transform = transform
200 |
201 | def __getitem__(self, idx):
202 |         """Return:
203 |             dict: 'im' is the rendered image tensor (resized to the query resolution if requested),
204 |                   'im_ref' is the Image record holding the rendering's pose
205 |         """
207 | data_dict = {}
208 | im_data = self.images[idx]
209 | data_dict['im_ref'] = im_data
210 | im = Image.open(join(self.root, im_data.name))
211 |
212 | if self.transform:
213 | im = self.transform(im)
214 | if self.final_resize:
215 | im = self.final_resize(im)
216 | data_dict['im'] = im
217 |
218 | return data_dict
219 |
220 | def __len__(self):
221 | return len(self.images)
222 |
223 | def get_full_paths(self):
224 | paths = list(map(lambda x: join(self.root, x.name), self.images))
225 | return paths
226 |
227 | def get_names(self):
228 | image_names = list(map(lambda x: x.name, self.images))
229 | image_names = np.array(image_names)
230 |
231 | return image_names
232 |
233 | def get_camera_centers(self):
234 | centers = []
235 | for im in self.images:
236 | R = qvec2rotmat(im.qvec)
237 | tvec = im.tvec
238 | im_center = - R.T @ tvec
239 | centers.append(im_center)
240 | centers = np.array(centers)
241 |
242 | return centers
243 |
244 | def get_poses(self):
245 | Rs = []
246 | ts = []
247 |
248 | for image in self.images:
249 | R = qvec2rotmat(image.qvec)
250 | t = image.tvec
251 | Rs.append(R)
252 | ts.append(t)
253 |
254 | return np.array(ts), np.array(Rs)
255 |
256 | @staticmethod
257 | def load_images(fpath, verbose):
258 | # Load the images
259 | if verbose:
260 | print('Loading the rendered and cameras')
261 | images = []
262 | with open(fpath, 'r') as rv:
263 | while True:
264 | line = rv.readline()
265 | if not line:
266 | break
267 | line = line.strip()
268 | fields = line.split(' ')
269 |
270 | name = fields[0]+'.png'
271 | tvec = np.array(tuple(map(float, fields[1:4])))
272 | qvec = np.array(tuple(map(float, fields[4:8])))
273 |
274 | im = RImage(id=-1, qvec=qvec, tvec=tvec, name=name,
275 | camera_id='', xys={}, point3D_ids={})
276 | images.append(im)
277 | return images
278 |
--------------------------------------------------------------------------------
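For illustration, a `PoseDataset` as defined above is built from a `paths_conf` dictionary with the image root, the COLMAP model, and the query/database split files; all paths below are hypothetical placeholders (in the repo they come from the path configuration):

```python
from gloc.datasets import PoseDataset

paths_conf = {
    'root':    'data/KingsCollege',                   # image root (hypothetical path)
    'colmap':  'data/all_colmaps/KingsCollege',       # COLMAP model directory (hypothetical)
    'q_file':  'data/KingsCollege/dataset_test.txt',  # query split (hypothetical)
    'db_file': 'data/KingsCollege/dataset_train.txt', # database split (hypothetical)
}
pd = PoseDataset('KingsCollege', paths_conf)
print(pd.num_queries(), 'queries out of', len(pd), 'images')
t, R = pd.get_pose(pd.q_frames_idxs[0])  # ground-truth pose of the first query
```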
/gloc/datasets/dataset_nolabels.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import numpy as np
4 | import torch.utils.data as data
5 | from PIL import Image
6 | import torchvision.transforms as T
7 |
8 | from gloc.utils import read_model_nopoints as read_model, parse_cam_model, qvec2rotmat
9 | from gloc.utils.camera_utils import read_cameras_intrinsics, Image as RImage
10 |
11 |
12 | class IntrinsicsDataset(data.Dataset):
13 | def __init__(self, name, paths_conf, transform=None):
14 | self.name = name
15 | self.root = paths_conf['root']
16 | self.colmap_model = paths_conf['colmap']
17 | self.q_files = paths_conf['q_intrinsics']
18 | self.transform = transform
19 |
20 | self.db_images, self.db_intrinsics = IntrinsicsDataset.load_colmap(self.colmap_model)
21 | self.q_list, self.q_intrinsics = IntrinsicsDataset.load_queries(self.q_files)
22 | self.intrinsics = self.db_intrinsics.copy()
23 | self.intrinsics.update(self.q_intrinsics)
24 |
25 | self.images = self.db_images.copy()
26 | for q in self.q_list:
27 | im = RImage(id=-1, qvec=-1, tvec=-1, name=q.id, camera_id='', xys={}, point3D_ids={})
28 | self.images.append(im)
29 |
30 | self.db_frames_idxs = list(range(len(self.db_images)))
31 | self.q_frames_idxs = list(np.arange(len(self.q_list)) + len(self.db_images))
32 | self.db_qvecs = np.array(list(map(lambda x:x.qvec, self.db_images)))
33 | self.db_tvecs = np.array(list(map(lambda x:x.tvec, self.db_images)))
34 |
35 | self.n_db = len(self.db_images)
36 | self.n_q = len(self.q_list)
37 | logging.info(f'Loaded dataset with {self.n_db} db images and {self.n_q} queries w/ intrinsics')
38 |
39 | def get_basename(self, im_idx):
40 | q_name = self.images[im_idx].name.replace('/', '_').split('.')[0]
41 | return q_name
42 |
43 | def get_pose(self, idx):
44 | assert idx < self.n_db, f'Only db images have extrinsics, {idx} is a query'
45 | R = qvec2rotmat(self.images[idx].qvec)
46 | t = self.images[idx].tvec
47 |
48 | return t, R
49 |
50 | def num_queries(self):
51 | return self.n_q
52 |
53 | def get_db_poses(self):
54 | Rs = []
55 | ts = []
56 |
57 | for q_idx in range(len(self.db_frames_idxs)):
58 | idx = self.db_frames_idxs[q_idx]
59 |
60 | R = qvec2rotmat(self.images[idx].qvec)
61 | t = self.images[idx].tvec
62 | Rs.append(R)
63 | ts.append(t)
64 |
65 | return np.array(ts), np.array(Rs)
66 |
67 | def __getitem__(self, idx):
68 |         """Return:
69 |             dict: 'im' is the image tensor,
70 |                   'im_ref' is the Image record (pose fields are -1 for query images)
71 |         """
73 | data_dict = {}
74 | im_data = self.images[idx]
75 | data_dict['im_ref'] = im_data
76 | im = Image.open(os.path.join(self.root, im_data.name))
77 | if self.transform:
78 | im = self.transform(im)
79 | data_dict['im'] = im
80 |
81 | return data_dict
82 |
83 | def __len__(self):
84 | return len(self.images)
85 |
86 | def get_pose_by_name(self, name):
87 | names =np.array(list(map(lambda x: x.name, self.images)))
88 | idx = np.argwhere(name == names)[0,0]
89 | qvec, tvec = self.images[idx].qvec, self.images[idx].tvec
90 |
91 | return qvec, tvec
92 |
93 | @staticmethod
94 | def load_queries(q_file_list):
95 |         q_intrinsics = {}
96 |         q_list = []
97 |         for q_file in q_file_list:
98 |             data, intrinsics_dict = IntrinsicsDataset.load_camera_intrinsics(q_file)
99 |             q_list += data
100 |             q_intrinsics.update(intrinsics_dict)
101 |
102 |         return q_list, q_intrinsics
103 |
104 | @staticmethod
105 | def load_camera_intrinsics(intrinsic_file):
106 | # Load the images
107 | logging.info(f'Loading intrinsics from {intrinsic_file}')
108 | cam_list = {}
109 | cameras = read_cameras_intrinsics(intrinsic_file)
110 | for cam_data in cameras:
111 | cam_dict = parse_cam_model(cam_data)
112 | K = np.array([
113 | [cam_dict["fx"], 0.0, cam_dict["cx"]],
114 | [0.0, cam_dict["fy"], cam_dict["cy"]],
115 | [0.0, 0.0, 1.0]])
116 | w, h = cam_dict["width"], cam_dict["height"]
117 |
118 | basename = os.path.splitext(cam_data.id)[0]
119 |             # query cameras only provide intrinsics here; their extrinsics are unknown
120 |             cam_list[basename] = {'K': K, 'T': None, 'w': w, 'h': h,
121 |                 'model': cam_data.model, 'params': cam_data.params}
121 |
122 | return cameras, cam_list
123 |
124 | @staticmethod
125 | def load_colmap(colmap_model):
126 | # Load the images
127 | logging.info(f'Loading colmap from {colmap_model}')
128 | cam_list = {}
129 | cameras, images = read_model(colmap_model)
130 | for i in images:
131 | qvec = images[i].qvec
132 | tvec = images[i].tvec
133 | cam_data = cameras[images[i].camera_id]
134 | cam_dict = parse_cam_model(cam_data)
135 |
136 | K = np.array([
137 | [cam_dict["fx"], 0.0, cam_dict["cx"]],
138 | [0.0, cam_dict["fy"], cam_dict["cy"]],
139 | [0.0, 0.0, 1.0]])
140 |
141 | R = qvec2rotmat(qvec)
142 | T = np.eye(4)
143 | T[0:3, 0:3] = R
144 | T[0:3, 3] = tvec
145 | w, h = cam_dict["width"], cam_dict["height"]
146 |
147 | basename = os.path.splitext(images[i].name)[0]
148 |
149 | cam_list[basename] = {'K':K, 'T':T, 'w':w, 'h':h}
150 | return list(images.values()), cam_list
151 |
--------------------------------------------------------------------------------
/gloc/datasets/get_dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import torchvision.transforms as T
3 |
4 | from gloc.datasets import PoseDataset
5 | from gloc.datasets.dataset_nolabels import IntrinsicsDataset
6 |
7 |
8 | def get_dataset(name, paths_conf, transform=None):
9 | # if 'Aachen' in name:
10 | if name in ['Aachen_night', 'Aachen_day', 'Aachen_real', 'Aachen_real_und']:
11 | dataset = IntrinsicsDataset(name, paths_conf, transform)
12 | else:
13 | dataset = PoseDataset(name, paths_conf, transform)
14 |
15 | return dataset
16 |
17 |
18 | def get_transform(args, colmap_dir=''):
19 | res = args.res
20 | if args.feat_model == 'Dinov2':
21 | cam_file = join(colmap_dir, 'cameras.txt')
22 | random_line = open(cam_file, 'r').readlines()[10].split(' ')
23 | w, h = int(random_line[2]), int(random_line[3])
24 | patch_size = 14
25 | new_h = patch_size * (h // patch_size)
26 | new_w = patch_size * (w // patch_size)
27 | transform = T.Compose([
28 | T.ToTensor(),
29 | T.Resize((new_h, new_w), antialias=True),
30 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
31 | ])
32 |
33 | elif ('Aachen' not in args.name) and (colmap_dir != ''):
34 | cam_file = join(colmap_dir, 'cameras.txt')
35 | random_line = open(cam_file, 'r').readlines()[10].split(' ')
36 | w, h = int(random_line[2]), int(random_line[3])
37 | ratio = min(h, w) / res
38 | new_h = int(h/ratio)
39 | new_w = int(w/ratio)
40 | transform = T.Compose([
41 | T.ToTensor(),
42 | T.Resize((new_h, new_w), antialias=True),
43 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
44 | ])
45 |
46 | else:
47 | transform = T.Compose([
48 | T.ToTensor(),
49 | T.Resize(res, antialias=True),
50 | T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
51 | ])
52 |
53 | return transform
54 |
--------------------------------------------------------------------------------
/gloc/datasets/imlist_dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | from PIL import Image
3 | import os
4 | import torch.utils.data as data
5 | from tqdm import tqdm
6 |
7 | from gloc.datasets import RenderedImagesDataset
8 |
9 |
10 | class ImListDataset(data.Dataset):
11 | def __init__(self, path_list, transform=None):
12 | self.path_list = path_list
13 | self.transform = transform
14 |
15 | def __getitem__(self, idx):
16 | im = Image.open(self.path_list[idx])
17 |
18 | if self.transform:
19 | im = self.transform(im)
20 | return im
21 |
22 | def __len__(self):
23 | return len(self.path_list)
24 |
25 |
26 | def find_candidates_paths(pose_dataset, n_beams, render_dir):
27 | candidates_pathlist = []
28 | for q_idx in tqdm(range(len(pose_dataset.q_frames_idxs)), ncols=100):
29 | q_name = pose_dataset.get_basename(pose_dataset.q_frames_idxs[q_idx])
30 | query_dir = os.path.join(render_dir, q_name)
31 |
32 | for beam_i in range(n_beams):
33 | beam_dir = join(query_dir, f'beam_{beam_i}')
34 |
35 | rd = RenderedImagesDataset(beam_dir, verbose=False)
36 | paths = rd.get_full_paths()
37 | candidates_pathlist += paths
38 |
39 | # return last query res; assumes they are all the same
40 | query_tensor = pose_dataset[pose_dataset.q_frames_idxs[q_idx]]['im']
41 | query_res = tuple(query_tensor.shape[-2:])
42 |
43 | return candidates_pathlist, query_res
44 |
--------------------------------------------------------------------------------
/gloc/extraction/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['extract_feats', 'utils']
2 |
3 |
4 | from gloc.extraction.extract_feats import extract_features, get_query_features, get_candidates_features
5 | from gloc.extraction.utils import get_predictions_from_step_dir, get_retrieval_predictions, split_renders_into_chunks, get_feat_dim
6 |
--------------------------------------------------------------------------------
/gloc/extraction/extract_feats.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import torch
4 | from tqdm import tqdm
5 | from os.path import join
6 | import numpy as np
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.dataset import Subset
9 |
10 |
11 | def extract_features(model, model_name, pose_dataset, res, bs=32, check_cache=True):
12 | pd = pose_dataset
13 | DS = pd.name
14 |
15 | q_cache_file = join('descr_cache',f'{DS}_{model_name}_{res}_q_descriptors.pth')
16 | db_cache_file = join('descr_cache', f'{DS}_{model_name}_{res}_db_descriptors.pth')
17 | if check_cache:
18 | if (os.path.isfile(db_cache_file) and os.path.isfile(q_cache_file)):
19 |
20 | logging.info(f"Loading {db_cache_file}")
21 | db_descriptors = torch.load(db_cache_file)
22 | q_descriptors = torch.load(q_cache_file)
23 |
24 | return db_descriptors, q_descriptors
25 |
26 | model = model.eval()
27 |
28 | queries_subset_ds = Subset(pd, pd.q_frames_idxs)
29 | database_subset_ds = Subset(pd, pd.db_frames_idxs)
30 |
31 | db_descriptors = get_query_features(model, database_subset_ds, bs)
32 | q_descriptors = get_query_features(model, queries_subset_ds, bs)
33 | db_descriptors = np.vstack(db_descriptors)
34 | q_descriptors = np.vstack(q_descriptors)
35 |
36 | if check_cache:
37 | os.makedirs('descr_cache', exist_ok=True)
38 | torch.save(db_descriptors, db_cache_file)
39 | torch.save(q_descriptors, q_cache_file)
40 |
41 | return db_descriptors, q_descriptors
42 |
43 |
44 | def get_query_features(model, dataset, bs=1):
45 | """
46 | Separate function for the queries as they might have different
47 | resolutions; thus it does not use a matrix to store descriptors
48 | but a list of arrays
49 | """
50 | model = model.eval()
51 | # bs = 1 as resolution might differ
52 | dataloader = DataLoader(dataset=dataset, num_workers=4, batch_size=bs)
53 |
54 | iterator = tqdm(dataloader, ncols=100)
55 | descriptors = []
56 | with torch.no_grad():
57 | for images in iterator:
58 | images = images['im'].cuda()
59 |
60 | descr = model(images)
61 | descr = descr.cpu().numpy()
62 | descriptors.append(descr)
63 |
64 | return descriptors
65 |
66 |
67 | def get_candidates_features(model, dataset, descr_dim, bs=32):
68 | dl = DataLoader(dataset=dataset, num_workers=8, batch_size=bs)
69 |
70 | len_ds = len(dataset)
71 | descriptors = np.empty((len_ds, *descr_dim), dtype=np.float32)
72 |
73 | with torch.no_grad():
74 | for i, images in enumerate(tqdm(dl, ncols=100)):
75 | descr = model(images.cuda())
76 |
77 | descr = descr.cpu().numpy()
78 | descriptors[i*bs:(i+1)*bs] = descr
79 |
80 | return descriptors
81 |
--------------------------------------------------------------------------------
/gloc/extraction/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch
6 | import einops
7 | from os.path import join
8 | from torch.utils.data.dataset import Subset
9 |
10 | from gloc.utils import utils, qvec2rotmat
11 | from gloc import extraction
12 | from gloc.datasets import RenderedImagesDataset
13 | from gloc import models
14 |
15 |
16 | def get_retrieval_predictions(model_name, res, pose_dataset, topk=5):
17 | model = models.get_retrieval_model(model_name)
18 |
19 | db_descriptors, q_descriptors = extraction.extract_features(model, model_name, pose_dataset, res, bs=1)
20 | logging.info(f'N. db descriptors: {db_descriptors.shape[0]}, N. Q descriptors: {q_descriptors.shape[0]}')
21 | logging.info('Computing first retrieval prediction...')
22 | all_pred_t, all_pred_R = utils.get_predictions(db_descriptors, q_descriptors, pose_dataset, top_k=topk)
23 | # all_true_t, all_true_R, all_pred_t, all_pred_R = utils.get_predictions_w_truths(db_descriptors, q_descriptors, pd, top_k=topk)
24 |
25 | return all_pred_t, all_pred_R
26 |
27 |
28 | def get_predictions_from_step_dir(fine_model, pd, render_dir, transform, N_per_beam, n_beams):
29 | queries_subset = Subset(pd, pd.q_frames_idxs)
30 | q_descriptors = extraction.get_features_from_dataset(fine_model, queries_subset)
31 |
32 | all_pred_t = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3))
33 | all_pred_R = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3, 3))
34 | names = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam), dtype=object)
35 | all_scores = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam))
36 |
37 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100):
38 | q_name = pd.get_basename(pd.q_frames_idxs[q_idx])
39 | query_dir = os.path.join(render_dir, q_name)
40 | query_tensor = pd[pd.q_frames_idxs[q_idx]]['im']
41 | query_res = tuple(query_tensor.shape[-2:])
42 |
43 | q_feats = extraction.get_query_descriptor_by_idx(q_descriptors, q_idx)
44 | for beam_i in range(n_beams):
45 | if n_beams == 1:
46 | beam_dir = query_dir
47 | else:
48 | beam_dir = join(query_dir, f'beam_{beam_i}')
49 | rd = RenderedImagesDataset(beam_dir, transform, query_res, verbose=False)
50 |
51 | scores_file = join(beam_dir, 'scores.pth')
52 | if not os.path.isfile(scores_file):
53 | r_db_descriptors = extraction.get_features_from_dataset(fine_model, rd, bs=fine_model.conf.bs, is_render=True)
54 | predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True)
55 | torch.save((predictions, scores), scores_file)
56 | else:
57 | predictions, scores = torch.load(scores_file)
58 | pred_t, pred_R = utils.get_pose_from_preds(q_idx, pd, rd, predictions, N_per_beam)
59 |
60 | all_pred_t[q_idx, beam_i] = pred_t
61 | all_pred_R[q_idx, beam_i] = pred_R
62 | all_scores[q_idx, beam_i] = scores[:N_per_beam]
63 |
64 | nn = []
65 | for pr in predictions[:N_per_beam]:
66 | nn.append(rd.images[pr].name)
67 | names[q_idx, beam_i] = np.array(nn)
68 |
69 | #del q_feats, r_db_descriptors
70 | if n_beams > 1:
71 | # flatten stuff to sort predictions based on similarity
72 | flatten_beams = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)')
73 | flatten_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3)
74 | flatten_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3)
75 |
76 | flat_preds = np.argsort(flatten_beams(all_scores))
77 | names = np.take_along_axis(flatten_beams(names), flat_preds, axis=1)
78 |
79 | flat_t = flatten_t(all_pred_t)
80 | flat_R = flatten_R(all_pred_R)
81 |
82 | stacked_t, stacked_R = [], []
83 | for i in range(len(all_pred_t)):
84 | sorted_t = flat_t[i, flat_preds[i]]
85 | sorted_R = flat_R[i, flat_preds[i]]
86 |
87 | stacked_t.append(sorted_t)
88 | stacked_R.append(sorted_R)
89 | all_pred_t = np.stack(stacked_t)
90 | all_pred_R = np.stack(stacked_R)
91 | else:
92 | all_pred_t = all_pred_t.squeeze()
93 | all_pred_R = all_pred_R.squeeze()
94 | names = names.squeeze()
95 |
96 | return all_pred_t, all_pred_R, names
97 |
98 |
99 | def split_renders_into_chunks(num_queries, num_candidates, n_beams, N_per_beam, im_per_chunk):
100 | chunk_limit = im_per_chunk
101 | if num_queries > chunk_limit:
102 | q_range = np.arange(num_queries)
103 | # this will be [0, 1100, 2200..]
104 | start_chunk_q_idx = q_range[::chunk_limit]
105 | # start from [1], so the first chunk is from 0 to 1100*beams*n_beam
106 | chunk_idx_ims = [start_q*n_beams*N_per_beam for start_q in start_chunk_q_idx[1:]]
107 | chunks = [np.arange(chunk_idx_ims[0])]
108 | for ic in range(1, len(chunk_idx_ims)):
109 | this_chunk = np.arange(chunk_idx_ims[ic-1], chunk_idx_ims[ic])
110 | chunks.append(this_chunk)
111 | last_chunk = np.arange(chunk_idx_ims[-1], num_candidates)
112 | chunks.append(last_chunk)
113 |
114 | chunk_start_q_idx = start_chunk_q_idx
115 | chunk_end_q_idx = list(start_chunk_q_idx[1:]) + [num_queries]
116 | else:
117 | chunks = [np.arange(num_candidates)]
118 | chunk_start_q_idx = [0]
119 | chunk_end_q_idx = [num_queries]
120 |
121 | return chunk_start_q_idx, chunk_end_q_idx, chunks
122 |
123 |
124 | def get_feat_dim(fine_model, query_res):
125 | x = torch.rand(1, 3, *query_res)
126 | dim = tuple(fine_model(x.cuda()).shape)[1:]
127 |
128 | return dim
129 |
--------------------------------------------------------------------------------
/gloc/initialization.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join, dirname
3 | import torch
4 | import logging
5 |
6 | from gloc import extraction
7 | from gloc.utils import utils
8 |
9 |
10 | def init_refinement(args, pose_dataset):
11 | first_step = 0
12 | scores = {}
13 |
14 | if (args.pose_prior is None) and (args.resume_step is None):
15 | # if no pose prior, init with retrieval
16 | all_pred_t, all_pred_R = extraction.get_retrieval_predictions(args.retr_model, args.res, pose_dataset, args.beams*args.M)
17 | elif args.pose_prior is not None:
18 | assert os.path.isfile(args.pose_prior), f'{args.pose_prior} does not exist as a file'
19 |
20 | logging.info(f'Loading pose prior from {args.pose_prior}')
21 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pose_dataset, args.beams*args.M)
22 |
23 | if args.resume_step is None:
24 | all_true_t, all_true_R = pose_dataset.get_q_poses()
25 | errors_t, errors_R = utils.get_all_errors_first_estimate(all_true_t, all_true_R, all_pred_t, all_pred_R)
26 |
27 | out_str, out_vals = utils.eval_poses(errors_t, errors_R, descr='Retrieval first estimate')
28 | scores['baseline'] = out_vals
29 | logging.info(out_str)
30 |
31 | # go from (NQ, M*beams, 3/3,3) to (NQ, beams, M, 3/3, 3)
32 | all_pred_t = utils.reshape_preds_per_beam(args.beams, args.M, all_pred_t)
33 | all_pred_R = utils.reshape_preds_per_beam(args.beams, args.M, all_pred_R)
34 | else:
35 | all_pred_t, all_pred_R = None, None
36 |
37 | scores['steps'] = []
38 | if args.first_step is not None:
39 | first_step = args.first_step
40 | if args.resume_step is not None:
41 | score_path = join(dirname(dirname(args.resume_step)), 'scores.pth')
42 | if os.path.isfile(score_path):
43 | scores = torch.load(score_path)
44 | if len(scores['steps']) > first_step:
45 | logging.info(f"Cutting score file to {first_step}, from {len(scores['steps'])}")
46 | scores['steps'] = scores['steps'][:first_step]
47 |
48 | return first_step, all_pred_t, all_pred_R, scores
49 |
--------------------------------------------------------------------------------
/gloc/models/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['layers', 'get_model', 'features', 'refinement_model']
2 |
3 |
4 | from gloc.models.get_model import get_retrieval_model, get_ref_model
5 |
--------------------------------------------------------------------------------
/gloc/models/features.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 | from collections import OrderedDict
4 | import torch
5 | from torch import nn
6 | from torchvision.models import resnet18, resnet50
7 |
8 | from gloc.models.layers import L2Norm
9 |
10 |
11 | class BaseFeaturesClass(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 | self.model = nn.Identity()
15 |
16 | def forward(self, x):
17 | """
18 | To be used by subclasses, each specifying their own `self.model`
19 | Args:
20 | x (torch.tensor): batch of images shape Bx3xHxW
21 | Returns:
22 | torch.tensor: Features maps of shape BxDxHxW
23 | """
24 | return self.model(x)
25 |
26 |
27 | class CosplaceFeatures(BaseFeaturesClass):
28 | def __init__(self, model_name):
29 | super().__init__()
30 | if 'r18' in model_name:
31 | arch = 'ResNet18'
32 | else: # 'r50' in model_name
33 | arch = 'ResNet50'
34 | # FC dim set to 512 as a placeholder, it will be truncated anyway before the last FC
35 | model = torch.hub.load("gmberton/cosplace", "get_trained_model",
36 | backbone=arch, fc_output_dim=512)
37 |
38 | backbone = model.backbone
39 | if '_l1' in model_name:
40 | backbone = backbone[:-3]
41 | elif '_l2' in model_name:
42 | backbone = backbone[:-2]
43 | elif '_l3' in model_name:
44 | backbone = backbone[:-1]
45 |
46 | self.model = backbone.eval()
47 |
48 |
49 | class ResnetFeatures(BaseFeaturesClass):
50 | def __init__(self, model_name):
51 | super().__init__()
52 |
53 | if model_name.startswith('resnet18'):
54 | model = resnet18(weights='DEFAULT')
55 | elif model_name.startswith('resnet50'):
56 | model = resnet50(weights='DEFAULT')
57 | else:
58 | raise NotImplementedError
59 |
60 | layers = list(model.children())[:-2] # Remove avg pooling and FC layer
61 | backbone = torch.nn.Sequential(*layers)
62 |
63 | if '_l1' in model_name:
64 | backbone = backbone[:-3]
65 | elif '_l2' in model_name:
66 | backbone = backbone[:-2]
67 | elif '_l3' in model_name:
68 | backbone = backbone[:-1]
69 |
70 | self.model = backbone.eval()
71 |
72 |
73 | class AlexnetFeatures(BaseFeaturesClass):
74 | def __init__(self, model_name):
75 | super().__init__()
76 |
77 | model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
78 | backbone = model.features
79 |
80 | if '_l1' in model_name:
81 | backbone = backbone[:4]
82 | elif '_l2' in model_name:
83 | backbone = backbone[:7]
84 | elif '_l3' in model_name:
85 | backbone = backbone[:9]
86 |
87 | self.model = backbone.eval()
88 |
89 |
90 | class DinoFeatures(BaseFeaturesClass):
91 | def __init__(self, conf):
92 | super().__init__()
93 | self.conf = conf
94 | self.clamp = conf.clamp
95 | self.norm = L2Norm()
96 | self.conf.bs = 32
97 | self.feat_level = conf.level[0]
98 |
99 | dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
100 | self.ref_model = dinov2_vits14
101 |
102 | # override
103 | def forward(self, x):
104 | desc = self.ref_model.get_intermediate_layers(x, n=self.feat_level, reshape=True)[-1]
105 | desc = self.norm(desc)
106 |
107 | return desc
108 |
109 |
110 | class RomaFeatures(BaseFeaturesClass):
111 | weight_urls = {
112 | "roma": {
113 | "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
114 | "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
115 | },
116 | "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
117 | }
118 |
119 | def __init__(self, conf):
120 | super().__init__()
121 | sys.path.append(str(Path(__file__).parent.joinpath('third_party/RoMa')))
122 | from roma.models.encoders import CNNandDinov2
123 |
124 | self.conf = conf
125 | weights = torch.hub.load_state_dict_from_url(self.weight_urls["roma"]["outdoor"])
126 | dinov2_weights = torch.hub.load_state_dict_from_url(self.weight_urls["dinov2"])
127 |
128 | ww = OrderedDict({k.replace('encoder.', ''): v for (k, v) in weights.items() if k.startswith('encoder') })
129 | encoder = CNNandDinov2(
130 | cnn_kwargs = dict(
131 | pretrained=False,
132 | amp = True),
133 | amp = True,
134 | use_vgg = True,
135 | dinov2_weights = dinov2_weights
136 | )
137 | encoder.load_state_dict(ww)
138 |
139 | self.ref_model = encoder.cnn
140 | self.clamp = conf.clamp
141 | self.scale_n = conf.scale_n
142 | self.norm = L2Norm()
143 | self.conf.bs = 16
144 | self.feat_level = conf.level[0]
145 |
146 | # override
147 | def forward(self, x):
148 | f_pyramid = self.ref_model(x)
149 | fmaps = f_pyramid[self.feat_level]
150 |
151 | if self.scale_n != -1:
152 | # optionally scale down fmaps
153 | nh, nw = tuple(fmaps.shape[-2:])
154 | half = nn.AdaptiveAvgPool2d((nh // self.scale_n, nw // self.scale_n))
155 | fmaps = half(fmaps)
156 |
157 | desc = self.norm(fmaps)
158 | return desc
159 |
--------------------------------------------------------------------------------
/gloc/models/get_model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 | from dataclasses import dataclass
4 | import torch
5 | from torch import nn
6 | import os
7 | from torchvision.models import resnet18, resnet50
8 |
9 | from gloc.models.layers import L2Norm, FlattenFeatureMaps
10 | from gloc.models import features
11 |
12 |
13 | def get_retrieval_model(model_name, cuda=True, eval=True):
14 | if model_name.startswith('cosplace'):
15 | model = torch.hub.load("gmberton/cosplace", "get_trained_model",
16 | backbone="ResNet18", fc_output_dim=512)
17 | else:
18 | raise NotImplementedError()
19 |
20 | if cuda:
21 | model = model.cuda()
22 | if eval:
23 | model = model.eval()
24 |
25 | return model
26 |
27 |
28 | def get_feature_model(args, model_name, cuda=True):
29 | if model_name.startswith('cosplace'):
30 | feat_model = features.CosplaceFeatures(model_name)
31 |
32 | elif model_name.startswith('resnet'):
33 | feat_model = features.ResnetFeatures(model_name)
34 |
35 | elif model_name.startswith('alexnet'):
36 | feat_model = features.AlexnetFeatures(model_name)
37 |
38 | elif model_name == 'Dinov2':
39 | conf = DinoConf(clamp=args.clamp_score, level=args.feat_level)
40 | feat_model = features.DinoFeatures(conf)
41 |
42 | elif model_name == 'Roma':
43 | conf = RomaConf(clamp=args.clamp_score, level=args.feat_level, scale_n=args.scale_fmaps)
44 | feat_model = features.RomaFeatures(conf)
45 |
46 | else:
47 | raise NotImplementedError()
48 |
49 | if cuda:
50 | feat_model = feat_model.cuda()
51 |
52 | return feat_model
53 |
54 |
55 | def get_ref_model(args, cuda=True):
56 | model_name = args.ref_model
57 | feat_model = get_feature_model(args, args.feat_model, cuda)
58 |
59 | if model_name == 'DenseFeatures':
60 | from gloc.models.refinement_model import DenseFeaturesRefiner
61 | model_class = DenseFeaturesRefiner
62 | conf = DenseFeaturesConf(clamp=args.clamp_score)
63 |
64 | else:
65 | raise NotImplementedError()
66 |
67 | model = model_class(conf, feat_model)
68 | if cuda:
69 | model = model.cuda()
70 |
71 | return model
72 |
73 |
74 | @dataclass
75 | class DenseFeaturesConf:
76 | clamp: float = -1
77 | def get_str__conf(self):
78 | repr = f"_cl{self.clamp}"
79 | return repr
80 |
81 |
82 | @dataclass
83 | class DinoConf:
84 | clamp: float = -1
85 | level: int = 8
86 | def get_str__conf(self):
87 | repr = f"_l{self.level}_cl{self.clamp}"
88 | return repr
89 |
90 | @dataclass
91 | class RomaConf:
92 | clamp: float = -1
93 | level: int = 4 # 1 2 4 8
94 | # pool feature maps to 1/n
95 | scale_n: int = -1
96 | def get_str__conf(self):
97 | repr = f"_l{self.level}_sn{self.scale_n}_cl{self.clamp}"
98 | return repr
99 |
--------------------------------------------------------------------------------
/gloc/models/layers.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | from torch import nn
4 | from torch.nn import Parameter
5 | import torch.nn.functional as F
6 |
7 |
8 | def gem(x, p=torch.ones(1)*3, eps: float = 1e-6):
9 | return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
10 |
11 |
12 | class GeM(nn.Module):
13 | def __init__(self, p=3, eps=1e-6):
14 | super().__init__()
15 | self.p = Parameter(torch.ones(1)*p)
16 | self.eps = eps
17 |
18 | def forward(self, x):
19 | return gem(x, p=self.p, eps=self.eps)
20 |
21 | def __repr__(self):
22 | return f"{self.__class__.__name__}(p={self.p.data.tolist()[0]:.4f}, eps={self.eps})"
23 |
24 |
25 | class Flatten(nn.Module):
26 | def __init__(self):
27 | super().__init__()
28 |
29 | def forward(self, x):
30 | assert x.shape[2] == x.shape[3] == 1, f"{x.shape[2]} != {x.shape[3]} != 1"
31 | return x[:, :, 0, 0]
32 |
33 |
34 | class FlattenFeatureMaps(nn.Module):
35 | def __init__(self):
36 | super().__init__()
37 | self.flatten = lambda x: einops.rearrange(x, 'b c h w -> b (c h w)')
38 | def forward(self, x):
39 | return self.flatten(x)
40 |
41 |
42 | class L2Norm(nn.Module):
43 | def __init__(self, dim=1):
44 | super().__init__()
45 | self.dim = dim
46 | def forward(self, x):
47 | return F.normalize(x, p=2, dim=self.dim)
48 |
--------------------------------------------------------------------------------
/gloc/models/refinement_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 |
5 | from gloc.models.layers import L2Norm
6 |
7 |
8 | class DenseFeaturesRefiner(nn.Module):
9 | def __init__(self, conf, ref_model):
10 | super().__init__()
11 | self.conf = conf
12 | self.ref_model = ref_model
13 | self.clamp = conf.clamp
14 | self.norm = L2Norm()
15 | self.conf.bs = 32
16 |
17 | def forward(self, x):
18 | """
19 | Args:
20 | x (torch.tensor): batch of images shape Bx3xHxW
21 | Returns:
22 | torch.tensor: Features of shape BxDxHxW
23 | """
24 | with torch.no_grad():
25 | desc = self.ref_model(x)
26 | desc = self.norm(desc)
27 |
28 | return desc
29 |
30 | def score_candidates(self, q_feats, r_db_descriptors):
31 | """_summary_
32 |
33 | Args:
34 | q_feats (np.array): shape 1 x C x H x W
35 | r_db_descriptors (np.array): shape N_cand x C x H x W
36 |
37 | Returns:
38 | torch.tensor : vector of shape (N_cand, ), score of each one
39 | """
40 | q_feats = torch.tensor(q_feats)
41 |
42 | # broadcast version (commented out below): faster than the loop, but needs much more memory
43 | # r_db = torch.tensor(r_db_descriptors).squeeze(1)
44 | # scores = torch.linalg.norm(q_feats - r_db, dim=1)
45 | scores = torch.zeros(len(r_db_descriptors), q_feats.shape[-2], q_feats.shape[-1])
46 | for i, desc in enumerate(r_db_descriptors):
47 | # q_feats : 1, D, H, W
48 | # desc : D, H, W
49 | # score : 1, H, W
50 | score = torch.linalg.norm(q_feats - torch.tensor(desc), dim=1)
51 | scores[i] = score[0]
52 |
53 | if self.clamp > 0:
54 | scores = scores.clamp(max=self.clamp)
55 | scores = scores.sum(dim=(1,2)) / np.prod(scores.shape[-2:])
56 |
57 | return scores
58 |
59 | def rank_candidates(self, q_feats, r_db_descriptors, get_scores=False):
60 | scores = self.score_candidates(q_feats, r_db_descriptors)
61 | preds = torch.argsort(scores)
62 | if get_scores:
63 | return preds, scores[preds]
64 | return preds
65 |
--------------------------------------------------------------------------------
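The scoring above reduces to a per-pixel L2 distance between dense feature maps, optionally clamped and then averaged over pixels. A self-contained sketch with made-up shapes, using the broadcast variant mentioned in the comment inside `score_candidates`:

```python
import numpy as np
import torch

q_feats = torch.rand(1, 64, 32, 48)      # hypothetical query feature map (1 x D x H x W)
cand_feats = torch.rand(8, 64, 32, 48)   # hypothetical rendered candidates (N x D x H x W)
clamp = 0.5                              # plays the role of conf.clamp; <= 0 disables clamping

scores = torch.linalg.norm(q_feats - cand_feats, dim=1)       # per-pixel distance, N x H x W
if clamp > 0:
    scores = scores.clamp(max=clamp)                          # limit the influence of outlier pixels
scores = scores.sum(dim=(1, 2)) / np.prod(scores.shape[-2:])  # mean distance per candidate

ranking = torch.argsort(scores)   # lowest distance first, as in rank_candidates
```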
/gloc/rendering/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['rend_conf', 'rend_utils']
2 |
3 |
4 | from gloc.rendering.rend_conf import get_renderer
5 | from gloc.rendering.rend_utils import log_poses, split_to_beam_folder
6 |
--------------------------------------------------------------------------------
/gloc/rendering/base_renderer.py:
--------------------------------------------------------------------------------
1 | class BaseRenderer:
2 | def __init__(self, conf):
3 | pass
4 |
5 | def load_model(self):
6 | return None
7 |
8 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh):
9 | pass
10 |
11 | def end_epoch(self, step_dir):
12 | pass
13 |
14 | @staticmethod
15 | def clean_file_names(r_dir, r_names, verbose):
16 | pass
17 |
--------------------------------------------------------------------------------
/gloc/rendering/mesh_renderer.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import logging
3 | import numpy as np
4 | import open3d as o3d
5 | from os.path import join
6 |
7 | from gloc.rendering.base_renderer import BaseRenderer
8 |
9 |
10 | class MeshRenderer(BaseRenderer):
11 | def __init__(self, conf):
12 | super().__init__(conf)
13 | self.mesh_path = conf.mesh_path
14 | self.background = 'black'
15 | self.supports_deferred_rendering = False
16 | logging.info(f'Using mesh from {self.mesh_path}')
17 |
18 | # override
19 | def load_model(self):
20 | mesh = o3d.io.read_triangle_model(self.mesh_path, False)
21 |
22 | for iter in range(len(mesh.materials)):
23 | mesh.materials[iter].shader = "defaultLit"
24 | # mesh.materials[iter].shader = "defaultUnlit"
25 |
26 | # - the original colors make the textures too dark - set to white
27 | mesh.materials[iter].base_color = [1.0, 1.0, 1.0, 1.0]
28 |
29 | return mesh
30 |
31 | # override
32 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, **kwargs):
33 | w, h = wh
34 |
35 | renderer = o3d.visualization.rendering.OffscreenRenderer(w, h)
36 | renderer.scene.add_model("Scene mesh", model)
37 | # - setup lighting
38 | renderer.scene.scene.enable_sun_light(True)
39 | if self.background == 'black':
40 | renderer.scene.set_background([0., 0., 0.,0.])
41 |
42 | for i, fname in enumerate(r_names):
43 | output_path = join(out_dir, f"{fname}.png")
44 | T, K = pose_list[i]
45 |
46 | renderer.setup_camera(K, T, w, h)
47 | color = np.array(renderer.render_to_image())
48 |
49 | img_rendering = PIL.Image.fromarray(color)
50 | img_rendering.save(output_path)
51 |
52 | del renderer
53 |
--------------------------------------------------------------------------------
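Note that `setup_camera(K, T, w, h)` above consumes the `(T, K)` tuples carried in `pose_list`: `K` is the 3x3 pinhole intrinsic matrix and `T` the 4x4 world-to-camera extrinsic assembled by the resampling protocols.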
/gloc/rendering/nerf_renderer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import shutil
4 | import numpy as np
5 | from os.path import join
6 | from pathlib import Path
7 | import mediapy as media
8 | import torch
9 | import logging
10 | from typing import List
11 |
12 | from nerfstudio.cameras.cameras import Cameras
13 | from nerfstudio.pipelines.base_pipeline import Pipeline
14 | from nerfstudio.utils import colormaps
15 | from nerfstudio.utils.eval_utils import eval_setup
16 | from nerfstudio.cameras.camera_paths import get_path_from_json
17 | 
18 |
19 | from gloc.rendering.base_renderer import BaseRenderer
20 | from gloc.utils import get_c2w_nerfconv
21 |
22 |
23 | class NerfRenderer(BaseRenderer):
24 | def __init__(self, conf):
25 | super().__init__(conf)
26 | self.ns_config = conf.ns_config
27 | self.ns_transform = conf.ns_transform
28 | self.supports_deferred_rendering = False
29 | logging.info(f'Using nerf from {self.ns_config}')
30 |
31 | # override
32 | def load_model(self):
33 | _, pipeline, _, _ = eval_setup(
34 | Path(self.ns_config),
35 | eval_num_rays_per_chunk=None,
36 | test_mode="inference",
37 | )
38 |
39 | return pipeline
40 |
41 | # override
42 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, deferred=False):
43 | # for refinement experiments,
44 | # the K matrix is the same throughout the pose list, so take the 1st
45 | K = pose_list[0][1]
46 | traj_file = self._gen_trajectory_file(out_dir, wh, K, r_names, render_ts, render_qvecs)
47 |
48 | with open(traj_file, "r", encoding="utf-8") as f:
49 | camera_path = json.load(f)
50 | camera_path = get_path_from_json(camera_path)
51 |
52 | output_path = Path(out_dir)
53 | rendered_output_names = ['rgb']
54 | self._render_trajectory(
55 | model,
56 | camera_path,
57 | output_filename=output_path,
58 | rendered_output_names=rendered_output_names,
59 | )
60 |
61 | def _gen_trajectory_file(self, out_dir, wh, K, r_names, render_ts, render_qvecs):
62 | width, height = wh
63 | f_len = K[0,0]
64 | traj_file = join(out_dir, 'trajectory.json')
65 |
66 | out = {
67 | 'camera_path': []
68 | }
69 | out["render_height"] = height
70 | out["render_width"] = width
71 | out["seconds"] = len(r_names)
72 |
73 | # load the transformation from dataparser_transforms.json
74 | with open(self.ns_transform) as f:
75 | json_data_txt = f.read()
76 | json_data = json.loads(json_data_txt)
77 |
78 | T_dp = np.eye(4)
79 | T_dp[0:3,:] = np.array(json_data["transform"])
80 | s_dp = json_data["scale"]
81 |
82 | for i in range(len(r_names)):
83 | cam_dict = {}
84 | name = r_names[i]
85 | tvec, qvec = render_ts[i], render_qvecs[i]
86 |
87 | # get camera to world matrix, in nerf convention
88 | c2w = get_c2w_nerfconv(qvec, tvec)
89 | # apply dataparser transform
90 | c2w = T_dp @ c2w
91 | c2w[0:3, 3] = s_dp * c2w[0:3, 3]
92 |
93 | cam_dict['camera_to_world'] = c2w.tolist()
94 | cam_dict['fov'] = np.degrees(2*np.arctan2(height, 2*f_len))
95 | cam_dict['file_path'] = name
96 |
97 | out['camera_path'].append(cam_dict)
98 |
99 | with open(traj_file, 'w') as f:
100 | json.dump(out, f, indent=4)
101 |
102 | return traj_file
103 |
104 | @staticmethod
105 | def _render_trajectory(
106 | pipeline: Pipeline,
107 | cameras: Cameras,
108 | output_filename: Path,
109 | rendered_output_names: List[str],
110 | colormap_options: colormaps.ColormapOptions = colormaps.ColormapOptions(),
111 | ) -> None:
112 | """Helper function to create a video of the spiral trajectory.
113 |
114 | Args:
115 | pipeline: Pipeline to evaluate with.
116 | cameras: Cameras to render.
117 | output_filename: Name of the output file.
118 | rendered_output_names: List of outputs to visualise.
119 | colormap_options: Options for colormap.
120 | """
121 | cameras = cameras.to(pipeline.device)
122 | output_image_dir = output_filename.parent / output_filename.stem
123 | for camera_idx in range(cameras.size):
124 | aabb_box = None
125 | camera_ray_bundle = cameras.generate_rays(camera_indices=camera_idx, aabb_box=aabb_box)
126 |
127 | with torch.no_grad():
128 | outputs = pipeline.model.get_outputs_for_camera_ray_bundle(camera_ray_bundle)
129 |
130 | render_image = []
131 | for rendered_output_name in rendered_output_names:
132 | output_image = outputs[rendered_output_name]
133 |
134 | output_image = colormaps.apply_colormap(
135 | image=output_image,
136 | colormap_options=colormap_options,
137 | ).cpu().numpy()
138 |
139 | render_image.append(output_image)
140 | render_image = np.concatenate(render_image, axis=1)
141 |
142 | media.write_image(output_image_dir / f"{camera_idx:05d}.png", render_image, fmt="png")
143 |
144 | @staticmethod
145 | def clean_file_names(r_dir, r_names, verbose=False):
146 | if verbose:
147 | print('Changing filenames format...')
148 |
149 | f_names = sorted(filter(lambda x: x.endswith('.png'), os.listdir(r_dir)))
150 | for i, f_name in enumerate(f_names):
151 | src = join(r_dir, f_name)
152 | dst = join(r_dir, r_names[i]+'.png')
153 | shutil.move(src, dst)
154 |
--------------------------------------------------------------------------------
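For reference, the `fov` entry written to the trajectory file above is the vertical field of view derived from the focal length in pixels, fov = 2 * arctan(h / (2 f)). With made-up numbers h = 1080 px and f = 1200 px, this gives 2 * arctan(0.45) ≈ 48.5 degrees.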
/gloc/rendering/rend_conf.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | from gloc.rendering.mesh_renderer import MeshRenderer
4 |
5 |
6 | @dataclass
7 | class O3DConf():
8 | mesh_path:str
9 |
10 | @dataclass
11 | class NeRFConf():
12 | ns_config:str
13 | ns_transform:str
14 |
15 | @dataclass
16 | class GSplattingConf():
17 | gaussians:str
18 |
19 |
20 | def get_renderer(args, paths_conf):
21 |
22 | if args.renderer == 'o3d':
23 | rend_class = MeshRenderer
24 | conf = O3DConf(paths_conf[args.name]['mesh_path'])
25 |
26 | elif args.renderer == 'nerf':
27 | from gloc.rendering.nerf_renderer import NerfRenderer
28 | conf = NeRFConf(ns_config=paths_conf[args.name]['ns_config'],
29 | ns_transform=paths_conf[args.name]['ns_transform'])
30 | rend_class = NerfRenderer
31 |
32 | elif args.renderer == 'g_splatting':
33 | from gloc.rendering.splatting_renderer import GaussianSplattingRenderer
34 | conf = GSplattingConf(paths_conf[args.name]['gaussians'])
35 | rend_class = GaussianSplattingRenderer
36 |
37 | else:
38 | raise NotImplementedError()
39 |
40 | renderer = rend_class(conf)
41 |
42 | return renderer
43 |
--------------------------------------------------------------------------------
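A short usage sketch of the renderer factory above; the scene name and paths are hypothetical (the real values come from the argument parser and the path configuration), and the Gaussian Splatting option additionally requires the `third_party/gaussian-splatting` submodule:

```python
from argparse import Namespace
from gloc.rendering import get_renderer

args = Namespace(renderer='g_splatting', name='KingsCollege')                      # hypothetical
paths_conf = {'KingsCollege': {'gaussians': 'pretrained/splatting/KingsCollege'}}  # hypothetical

renderer = get_renderer(args, paths_conf)   # -> GaussianSplattingRenderer
model = renderer.load_model()               # the gaussians themselves are loaded later, when a Scene is built
```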
/gloc/rendering/rend_utils.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 | import shutil
3 | import os
4 |
5 |
6 | def log_poses(r_dir, r_names, render_ts, render_qvecs, renderer):
7 | if renderer == 'nerf':
8 | width = 5
9 | r_names = [str(idx).zfill(width) for idx in range(len(r_names))]
10 |
11 | with open(join(r_dir, 'rendered_views.txt'), 'w') as rv:
12 | for i in range(len(r_names)):
13 | line_data = [r_names[i], *tuple(render_ts[i]), *tuple(render_qvecs[i])]
14 | line = " ".join(map(str, line_data))
15 | rv.write(line+'\n')
16 |
17 |
18 | def split_to_beam_folder(r_dir, n_beams, r_names_per_beam_q_idx, create_beams=False):
19 | for beam_i in range(n_beams):
20 | beam_dir = join(r_dir, f'beam_{beam_i}')
21 | if create_beams:
22 | os.makedirs(beam_dir, exist_ok=True)
23 | beam_names = r_names_per_beam_q_idx[beam_i]
24 | for b_name in beam_names:
25 | src = join(r_dir, b_name+'.png')
26 | dst = join(beam_dir, b_name+'.png')
27 | shutil.move(src, dst)
28 |
--------------------------------------------------------------------------------
/gloc/rendering/splatting_renderer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from os.path import join
3 | import os
4 | import sys
5 | from pathlib import Path
6 | from argparse import Namespace
7 | import torch
8 | import torchvision
9 | import numpy as np
10 | from tqdm import tqdm
11 | import shutil
12 | from glob import glob
13 | from typing import List
14 |
15 | from gloc.utils import camera_utils, Image as RImage
16 | from gloc.rendering.base_renderer import BaseRenderer
17 |
18 | sys.path.append(str(Path(__file__).parent.parent.parent.joinpath('third_party/gaussian-splatting')))
19 |
20 | from scene import Scene
21 | from gaussian_renderer import render
22 | from gaussian_renderer import GaussianModel
23 |
24 |
25 | class GaussianSplattingRenderer(BaseRenderer):
26 | def __init__(self, conf):
27 | super().__init__(conf)
28 | self.pipeline = Namespace(convert_SHs_python=False, compute_cov3D_python=False, debug=False)
29 | self.gaussians = conf.gaussians
30 | self.sh_degree = 3
31 | self.supports_deferred_rendering = True
32 | self.iteration = 7000
33 | self.buf_deferred = self.init_buffer()
34 | logging.info(f'Using Gaussians from {self.gaussians}')
35 | self.dataset = Namespace(data_device='cuda', eval=False, images='dont',
36 | model_path=self.gaussians, resolution='dont', sh_degree=3,
37 | source_path='', white_background=False)
38 |
39 | # override
40 | def load_model(self):
41 | gaussians = GaussianModel(self.sh_degree)
42 | return gaussians
43 |
44 | # override
45 | def render_poses(self, out_dir, model, r_names, render_ts, render_qvecs, pose_list, wh, deferred=True):
46 | # for refinement experiments,
47 | # the K matrix is the same throughout the pose list, so take the 1st
48 | K = pose_list[0][1]
49 | if deferred:
50 | self.update_buffer(wh, K, r_names, render_ts, render_qvecs)
51 | else:
52 | if not isinstance(wh, list):
53 | # wh, r_names, render_ts, render_qvecs = [wh], [r_names], [render_ts], [render_qvecs]
54 | wh = [wh]
55 |
56 | # immediate (non-deferred) rendering: build a temporary colmap model for this batch and render it now
57 | colmap_dir = self._gen_colmap(out_dir, wh, [K], [r_names], [render_ts], [render_qvecs])
58 | self.dataset.source_path = colmap_dir
59 | scene = Scene(self.dataset, model, load_iteration=self.iteration, shuffle=False)
60 | bg_color = [0, 0, 0]
61 |
62 | with torch.no_grad():
63 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
64 | views = scene.getTrainCameras()
65 | for view in tqdm(views, ncols=100):
66 | rendering = render(view, model, self.pipeline, background)["render"]
67 | torchvision.utils.save_image(rendering, join(out_dir, view.image_name+'.png'))
68 |
69 | # override
70 | def end_epoch(self, step_dir):
71 | if self.is_buffer_empty():
72 | return
73 |
74 | model = GaussianModel(self.sh_degree)
75 |
76 | wh_l, K_l, r_names_l, render_ts_l, render_qvecs_l = self.read_buffer()
77 | colmap_dir = self._gen_colmap(step_dir, wh_l, K_l, r_names_l, render_ts_l, render_qvecs_l)
78 | self.dataset.source_path = colmap_dir
79 | scene = Scene(self.dataset, model, load_iteration=self.iteration, shuffle=False)
80 | bg_color = [0, 0, 0]
81 |
82 | with torch.no_grad():
83 | background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
84 | views = scene.getTrainCameras()
85 | for view in tqdm(views, ncols=100):
86 | rendering = render(view, model, self.pipeline, background)["render"]
87 | torchvision.utils.save_image(rendering, join(step_dir, view.image_name+'.png'))
88 |
89 | # clear buffer
90 | self.buf_deferred = self.init_buffer()
91 | print('Sorting renders into per-query folders...')
92 | self.sort_step_dir(step_dir)
93 |
94 | def sort_step_dir(self, step_dir):
95 | # put all the renders inside their query directory
96 | all_images = glob(join(step_dir, '*.png'))
97 | im_per_query = {}
98 | for im in all_images:
99 | r_name = im.split('.png')[0]
100 | end_name = r_name.rfind('_')
101 | q_name = r_name[:end_name]
102 | q_name = os.path.basename(q_name)
103 | if q_name not in im_per_query:
104 | im_per_query[q_name] = [r_name]
105 | else:
106 | im_per_query[q_name].append(r_name)
107 |
108 | for q_name in tqdm(im_per_query, ncols=100):
109 | q_renders = im_per_query[q_name]
110 | os.makedirs(join(step_dir, q_name), exist_ok=True)
111 | for q_r in q_renders:
112 | rbg_r = q_r + '.png'
113 | shutil.move(rbg_r, join(step_dir, q_name))
114 |
115 | def is_buffer_empty(self):
116 | return (len(self.buf_deferred['K']) == 0)
117 |
118 | def read_buffer(self):
119 | bf = (self.buf_deferred['wh'], self.buf_deferred['K'], self.buf_deferred['r_names'],
120 | self.buf_deferred['render_ts'], self.buf_deferred['render_qvecs'])
121 | return bf
122 |
123 | def update_buffer(self, wh, K, r_names, render_ts, render_qvecs):
124 | self.buf_deferred['wh'].append(wh)
125 | self.buf_deferred['K'].append(K)
126 | self.buf_deferred['r_names'].append(r_names)
127 | self.buf_deferred['render_ts'].append(render_ts)
128 | self.buf_deferred['render_qvecs'].append(render_qvecs)
129 |
130 | @staticmethod
131 | def init_buffer():
132 | buffer = {
133 | 'wh': [],
134 | 'K': [],
135 | 'r_names': [],
136 | 'render_ts': [],
137 | 'render_qvecs': []
138 | }
139 | return buffer
140 |
141 | @staticmethod
142 | def _gen_colmap(out_dir: str, wh_l: List[tuple], K_l: List[np.array],
143 | r_names_l: List[str], render_ts_l: List[np.array], render_qvecs_l: List[np.array]):
144 | out_cameras = {}
145 | out_images = {}
146 | # print(type(wh_l), type(wh_l[0]), len(wh_l))
147 | # print(len(r_names_l), len(render_ts_l), len(render_qvecs_l))
148 | n_cameras = len(K_l)
149 | im_per_camera = len(r_names_l[0])
150 | # print(n_cameras)
151 | # print(K_l)
152 | for c_id in range(n_cameras):
153 | wh, K, r_names, render_ts, render_qvecs = wh_l[c_id], K_l[c_id], r_names_l[c_id], render_ts_l[c_id], render_qvecs_l[c_id]
154 | width, height = wh
155 | # assumes PINHOLE model
156 | model = 'PINHOLE'
157 | fx, fy = K[0,0], K[1,1]
158 | cx, cy = K[0,2], K[1,2]
159 | params = np.array([fx, fy, cx, cy])
160 | c = camera_utils.Camera(id=c_id, model=model, width=width, height=height, params=params)
161 | out_cameras[c_id] = c
162 |
163 | for r_id in range(len(r_names)):
164 | name = r_names[r_id]
165 | im_id = c_id*im_per_camera + r_id
166 | im = RImage(id=im_id, qvec=render_qvecs[r_id],
167 | tvec=render_ts[r_id], camera_id=c_id,
168 | name=name, xys={}, point3D_ids={})
169 | out_images[im_id] = im
170 |
171 | colmap_subdir = join(out_dir, 'sparse', '0')
172 | os.makedirs(colmap_subdir)
173 | camera_utils.write_model_nopoints(out_cameras, out_images, colmap_subdir, ext='.txt')
174 |
175 | return out_dir
176 |
--------------------------------------------------------------------------------
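A rough sketch of the deferred-rendering flow this renderer supports (the loop and helper below are placeholders, not repository code): with `deferred=True`, `render_poses` only buffers camera parameters, and `end_epoch` later builds one temporary COLMAP model covering all buffered views, renders them in a single pass, and sorts the outputs into per-query folders.

```python
model = renderer.load_model()
for query in queries:                                  # hypothetical query loop
    names, ts, qvecs, poses = make_candidates(query)   # hypothetical helper producing candidate poses
    renderer.render_poses(step_dir, model, names, ts, qvecs, poses, wh, deferred=True)  # buffer only
renderer.end_epoch(step_dir)  # render everything buffered, then sort into per-query folders
```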
/gloc/resamplers/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['get_protocol', 'strategies']
2 |
3 | from gloc.resamplers.get_protocol import get_protocol
4 | from gloc.resamplers.sampling_utils import gen_rotations, gen_translations, parse_pose_data
5 |
--------------------------------------------------------------------------------
/gloc/resamplers/get_protocol.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from os.path import join
3 | import torch
4 |
5 | from gloc.resamplers.scalers_conf import get_sampler
6 | from gloc.resamplers.strategies import Protocol1, Protocol2
7 |
8 |
9 | def get_protocol(args, n_views, protocol):
10 | scaler_name = protocol.split('_')[-1]
11 | sampler, scaler = get_sampler(args, args.sampler, scaler_name)
12 |
13 | if protocol.startswith('1'):
14 | protocol_conf = Protocol1Conf(N_steps=args.steps, n_views=n_views)
15 | protocol_class = Protocol1
16 |
17 | elif protocol.startswith('2'):
18 | protocol_conf = Protocol2Conf(N_steps=args.steps, n_views=n_views, M_candidates=args.M)
19 | protocol_class = Protocol2
20 |
21 | else:
22 | raise NotImplementedError
23 |
24 | protocol_obj = protocol_class(protocol_conf, sampler, scaler, protocol)
25 | return protocol_obj
26 |
27 |
28 | @dataclass
29 | class ProtocolConf:
30 | N_steps: int = 20
31 | n_views: int = 20
32 |
33 |
34 | @dataclass
35 | class Protocol1Conf(ProtocolConf):
36 | pass
37 |
38 |
39 | @dataclass
40 | class Protocol2Conf(ProtocolConf):
41 | M_candidates: int = 4
42 |
--------------------------------------------------------------------------------
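For example (assuming the naming implied by the parsing above), a protocol string such as `'2_1'` selects `Protocol2` combined with scaler `'1'` (the `UniformScaler` from `scalers_conf.py`), while `'1_0'` selects `Protocol1` with the constant-noise scaler `'0'`.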
/gloc/resamplers/samplers.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import numpy as np
4 | from pyquaternion import Quaternion
5 |
6 | from gloc.utils import rotmat2qvec, qvec2rotmat
7 |
8 |
9 | class RandomConstantSampler():
10 | def __init__(self, conf):
11 | pass
12 |
13 | def sample_batch(self, n_views, center_noise, angle_noise, old_t, old_R):
14 | qvecs = []
15 | tvecs = []
16 | poses = []
17 |
18 | for _ in range(n_views):
19 | new_tvec, new_qvec, new_T = self.sample(center_noise, angle_noise, old_t, old_R)
20 |
21 | qvecs.append(new_qvec)
22 | tvecs.append(new_tvec)
23 | poses.append(new_T)
24 |
25 | return tvecs, qvecs, poses
26 |
27 | @staticmethod
28 | def sample(center_noise, angle_noise, old_t, old_R, low_std_ax=1):
29 | old_qvec = rotmat2qvec(old_R) # transform to qvec
30 |
31 | r_axis = Quaternion.random().axis # sample random axis
32 | teta = angle_noise # fixed perturbation angle (constant sampler: the magnitude is not random)
33 | r_quat = Quaternion(axis=r_axis, degrees=teta)
34 | new_qvec = r_quat * old_qvec # perturb the original pose
35 |
36 | # convert from Quaternion to np.array
37 | new_qvec = new_qvec.elements
38 | new_R = qvec2rotmat(new_qvec)
39 |
40 | old_center = - old_R.T @ old_t # get image center using original pose
41 | perturb_c = np.random.rand(2)
42 | perturb_low_ax = np.random.rand(1)*0.1
43 | perturb_c = np.insert(perturb_c, low_std_ax, perturb_low_ax)
44 | perturb_c /= np.linalg.norm(perturb_c) # normalize noise vector
45 |
46 | # move along the noise direction for a fixed magnitude
47 | new_center = old_center + perturb_c*center_noise
48 | new_t = - new_R @ new_center # use the new pose to convert to translation vector
49 |
50 | new_T = np.eye(4)
51 | new_T[0:3, 0:3] = new_R
52 | new_T[0:3, 3] = new_t
53 |
54 | return new_t, new_qvec, new_T
55 |
56 |
57 | class RandomGaussianSampler():
58 | def __init__(self, conf):
59 | pass
60 |
61 | def sample_batch(self, n_views, center_std, max_angle, old_t, old_R):
62 | qvecs = []
63 | tvecs = []
64 | poses = []
65 |
66 | for _ in range(n_views):
67 | new_tvec, new_qvec, new_T = self.sample(center_std, max_angle, old_t, old_R)
68 |
69 | qvecs.append(new_qvec)
70 | tvecs.append(new_tvec)
71 | poses.append(new_T)
72 |
73 | return tvecs, qvecs, poses
74 |
75 | @staticmethod
76 | def sample(center_std, max_angle, old_t, old_R):
77 | old_qvec = rotmat2qvec(old_R) # transform to qvec
78 |
79 | r_axis = Quaternion.random().axis # sample random axis
80 | teta = np.random.rand()*max_angle # sample random angle smaller than theta
81 | r_quat = Quaternion(axis=r_axis, degrees=teta)
82 | new_qvec = r_quat * old_qvec # perturb the original pose
83 |
84 | # convert from Quaternion to np.array
85 | new_qvec = new_qvec.elements
86 | new_R = qvec2rotmat(new_qvec)
87 |
88 | old_center = - old_R.T @ old_t # get image center using original pose
89 | perturb_c = torch.normal(0., center_std)
90 | new_center = old_center + np.array(perturb_c) # perturb it
91 | new_t = - new_R @ new_center # use the new pose to convert to translation vector
92 |
93 | new_T = np.eye(4)
94 | new_T[0:3, 0:3] = new_R
95 | new_T[0:3, 3] = new_t
96 |
97 | return new_t, new_qvec, new_T
98 |
99 |
100 | class RandomDoubleAxisSampler():
101 | rotate_axis = {
102 | 'pitch': [1, 0, 0], # x, pitch
103 | 'yaw': [0, 1, 0] # y, yaw
104 | }
105 |
106 | def __init__(self, conf):
107 | pass
108 |
109 | def sample_batch(self, n_views: int, center_std: torch.tensor, max_angle: torch.tensor,
110 | old_t: np.array, old_R: np.array):
111 | qvecs = []
112 | tvecs = []
113 | poses = []
114 |
115 | for _ in range(n_views):
116 | # apply yaw first
117 | ax = self.rotate_axis['yaw']
118 | new_tvec, _, new_R, _ = self.sample(ax, center_std, float(max_angle[0]), old_t, old_R)
119 |
120 | # apply pitch then
121 | ax = self.rotate_axis['pitch']
122 | new_tvec, new_qvec, _, new_T = self.sample(ax, center_std, float(max_angle[1]), new_tvec, new_R)
123 |
124 | qvecs.append(new_qvec)
125 | tvecs.append(new_tvec)
126 | poses.append(new_T)
127 |
128 | return tvecs, qvecs, poses
129 |
130 | @staticmethod
131 | def sample(axis, center_std: torch.tensor, max_angle: float, old_t: np.array, old_R: np.array, rot_only: bool =False):
132 | old_qvec = rotmat2qvec(old_R) # transform to qvec
133 |
134 | teta = np.random.rand()*max_angle # sample random angle smaller than theta
135 | r_quat = Quaternion(axis=axis, degrees=teta)
136 | new_qvec = r_quat * old_qvec # perturb the original pose
137 |
138 | # convert from Quaternion to np.array
139 | new_qvec = new_qvec.elements
140 | new_R = qvec2rotmat(new_qvec)
141 |
142 | old_center = - old_R.T @ old_t # get image center using original pose
143 | if not rot_only:
144 | perturb_c = torch.normal(0., center_std)
145 | new_center = old_center + np.array(perturb_c) # perturb it
146 | new_t = - new_R @ new_center # use the new pose to convert to translation vector
147 | else:
148 | new_t = - new_R @ old_center # rotation-only: keep the original camera centre
149 |
150 | new_T = np.eye(4)
151 | new_T[0:3, 0:3] = new_R
152 | new_T[0:3, 3] = new_t
153 |
154 | return new_t, new_qvec, new_R, new_T
155 |
156 |
157 | class RandomSamplerByAxis():
158 | rotate_axis = [
159 | [1, 0, 0], # x, pitch
160 | [0, 1, 0] # y, yaw
161 | ]
162 |
163 | def __init__(self, conf):
164 | pass
165 |
166 | def sample_batch(self, n_views, center_std, max_angle, old_t, old_R):
167 | qvecs = []
168 | tvecs = []
169 | poses = []
170 |
171 | for i in range(n_views):
172 | # use first axis half the time, the other for the rest
173 | ax_i = i // ((n_views+1) // 2)
174 | ax = self.rotate_axis[ax_i]
175 |
176 | new_tvec, new_qvec, new_T = self.sample(ax, center_std, max_angle, old_t, old_R)
177 |
178 | qvecs.append(new_qvec)
179 | tvecs.append(new_tvec)
180 | poses.append(new_T)
181 |
182 | return tvecs, qvecs, poses
183 |
184 | @staticmethod
185 | def sample(axis, center_std, max_angle, old_t, old_R):
186 | old_qvec = rotmat2qvec(old_R) # transform to qvec
187 |
188 | teta = np.random.rand()*max_angle # sample random angle smaller than theta
189 | r_quat = Quaternion(axis=axis, degrees=teta)
190 | new_qvec = r_quat * old_qvec # perturb the original pose
191 |
192 | # convert from Quaternion to np.array
193 | new_qvec = new_qvec.elements
194 | new_R = qvec2rotmat(new_qvec)
195 |
196 | old_center = - old_R.T @ old_t # get image center using original pose
197 | perturb_c = torch.normal(0., center_std)
198 | new_center = old_center + np.array(perturb_c) # perturb it
199 | new_t = - new_R @ new_center # use the new pose to convert to translation vector
200 |
201 | new_T = np.eye(4)
202 | new_T[0:3, 0:3] = new_R
203 | new_T[0:3, 3] = new_t
204 |
205 | return new_t, new_qvec, new_T
206 |
207 |
208 | class RandomAndDoubleAxisSampler():
209 | rotate_axis = {
210 | 'pitch': [1, 0, 0], # x, pitch
211 | 'yaw': [0, 1, 0] # y, yaw
212 | }
213 |
214 | def __init__(self, conf):
215 | pass
216 |
217 | def sample_batch(self, n_views: int, center_std: torch.tensor, max_angle: torch.tensor,
218 | old_t: np.array, old_R: np.array):
219 | qvecs = []
220 | tvecs = []
221 | poses = []
222 |
223 | for i in range(n_views):
224 | # use DoubleRotation half the time, Random the rest
225 | ax_i = i // ((n_views+1) // 2)
226 | if ax_i == 0:
227 | # double ax rotation
228 |
229 | # apply yaw first
230 | ax = self.rotate_axis['yaw']
231 | new_tvec, _, new_R, _ = self.sample(ax, center_std, float(max_angle[0]), old_t, old_R)
232 |
233 | # apply pitch then
234 | ax = self.rotate_axis['pitch']
235 | new_tvec, new_qvec, _, new_T = self.sample(ax, center_std, float(max_angle[1]), new_tvec, new_R)
236 |
237 | qvecs.append(new_qvec)
238 | tvecs.append(new_tvec)
239 | poses.append(new_T)
240 | else:
241 | # use random axis, with yaw magnitude
242 | new_tvec, new_qvec, _, new_T = self.sample(None, center_std, float(max_angle[0]), old_t, old_R)
243 |
244 | qvecs.append(new_qvec)
245 | tvecs.append(new_tvec)
246 | poses.append(new_T)
247 |
248 | return tvecs, qvecs, poses
249 |
250 | @staticmethod
251 | def sample(axis, center_std: torch.tensor, max_angle: float, old_t: np.array, old_R: np.array, rot_only: bool =False):
252 | old_qvec = rotmat2qvec(old_R) # transform to qvec
253 |
254 | if axis is None:
255 | # if no axis provided, use a random one
256 | axis = Quaternion.random().axis # sample random axis
257 |
258 | teta = np.random.rand()*max_angle # sample random angle smaller than theta
259 | r_quat = Quaternion(axis=axis, degrees=teta)
260 | new_qvec = r_quat * old_qvec # perturb the original pose
261 |
262 | # convert from Quaternion to np.array
263 | new_qvec = new_qvec.elements
264 | new_R = qvec2rotmat(new_qvec)
265 |
266 | old_center = - old_R.T @ old_t # get image center using original pose
267 | if not rot_only:
268 | perturb_c = torch.normal(0., center_std)
269 | new_center = old_center + np.array(perturb_c) # perturb it
270 | new_t = - new_R @ new_center # use the new pose to convert to translation vector
271 | else:
272 | new_t = - new_R @ old_center # rotation-only: keep the original camera centre
273 |
274 | new_T = np.eye(4)
275 | new_T[0:3, 0:3] = new_R
276 | new_T[0:3, 3] = new_t
277 |
278 | return new_t, new_qvec, new_R, new_T
279 |
--------------------------------------------------------------------------------
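All samplers above share the same perturbation recipe: rotate the pose by a random quaternion, and add noise to the camera centre c = -R^T t rather than to the translation vector directly. A minimal sketch mirroring `RandomGaussianSampler.sample`, with made-up noise magnitudes:

```python
import numpy as np
import torch
from pyquaternion import Quaternion
from gloc.utils import rotmat2qvec, qvec2rotmat

old_R = np.eye(3)                # hypothetical current rotation
old_t = np.array([0., 0., 2.])   # hypothetical current translation

# rotation noise: random axis, random angle below 5 degrees
theta = np.random.rand() * 5.0
new_qvec = (Quaternion(axis=Quaternion.random().axis, degrees=theta) * rotmat2qvec(old_R)).elements
new_R = qvec2rotmat(new_qvec)

# translation noise acts on the camera centre, then is mapped back to a translation vector
old_center = -old_R.T @ old_t
new_center = old_center + np.array(torch.normal(0., torch.tensor([0.3, 0.3, 0.3])))
new_t = -new_R @ new_center
```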
/gloc/resamplers/sampling_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from pyquaternion import Quaternion
3 |
4 | from gloc.utils import qvec2rotmat
5 |
6 |
7 | def gen_translations(border_points, radius, points_per_meter, q_center, axis_up):
8 | theta = np.linspace(0, 2 * np.pi, border_points, endpoint=False)
9 |
10 | x = np.cos(theta)*radius
11 | if axis_up == 'y':
12 | y = np.zeros(x.shape)
13 | z = np.sin(theta)*radius
14 | elif axis_up == 'z':
15 | y = np.sin(theta)*radius
16 | z = np.zeros(x.shape)
17 | else:
18 | raise NotImplementedError()
19 |
20 | points = np.array([x, y, z]).transpose()
21 | n_points = int(radius * points_per_meter+2)
22 | for point in points:
23 | x0 = np.linspace(0, point[0], n_points)[1:-1]
24 | if axis_up == 'y':
25 | y0 = np.zeros(x0.shape)
26 | z0 = np.linspace(0, point[2], n_points)[1:-1]
27 | elif axis_up == 'z':
28 | y0 = np.linspace(0, point[1], n_points)[1:-1]
29 | z0 = np.zeros(x0.shape)
30 | else:
31 | raise NotImplementedError()
32 |
33 | new_p=np.stack((x0, y0, z0)).transpose()
34 | points = np.vstack((points, new_p))
35 |
36 | points += q_center
37 | points = np.insert(points, 0, q_center, axis=0)
38 | camera_centers = list(points)
39 | return camera_centers
40 |
41 |
42 | def gen_rotations(qvec, R, tvec, c_center, points_per_axis, max_angle, n_axis):
43 | axis_set = [
44 | [1, 0, 0],
45 | [0, 1, 0],
46 | [1, -1, 0],
47 | [1, 1, 0],
48 | ]
49 |
50 | gen_poses = [(qvec, R, tvec)]
51 | for axis in axis_set[:n_axis]:
52 | theta = np.linspace(-max_angle, max_angle, points_per_axis)
53 | theta = np.delete(theta, points_per_axis // 2)
54 |
55 | for th in theta:
56 | my_quaternion = Quaternion(axis=axis, angle=th)
57 | new_qvec = my_quaternion * qvec
58 | new_R = qvec2rotmat(new_qvec)
59 | new_t = - new_R @ c_center
60 | gen_poses.append((new_qvec, new_R, new_t))
61 | return gen_poses
62 |
63 |
64 | def parse_pose_data(q_basename, gen_poses, K, sub_index):
65 | render_qvecs = []
66 | render_ts = []
67 | calibr_pose = []
68 | r_names = []
69 | for i, pose in enumerate(gen_poses):
70 | new_qvec, new_R, new_t = pose
71 | T = np.eye(4)
72 | T[0:3, 0:3] = new_R
73 | T[0:3, 3] = new_t
74 |
75 | render_qvecs.append(new_qvec)
76 | render_ts.append(new_t)
77 | r_names.append(q_basename + f'_{sub_index}' + f'_{i}')
78 | calibr_pose.append((T, K))
79 |
80 | return render_qvecs, render_ts, r_names, calibr_pose
81 |
--------------------------------------------------------------------------------
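As a worked example of the translation grid built by `gen_translations` (with made-up parameters): `border_points=8`, `radius=2` and `points_per_meter=2` give `n_points = int(2*2 + 2) = 6`, so each of the 8 border points on the circle contributes 4 interior points on the segment towards the query centre, and the query centre itself is prepended, for a total of 1 + 8 + 8*4 = 41 candidate camera centres.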
/gloc/resamplers/scalers.py:
--------------------------------------------------------------------------------
1 | class ConstantScaler():
2 | def __init__(self, conf):
3 | self.max_angle = conf.max_angle
4 | self.center_std = conf.max_center_std
5 |
6 | def step(self, i):
7 | pass
8 |
9 | def get_noise(self):
10 | return self.center_std, self.max_angle
11 |
12 | def get_max_noise(self, multiplier=1):
13 | return self.center_std*multiplier, self.max_angle*multiplier
14 |
15 |
16 | class UniformScaler():
17 | def __init__(self, conf):
18 | # gamma is the minimum multiplier that will be applied
19 | self.max_angle = conf.max_angle
20 | self.max_center_std = conf.max_center_std
21 | self.current_angle = conf.max_angle
22 | self.current_center_std = conf.max_center_std
23 |
24 | self.n_steps = conf.N_steps
25 | self.gamma = conf.gamma
26 |
27 | def step(self, i):
28 | scale_noise = max(self.gamma, (self.n_steps - i)/self.n_steps)
29 |
30 | self.current_center_std = self.max_center_std*scale_noise
31 | self.current_angle = self.max_angle* scale_noise
32 |
33 | def get_noise(self):
34 | return self.current_center_std, self.current_angle
35 |
36 | def get_max_noise(self, multiplier=1):
37 | return self.max_center_std*multiplier, self.max_angle*multiplier
38 |
--------------------------------------------------------------------------------
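A small illustration of how `UniformScaler` shrinks the perturbation noise over the refinement steps (hypothetical values: 20 steps, `gamma = 0.1`, so the noise never drops below 10% of its maximum):

```python
n_steps, gamma = 20, 0.1
max_center_std, max_angle = 1.0, 8.0

for i in (0, 10, 19):
    scale = max(gamma, (n_steps - i) / n_steps)
    print(i, max_center_std * scale, max_angle * scale)
# 0  -> 1.0, 8.0   (full noise on the first step)
# 10 -> 0.5, 4.0   (halved midway)
# 19 -> 0.1, 0.8   (clamped by gamma near the end)
```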
/gloc/resamplers/scalers_conf.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 |
4 | from gloc.resamplers.samplers import (RandomGaussianSampler, RandomSamplerByAxis,
5 | RandomDoubleAxisSampler, RandomAndDoubleAxisSampler)
6 | from gloc.resamplers.scalers import ConstantScaler, UniformScaler
7 |
8 |
9 | def get_sampler(args, sampler_name, scaler_name):
10 | max_angle_delta = torch.tensor(args.teta)
11 | max_center_std = torch.tensor(args.center_std)
12 |
13 | sampler_conf = SamplerConf()
14 | if sampler_name == 'rand':
15 | sampler_class = RandomGaussianSampler
16 | elif sampler_name == 'rand_yaw_or_pitch':
17 | sampler_class = RandomSamplerByAxis
18 | elif sampler_name == 'rand_yaw_and_pitch':
19 | sampler_class = RandomDoubleAxisSampler
20 | elif sampler_name == 'rand_and_yaw_and_pitch':
21 | sampler_class = RandomAndDoubleAxisSampler
22 | else:
23 | raise NotImplementedError()
24 |
25 | if scaler_name == '0':
26 | scaler_conf = ConstantScalerConf(max_center_std=max_center_std, max_angle=max_angle_delta)
27 | scaler_class = ConstantScaler
28 | elif scaler_name == '1':
29 | scaler_conf = UniformScalerConf(max_center_std=max_center_std, max_angle=max_angle_delta,
30 | N_steps=args.steps, gamma=args.gamma)
31 | scaler_class = UniformScaler
32 | else:
33 | raise NotImplementedError()
34 |
35 | sampler = sampler_class(sampler_conf)
36 | scaler = scaler_class(scaler_conf)
37 |
38 | return sampler, scaler
39 |
40 |
41 | @dataclass
42 | class ConstantScalerConf:
43 | max_center_std: torch.tensor = torch.tensor([1, 1, 1])
44 | max_angle: torch.tensor = torch.tensor([8])
45 |
46 |
47 | @dataclass
48 | class UniformScalerConf(ConstantScalerConf):
49 | N_steps: int = 20
50 | gamma: float = 0.1
51 |
52 |
53 | @dataclass
54 | class SamplerConf:
55 | pass
--------------------------------------------------------------------------------
/gloc/resamplers/strategies.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from gloc.utils import rotmat2qvec
4 |
5 |
6 | class BaseProtocol:
7 | """This base dummy class serves as template for subclasses. it always returns
8 | the same poses without perturbing them"""
9 | def __init__(self, conf, sampler, scaler, protocol_name):
10 | self.sampler = sampler
11 | self.scaler = scaler
12 | self.n_steps = conf.N_steps
13 | self.n_views = conf.n_views
14 | self.protocol = protocol_name
15 | # init for later
16 | self.center_std = None
17 | self.max_angle = None
18 |
19 | def init_step(self, i):
20 | self.scaler.step(i)
21 | self.center_std, self.max_angle = self.scaler.get_noise()
22 |
23 | def get_pertubr_str(self, step, res):
24 | c_str = "_".join(list(map(lambda x: f'{x:.1f}'.replace('.', ','), map(float, self.center_std))))
25 | angle_str = "_".join(list(map(lambda x: f'{x:.1f}'.replace('.', ','), map(float, self.max_angle))))
26 |
27 | perturb_str = f'pt{self.protocol}_s{step}_sz{res}_theta{angle_str}_t{c_str}'
28 | return perturb_str
29 |
30 | @staticmethod
31 | def get_r_name(q_name, r_i, beam_i):
32 | r_name = q_name+f'_{r_i}beam{beam_i}'
33 | return r_name
34 |
35 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs):
36 | # this base class returns the same pose all over again
37 | render_qvecs = []
38 | render_ts = []
39 | calibr_pose = []
40 | r_names = []
41 |
42 | for i in range(self.n_views):
43 | t = pred_t[i]
44 | R = pred_R[i]
45 | qvec = rotmat2qvec(R)
46 |
47 | render_qvecs.append(qvec)
48 | render_ts.append(t)
49 | r_names.append(BaseProtocol.get_r_name(q_name, i, beam_i))
50 | T = np.eye(4)
51 | T[0:3, 0:3] = R
52 | T[0:3, 3] = t
53 | calibr_pose.append((T, K))
54 |
55 | return r_names, render_ts, render_qvecs, calibr_pose
56 |
57 |
58 | class Protocol1(BaseProtocol):
59 | """
60 | This protocol keeps only the first (best-ranked) prediction and perturbs it n_views - 1 times
61 | """
62 | def __init__(self, conf, sampler, scaler, protocol_name):
63 | super().__init__(conf, sampler, scaler, protocol_name)
64 |
65 | # override
66 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs):
67 | render_qvecs = []
68 | render_ts = []
69 | calibr_pose = []
70 | r_names = []
71 |
72 | t = pred_t[0] # take first prediction
73 | R = pred_R[0] # take first prediction
74 | qvec = rotmat2qvec(R) # transform to qvec
75 |
76 | #### keep previous estimate #####
77 | render_qvecs.append(qvec)
78 | render_ts.append(t)
79 | r_names.append(BaseProtocol.get_r_name(q_name, 0, beam_i))
80 | T = np.eye(4)
81 | T[0:3, 0:3] = R
82 | T[0:3, 3] = t
83 | calibr_pose.append((T, K))
84 | ####################
85 | views_per_candidate = self.n_views - 1
86 | new_ts, new_qvecs, new_poses = self.sampler.sample_batch(views_per_candidate,
87 | self.center_std, self.max_angle,
88 | t, R)
89 | render_ts += new_ts
90 | render_qvecs += new_qvecs
91 | for j in range(views_per_candidate):
92 | r_name = BaseProtocol.get_r_name(q_name, j + 1, beam_i)
93 | r_names.append(r_name)
94 | calibr_pose.append((new_poses[j], K))
95 |
96 | return r_names, render_ts, render_qvecs, calibr_pose
97 |
98 |
99 | class Protocol2(BaseProtocol):
100 | """
101 | This protocol keeps the first M predictions, perturbing each of them n_views // M - 1 times
102 | """
103 | def __init__(self, conf, sampler, scaler, protocol_name):
104 | super().__init__(conf, sampler, scaler, protocol_name)
105 | self.M = conf.M_candidates
106 |
107 | # override
108 | def resample(self, K, q_name, pred_t, pred_R, beam_i=0, *args, **kwargs):
109 | render_qvecs = []
110 | render_ts = []
111 | calibr_pose = []
112 | r_names = []
113 |
114 | #### keep previous first M estimates #####
115 | for i in range(self.M):
116 | t = pred_t[i]
117 | R = pred_R[i]
118 | qvec = rotmat2qvec(R)
119 |
120 | render_qvecs.append(qvec)
121 | render_ts.append(t)
122 | r_names.append(BaseProtocol.get_r_name(q_name, i, beam_i))
123 | T = np.eye(4)
124 | T[0:3, 0:3] = R
125 | T[0:3, 3] = t
126 | calibr_pose.append((T, K))
127 | ####################
128 |
129 | views_per_candidate = self.n_views // self.M - 1
130 | for i in range(self.M):
131 | t = pred_t[i]
132 | R = pred_R[i]
133 |
134 | new_ts, new_qvecs, new_poses = self.sampler.sample_batch(views_per_candidate, self.center_std, self.max_angle,
135 | t, R)
136 | render_ts += new_ts
137 | render_qvecs += new_qvecs
138 | for j in range(views_per_candidate):
139 | r_name = BaseProtocol.get_r_name(q_name, self.M+i*views_per_candidate+j, beam_i)
140 | r_names.append(r_name)
141 | calibr_pose.append((new_poses[j], K))
142 | return r_names, render_ts, render_qvecs, calibr_pose
143 |
--------------------------------------------------------------------------------
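As a worked example of the Protocol2 budget (with hypothetical values n_views = 32 and M = 4): each step keeps the 4 best poses as-is and samples 32 // 4 - 1 = 7 perturbations around each of them, so 4 + 4 * 7 = 32 candidates are rendered per query, exactly n_views.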
/gloc/utils/__init__.py:
--------------------------------------------------------------------------------
1 | __all__ = ['camera_utils', 'utils', 'visualization']
2 |
3 |
4 | from gloc.utils.camera_utils import (qvec2rotmat, rotmat2qvec, get_c2w_nerfconv,
5 | read_model_nopoints, parse_cam_model, Image)
6 | from gloc.utils import utils
7 | from gloc.utils import visualization
8 |
--------------------------------------------------------------------------------
/gloc/utils/camera_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import collections
4 | import struct
5 |
6 |
7 | #### Code taken from Colmap:
8 | # from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
9 | CameraModel = collections.namedtuple(
10 | "CameraModel", ["model_id", "model_name", "num_params"])
11 | Camera = collections.namedtuple(
12 | "Camera", ["id", "model", "width", "height", "params"])
13 | BaseImage = collections.namedtuple(
14 | "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
15 | Point3D = collections.namedtuple(
16 | "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
17 |
18 |
19 | class Image(BaseImage):
20 | def qvec2rotmat(self):
21 | return qvec2rotmat(self.qvec)
22 |
23 |
24 | CAMERA_MODELS = {
25 | CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
26 | CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
27 | CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
28 | CameraModel(model_id=3, model_name="RADIAL", num_params=5),
29 | CameraModel(model_id=4, model_name="OPENCV", num_params=8),
30 | CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
31 | CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
32 | CameraModel(model_id=7, model_name="FOV", num_params=5),
33 | CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
34 | CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
35 | CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
36 | }
37 | CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
38 | for camera_model in CAMERA_MODELS])
39 | CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
40 | for camera_model in CAMERA_MODELS])
41 |
42 |
43 |
44 | def parse_cam_model(cam_data):
45 | model = cam_data.model
46 | width = cam_data.width
47 | height = cam_data.height
48 |
49 | if model == "SIMPLE_PINHOLE" or model == "SIMPLE_RADIAL" or model == "RADIAL" or model == "SIMPLE_RADIAL_FISHEYE" or model == "RADIAL_FISHEYE":
50 | fx = cam_data.params[0]
51 | fy = fx
52 | cx = cam_data.params[1]
53 | cy = cam_data.params[2]
54 | elif model == "PINHOLE" or model == "OPENCV" or model == "OPENCV_FISHEYE" or model == "FULL_OPENCV" or model == "FOV" or model == "THIN_PRISM_FISHEYE":
55 | fx = cam_data.params[0]
56 | fy = cam_data.params[1]
57 | cx = cam_data.params[2]
58 | cy = cam_data.params[3]
59 |
60 | return {"width":width, "height":height, "fx":fx, "fy":fy, "cx":cx, "cy":cy}
61 |
62 |
63 | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
64 | """Read and unpack the next bytes from a binary file.
65 | :param fid:
66 | :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
67 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
68 | :param endian_character: Any of {@, =, <, >, !}
69 | :return: Tuple of read and unpacked values.
70 | """
71 | data = fid.read(num_bytes)
72 | return struct.unpack(endian_character + format_char_sequence, data)
73 |
74 |
75 | def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
76 | """pack and write to a binary file.
77 | :param fid:
78 | :param data: data to send, if multiple elements are sent at the same time,
79 | they should be encapsuled either in a list or a tuple
80 | :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
81 | should be the same length as the data list or tuple
82 | :param endian_character: Any of {@, =, <, >, !}
83 | """
84 | if isinstance(data, (list, tuple)):
85 | bytes = struct.pack(endian_character + format_char_sequence, *data)
86 | else:
87 | bytes = struct.pack(endian_character + format_char_sequence, data)
88 | fid.write(bytes)
89 |
90 |
91 | def get_c2w_nerfconv(qvec, tvec):
92 | R = qvec2rotmat(qvec)
93 |
94 | c_T_w = np.eye(4)
95 | c_T_w[0:3, 0:3] = R
96 | c_T_w[0:3, 3] = tvec
97 |
98 | c2w = np.linalg.inv(c_T_w)
99 | c2w[0:3,2] *= -1 # flip the y and z axis
100 | c2w[0:3,1] *= -1
101 | # for Aachen, the axes of the 3D model are oriented differently,
102 | # so comment out the two lines below
103 | c2w = c2w[[1, 0, 2, 3], :] # swap y and z
104 | c2w[2,:] *= -1 # flip whole world upside down
105 |
106 | return c2w
107 |
108 |
109 | def qvec2rotmat(qvec):
110 | return np.array([
111 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
112 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
113 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
114 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
115 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
116 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
117 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
118 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
119 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
120 |
121 |
122 | def rotmat2qvec(R):
123 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
124 | K = np.array([
125 | [Rxx - Ryy - Rzz, 0, 0, 0],
126 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
127 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
128 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
129 | eigvals, eigvecs = np.linalg.eigh(K)
130 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
131 | if qvec[0] < 0:
132 | qvec *= -1
133 | return qvec
134 |
135 |
136 | def read_cameras_intrinsics(path):
137 | """
138 | see: src/base/reconstruction.cc
139 | void Reconstruction::WriteCamerasText(const std::string& path)
140 | void Reconstruction::ReadCamerasText(const std::string& path)
141 | """
142 | cameras = []
143 | with open(path, "r") as fid:
144 | while True:
145 | line = fid.readline()
146 | if not line:
147 | break
148 | line = line.strip()
149 | if len(line) > 0 and line[0] != "#":
150 | elems = line.split()
151 | #camera_id = int(elems[0])
152 | path = elems[0]
153 | model = elems[1]
154 | width = int(elems[2])
155 | height = int(elems[3])
156 | params = np.array(tuple(map(float, elems[4:])))
157 | cameras.append(Camera(id=path, model=model,
158 | width=width, height=height,
159 | params=params))
160 | return cameras
161 |
162 |
163 | def read_cameras_text(path):
164 | """
165 | see: src/base/reconstruction.cc
166 | void Reconstruction::WriteCamerasText(const std::string& path)
167 | void Reconstruction::ReadCamerasText(const std::string& path)
168 | """
169 | cameras = {}
170 | with open(path, "r") as fid:
171 | while True:
172 | line = fid.readline()
173 | if not line:
174 | break
175 | line = line.strip()
176 | if len(line) > 0 and line[0] != "#":
177 | elems = line.split()
178 | camera_id = int(elems[0])
179 | model = elems[1]
180 | width = int(elems[2])
181 | height = int(elems[3])
182 | params = np.array(tuple(map(float, elems[4:])))
183 | cameras[camera_id] = Camera(id=camera_id, model=model,
184 | width=width, height=height,
185 | params=params)
186 | return cameras
187 |
188 |
189 | def read_cameras_binary(path_to_model_file):
190 | """
191 | see: src/base/reconstruction.cc
192 | void Reconstruction::WriteCamerasBinary(const std::string& path)
193 | void Reconstruction::ReadCamerasBinary(const std::string& path)
194 | """
195 | cameras = {}
196 | with open(path_to_model_file, "rb") as fid:
197 | num_cameras = read_next_bytes(fid, 8, "Q")[0]
198 | for _ in range(num_cameras):
199 | camera_properties = read_next_bytes(
200 | fid, num_bytes=24, format_char_sequence="iiQQ")
201 | camera_id = camera_properties[0]
202 | model_id = camera_properties[1]
203 | model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
204 | width = camera_properties[2]
205 | height = camera_properties[3]
206 | num_params = CAMERA_MODEL_IDS[model_id].num_params
207 | params = read_next_bytes(fid, num_bytes=8*num_params,
208 | format_char_sequence="d"*num_params)
209 | cameras[camera_id] = Camera(id=camera_id,
210 | model=model_name,
211 | width=width,
212 | height=height,
213 | params=np.array(params))
214 | assert len(cameras) == num_cameras
215 | return cameras
216 |
217 |
218 |
219 | def read_images_text(path):
220 | """
221 | see: src/base/reconstruction.cc
222 | void Reconstruction::ReadImagesText(const std::string& path)
223 | void Reconstruction::WriteImagesText(const std::string& path)
224 | """
225 | images = {}
226 | with open(path, "r") as fid:
227 | while True:
228 | line = fid.readline()
229 | if not line:
230 | break
231 | line = line.strip()
232 | if len(line) > 0 and line[0] != "#":
233 | elems = line.split()
234 | image_id = int(elems[0])
235 | qvec = np.array(tuple(map(float, elems[1:5])))
236 | tvec = np.array(tuple(map(float, elems[5:8])))
237 | camera_id = int(elems[8])
238 | image_name = elems[9]
239 | elems = fid.readline().split()
240 | # xys = np.column_stack([tuple(map(float, elems[0::3])),
241 | # tuple(map(float, elems[1::3]))])
242 | # point3D_ids = np.array(tuple(map(int, elems[2::3])))
243 | # images[image_id] = Image(
244 | # id=image_id, qvec=qvec, tvec=tvec,
245 | # camera_id=camera_id, name=image_name,
246 | # xys=xys, point3D_ids=point3D_ids)
247 | images[image_id] = Image(
248 | id=image_id, qvec=qvec, tvec=tvec,
249 | camera_id=camera_id, name=image_name,
250 | xys={}, point3D_ids={})
251 | return images
252 |
253 |
254 | def read_images_binary(path_to_model_file):
255 | """
256 | see: src/base/reconstruction.cc
257 | void Reconstruction::ReadImagesBinary(const std::string& path)
258 | void Reconstruction::WriteImagesBinary(const std::string& path)
259 | """
260 | images = {}
261 | with open(path_to_model_file, "rb") as fid:
262 | num_reg_images = read_next_bytes(fid, 8, "Q")[0]
263 | for _ in range(num_reg_images):
264 | binary_image_properties = read_next_bytes(
265 | fid, num_bytes=64, format_char_sequence="idddddddi")
266 | image_id = binary_image_properties[0]
267 | qvec = np.array(binary_image_properties[1:5])
268 | tvec = np.array(binary_image_properties[5:8])
269 | camera_id = binary_image_properties[8]
270 | image_name = ""
271 | current_char = read_next_bytes(fid, 1, "c")[0]
272 | while current_char != b"\x00": # look for the ASCII 0 entry
273 | image_name += current_char.decode("utf-8")
274 | current_char = read_next_bytes(fid, 1, "c")[0]
275 | num_points2D = read_next_bytes(fid, num_bytes=8,
276 | format_char_sequence="Q")[0]
277 | x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
278 | format_char_sequence="ddq"*num_points2D)
279 | xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
280 | tuple(map(float, x_y_id_s[1::3]))])
281 | point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
282 | images[image_id] = Image(
283 | id=image_id, qvec=qvec, tvec=tvec,
284 | camera_id=camera_id, name=image_name,
285 | xys={}, point3D_ids={})
286 | return images
287 |
288 |
289 | def read_points3D_text(path):
290 | """
291 | see: src/colmap/scene/reconstruction.cc
292 | void Reconstruction::ReadPoints3DText(const std::string& path)
293 | void Reconstruction::WritePoints3DText(const std::string& path)
294 | """
295 | points3D = {}
296 | with open(path, "r") as fid:
297 | while True:
298 | line = fid.readline()
299 | if not line:
300 | break
301 | line = line.strip()
302 | if len(line) > 0 and line[0] != "#":
303 | elems = line.split()
304 | point3D_id = int(elems[0])
305 | xyz = np.array(tuple(map(float, elems[1:4])))
306 | rgb = np.array(tuple(map(int, elems[4:7])))
307 | error = float(elems[7])
308 | image_ids = np.array(tuple(map(int, elems[8::2])))
309 | point2D_idxs = np.array(tuple(map(int, elems[9::2])))
310 | points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,
311 | error=error, image_ids=image_ids,
312 | point2D_idxs=point2D_idxs)
313 | return points3D
314 |
315 |
316 | def read_points3D_binary(path_to_model_file):
317 | """
318 | see: src/colmap/scene/reconstruction.cc
319 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
320 | void Reconstruction::WritePoints3DBinary(const std::string& path)
321 | """
322 | points3D = {}
323 | with open(path_to_model_file, "rb") as fid:
324 | num_points = read_next_bytes(fid, 8, "Q")[0]
325 | for _ in range(num_points):
326 | binary_point_line_properties = read_next_bytes(
327 | fid, num_bytes=43, format_char_sequence="QdddBBBd")
328 | point3D_id = binary_point_line_properties[0]
329 | xyz = np.array(binary_point_line_properties[1:4])
330 | rgb = np.array(binary_point_line_properties[4:7])
331 | error = np.array(binary_point_line_properties[7])
332 | track_length = read_next_bytes(
333 | fid, num_bytes=8, format_char_sequence="Q")[0]
334 | track_elems = read_next_bytes(
335 | fid, num_bytes=8*track_length,
336 | format_char_sequence="ii"*track_length)
337 | image_ids = np.array(tuple(map(int, track_elems[0::2])))
338 | point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
339 | points3D[point3D_id] = Point3D(
340 | id=point3D_id, xyz=xyz, rgb=rgb,
341 | error=error, image_ids=image_ids,
342 | point2D_idxs=point2D_idxs)
343 | return points3D
344 |
345 |
346 | def detect_model_format(path, ext):
347 | if os.path.isfile(os.path.join(path, "cameras" + ext)) and \
348 | os.path.isfile(os.path.join(path, "images" + ext)):
349 | print("Detected model format: '" + ext + "'")
350 | return True
351 |
352 | return False
353 |
354 |
355 | def read_model_nopoints(path, ext=""):
356 | # try to detect the extension automatically
357 | if ext == "":
358 | if detect_model_format(path, ".bin"):
359 | ext = ".bin"
360 | elif detect_model_format(path, ".txt"):
361 | ext = ".txt"
362 | else:
363 | print("Provide model format: '.bin' or '.txt'")
364 | return
365 |
366 | if ext == ".txt":
367 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
368 | images = read_images_text(os.path.join(path, "images" + ext))
369 | else:
370 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
371 | images = read_images_binary(os.path.join(path, "images" + ext))
372 | return cameras, images
373 |
374 |
375 | def read_model(path, ext=""):
376 | # try to detect the extension automatically
377 | if ext == "":
378 | if detect_model_format(path, ".bin"):
379 | ext = ".bin"
380 | elif detect_model_format(path, ".txt"):
381 | ext = ".txt"
382 | else:
383 | print("Provide model format: '.bin' or '.txt'")
384 | return
385 |
386 | if ext == ".txt":
387 | cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
388 | images = read_images_text(os.path.join(path, "images" + ext))
389 | points3D = read_points3D_text(os.path.join(path, "points3D") + ext)
390 | else:
391 | cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
392 | images = read_images_binary(os.path.join(path, "images" + ext))
393 | points3D = read_points3D_binary(os.path.join(path, "points3D") + ext)
394 | return cameras, images, points3D
395 |
396 |
397 | def write_cameras_text(cameras, path, header=True):
398 | """
399 | see: src/base/reconstruction.cc
400 | void Reconstruction::WriteCamerasText(const std::string& path)
401 | void Reconstruction::ReadCamerasText(const std::string& path)
402 | """
403 | HEADER = "# Camera list with one line of data per camera:\n" + \
404 | "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n" + \
405 | "# Number of cameras: {}\n".format(len(cameras))
406 | with open(path, "w") as fid:
407 | if header:
408 | fid.write(HEADER)
409 | for _, cam in cameras.items():
410 | to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
411 | line = " ".join([str(elem) for elem in to_write])
412 | fid.write(line + "\n")
413 |
414 |
415 | def write_cameras_binary(cameras, path_to_model_file):
416 | """
417 | see: src/colmap/scene/reconstruction.cc
418 | void Reconstruction::WriteCamerasBinary(const std::string& path)
419 | void Reconstruction::ReadCamerasBinary(const std::string& path)
420 | """
421 | with open(path_to_model_file, "wb") as fid:
422 | write_next_bytes(fid, len(cameras), "Q")
423 | for _, cam in cameras.items():
424 | model_id = CAMERA_MODEL_NAMES[cam.model].model_id
425 | camera_properties = [cam.id,
426 | model_id,
427 | cam.width,
428 | cam.height]
429 | write_next_bytes(fid, camera_properties, "iiQQ")
430 | for p in cam.params:
431 | write_next_bytes(fid, float(p), "d")
432 | return cameras
433 |
434 |
435 | def write_images_text(images, path):
436 | """
437 | see: src/colmap/scene/reconstruction.cc
438 | void Reconstruction::ReadImagesText(const std::string& path)
439 | void Reconstruction::WriteImagesText(const std::string& path)
440 | """
441 | if len(images) == 0:
442 | mean_observations = 0
443 | else:
444 | mean_observations = sum((len(img.point3D_ids) for _, img in images.items()))/len(images)
445 | HEADER = "# Image list with two lines of data per image:\n" + \
446 | "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n" + \
447 | "# POINTS2D[] as (X, Y, POINT3D_ID)\n" + \
448 | "# Number of images: {}, mean observations per image: {}\n".format(len(images), mean_observations)
449 |
450 | with open(path, "w") as fid:
451 | fid.write(HEADER)
452 | for _, img in images.items():
453 | image_header = [img.id, *img.qvec, *img.tvec, img.camera_id, img.name]
454 | first_line = " ".join(map(str, image_header))
455 | fid.write(first_line + "\n")
456 |
457 | points_strings = []
458 | for xy, point3D_id in zip(img.xys, img.point3D_ids):
459 | points_strings.append(" ".join(map(str, [*xy, point3D_id])))
460 | fid.write(" ".join(points_strings) + "\n")
461 |
462 |
463 | def write_images_binary(images, path_to_model_file):
464 | """
465 | see: src/colmap/scene/reconstruction.cc
466 | void Reconstruction::ReadImagesBinary(const std::string& path)
467 | void Reconstruction::WriteImagesBinary(const std::string& path)
468 | """
469 | with open(path_to_model_file, "wb") as fid:
470 | write_next_bytes(fid, len(images), "Q")
471 | for _, img in images.items():
472 | write_next_bytes(fid, img.id, "i")
473 | write_next_bytes(fid, img.qvec.tolist(), "dddd")
474 | write_next_bytes(fid, img.tvec.tolist(), "ddd")
475 | write_next_bytes(fid, img.camera_id, "i")
476 | for char in img.name:
477 | write_next_bytes(fid, char.encode("utf-8"), "c")
478 | write_next_bytes(fid, b"\x00", "c")
479 | write_next_bytes(fid, len(img.point3D_ids), "Q")
480 | for xy, p3d_id in zip(img.xys, img.point3D_ids):
481 | write_next_bytes(fid, [*xy, p3d_id], "ddq")
482 |
483 |
484 | def write_points3D_text(points3D, path):
485 | """
486 | see: src/colmap/scene/reconstruction.cc
487 | void Reconstruction::ReadPoints3DText(const std::string& path)
488 | void Reconstruction::WritePoints3DText(const std::string& path)
489 | """
490 | if len(points3D) == 0:
491 | mean_track_length = 0
492 | else:
493 | mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items()))/len(points3D)
494 | HEADER = "# 3D point list with one line of data per point:\n" + \
495 | "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n" + \
496 | "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length)
497 |
498 | with open(path, "w") as fid:
499 | fid.write(HEADER)
500 | for _, pt in points3D.items():
501 | point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
502 | fid.write(" ".join(map(str, point_header)) + " ")
503 | track_strings = []
504 | for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
505 | track_strings.append(" ".join(map(str, [image_id, point2D])))
506 | fid.write(" ".join(track_strings) + "\n")
507 |
508 |
509 | def write_points3D_binary(points3D, path_to_model_file):
510 | """
511 | see: src/colmap/scene/reconstruction.cc
512 | void Reconstruction::ReadPoints3DBinary(const std::string& path)
513 | void Reconstruction::WritePoints3DBinary(const std::string& path)
514 | """
515 | with open(path_to_model_file, "wb") as fid:
516 | write_next_bytes(fid, len(points3D), "Q")
517 | for _, pt in points3D.items():
518 | write_next_bytes(fid, pt.id, "Q")
519 | write_next_bytes(fid, pt.xyz.tolist(), "ddd")
520 | write_next_bytes(fid, pt.rgb.tolist(), "BBB")
521 | write_next_bytes(fid, pt.error, "d")
522 | track_length = pt.image_ids.shape[0]
523 | write_next_bytes(fid, track_length, "Q")
524 | for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
525 | write_next_bytes(fid, [image_id, point2D_id], "ii")
526 |
527 |
528 | def write_model(cameras, images, points3D, path, ext=".bin"):
529 | if ext == ".txt":
530 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
531 | write_images_text(images, os.path.join(path, "images" + ext))
532 | write_points3D_text(points3D, os.path.join(path, "points3D") + ext)
533 | else:
534 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
535 | write_images_binary(images, os.path.join(path, "images" + ext))
536 | write_points3D_binary(points3D, os.path.join(path, "points3D") + ext)
537 | return cameras, images, points3D
538 |
539 |
540 | def write_model_nopoints(cameras, images, path, ext=".txt"):
541 | if ext == ".txt":
542 | write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
543 | write_images_text(images, os.path.join(path, "images" + ext))
544 | else:
545 | write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
546 | write_images_binary(images, os.path.join(path, "images" + ext))
547 |
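For reference, a minimal round-trip sketch using the readers and writers defined above (the input directory is only illustrative; any folder holding `cameras`, `images` and `points3D` in `.bin` or `.txt` form works, and `read_model` auto-detects the extension):

```python
import os

# Read a COLMAP model (extension auto-detected) and re-export it as text.
# The input path is an example layout, not a guaranteed location on disk.
cameras, images, points3D = read_model("data/all_colmaps/chess/colmap_320")
print(f"{len(cameras)} cameras, {len(images)} images, {len(points3D)} 3D points")

out_dir = "exported_model_txt"
os.makedirs(out_dir, exist_ok=True)
write_model(cameras, images, points3D, out_dir, ext=".txt")
```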
--------------------------------------------------------------------------------
/gloc/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from math import ceil
3 | import numpy as np
4 | import einops
5 | import faiss
6 | from os.path import join
7 |
8 | from gloc.utils import qvec2rotmat, rotmat2qvec
9 |
10 | threshs_t = [ 0.25, 0.5, 5.0, 10.0]
11 | threshs_R = [2.0, 5.0, 10.0, 15.0]
12 |
13 |
14 | def log_pose_estimate(render_dir, pd, pred_R, pred_t, flat_preds=None, top_ns=[3, 6]):
15 | f_results = join(render_dir, 'est_poses.txt')
16 | is_aachen = 'Aachen' in pd.name
17 | print(f'WRITING TO {f_results}')
18 | with open(f_results, 'w') as f:
19 | for q_idx in range(len(pd.q_frames_idxs)):
20 | idx = pd.q_frames_idxs[q_idx]
21 | name = pd.images[idx].name
22 | if flat_preds is not None:
23 | qvec = rotmat2qvec(pred_R[q_idx][flat_preds[q_idx,0]])
24 | tvec = pred_t[q_idx][flat_preds[q_idx,0]]
25 | else:
26 | qvec = rotmat2qvec(pred_R[q_idx][0])
27 | tvec = pred_t[q_idx][0]
28 |
29 | if is_aachen:
30 | name = os.path.basename(name)
31 | qvec = ' '.join(map(str, qvec))
32 | tvec = ' '.join(map(str, tvec))
33 | f.write(f'{name} {qvec} {tvec}\n')
34 |
35 | for topn in top_ns:
36 | tn_results = join(render_dir, f'top{topn}_est_poses.txt')
37 | print(f'WRITING TO {tn_results}')
38 |
39 | with open(tn_results, 'w') as f:
40 | for q_idx in range(len(pd.q_frames_idxs)):
41 | idx = pd.q_frames_idxs[q_idx]
42 | name = pd.images[idx].name
43 | for i in range(topn):
44 | if flat_preds is not None:
45 | qvec = rotmat2qvec(pred_R[q_idx][flat_preds[q_idx, i]])
46 | tvec = pred_t[q_idx][flat_preds[q_idx, i]]
47 | else:
48 | qvec = rotmat2qvec(pred_R[q_idx, i])
49 | tvec = pred_t[q_idx, i]
50 |
51 | name = os.path.basename(name)
52 | qvec = ' '.join(map(str, qvec))
53 | tvec = ' '.join(map(str, tvec))
54 | f.write(f'{name} {qvec} {tvec}\n')
55 |
56 | return f_results
57 |
58 |
59 | def load_pose_prior(pose_file, pd, M=1):
60 | with open(pose_file, 'r') as fp:
61 | est_poses = fp.readlines()
62 | # the format is: 'basename_img.ext qw qx qy qz tx ty tz\n'
63 | est_poses = list(map(lambda x: x.strip().split(' '), est_poses))
64 | poses_dict = {}
65 | for pose in est_poses:
66 | qvec_float = list(map(float, pose[1:5]))
67 | tvec_float = list(map(float, pose[5:8]))
68 | if pose[0] not in poses_dict:
69 | poses_dict[pose[0]] = []
70 |
71 | poses_dict[pose[0]].append( (np.array(qvec_float), np.array(tvec_float)) )
72 |
73 | all_pred_t, all_pred_R = np.empty((pd.n_q, M, 3)), np.empty((pd.n_q, M, 3, 3))
74 |
75 | if pd.name in ['Aachen_night', 'Aachen_day', 'Aachen_real', 'Aachen_real_und']:
76 | get_q_key = lambda x: os.path.basename(x)
77 | else:
78 | get_q_key = lambda x: x
79 | for q_idx in range(len(pd.q_frames_idxs)):
80 | idx = pd.q_frames_idxs[q_idx]
81 | #q_key = os.path.basename(pd.images[idx].name)
82 | q_key = get_q_key(pd.images[idx].name)
83 | poses_q = poses_dict[q_key]
84 |
85 | if len(poses_q) == 1:
86 | qvec, tvec = poses_q[0]
87 | R = qvec2rotmat(qvec).reshape(-1, 3, 3)
88 | R_rp = np.repeat(R, M, axis=0)
89 | tvec_rp = np.repeat(tvec.reshape(-1, 3), M, axis=0)
90 |
91 | all_pred_t[q_idx] = tvec_rp
92 | all_pred_R[q_idx] = R_rp
93 |
94 | else:
95 | assert len(poses_q) >= M, f'This query has {len(poses_q)} poses and you asked for {M}'
96 | for i in range(M):
97 | qvec, tvec = poses_q[i]
98 | R = qvec2rotmat(qvec)
99 | tvec = tvec
100 | all_pred_t[q_idx, i] = tvec
101 | all_pred_R[q_idx, i] = R
102 |
103 | return all_pred_t, all_pred_R
104 |
105 |
106 | def reshape_preds_per_beam(n_beams, M, preds):
107 | n_dims = len(preds.shape)
108 | to_stack = []
109 | for i in range(n_beams):
110 |         # with n beams, beam i gets the i-th candidate, then i+n, i+2n, and so on
111 |         # i.e. take one candidate every n (e.g. with 2 beams, beam 0 gets candidates 0, 2, 4, ...)
112 | ts = preds[:,i::n_beams][:,:M, :]
113 | to_stack.append(ts)
114 |
115 | if n_dims == 3:
116 | stacked = np.hstack(to_stack).reshape(-1, n_beams, M, 3)
117 | elif n_dims == 4:
118 | stacked = np.hstack(to_stack).reshape(-1, n_beams, M, 3, 3)
119 |
120 | return stacked
121 |
122 |
123 | def repeat_first_preds_per_beam(n_beams, M, preds):
124 | first_preds = preds[:, :n_beams, :]
125 | # go from (Q, n_beams, 3/3,3), to (Q, n_beams, 1, 3/3,3)
126 | # so that then the first pred. in each beam can be repeated N times
127 | first_preds = np.expand_dims(first_preds, axis=2)
128 | result = np.repeat(first_preds, M, axis=2)
129 |
130 | return result
131 |
132 |
133 | def get_n_steps(num_queries, render_per_step, max_steps, renderer, hard_stop):
134 | if hard_stop > 0:
135 | return hard_stop
136 | if renderer != 'o3d':
137 | return max_steps
138 | # due to open3d bug, black images after 1e5 renders
139 | max_renders = 1e5
140 | n_steps = ceil(max_renders / (num_queries*render_per_step))
141 | return n_steps
142 |
143 |
144 | def eval_poses(errors_t, errors_R, descr=''):
145 | med_t = np.median(errors_t)
146 | med_R = np.median(errors_R)
147 | out = f'Results {descr}:'
148 | out += f'\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg'
149 | out_vals = []
150 |
151 | out += '\nPercentage of test images localized within:'
152 | for th_t, th_R in zip(threshs_t, threshs_R):
153 | ratio = np.mean((errors_t < th_t) & (errors_R < th_R))
154 | out += f'\n\t{th_t:.2f}m, {th_R:.0f}deg : {ratio*100:.2f}%'
155 | out_vals.append(ratio)
156 |
157 | return out, np.array(out_vals)
158 |
159 |
160 | def eval_poses_top_n(all_errors_t, all_errors_R, descr=''):
161 | best_candidates = [1, 5, max(20, all_errors_R.shape[1])]
162 |
163 | med_t = np.median(all_errors_t[:, 0])
164 | med_best_t = np.median(all_errors_t.min(axis=1))
165 |
166 | med_R = np.median(all_errors_R[:, 0])
167 | med_best_R = np.median(all_errors_R.min(axis=1))
168 |
169 | out = f'Results {descr}:'
170 | out += f'\nMedian errors on first/best: {med_t:.2f}m, {med_R:.2f}deg // {med_best_t:.2f}m, {med_best_R:.2f}deg'
171 | out_vals = np.zeros((len(best_candidates), len(threshs_t)))
172 |
173 | out += f"\nPercentage of test images localized within (TOP {'|TOP '.join(map(str, best_candidates))}):"
174 | for i, (th_t, th_R) in enumerate(zip(threshs_t, threshs_R)):
175 | out += f'\n\t{th_t:.2f}m, {int(th_R):2d}deg :'
176 | for j, best in enumerate(best_candidates):
177 | ratio = np.mean( ((all_errors_t[:, :best] < th_t) & (all_errors_R[:, :best] < th_R)).any(axis=1) )
178 | out += f' {ratio*100:4.1f}% |'
179 | out_vals[j, i] = ratio
180 |
181 | return out, out_vals
182 |
183 |
184 | def get_predictions(db_descriptors, q_descriptors, pose_dataset, fc_output_dim=512, top_k=20):
185 | pd = pose_dataset
186 |
187 | # Use a kNN to find predictions
188 | faiss_index = faiss.IndexFlatL2(fc_output_dim)
189 | faiss_index.add(db_descriptors)
190 |
191 | _, predictions = faiss_index.search(q_descriptors, top_k)
192 | pred_Rs = []
193 | pred_ts = []
194 | for q_idx in range(len(pd.q_frames_idxs)):
195 | pred_Rs_per_query = []
196 | pred_ts_per_query = []
197 | for k in range(top_k):
198 | pred_idx = pd.db_frames_idxs[predictions[q_idx, k]]
199 | R = qvec2rotmat(pd.images[pred_idx].qvec)
200 | t = pd.images[pred_idx].tvec
201 | pred_Rs_per_query.append(R)
202 | pred_ts_per_query.append(t)
203 |
204 | pred_Rs.append(np.array(pred_Rs_per_query))
205 | pred_ts.append(np.array(pred_ts_per_query))
206 |
207 | return np.array(pred_ts), np.array(pred_Rs)
208 |
209 |
210 | def get_error(R, t, R_gt, t_gt):
211 | e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0)
212 | cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1., 1.)
213 | e_R = np.rad2deg(np.abs(np.arccos(cos)))
214 |
215 | return e_t, e_R
216 |
217 |
218 | def get_errors_from_preds(true_t, true_R, pred_t, pred_R, top_k=20):
219 | errors_t = []
220 | errors_R = []
221 |
222 | top_k = min(top_k, pred_t.shape[1])
223 | for k in range(top_k):
224 | e_t, e_R = get_error(pred_R[0, k], pred_t[0, k], true_R[0], true_t[0])
225 |
226 | errors_t.append(e_t)
227 | errors_R.append(e_R)
228 | errors_t, errors_R = np.array(errors_t), np.array(errors_R)
229 |
230 | return errors_t, errors_R
231 |
232 |
233 | def get_all_errors_first_estimate(true_t, true_R, pred_t, pred_R):
234 | errors_t = []
235 | errors_R = []
236 | n_queries = len(pred_R)
237 | for q_idx in range(n_queries):
238 | e_t, e_R = get_error(pred_R[q_idx, 0], pred_t[q_idx, 0], true_R[q_idx], true_t[q_idx])
239 |
240 | errors_t.append(e_t)
241 | errors_R.append(e_R)
242 | errors_t = np.array(errors_t)
243 | errors_R = np.array(errors_R)
244 |
245 | return errors_t, errors_R
246 |
247 |
248 | def get_pose_from_preds_w_truth(q_idx, pd, rd, predictions, top_k=20):
249 | true_Rs = []
250 | true_ts = []
251 | pred_Rs = []
252 | pred_ts = []
253 |
254 | idx = pd.q_frames_idxs[q_idx]
255 | R_gt = qvec2rotmat(pd.images[idx].qvec)
256 | t_gt = pd.images[idx].tvec
257 | true_Rs.append(R_gt)
258 | true_ts.append(t_gt)
259 |
260 | pred_Rs_per_query = []
261 | pred_ts_per_query = []
262 |
263 | top_k = min(top_k, len(predictions))
264 | for k in range(top_k):
265 | pred_idx = predictions[k]
266 |
267 | R = qvec2rotmat(rd.images[pred_idx].qvec)
268 | t = rd.images[pred_idx].tvec
269 | pred_Rs_per_query.append(R)
270 | pred_ts_per_query.append(t)
271 |
272 | pred_Rs.append(np.array(pred_Rs_per_query))
273 | pred_ts.append(np.array(pred_ts_per_query))
274 |
275 | true_t, true_R, pred_t, pred_R= np.array(true_ts), np.array(true_Rs), np.array(pred_ts), np.array(pred_Rs)
276 | return true_t, true_R, pred_t, pred_R
277 |
278 |
279 | def get_pose_from_preds(q_idx, pd, rd, predictions, top_k=20):
280 | pred_Rs = []
281 | pred_ts = []
282 |
283 | pred_Rs_per_query = []
284 | pred_ts_per_query = []
285 |
286 | top_k = min(top_k, len(predictions))
287 | for k in range(top_k):
288 | pred_idx = predictions[k]
289 |
290 | R = qvec2rotmat(rd.images[pred_idx].qvec)
291 | t = rd.images[pred_idx].tvec
292 | pred_Rs_per_query.append(R)
293 | pred_ts_per_query.append(t)
294 |
295 | pred_Rs.append(np.array(pred_Rs_per_query))
296 | pred_ts.append(np.array(pred_ts_per_query))
297 |
298 | pred_t, pred_R= np.array(pred_ts), np.array(pred_Rs)
299 | return pred_t, pred_R
300 |
301 |
302 | def sort_preds_across_beams(all_scores, all_pred_t, all_pred_R, all_errors_t, all_errors_R):
303 | # flatten stuff to sort predictions based on similarity
304 | flat_err = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)')
305 | flat_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3)
306 | flat_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3)
307 | flat_preds = np.argsort(flat_err(all_scores))
308 | all_errors_t = np.take_along_axis(flat_err(all_errors_t), flat_preds, axis=1)
309 | all_errors_R = np.take_along_axis(flat_err(all_errors_R), flat_preds, axis=1)
310 | flat_pred_t = flat_t(all_pred_t)
311 | flat_pred_R = flat_R(all_pred_R)
312 |
313 | return flat_pred_R, flat_pred_t, flat_preds, all_errors_t, all_errors_R
314 |
315 |
316 | def update_scores(scores, scores_temp):
317 | if scores is None:
318 | scores = scores_temp
319 | else:
320 | scores['steps'] += scores_temp['steps']
321 |
322 | return scores
323 |
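A hedged sketch of the error metric implemented by `get_error` above: the translation error is the distance between camera centers (`C = -R^T t`), and the rotation error is the geodesic angle `arccos((trace(R_gt^T R) - 1) / 2)` in degrees. The poses below are synthetic and only illustrate the expected magnitudes; the imports assume the repository root is on `PYTHONPATH`:

```python
import numpy as np

from gloc.utils import qvec2rotmat
from gloc.utils.utils import eval_poses, get_error

R_gt = np.eye(3)
t_gt = np.zeros(3)
# estimate offset by ~0.3 m and rotated ~2 deg about the y axis
R_est = qvec2rotmat(np.array([0.9998477, 0.0, 0.0174524, 0.0]))
t_est = np.array([0.3, 0.0, 0.0])

e_t, e_R = get_error(R_est, t_est, R_gt, t_gt)  # ~0.30 m, ~2.0 deg
summary, recalls = eval_poses(np.array([e_t]), np.array([e_R]), descr='toy example')
print(summary)
```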
--------------------------------------------------------------------------------
/gloc/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from os.path import join
4 | import matplotlib as mpl
5 | import matplotlib.pyplot as plt
6 | import cv2
7 |
8 | from gloc.utils.utils import threshs_R, threshs_t
9 |
10 |
11 | def plot_error_distr(err_R, err_t, step_num, save_dir, f_name):
12 | fig, axi = plt.subplots(1, 2, figsize=(12,6), dpi=80)
13 |
14 | meds = [np.median(err_R[:,0]), np.median(err_t[:,0])]
15 | for i, errors in enumerate([err_R, err_t]):
16 | ax = axi[i]
17 | for label in (ax.get_xticklabels() + ax.get_yticklabels()):
18 | label.set_fontsize(16)
19 | med = meds[i]
20 | ax.hist(errors[:, 0], bins=50, range=(0, 20), alpha=0.8)
21 | ax.vlines(med, 0, 50, colors='red', label='50th perc.')
22 | ax.legend(prop={"size":15})
23 |
24 | fig.suptitle(f'Step {step_num}\nMedian errors: ( {meds[1]:.1f} m, {meds[0]:.1f}° )', fontsize=18)
25 | plt.tight_layout()
26 | plt.savefig(join(save_dir, f_name), bbox_inches='tight', pad_inches=0.0)
27 | plt.close()
28 |
29 |
30 | def plot_scores(scores, out_dir):
31 | threshs = list(zip(threshs_t, threshs_R))
32 |
33 | steps = np.array(scores['steps'])
34 | fig, axis = plt.subplots(1, 2, figsize=(14, 7), dpi=100)
35 | x=list(range(len(steps)))
36 |
37 | for i, idx in enumerate([1, 2]):
38 | ax = axis[i]
39 |
40 | ax.plot(x, steps[:, 0, idx], label=f'Recall')
41 | ax.plot(x, steps[:, 2, idx], label=f'Upper bound')
42 | ax.hlines(scores['baseline'][idx], x[0], x[-1], colors='red', linestyles='dashed', label='baseline')
43 | ax.set_title(f'Threshold= {threshs[idx]}')
44 | ax.legend()
45 | plt.tight_layout()
46 | plt.savefig(f'{out_dir}/scores_plot.png')
47 |
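`plot_scores` consumes the running `scores` dict built during refinement. A minimal sketch of its assumed layout, where each step entry is the `(3, 4)` recall array returned by `utils.eval_poses_top_n` (top-1 / top-5 / top-all candidates by the four thresholds) and `'baseline'` holds the recalls of the initial poses (both shapes are assumptions inferred from the indexing above):

```python
import numpy as np

from gloc.utils.visualization import plot_scores

rng = np.random.default_rng(0)
scores = {
    # ten refinement steps, each a (3, 4) array of recalls in [0, 1]
    'steps': [np.sort(rng.uniform(0.3, 0.9, size=(3, 4)), axis=0) for _ in range(10)],
    # recalls of the un-refined poses at the same four thresholds
    'baseline': np.array([0.20, 0.45, 0.70, 0.80]),
}
plot_scores(scores, out_dir='.')  # writes ./scores_plot.png
```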
--------------------------------------------------------------------------------
/parse_args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os.path import join
3 | from datetime import datetime
4 |
5 |
6 | def parse_args():
7 | parser = argparse.ArgumentParser(description='Argument parser')
8 | # exp args
9 | parser.add_argument('name', type=str, help='DS name',
10 | choices=['Aachen', 'Aachen_real', 'Aachen_day', 'Aachen_night', 'Aachen_real_und', 'Aachen_small',
11 | 'KingsCollege', 'KingsCollege_und', 'StMarysChurch_und',
12 | 'ShopFacade', 'ShopFacade_und',
13 | 'OldHospital', 'OldHospital_und',
14 | 'chess', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin', 'heads'])
15 | parser.add_argument('--exp_name', type=str, help='log folder', default='default')
16 | parser.add_argument('--res', type=int, help='resolution', default=320)
17 | parser.add_argument('--seed', type=int, help='seed', default=0)
18 | parser.add_argument('--first_step', type=int, help='start from', default=None)
19 |     parser.add_argument('--hard_stop', type=int, help="interrupt at step N, but don't consider it for scaling noise", default=-1)
20 | parser.add_argument('--resume_step', type=str, help='resume folder', default=None)
21 |     parser.add_argument('--save_feats', action='store_true', help='save features', default=False)
22 | parser.add_argument('--pose_prior', type=str, help='start from a pose prior in this file', default=None)
23 | parser.add_argument('--clean_logs', action='store_true', help='remove renderings in the end', default=False)
24 | parser.add_argument('--chunk_size', type=int, help='n feats at a time', default=1100)
25 |
26 | # model args
27 | parser.add_argument('--retr_model', type=str, help='retrieval model', default='cosplace', choices=['cosplace'])
28 | parser.add_argument('--ref_model', type=str, help='fine model', default='DenseFeatures',
29 | choices=['DenseFeatures'])
30 | parser.add_argument('--feat_model', type=str, help='refinement model arch', default='cosplace',
31 | choices=['',
32 | 'cosplace_r18_l1', 'cosplace_r18_l2', 'cosplace_r18_l3',
33 | 'cosplace_r50_l1', 'cosplace_r50_l2', 'cosplace_r50_l3',
34 | 'resnet18_l1', 'resnet18_l2', 'resnet18_l3',
35 | 'resnet50_l1', 'resnet50_l2', 'resnet50_l3',
36 | 'alexnet_l1', 'alexnet_l2', 'alexnet_l3',
37 | 'Dinov2', 'Roma',
38 | ])
39 |
40 | # fine models args
41 | parser.add_argument('--clamp_score', type=float, help='thresholded scoring function', default=-1)
42 | parser.add_argument('--feat_level', nargs='+', type=int, help='Level of features for ALIKED', default=[-1])
43 | parser.add_argument('--scale_fmaps', type=int, help='Scale F.maps to 1/n', default=6)
44 |
45 | # path args
46 |     parser.add_argument("--storage_dir", type=str, default='/storage/gtrivigno/vloc/renderings', help='storage dir for renderings')
47 | parser.add_argument("--fix_storage", action='store_true', default=False, help='model path')
48 |
49 | # render args
50 | parser.add_argument('--colmap_res', type=int, help='res', default=320)
51 | parser.add_argument('--mesh', type=str, help='mesh type', choices=['colored', 'colored_14', 'colored_15', 'textured'], default='colored')
52 | parser.add_argument('--renderer', type=str, help='renderer type', choices=['o3d', 'nerf', 'g_splatting'], default='o3d')
53 |
54 | # perturb args
55 | parser.add_argument('-pt', '--protocol', type=str, help='protocol',
56 | choices=['1_0', '1_1', '2_0', '2_1'], default='2_1')
57 | parser.add_argument('--sampler', type=str, help='sampler', default='rand',
58 | choices=['rand', 'rand_yaw_or_pitch', 'rand_yaw_and_pitch', 'rand_and_yaw_and_pitch'])
59 |     parser.add_argument('--beams', type=int, help='N. beams to optimize independently', default=1)
60 | parser.add_argument('--steps', type=int, help='iterations', default=20)
61 | parser.add_argument('--teta', nargs='+', type=float, help='max angle', default=[8])
62 | parser.add_argument('--center_std', nargs='+', type=float, default=[1., 1., 1.])
63 | parser.add_argument('--N', type=int, help='N views to render in total, per query', default=20)
64 | parser.add_argument('--M', type=int, help='In each beam, perturb the first M rather than only first cand.', default=4)
65 | parser.add_argument('--gamma', type=float, help='min scale', default=0.1)
66 |
67 | # eval scripts
68 | parser.add_argument("--eval_renders_dir", type=str, default='', help='eval render dir')
69 |
70 | # random args
71 | parser.add_argument('--only_step', type=int, help='iterations', default=-1)
72 |
73 | args = parser.parse_args()
74 | args.save_dir = join("logs", args.exp_name, datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
75 |
76 | ## some consistency checks
77 |     assert args.N % args.beams == 0, 'N (total views to render) has to be a multiple of the number of beams'
78 | if args.mesh == 'textured' and not args.name.startswith('Aachen'):
79 | raise ValueError('Textured mesh is only available for Aachen')
80 | if args.protocol[0] not in ['0', '1']:
81 | assert args.N % args.M == 0, f'In protocol 2, N ({args.N}) has to be a multiple of M ({args.M})'
82 | assert (args.N // args.beams) % args.M == 0, f'In protocols with M!=1, N/beams ({args.N//args.beams}) has to be a multiple of M ({args.M})'
83 | if 'yaw_and_pitch' in args.sampler:
84 | assert len(args.teta) == 2, f'Sampler {args.sampler} requires 2 angles, 1 for yaw, 1 for pitch'
85 | else:
86 | assert len(args.teta) == 1, f'Sampler {args.sampler} requires only 1 angle'
87 |
88 | return args
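As an illustration of the consistency checks above (hypothetical values, not a recommended configuration): `N` must be divisible by `--beams`, and under a `2_*` protocol both `N` and `N / beams` must be divisible by `M`, so an invocation such as

`$ python refine_pose.py chess --renderer g_splatting --protocol 2_1 --beams 2 --N 20 --M 5 --steps 20`

passes all three asserts (20 % 2 == 0, 20 % 5 == 0, 10 % 5 == 0).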
--------------------------------------------------------------------------------
/path_configs.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 |
4 | def get_paths(base_path, colmap_res, mesh_type):
5 | aachen_meshes = {
6 | 'colored': 'AC13_colored.ply',
7 | 'colored_14': 'AC14_colored.ply',
8 | 'colored_15': 'AC15_colored.ply',
9 | 'textured': 'AC13-C_textured/AC13-C_textured.obj',
10 | }
11 |
12 | paths_conf = {
13 | # Aachen
14 | 'Aachen': {
15 | 'root': f'{base_path}/Aachen-Day-Night/images/images_upright',
16 | 'colmap': f'{base_path}/all_colmaps/{colmap_res}_undist',
17 | 'q_file': '',
18 | 'db_file': '',
19 | 'q_intrinsics': [f'{base_path}/Aachen-Day-Night/queries/undist_rs{colmap_res}_allquery_intrinsics.txt'],
20 | 'mesh_path': f'{base_path}/meshes',
21 | },
22 | }
23 |
24 | # Cambridge scenes
25 | cambridge_scenes = ['StMarysChurch', 'OldHospital', 'KingsCollege', 'ShopFacade']
26 | gs_models = {
27 | 'StMarysChurch': 'gauss_church',
28 | 'OldHospital': 'gauss_hosp',
29 | 'KingsCollege': 'gauss_kings',
30 | 'ShopFacade': 'gauss_shop',
31 | }
32 | for cs in cambridge_scenes:
33 | paths_conf[cs] = {
34 | 'root': f'{base_path}/{cs}',
35 | 'colmap': f'{base_path}/all_colmaps/{cs}/{colmap_res}_undist/sparse/0',
36 | 'q_file': f'{base_path}/{cs}/dataset_test.txt',
37 | 'db_file': f'{base_path}/{cs}/dataset_train.txt',
38 | 'mesh_path': NotImplemented,
39 | 'ns_config': NotImplemented,
40 | 'ns_transform': NotImplemented,
41 | 'gaussians': f'{base_path}/cambridge_splats/{gs_models[cs]}',
42 | }
43 |
44 | # 7scenes
45 | scenes7 = ['chess', 'office', 'fire', 'stairs', 'redkitchen', 'pumpkin', 'heads']
46 | for sc in scenes7:
47 | paths_conf[sc] = {
48 | 'root': f'{base_path}/7scenes/{sc}',
49 | 'colmap': f'{base_path}/all_colmaps/{sc}/colmap_{colmap_res}',
50 | 'q_file': f'{base_path}/7scenes/{sc}/test.txt',
51 | 'db_file': f'{base_path}/7scenes/{sc}/train.txt',
52 | 'mesh_path': NotImplemented,
53 | 'ns_config': NotImplemented,
54 | 'ns_transform': NotImplemented,
55 | 'gaussians': f'{base_path}/7scenes_splats/{sc}',
56 | }
57 | ######################################
58 |
59 | paths_conf['Aachen']['mesh_path'] = join(paths_conf['Aachen']['mesh_path'], aachen_meshes.get(mesh_type, ''))
60 |
61 | return paths_conf
62 |
63 |
64 | def get_path_conf(colmap_res, mesh_type):
65 | base_path = 'data'
66 | temp_path = 'data/temp'
67 |
68 | conf = get_paths(base_path, colmap_res, mesh_type)
69 | conf['temp'] = temp_path
70 | return conf
71 |
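A minimal usage sketch of the path resolution above (the printed values simply mirror the templates in `get_paths` with the default `data` root):

```python
from path_configs import get_path_conf

conf = get_path_conf(colmap_res=320, mesh_type='colored')
print(conf['chess']['colmap'])     # data/all_colmaps/chess/colmap_320
print(conf['chess']['gaussians'])  # data/7scenes_splats/chess
print(conf['temp'])                # data/temp
```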
--------------------------------------------------------------------------------
/refine_pose.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import shutil
5 | import torch
6 | from os.path import join
7 | from tqdm import tqdm
8 | import numpy as np
9 | from torch.utils.data.dataset import Subset
10 | import torchvision.transforms as T
11 |
12 | import commons
13 | from parse_args import parse_args
14 | from path_configs import get_path_conf
15 | from gloc import extraction
16 | from gloc import initialization
17 | from gloc import rendering
18 | from gloc.models import get_ref_model
19 | from gloc.rendering import get_renderer
20 | from gloc.datasets import get_dataset, find_candidates_paths, get_transform, RenderedImagesDataset, ImListDataset
21 | from gloc.utils import utils, visualization
22 | from gloc.resamplers import get_protocol
23 | from configs import get_config
24 |
25 |
26 | def main(args):
27 | commons.make_deterministic(args.seed)
28 | commons.setup_logging(args.save_dir, console="info")
29 | logging.info(" ".join(sys.argv))
30 | logging.info(f"Arguments: {args}")
31 | logging.info(f"The outputs are being saved in {args.save_dir}")
32 |
33 | paths_conf = get_path_conf(args.colmap_res, args.mesh)
34 | temp_dir = join(paths_conf['temp'], args.exp_name)
35 | os.makedirs(temp_dir)
36 |
37 | exp_config = get_config(args.name)
38 | scores = None
39 | for i in range(len(exp_config)):
40 | ref_args = exp_config[i]
41 | args.__dict__.update(ref_args)
42 | scores_temp, render_dir = refinement_loop(args)
43 |
44 | scores = utils.update_scores(scores, scores_temp)
45 | args.pose_prior = join(render_dir, 'est_poses.txt')
46 |
47 | visualization.plot_scores(scores, args.save_dir)
48 |
49 | ### cleaning up...
50 | logging.info(f'Moving rendering from temp dir {temp_dir} to {args.save_dir}')
51 | shutil.move(join(temp_dir, 'renderings'), args.save_dir, copy_function=shutil.move)
52 | shutil.rmtree(temp_dir)
53 | logging.info('Terminating without errors!')
54 |
55 |
56 | def refinement_loop(args):
57 | DS = args.name
58 | res = args.res
59 |
60 | paths_conf = get_path_conf(args.colmap_res, args.mesh)
61 | transform = get_transform(args, paths_conf[DS]['colmap'])
62 | pose_dataset = get_dataset(DS, paths_conf[DS], transform)
63 | temp_dir = join(paths_conf['temp'], args.exp_name)
64 |
65 | first_step, all_pred_t, all_pred_R, scores = initialization.init_refinement(args, pose_dataset)
66 | ######### START REFINEMENT LOOP
67 | N_steps = args.steps
68 | N_per_beam = args.N // args.beams
69 | n_beams = args.beams
70 | N_views = args.N
71 | fine_model = get_ref_model(args)
72 |
73 | logging.info('Recomputing query features with refinement model...')
74 | queries_subset = Subset(pose_dataset, pose_dataset.q_frames_idxs)
75 | q_descriptors = extraction.get_query_features(fine_model, queries_subset)
76 |
77 | resampler = get_protocol(args, N_per_beam, args.protocol)
78 | renderer = get_renderer(args, paths_conf)
79 |
80 | max_step = utils.get_n_steps(pose_dataset.num_queries(), N_views, N_steps, args.renderer, args.hard_stop)
81 | for step in range(first_step, N_steps):
82 | if (step - first_step) == max_step:
83 | logging.info('Stopping due to Open3D bug')
84 | break
85 |
86 | resampler.init_step(step)
87 | center_std, angle_delta = resampler.scaler.get_noise()
88 |
89 | logging.info(f'[||] Starting iteration n.{step+1}/{N_steps} [||]')
90 | logging.info(f'Perturbing poses with Theta {angle_delta} and center STD {center_std}. Resolution {res}')
91 |
92 | if (first_step == step) and (args.resume_step is not None):
93 | render_dir = args.resume_step
94 | logging.info(f'Resuming from step dir {render_dir}...')
95 | else:
96 | perturb_str = resampler.get_pertubr_str(step, res)
97 | render_dir = perturb_step(perturb_str, pose_dataset, renderer, resampler, all_pred_t, all_pred_R, temp_dir, n_beams)
98 | renderer.end_epoch(render_dir)
99 |
100 | (all_pred_t, all_pred_R,
101 | all_errors_t, all_errors_R) = rank_candidates(fine_model, pose_dataset, render_dir, pose_dataset.transform,
102 | q_descriptors, N_per_beam, n_beams, chunk_limit=args.chunk_size)
103 | result_str, results = utils.eval_poses_top_n(all_errors_t, all_errors_R, descr=f'step {step}')
104 | logging.info(result_str)
105 |
106 | scores['steps'].append(results)
107 | torch.save(scores, join(args.save_dir, 'scores.pth'))
108 |
109 | if args.clean_logs:
110 | from clean_logs import main as cl_logs
111 | logging.info('Removing rendering files...')
112 | cl_logs(temp_dir, only_step=step)
113 |
114 | return scores, render_dir
115 |
116 |
117 | def perturb_step(perturb_str, pose_dataset, renderer, resampler, pred_t, pred_R, basepath, n_beams=1):
118 | out_dir = os.path.join(basepath, 'renderings', perturb_str)
119 | os.makedirs(out_dir)
120 | logging.info(f'Generating renders in {out_dir}')
121 |
122 | rend_model = renderer.load_model()
123 | r_names_per_beam = {}
124 | for q_idx in tqdm(range(len(pose_dataset.q_frames_idxs)), ncols=100):
125 | idx = pose_dataset.q_frames_idxs[q_idx]
126 | q_name = pose_dataset.get_basename(idx)
127 | q_key_name = os.path.splitext(pose_dataset.images[idx].name)[0]
128 | r_names_per_beam[q_idx] = {}
129 |
130 | K, w, h = pose_dataset.get_intrinsics(q_key_name)
131 |
132 | r_dir = os.path.join(out_dir, q_name)
133 | os.makedirs(r_dir)
134 | for beam_i in range(n_beams):
135 | beam_dir = join(r_dir, f'beam_{beam_i}')
136 | os.makedirs(beam_dir)
137 | pred_t_beam = pred_t[q_idx, beam_i]
138 | pred_R_beam = pred_R[q_idx, beam_i]
139 |
140 | r_names, render_ts, render_qvecs, calibr_pose = resampler.resample(K, q_name, pred_t_beam, pred_R_beam, q_idx=q_idx, beam_i=beam_i)
141 | r_names_per_beam[q_idx][beam_i] = r_names
142 | # poses have to be logged in 'beam_dir', but rendered in 'r_dir', so that
143 | # they can be rendered all together in 'deferred' mode, thus being more efficient
144 | rendering.log_poses(beam_dir, r_names, render_ts, render_qvecs, args.renderer)
145 | if renderer.supports_deferred_rendering:
146 | to_render_dir = r_dir
147 | else:
148 | to_render_dir = beam_dir
149 | renderer.render_poses(to_render_dir, rend_model, r_names, render_ts, render_qvecs, calibr_pose, (w, h),
150 | deferred=renderer.supports_deferred_rendering)
151 | del rend_model
152 |
153 | renderer.end_epoch(out_dir)
154 |     logging.info('Moving each render into its beam folder')
155 | if renderer.supports_deferred_rendering:
156 | for q_idx in range(len(pose_dataset.q_frames_idxs)):
157 | idx = pose_dataset.q_frames_idxs[q_idx]
158 | q_name = pose_dataset.get_basename(idx)
159 |
160 | r_dir = join(out_dir, q_name)
161 | rendering.split_to_beam_folder(r_dir, n_beams, r_names_per_beam[q_idx])
162 |
163 | return out_dir
164 |
165 |
166 | def rank_candidates(fine_model, pose_dataset, render_dir, transform, q_descriptors, N_per_beam, n_beams, chunk_limit=1100):
167 | all_pred_t = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam, 3))
168 | all_pred_R = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam, 3, 3))
169 | all_errors_t = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam))
170 | all_errors_R = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam))
171 | all_scores = np.empty((len(pose_dataset.q_frames_idxs), n_beams, N_per_beam))
172 |
173 |     logging.info('Extracting candidate paths')
174 | candidates_pathlist, query_res = find_candidates_paths(pose_dataset, n_beams, render_dir)
175 |
176 |     logging.info(f'Found {len(candidates_pathlist)} images for {pose_dataset.n_q} queries, now extracting features in one pass')
177 | same_res_transform = T.Compose(transform.transforms.copy())
178 | same_res_transform.transforms[1] = T.Resize(query_res, antialias=True)
179 | imlist_ds = ImListDataset(candidates_pathlist, same_res_transform)
180 |
181 | chunk_start_q_idx, chunk_end_q_idx, chunks = extraction.split_renders_into_chunks(
182 | pose_dataset.n_q, len(imlist_ds), n_beams, N_per_beam, chunk_limit
183 | )
184 | dim = extraction.get_feat_dim(fine_model, query_res)
185 |
186 | logging.info(f'Query splits: {chunk_start_q_idx}, {chunk_end_q_idx}')
187 | logging.info(f'Chunk splits: {[c[-1] for c in chunks]}')
188 | for ic, chunk in enumerate(chunks):
189 | q_idx_start = chunk_start_q_idx[ic]
190 | q_idx_end = chunk_end_q_idx[ic]
191 |
192 | logging.info(f'Chunk n.{ic}')
193 | logging.info(f'Query from {q_idx_start} to {q_idx_end}')
194 | logging.info(f'Images from {chunk[0]} to {chunk[-1]}')
195 |
196 | chunk_ds = Subset(imlist_ds, chunk)
197 | descriptors = extraction.get_candidates_features(fine_model, chunk_ds, dim)
198 |
199 | logging.info(f'Extracted shape {descriptors.shape}, now computing predictions')
200 | for q_idx in tqdm(range(q_idx_start, q_idx_end), ncols=100):
201 | q_name = pose_dataset.get_basename(pose_dataset.q_frames_idxs[q_idx])
202 | query_dir = os.path.join(render_dir, q_name)
203 | q_feats = q_descriptors[q_idx]
204 |
205 | for beam_i in range(n_beams):
206 | beam_dir = join(query_dir, f'beam_{beam_i}')
207 | rd = RenderedImagesDataset(beam_dir, verbose=False)
208 |
209 | start_idx = (q_idx-q_idx_start)*n_beams*N_per_beam + beam_i*N_per_beam
210 | end_idx = start_idx + N_per_beam
211 | r_db_descriptors = descriptors[start_idx:end_idx]
212 | predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True)
213 | true_t, true_R, pred_t, pred_R = utils.get_pose_from_preds_w_truth(q_idx, pose_dataset, rd, predictions, N_per_beam)
214 | errors_t, errors_R = utils.get_errors_from_preds(true_t, true_R, pred_t, pred_R, N_per_beam)
215 |
216 | all_pred_t[q_idx, beam_i] = pred_t
217 | all_pred_R[q_idx, beam_i] = pred_R
218 | all_errors_t[q_idx, beam_i] = errors_t
219 | all_errors_R[q_idx, beam_i] = errors_R
220 | all_scores[q_idx, beam_i] = scores[:N_per_beam]
221 |
222 | # save scores within each beam so renders can be deleted afterwards
223 | torch.save((predictions, scores), join(beam_dir, 'scores.pth'))
224 |
225 | del q_feats, r_db_descriptors, descriptors
226 |
227 |     # sort predictions/errors according to the score across beams
228 | # only needed to log and eval poses, the optimization is beam-independent
229 | flat_pred_R, flat_pred_t, flat_preds, all_errors_t, all_errors_R = utils.sort_preds_across_beams(all_scores, all_pred_t, all_pred_R, all_errors_t, all_errors_R)
230 | # log pose estimate
231 | if flat_preds.shape[-1] > 6:
232 |         # only log if there are more than 6 preds per query
233 | logging.info(f'Generating pose file...')
234 | utils.log_pose_estimate(render_dir, pose_dataset, flat_pred_R, flat_pred_t, flat_preds=flat_preds)
235 |
236 | return all_pred_t, all_pred_R, all_errors_t, all_errors_R
237 |
238 |
239 | if __name__ == '__main__':
240 | args = parse_args()
241 | main(args)
242 |
--------------------------------------------------------------------------------
/refine_pose_aachen.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import shutil
5 | import torch
6 | from os.path import join
7 | from tqdm import tqdm
8 | import numpy as np
9 | import einops
10 | from torch.utils.data.dataset import Subset
11 | import torchvision.transforms as T
12 |
13 | import commons
14 | from parse_args import parse_args
15 | from path_configs import get_path_conf
16 | from gloc.models import get_ref_model
17 | from gloc import extraction
18 | from gloc import rendering
19 | from gloc.rendering import get_renderer
20 | from gloc.utils import utils, rotmat2qvec
21 | from gloc.resamplers import get_protocol
22 | from gloc.datasets import RenderedImagesDataset, get_dataset, get_transform
23 |
24 |
25 | def main(args):
26 | DS = args.name
27 | res = args.res
28 |
29 | commons.make_deterministic(args.seed)
30 | commons.setup_logging(args.save_dir, console="info")
31 | logging.info(" ".join(sys.argv))
32 | logging.info(f"Arguments: {args}")
33 | logging.info(f"The outputs are being saved in {args.save_dir}")
34 |
35 | paths_conf = get_path_conf(args.colmap_res, args.mesh)
36 | temp_dir = join(paths_conf['temp'], args.exp_name)
37 | os.makedirs(temp_dir)
38 |
39 | transform = get_transform(args)
40 | pd = get_dataset(DS, paths_conf[DS], transform)
41 |
42 | if args.pose_prior == '':
43 | all_pred_t, all_pred_R = extraction.get_retrieval_predictions(args.model, args.outdim, args.res, pd, topk=args.beams*args.M)
44 | else:
45 | logging.info(f'Loading pose prior from {args.pose_prior}')
46 | all_pred_t, all_pred_R = utils.load_pose_prior(args.pose_prior, pd, args.beams*args.M)
47 |
48 | ######### START REFINEMENT LOOP
49 | N_steps = args.steps
50 | n_beams = args.beams
51 | N_per_beam = args.N // args.beams
52 | M = args.M
53 | fine_model = get_ref_model(args)
54 |
55 | logging.info('Recomputing query features with refinement model...')
56 | queries_subset = Subset(pd, pd.q_frames_idxs)
57 | q_descriptors = extraction.get_features_from_dataset(fine_model, queries_subset, use_tqdm=True)
58 |
59 | resampler = get_protocol(args, N_per_beam, args.protocol)
60 | renderer = get_renderer(args, paths_conf)
61 |
62 | first_step = 0
63 | if args.first_step is not None:
64 | first_step = args.first_step
65 |
66 | max_step = utils.get_n_steps(pd.num_queries(), args.N, N_steps, args.renderer, args.hard_stop)
67 | # go from (NQ, M*beams, 3/3,3) to (NQ, beams, M, 3/3, 3)
68 | all_pred_t = utils.reshape_preds_per_beam(n_beams, M, all_pred_t)
69 | all_pred_R = utils.reshape_preds_per_beam(n_beams, M, all_pred_R)
70 | for step in range(first_step, N_steps):
71 | if (step - first_step) == max_step:
72 | logging.info('Stopping due to Open3D bug')
73 | break
74 |
75 | resampler.init_step(step)
76 | center_std, angle_delta = resampler.scaler.get_noise()
77 |
78 | logging.info(f'[||] Starting iteration n.{step+1}/{N_steps} [||]')
79 | logging.info(f'Perturbing poses with Theta {angle_delta} and center STD {center_std}. Resolution {res}')
80 |
81 | if (first_step == step) and (args.resume_step is not None):
82 | render_dir = args.resume_step
83 | else:
84 | perturb_str = resampler.get_pertubr_str(step, res)
85 | render_dir = perturb_step(perturb_str, pd, renderer, resampler, all_pred_t, all_pred_R, temp_dir, n_beams)
86 |
87 | all_pred_t, all_pred_R = rank_candidates(fine_model, pd, render_dir,
88 | pd.transform, q_descriptors, N_per_beam, n_beams)
89 |
90 | logging.info(f'[!] Concluded iteration n.{step+1}/{N_steps} [!]')
91 |
92 | ### cleaning up...
93 | if args.clean_logs:
94 | from clean_logs import main as cl_logs
95 | logging.info('Removing rendering files...')
96 | cl_logs(temp_dir)
97 |
98 | logging.info(f'Moving rendering from temp dir {temp_dir} to {args.save_dir}')
99 | shutil.move(join(temp_dir, 'renderings'), args.save_dir, copy_function=shutil.move)
100 | shutil.rmtree(temp_dir)
101 | logging.info('Terminating without errors!')
102 |
103 |
104 | def perturb_step(perturb_str, pd, renderer, resampler, pred_t, pred_R, basepath, n_beams=1):
105 | out_dir = os.path.join(basepath, 'renderings', perturb_str)
106 | os.makedirs(out_dir)
107 | logging.info(f'Generating renders in {out_dir}')
108 |
109 | rend_model = renderer.load_model()
110 | r_names_per_beam = {}
111 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100):
112 | idx = pd.q_frames_idxs[q_idx]
113 | q_name = pd.get_basename(idx)
114 | q_key_name = pd.images[idx].name.split('.')[0]
115 | r_names_per_beam[q_idx] = {}
116 |
117 | w = pd.q_intrinsics[q_key_name]['w']
118 | h = pd.q_intrinsics[q_key_name]['h']
119 | K = pd.q_intrinsics[q_key_name]['K']
120 | r_dir = os.path.join(out_dir, q_name)
121 | os.makedirs(r_dir)
122 | for beam_i in range(n_beams):
123 | beam_dir = join(r_dir, f'beam_{beam_i}')
124 | os.makedirs(beam_dir)
125 | pred_t_beam = pred_t[q_idx, beam_i]
126 | pred_R_beam = pred_R[q_idx, beam_i]
127 |
128 | r_names, render_ts, render_qvecs, calibr_pose = resampler.resample(K, q_name, pred_t_beam, pred_R_beam,
129 | q_idx=q_idx, beam_i=beam_i)
130 |
131 | r_names_per_beam[q_idx][beam_i] = r_names
132 | # poses have to be logged in 'beam_dir', but rendered in 'r_dir', so that
133 | # they can be rendered all together by deferred renderers such as ibmr
134 | rendering.log_poses(beam_dir, r_names, render_ts, render_qvecs, args.renderer)
135 | renderer.render_poses(r_dir, rend_model, r_names, render_ts, render_qvecs, calibr_pose, (w, h))
136 | del rend_model
137 |
138 | renderer.end_epoch(out_dir)
139 |     logging.info('Moving each render into its beam folder')
140 | for q_idx in range(len(pd.q_frames_idxs)):
141 | idx = pd.q_frames_idxs[q_idx]
142 | q_name = pd.get_basename(idx)
143 |
144 | r_dir = join(out_dir, q_name)
145 | for beam_i in range(n_beams):
146 | beam_dir = join(r_dir, f'beam_{beam_i}')
147 | beam_names = r_names_per_beam[q_idx][beam_i]
148 | for b_name in beam_names:
149 | src = join(r_dir, b_name+'.png')
150 | dst = join(beam_dir, b_name+'.png')
151 | shutil.move(src, dst)
152 |
153 | return out_dir
154 |
155 |
156 | def rank_candidates(fine_model, pd, render_dir, transform, q_descriptors, N_per_beam, n_beams):
157 | all_pred_t = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3))
158 | all_pred_R = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam, 3, 3))
159 | all_scores = np.empty((len(pd.q_frames_idxs), n_beams, N_per_beam))
160 |
161 | logging.info('Extracting features from rendered images...')
162 | for q_idx in tqdm(range(len(pd.q_frames_idxs)), ncols=100):
163 | q_name = pd.get_basename(pd.q_frames_idxs[q_idx])
164 | query_dir = os.path.join(render_dir, q_name)
165 | query_tensor = pd[pd.q_frames_idxs[q_idx]]['im']
166 | query_res = tuple(query_tensor.shape[-2:])
167 | q_feats = extraction.get_query_descriptor_by_idx(q_descriptors, q_idx)
168 |
169 | for beam_i in range(n_beams):
170 | beam_dir = join(query_dir, f'beam_{beam_i}')
171 | rd = RenderedImagesDataset(beam_dir, transform, query_res, verbose=False)
172 | r_db_descriptors = extraction.get_features_from_dataset(fine_model, rd, bs=fine_model.conf.bs, is_render=True)
173 | predictions, scores = fine_model.rank_candidates(q_feats, r_db_descriptors, get_scores=True)
174 | pred_t, pred_R = utils.get_pose_from_preds(q_idx, pd, rd, predictions, N_per_beam)
175 |
176 | all_pred_t[q_idx][beam_i] = pred_t
177 | all_pred_R[q_idx][beam_i] = pred_R
178 | all_scores[q_idx][beam_i] = scores[:N_per_beam]
179 |
180 | scores_file = join(beam_dir, 'scores.pth')
181 | torch.save((predictions, scores), scores_file)
182 |
183 | del q_feats, r_db_descriptors
184 | # flatten stuff for eval
185 | flat_score = lambda x: einops.rearrange(x, 'q nb N -> q (nb N)')
186 | flat_R = lambda x: einops.rearrange(x, 'q nb N d1 d2 -> q (nb N) d1 d2', d1=3, d2=3)
187 | flat_t = lambda x: einops.rearrange(x, 'q nb N d -> q (nb N) d', d=3)
188 | flat_preds = np.argsort(flat_score(all_scores))
189 | flat_pred_t = flat_t(all_pred_t)
190 | flat_pred_R = flat_R(all_pred_R)
191 |
192 | # log pose estimate
193 | logging.info(f'Generating pose file...')
194 | f_results = utils.log_pose_estimate(render_dir, pd, flat_pred_R, flat_pred_t, flat_preds=flat_preds)
195 |
196 | method_name = os.path.dirname(render_dir).split('/')[-2]
197 | method_name = method_name.replace('_aachenreal_ibmr', '')
198 | dir_to_step = lambda x: int(x.split('/')[-1].split('_')[2].split('s')[-1])
199 | step_n = dir_to_step(render_dir)
200 | if (step_n+1) % 5 == 0:
201 | method_name += f'_s{step_n}'
202 | try:
203 | commons.submit_poses(method=method_name, path=f_results)
204 | except:
205 | logging.info('Submit script failed')
206 |
207 | return all_pred_t, all_pred_R
208 |
209 |
210 | if __name__ == '__main__':
211 | args = parse_args()
212 | main(args)
213 |
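To make the beam bookkeeping in `main` concrete, a small hedged check of how `utils.reshape_preds_per_beam` deals retrieval candidates out across beams (synthetic translations only; candidate `i` ends up in beam `i % n_beams`):

```python
import numpy as np

from gloc.utils.utils import reshape_preds_per_beam

n_beams, M = 2, 3
# one query, six candidate translations labelled 0..5 for readability
pred_t = np.arange(6, dtype=float).reshape(1, 6, 1).repeat(3, axis=2)

per_beam = reshape_preds_per_beam(n_beams, M, pred_t)
print(per_beam.shape)        # (1, 2, 3, 3)
print(per_beam[0, 0, :, 0])  # [0. 2. 4.] -> beam 0
print(per_beam[0, 1, :, 0])  # [1. 3. 5.] -> beam 1
```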
--------------------------------------------------------------------------------
/render_dataset_from_script.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os.path import join
4 | import numpy as np
5 |
6 | from path_configs import get_path_conf
7 | from gloc import rendering
8 | from gloc.rendering import get_renderer
9 | from gloc.datasets import PoseDataset
10 | from gloc.utils import rotmat2qvec
11 |
12 | """
13 | Example:
14 | python render_dataset_from_script.py stairs --out_dir nerf_renders --colmap_res 320 --renderer g_splatting
15 | """
16 |
17 | parser = argparse.ArgumentParser(description='Argument parser')
18 | # general args
19 | parser.add_argument('name', type=str, help='')
20 | parser.add_argument('--out_dir', type=str, help='', default='')
21 | parser.add_argument('--renderer', type=str, help='', default='nerf')
22 | parser.add_argument('--colmap_res', type=int, default=320, help='')
23 | args = parser.parse_args()
24 |
25 | print(args)
26 |
27 | DS = args.name
28 | out_dir = join(args.out_dir, args.name, args.renderer, str(args.colmap_res))
29 | os.makedirs(out_dir, exist_ok=True)
30 | print(f'Renders will be saved in {out_dir}')
31 |
32 | ##### parse path info and instantiate dataset
33 | paths_conf = get_path_conf(args.colmap_res, None)
34 | pd = PoseDataset(DS, paths_conf[DS])
35 | ####################
36 |
37 | #### parse pose and intrinsics metadata
38 | all_tvecs, all_Rs = pd.get_q_poses()
39 | names = [pd.images[pd.q_frames_idxs[q_idx]].name.replace('/', '_') for q_idx in range(len(pd.q_frames_idxs)) ]
40 | key = os.path.splitext(pd.images[0].name)[0]
41 | chosen_camera = pd.intrinsics[key]
42 | height = chosen_camera['h']
43 | width = chosen_camera['w']
44 | K = chosen_camera['K']
45 |
46 | calibr_pose = []
47 | all_qvecs = []
48 | for tvec, R in zip(all_tvecs, all_Rs):
49 | all_qvecs.append(rotmat2qvec(R))
50 |
51 | T = np.eye(4)
52 | T[0:3, 0:3] = R
53 | T[0:3, 3] = tvec
54 | calibr_pose.append((T, K))
55 | ####################
56 | renderer = get_renderer(args, paths_conf)
57 | mod = renderer.load_model()
58 | renderer.render_poses(out_dir, mod, names, all_tvecs, all_qvecs, calibr_pose, (width, height), deferred=False)
59 |
60 | renderer.clean_file_names(out_dir, names, verbose=True)
61 | # specify 'mesh' as renderer because we converted filenames
62 | rendering.log_poses(out_dir, names, all_tvecs, all_qvecs, 'mesh')
63 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.1
2 | faiss-cpu==1.7.4
3 | gdown==4.7.1
4 | googledrivedownloader==0.4
5 | h5py==3.8.0
6 | huggingface-hub==0.14.1
7 | imageio==2.28.1
8 | open3d==0.17.0
9 | opencv-python==4.7.0.72
10 | pandas==2.0.1
11 | plotly==5.14.1
12 | requests==2.28.1
13 | scikit-image==0.20.0
14 | scikit-learn==1.2.2
15 | scipy==1.9.1
16 | tqdm==4.65.0
17 | utm==0.7.0
18 |
--------------------------------------------------------------------------------
/submit_poses.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import commons
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument("result_path", type=str)
6 | parser.add_argument('-m', "--method_name", required=True, type=str)
7 | args = parser.parse_args()
8 |
9 | commons.submit_poses(method=args.method_name, path=args.result_path)
10 |
--------------------------------------------------------------------------------