├── system.png ├── code ├── models │ ├── .SetOfSet.py.swp │ ├── baseNet.py │ ├── SetOfSet.py │ └── layers.py ├── utils │ ├── __pycache__ │ │ └── general_utils.cpython-39.pyc │ ├── Phases.py │ ├── pos_enc_utils.py │ ├── ba_io.py │ ├── path_utils.py │ ├── sparse_utils.py │ ├── dataset_utils.py │ ├── general_utils.py │ ├── ba_functions.py │ ├── plot_utils.py │ └── ceres_utils.py ├── .vscode │ └── launch.json ├── confs │ ├── inference.conf │ └── training.conf ├── datasets │ ├── Projective.py │ ├── ScenesDataSet.py │ ├── Euclidean.py │ └── SceneData.py ├── inference.py ├── multiple_scenes_learning.py ├── loss_functions.py ├── run_ba.py ├── re_ba.py ├── joint_optimization.py ├── train.py └── evaluation.py ├── .vscode └── settings.json ├── LICENSE ├── bundle_adjustment ├── README.md └── custom_cpp_cost_functions.cpp ├── README.md └── environment.yml /system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/system.png -------------------------------------------------------------------------------- /code/models/.SetOfSet.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/code/models/.SetOfSet.py.swp -------------------------------------------------------------------------------- /code/utils/__pycache__/general_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/code/utils/__pycache__/general_utils.cpython-39.pyc -------------------------------------------------------------------------------- /code/utils/Phases.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class Phases(Enum): 5 | OPTIMIZATION = 1 6 | TRAINING = 2 7 | VALIDATION = 3 8 | TEST = 4 9 | FINE_TUNE = 5 10 | SHORT_OPTIMIZATION = 6 11 | INFERENCE = 
7 -------------------------------------------------------------------------------- /code/.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: 当前文件", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal" 13 | } 14 | ] 15 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "MicroPython.executeButton": [ 3 | { 4 | "text": "▶", 5 | "tooltip": "Run", 6 | "alignment": "left", 7 | "command": "extension.executeFile", 8 | "priority": 3.5 9 | } 10 | ], 11 | "MicroPython.syncButton": [ 12 | { 13 | "text": "$(sync)", 14 | "tooltip": "sync", 15 | "alignment": "left", 16 | "command": "extension.execute", 17 | "priority": 4 18 | } 19 | ] 20 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/confs/inference.conf: -------------------------------------------------------------------------------- 1 | exp_name = Inference 2 | num_iter = 1 3 | dataset 4 | { 5 | use_gt = False 6 | calibrated = True 7 | flag = 2 # 1 for val, 2 for test 8 | valset_path = /home/zeeq/data/testset/npzs100-130 9 | testset_path = /home/zeeq/data/testset/npzs100-130 10 | results_path = /home/zeeq/results/test10 11 | use_spatial_encoder = True 12 | gps_embed_width = 128 13 | egps_embed_rank = 4 14 | x_embed_rank = 4 15 | dsc_egps_embed_width = 32 # 0 means not embed 16 | addNoise = True 17 | noise_mean = 0 18 | noise_std = 0.01 19 | noise_radio = 1.0 20 | alpha = 0.9 # rot 21 | beta = 0.1 # trans 22 | } 23 | model 24 | { 25 | model_path = /path/to/trained/model.pt # //deepaat/models/Model_Ep44000_embed44.pt 26 | type = SetOfSet.SetOfSetNet 27 | num_features = 256 28 | num_blocks = 3 29 | block_size = 2 30 | use_skip = True 31 | } 32 | train 33 | { 34 | lr = 1e-3 35 | num_of_epochs = 50 36 | eval_intervals = 2000 37 | train_trans = True 38 | save_predictions = False 39 | } 40 | loss 41 | { 42 | func = GTLoss 43 | mask_thred = 0.5 44 | } 45 | ba 46 | { 47 | run_ba = False 48 | repeat=True # If repeat, the first time is from our points and the second from triangulation 49 | triangulation=True 50 | only_last_eval = True 51 | refined = False 52 | max_iter = 50 53 | ba_times = 2 54 | repro_thre = 2 55 | } 56 
| -------------------------------------------------------------------------------- /code/utils/pos_enc_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Embedder: 5 | def __init__(self, **kwargs): 6 | self.kwargs = kwargs 7 | self.create_embedding_fn() 8 | 9 | def create_embedding_fn(self): 10 | embed_fns = [] 11 | d = self.kwargs['input_dims'] 12 | out_dim = 0 13 | if self.kwargs['include_input']: 14 | embed_fns.append(lambda x: x) 15 | out_dim += d 16 | 17 | max_freq = self.kwargs['max_freq_log2'] 18 | N_freqs = self.kwargs['num_freqs'] 19 | 20 | if self.kwargs['log_sampling']: 21 | freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs) 22 | else: 23 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs) 24 | 25 | for freq in freq_bands: 26 | for p_fn in self.kwargs['periodic_fns']: 27 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 28 | out_dim += d 29 | 30 | self.embed_fns = embed_fns 31 | self.out_dim = out_dim 32 | 33 | def embed(self, inputs): 34 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 35 | 36 | 37 | def get_embedder(multires, in_dim): 38 | embed_kwargs = { 39 | 'include_input': True, 40 | 'input_dims': in_dim, 41 | 'max_freq_log2': multires - 1, 42 | 'num_freqs': multires, 43 | 'log_sampling': True, 44 | 'periodic_fns': [torch.sin, torch.cos], 45 | } 46 | 47 | embedder_obj = Embedder(**embed_kwargs) 48 | embed = lambda x, eo=embedder_obj: eo.embed(x) 49 | return embed, embedder_obj.out_dim -------------------------------------------------------------------------------- /code/utils/ba_io.py: -------------------------------------------------------------------------------- 1 | import scipy.io as sio 2 | import numpy as np 3 | import os 4 | 5 | 6 | def read_mat_files(path): 7 | raw_data = sio.loadmat(path + '.mat', squeeze_me=True) 8 | Xs = raw_data['Points3D'].T 9 | M = raw_data['M'] 10 | m, n = M.shape 11 | m = m // 2 12 
| xs = M.reshape([m, 2, n]).transpose([0, 2, 1]) 13 | Ps = np.stack(raw_data['Ps']) 14 | data = {'Ps': Ps, 'Xs': Xs, 'xs': xs} 15 | return data 16 | 17 | def read_euc_gt_mat_files(path): 18 | raw_data = sio.loadmat(path + '.mat', squeeze_me=True) 19 | M = raw_data['M'] 20 | if not isinstance(M, (np.ndarray, np.generic) ): 21 | M = np.asarray(M.todense()) 22 | Rs = np.stack(raw_data['R_gt']) 23 | ts = np.stack(raw_data['T_gt']) 24 | Ks = np.stack(raw_data['K_gt']) 25 | m, n = M.shape 26 | m = m // 2 27 | xs = M.reshape([m, 2, n]).transpose([0, 2, 1]) 28 | data = {'Rs': Rs, 'ts': ts, 'Ks':Ks, 'xs': xs} 29 | return data 30 | 31 | def read_proj_gt_mat_files(path): 32 | raw_data = sio.loadmat(path + '.mat', squeeze_me=True) 33 | M = np.asarray(raw_data['M']) 34 | m, n = M.shape 35 | m = m // 2 36 | xs = M.reshape([m, 2, n]).transpose([0, 2, 1]) 37 | data = {'xs': xs} 38 | return data 39 | 40 | def read_euc_our_mat_files(path, name='Final_Cameras'): 41 | raw_data = sio.loadmat(os.path.join(path, 'cameras', name) + '.mat', squeeze_me=True) 42 | Xs = raw_data['pts3D'][:3].T.astype(np.double) 43 | Rs = raw_data['Rs'] 44 | ts = raw_data['ts'] 45 | Ks = raw_data['Ks'] 46 | data = {'Xs': Xs, 'Rs': Rs, 'ts': ts, 'Ks':Ks} 47 | return data 48 | 49 | 50 | def read_proj_our_mat_files(path, name='Final_Cameras'): 51 | raw_data = sio.loadmat(os.path.join(path, 'cameras', name) + '.mat', squeeze_me=True) 52 | Xs = raw_data['pts3D'][:3].T.astype(np.double) 53 | Ps = raw_data['Ps'] 54 | data = {'Ps': Ps, 'Xs':Xs} 55 | return data 56 | 57 | -------------------------------------------------------------------------------- /code/confs/training.conf: -------------------------------------------------------------------------------- 1 | exp_name = Learning_Euc 2 | # random_seed=0 3 | dataset 4 | { 5 | batch_size = 1 6 | shuffle_data = True # if true, then shuffle rows for each data 7 | trainset_path = /home/zeeq/data/trainset/npzs100-130 # Change to your own path 8 | valset_path = 
/home/zeeq/data/valset/npzs100-130 9 | testset_path = /home/zeeq/data/valset/npzs100-130 10 | results_path = /home/zeeq/results/val10 11 | use_spatial_encoder = True 12 | gps_embed_width = 128 13 | egps_embed_rank = 4 14 | x_embed_rank = 4 15 | dsc_egps_embed_width = 32 # 0 means not embed 16 | addNoise = True 17 | noise_mean = 0 18 | noise_std = 0.01 19 | noise_radio = 1.0 20 | alpha = 0.9 # rot 21 | beta = 0.1 # trans 22 | } 23 | model 24 | { 25 | type = SetOfSet.SetOfSetNet 26 | num_features = 256 27 | num_blocks = 3 28 | block_size = 2 29 | use_skip = True 30 | 31 | multires = 10 # discard 32 | } 33 | train 34 | { 35 | lr = 1e-3 36 | num_of_epochs = 50 37 | eval_intervals = 1000 38 | train_trans = True 39 | save_predictions = False 40 | 41 | scheduler_milestone = [12000,16000,18000] #[20000, 30000, 40000, 45000] # 20000 42 | gamma = 0.2 43 | optimization_num_of_epochs = 500 44 | optimization_eval_intervals = 250 45 | optimization_lr = 1e-3 46 | min_valid_pts = 100 # if any one camera in data batch watch points less than this, it will be skipped 47 | } 48 | loss 49 | { 50 | func = GTLoss 51 | mask_thred = 0.5 52 | 53 | infinity_pts_margin = 1e-4 54 | normalize_grad = True 55 | hinge_loss = True 56 | hinge_loss_weight = 1 57 | } 58 | ba 59 | { 60 | run_ba = False 61 | repeat=True # If repeat, the first time is from our points and the second from triangulation 62 | triangulation=True 63 | only_last_eval = True 64 | refined = False # multi-ba 65 | max_iter = 50 66 | ba_times = 2 67 | repro_thre = 2 68 | } -------------------------------------------------------------------------------- /code/datasets/Projective.py: -------------------------------------------------------------------------------- 1 | import cv2 # Do not remove 2 | import torch 3 | from utils import geo_utils, general_utils, dataset_utils, path_utils 4 | import scipy.io as sio 5 | import numpy as np 6 | import os.path 7 | 8 | 9 | def get_raw_data(conf, scan): 10 | """ 11 | :param conf: 12 | :return: 13 
| M - Points Matrix (2mxn) 14 | Ns - Normalization matrices (mx3x3) 15 | Ps_gt - Olsson's estimated camera matrices (mx3x4) 16 | NBs - Normzlize Bifocal Tensor (Normalized Fn) (3mx3m) 17 | triplets 18 | """ 19 | # Init 20 | dataset_path_format = os.path.join(path_utils.path_to_datasets(), 'Projective', '{}.npz') 21 | 22 | # Get conf parameters 23 | if scan is None: 24 | scan = conf.get_string('dataset.scan') 25 | use_gt = conf.get_bool('dataset.use_gt') 26 | 27 | # Get raw data 28 | dataset = np.load(dataset_path_format.format(scan)) 29 | 30 | # Get bifocal tensors and 2D points 31 | M = dataset['M'] 32 | Ps_gt = dataset['Ps_gt'] 33 | Ns = dataset['Ns'] 34 | mask = dataset['mask'] 35 | 36 | if use_gt: 37 | M = torch.from_numpy(dataset_utils.correct_matches_global(M, Ps_gt, Ns)) 38 | 39 | M = torch.from_numpy(M).float() 40 | Ps_gt = torch.from_numpy(Ps_gt).float() 41 | Ns = torch.from_numpy(Ns).float() 42 | mask = torch.from_numpy(mask).int() 43 | 44 | return M, Ns, Ps_gt, mask 45 | 46 | 47 | def test_Ps_M(Ps, M, Ns): 48 | global_rep_err = geo_utils.calc_global_reprojection_error(Ps.numpy(), M.numpy(), Ns.numpy()) 49 | print("Reprojection Error: Mean = {}, Max = {}".format(np.nanmean(global_rep_err), np.nanmax(global_rep_err))) 50 | 51 | 52 | def test_projective_dataset(scan): 53 | dataset_path_format = os.path.join(path_utils.path_to_datasets(), 'Projective', '{}.npz') 54 | 55 | # Get raw data 56 | dataset = np.load(dataset_path_format.format(scan)) 57 | 58 | # Get bifocal tensors and 2D points 59 | M = dataset['M'] 60 | Ps_gt = dataset['Ps_gt'] 61 | Ns = dataset['Ns'] 62 | 63 | M_gt = torch.from_numpy(dataset_utils.correct_matches_global(M, Ps_gt, Ns)).float() 64 | 65 | M = torch.from_numpy(M).float() 66 | Ps_gt = torch.from_numpy(Ps_gt).float() 67 | Ns = torch.from_numpy(Ns).float() 68 | 69 | print("Test Ps and M") 70 | test_Ps_M(Ps_gt, M, Ns) 71 | 72 | print("Test Ps and M_gt") 73 | test_Ps_M(Ps_gt, M_gt, Ns) 74 | 75 | 76 | if __name__ == "__main__": 77 | scan = 
"Alcatraz Courtyard" 78 | test_projective_dataset(scan) 79 | 80 | 81 | -------------------------------------------------------------------------------- /code/inference.py: -------------------------------------------------------------------------------- 1 | import cv2 # DO NOT REMOVE 2 | from time import time 3 | from utils import general_utils, dataset_utils, plot_utils 4 | from utils.Phases import Phases 5 | from datasets import SceneData, ScenesDataSet 6 | from datasets.ScenesDataSet import DataLoader, myDataSet 7 | from single_scene_optimization import train_single_model 8 | import evaluation 9 | import torch 10 | import pandas as pd 11 | 12 | def inference(conf, device, phase): 13 | 14 | # Get conf 15 | flag = conf.get_int("dataset.flag") # 1 means val, 2 means test 16 | scans_list = SceneData.get_data_list(conf, flag) 17 | bundle_adjustment = conf.get_bool("ba.run_ba") 18 | refined = conf.get_bool("ba.refined") 19 | 20 | # Create model 21 | model_path = conf.get_string('model.model_path') 22 | model = general_utils.get_class("models." 
+ conf.get_string("model.type"))(conf).to(device) 23 | model.load_state_dict(torch.load(model_path)) 24 | model.eval() 25 | 26 | errors_list = [] 27 | with torch.no_grad(): 28 | data_loader = myDataSet(conf, flag, scans_list).to(device) 29 | for batch_data in data_loader: 30 | for scene_data in batch_data: 31 | print(f"processing {scene_data.scan_name} ...") 32 | 33 | # Optimize Scene 34 | begin_time = time() 35 | pred_mask, pred_cam = model(scene_data) 36 | pred_time = time() - begin_time 37 | outputs = evaluation.prepare_predictions_2(scene_data, pred_mask, pred_cam, conf, 0, bundle_adjustment, refined=refined) 38 | errors = evaluation.compute_errors(outputs, conf, bundle_adjustment, refined=refined) 39 | 40 | errors['Inference time'] = pred_time 41 | errors['Scene'] = scene_data.scan_name 42 | errors['all_pts'] = scene_data.M.shape[-1] 43 | errors['pred_pts'] = outputs['pts3D_pred'].shape[1] 44 | # errors['after_ba_pts'] = outputs['Xs_ba'].shape[1] 45 | errors['gt_pts'] = scene_data.mask[:,scene_data.mask.sum(axis=0)!=0].shape[1] 46 | 47 | errors_list.append(errors) 48 | dataset_utils.save_cameras(outputs, conf, curr_epoch=None, phase=phase) 49 | plot_utils.plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan=scene_data.scan_name, epoch=None, bundle_adjustment=bundle_adjustment) 50 | 51 | # Write results 52 | df_errors = pd.DataFrame(errors_list) 53 | mean_errors = df_errors.mean(numeric_only=True) 54 | # df_errors = pd.concat([df_errors,mean_errors], axis=0, ignore_index=True) 55 | df_errors = df_errors.append(mean_errors, ignore_index=True) 56 | df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean" 57 | df_errors.set_index("Scene", inplace=True) 58 | df_errors = df_errors.round(3) 59 | print(df_errors.to_string(), flush=True) 60 | general_utils.write_results(conf, df_errors, file_name="Inference") 61 | 62 | 63 | if __name__ == "__main__": 64 | conf, device, phase = general_utils.init_exp(Phases.INFERENCE.name) 65 | inference(conf, device, 
phase) -------------------------------------------------------------------------------- /code/utils/path_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.Phases import Phases 3 | 4 | 5 | def join_and_create(path, folder): 6 | full_path = os.path.join(path, folder) 7 | if not os.path.exists(full_path): 8 | os.mkdir(full_path) 9 | 10 | return full_path 11 | 12 | 13 | def path_to_datasets(): 14 | return os.path.join('..', 'datasets') 15 | 16 | 17 | def path_to_condition(conf): 18 | experiments_folder = conf.get_string('dataset.results_path') 19 | if not os.path.exists(experiments_folder): 20 | os.mkdir(experiments_folder) 21 | exp_name = conf.get_string('exp_name') 22 | return join_and_create(experiments_folder, exp_name) 23 | 24 | 25 | def path_to_exp(conf): 26 | exp_ver = conf.get_string('exp_version') 27 | exp_ver_path = join_and_create(path_to_condition(conf), exp_ver) 28 | 29 | return exp_ver_path 30 | 31 | 32 | def path_to_phase(conf, phase): 33 | exp_path = path_to_exp(conf) 34 | return join_and_create(exp_path, phase.name) 35 | 36 | 37 | def path_to_scan(conf, phase, scan=None): 38 | exp_path = path_to_phase(conf, phase) 39 | scan = conf.get_string("dataset.scan") if scan is None else scan 40 | return join_and_create(exp_path, scan) 41 | 42 | 43 | def path_to_model(conf, phase, epoch=None, scan=None): 44 | if phase in [Phases.TRAINING, Phases.VALIDATION, Phases.TEST]: 45 | parent_folder = path_to_exp(conf) 46 | else: 47 | parent_folder = path_to_scan(conf, phase, scan=scan) 48 | 49 | models_path = join_and_create(parent_folder, 'models') 50 | 51 | if epoch is None: 52 | model_file_name = "Final_Model.pt" 53 | else: 54 | model_file_name = "Model_Ep{}.pt".format(epoch) 55 | 56 | return os.path.join(models_path, model_file_name) 57 | 58 | 59 | def path_to_learning_data(conf, phase): 60 | return join_and_create(path_to_condition(conf), phase) 61 | 62 | 63 | def path_to_cameras(conf, phase, 
epoch=None, scan=None): 64 | scan_path = path_to_scan(conf, phase, scan=scan) 65 | cameras_path = join_and_create(scan_path, 'cameras') 66 | 67 | if epoch is None: 68 | cameras_file_name = "Final_Cameras" 69 | else: 70 | cameras_file_name = "Cameras_Ep{}".format(epoch) 71 | 72 | return os.path.join(cameras_path, cameras_file_name) 73 | 74 | 75 | def path_to_plots(conf, phase, epoch=None, scan=None): 76 | scan_path = path_to_scan(conf, phase, scan=scan) 77 | plots_path = join_and_create(scan_path, 'plots') 78 | 79 | if epoch is None: 80 | plots_file_name = "Final_plots.html" 81 | else: 82 | plots_file_name = "Plot_Ep{}.html".format(epoch) 83 | 84 | return os.path.join(plots_path, plots_file_name) 85 | 86 | 87 | def path_to_logs(conf, phase): 88 | phase_path = path_to_phase(conf, phase) 89 | logs_path = join_and_create(phase_path, "logs") 90 | return logs_path 91 | 92 | 93 | def path_to_code_logs(conf): 94 | exp_path = path_to_exp(conf) 95 | code_path = join_and_create(exp_path, "code") 96 | return code_path 97 | 98 | 99 | def path_to_conf(conf_file): 100 | return os.path.join( 'confs', conf_file) -------------------------------------------------------------------------------- /code/multiple_scenes_learning.py: -------------------------------------------------------------------------------- 1 | import cv2 # DO NOT REMOVE 2 | from utils import general_utils, dataset_utils 3 | from utils.Phases import Phases 4 | from datasets.ScenesDataSet import ScenesDataSet, DataLoader, myDataSet 5 | from datasets import SceneData 6 | from single_scene_optimization import train_single_model 7 | import train 8 | import copy 9 | import time 10 | 11 | 12 | def main(): 13 | # Init Experiment 14 | conf, device, phase = general_utils.init_exp(Phases.TRAINING.name) 15 | general_utils.log_code(conf) 16 | 17 | # Get configuration 18 | sample = conf.get_bool('dataset.sample') 19 | batch_size = conf.get_int('dataset.batch_size') 20 | 21 | train_list = SceneData.get_data_list(conf, 0) 22 | 
val_list = SceneData.get_data_list(conf, 1) 23 | test_list = SceneData.get_data_list(conf, 2) 24 | 25 | if sample: 26 | min_sample_size = conf.get_int('dataset.min_sample_size') 27 | max_sample_size = conf.get_int('dataset.max_sample_size') 28 | # optimization_num_of_epochs = conf.get_int("train.optimization_num_of_epochs") 29 | # optimization_eval_intervals = conf.get_int('train.optimization_eval_intervals') 30 | # optimization_lr = conf.get_int('train.optimization_lr') 31 | 32 | # Create train, test and validation sets 33 | train_scenes = SceneData.create_scene_data_from_list(train_list, conf, 0) 34 | validation_scenes = SceneData.create_scene_data_from_list(val_list, conf, 1) 35 | test_scenes = SceneData.create_scene_data_from_list(test_list, conf, 2) 36 | 37 | train_set = ScenesDataSet(train_scenes, return_all=False, min_sample_size=min_sample_size, max_sample_size=max_sample_size) 38 | validation_set = ScenesDataSet(validation_scenes, return_all=True) 39 | test_set = ScenesDataSet(test_scenes, return_all=True) 40 | 41 | # Create dataloaders 42 | train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True).to(device) 43 | validation_loader = DataLoader(validation_set, batch_size=1, shuffle=False).to(device) 44 | test_loader = DataLoader(test_set, batch_size=1, shuffle=False).to(device) 45 | else: 46 | train_loader = myDataSet(conf, 0, train_list, batch_size, True).to(device) 47 | validation_loader = myDataSet(conf, 1, val_list).to(device) 48 | test_loader = myDataSet(conf, 2, test_list).to(device) 49 | 50 | # Train model 51 | model = general_utils.get_class("models." 
+ conf.get_string("model.type"))(conf).to(device) 52 | train_stat, train_errors, validation_errors, test_errors = train.train(conf, train_loader, model, phase, validation_loader, test_loader) 53 | # Write results 54 | general_utils.write_results(conf, train_stat, file_name="Train_Stats") 55 | general_utils.write_results(conf, train_errors, file_name="Train") 56 | general_utils.write_results(conf, validation_errors, file_name="Validation") 57 | general_utils.write_results(conf, test_errors, file_name="Test") 58 | 59 | 60 | def optimization_all_sets(conf, device, phase): 61 | # Get logs directories 62 | scans_list = conf.get_list('dataset.scans_list') 63 | for i, scan in enumerate(scans_list): 64 | conf["dataset"]["scan"] = scan 65 | train_single_model(conf, device, phase) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /code/models/baseNet.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | from utils import geo_utils 5 | from pytorch3d import transforms as py3d_trans 6 | import numpy as np 7 | 8 | 9 | class BaseNet(torch.nn.Module): 10 | def __init__(self, conf): 11 | super(BaseNet, self).__init__() 12 | 13 | self.calibrated = conf.get_bool('dataset.calibrated') 14 | self.normalize_output = conf.get_string('model.normalize_output', default=None) 15 | self.rot_representation = conf.get_string('model.rot_representation', default='quat') 16 | self.soft_sign = torch.nn.Softsign() 17 | 18 | if self.calibrated and self.rot_representation == '6d': 19 | print('rot representation: ' + self.rot_representation) 20 | self.out_channels = 9 21 | elif self.calibrated and self.rot_representation == 'quat': 22 | self.out_channels = 7 23 | elif self.calibrated and self.rot_representation == 'svd': 24 | self.out_channels = 12 25 | elif not self.calibrated: 26 | self.out_channels = 12 27 | else: 28 | print("Illegal output 
format") 29 | exit() 30 | 31 | @abc.abstractmethod 32 | def forward(self, data): 33 | pass 34 | 35 | def extract_model_outputs(self, x, pts_3D, data): 36 | # Get points 37 | pts_3D = geo_utils.ones_padding(pts_3D) 38 | 39 | # Get calibrated predictions 40 | if self.calibrated: 41 | # Get rotation 42 | if self.rot_representation == '6d': 43 | RTs = py3d_trans.rotation_6d_to_matrix(x[:, :6]) 44 | elif self.rot_representation == 'svd': 45 | m = x[:, :9].reshape(-1, 3, 3) 46 | RTs = geo_utils.project_to_rot(m) 47 | elif self.rot_representation == 'quat': 48 | RTs = py3d_trans.quaternion_to_matrix(x[:, :4]) 49 | else: 50 | print("Illegal output format") 51 | exit() 52 | 53 | # Get translation 54 | minRTts = x[:, -3:] 55 | 56 | # n = len(RTs) 57 | # Ps = torch.zeros([n, 3, 4]).cuda(0) 58 | 59 | # for i,r,t in zip(np.arange(n),RTs,minRTts): 60 | # Ps[i] = r.T @ torch.cat((torch.eye(3).cuda(0), -t.view(3, 1)), dim=1) 61 | 62 | # Get camera matrix 63 | Ps = torch.cat((RTs, minRTts.unsqueeze(dim=-1)), dim=-1) 64 | 65 | else: # Projective 66 | Ps = x.reshape(-1, 3, 4) 67 | 68 | # Normalize predictions 69 | if self.normalize_output == "Chirality": 70 | scale = torch.sign(Ps[:, 0:3, 0:3].det()) / Ps[:, 2, 0:3].norm(dim=1) 71 | Ps = Ps * scale.reshape(-1, 1, 1) 72 | elif self.normalize_output == "Differentiable Chirality": 73 | scale = self.soft_sign(Ps[:, 0:3, 0:3].det() * 10e3) / Ps[:, 2, 0:3].norm(dim=1) 74 | Ps = Ps * scale.reshape(-1, 1, 1) 75 | elif self.normalize_output == "Frobenius": 76 | Ps = Ps / Ps.norm(dim=(1, 2), p='fro', keepdim=True) 77 | 78 | # The model outputs a normalized camera! Meaning from world coordinates to camera coordinates, not to pixels in the image. 
79 | pred_cams = {"Ps_norm": Ps, "pts3D": pts_3D} 80 | return pred_cams 81 | 82 | def extract_camera_outputs(self, x): 83 | 84 | # quats = self.norm_quats(x[:, :4]) 85 | # pred_cams = {"quats": geo_utils.norm_quats(x[:, :4]), "ts": geo_utils.scale_ts_norm(x[:, -3:])} 86 | pred_cams = {"quats": x[:, :4], "ts": x[:, -3:]} 87 | return pred_cams -------------------------------------------------------------------------------- /code/utils/sparse_utils.py: -------------------------------------------------------------------------------- 1 | from hashlib import new 2 | import torch 3 | 4 | 5 | class SparseMat: 6 | def __init__(self, values, indices, cam_per_pts, pts_per_cam, shape): 7 | assert len(shape) == 3 8 | self.values = values 9 | self.indices = indices 10 | self.shape = shape 11 | self.cam_per_pts = cam_per_pts 12 | self.pts_per_cam = pts_per_cam 13 | self.device = self.values.device 14 | 15 | @property 16 | def size(self): 17 | return self.shape 18 | 19 | 20 | def sum(self, dim): 21 | assert dim == 1 or dim == 0 22 | n_features = self.shape[2] 23 | out_size = self.shape[0] if dim == 1 else self.shape[1] 24 | indices_index = 0 if dim == 1 else 1 25 | mat_sum = torch.zeros(out_size, n_features, device=self.device) 26 | return mat_sum.index_add(0, self.indices[indices_index], self.values) 27 | 28 | 29 | def mean(self, dim): 30 | assert dim == 1 or dim == 0 31 | if dim == 0: 32 | return self.sum(dim=0) / self.cam_per_pts 33 | else: 34 | return self.sum(dim=1) / self.pts_per_cam 35 | 36 | 37 | def std(self, dim): 38 | assert dim == 1 or dim == 0 39 | #mean_mat = torch.zeros(self.shape, device=self.device) 40 | if dim==0: 41 | meanmat = self.mean(dim=0) # n,d 42 | mean_mat = meanmat.unsqueeze(dim=0).repeat(self.shape[0], 1, 1) # m,n,d 43 | mat_vals = mean_mat[self.indices[0], self.indices[1], :] # nnz 44 | sparse_meanmat = SparseMat(mat_vals, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) # m,n,d 45 | sparse_std = (((self-sparse_meanmat)**2).sum(dim=0) / 
self.cam_per_pts).pow(0.5) # n,d 46 | else: 47 | meanmat = self.mean(dim=1) 48 | mean_mat = meanmat.unsqueeze(dim=1).repeat(1, self.shape[1], 1) 49 | mat_vals = mean_mat[self.indices[0], self.indices[1], :] 50 | sparse_meanmat = SparseMat(mat_vals, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) # m,n,d 51 | sparse_std = (((self-sparse_meanmat)**2).sum(dim=1) / self.pts_per_cam).pow(0.5) # m,d 52 | 53 | return sparse_std 54 | 55 | # concat in last dim 56 | def last_dim_cat(self, other): 57 | new_values = torch.cat([self.values, other.values], -1) 58 | new_shape = (self.shape[0], self.shape[1], new_values.shape[-1]) 59 | return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, new_shape) 60 | 61 | 62 | def to(self, device, **kwargs): 63 | self.device = device 64 | self.values = self.values.to(device, **kwargs) 65 | self.indices = self.indices.to(device, **kwargs) 66 | self.pts_per_cam = self.pts_per_cam.to(device, **kwargs) 67 | self.cam_per_pts = self.cam_per_pts.to(device, **kwargs) 68 | return self 69 | 70 | 71 | 72 | def __add__(self, other): 73 | assert self.shape == other.shape 74 | # assert (self.indices == other.indices).all() # removed due to runtime 75 | new_values = self.values + other.values 76 | return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) 77 | 78 | 79 | def __sub__(self, other): 80 | assert self.shape == other.shape 81 | new_values = self.values - other.values 82 | return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) 83 | 84 | 85 | def __mul__(self, other): 86 | assert self.shape[:-1] == other.shape[:-1] 87 | new_values = self.values * other.values 88 | return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) 89 | 90 | 91 | def __truediv__(self, other): 92 | assert self.shape[:-1] == other.shape[:-1] 93 | new_values = self.values / other.values 94 | return SparseMat(new_values, self.indices, 
self.cam_per_pts, self.pts_per_cam, self.shape) 95 | 96 | 97 | def __pow__(self, other): 98 | return SparseMat(self.values ** other, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape) -------------------------------------------------------------------------------- /bundle_adjustment/README.md: -------------------------------------------------------------------------------- 1 | # Python Bundle Adjustment 2 | These instructions are based on and almost identical to the ones at [https://github.com/drormoran/gasfm/tree/main/bundle_adjustment/README.md](https://github.com/drormoran/gasfm/tree/main/bundle_adjustment/README.md). 3 | 4 | 5 | ## Conda environment 6 | Use the project's `deepaat` conda environment, and set both version variables to that environment's Python version (the example output paths below were built with Python 3.9, hence `cpython-39`). 7 | ``` 8 | conda activate deepaat 9 | export PYBIND11_PYTHON_VERSION="3.8" 10 | export PYTHON_VERSION="3.8" 11 | ``` 12 | 13 | ## Directory structure 14 | After this setup, the directory structure should be: 15 | ``` 16 | DeepAAT 17 | ├── bundle_adjustment 18 | │   ├── ceres-solver 19 | │   │   ├── ceres-bin 20 | │   │   | └── lib 21 | │   │   | └── PyCeres.cpython-39-x86_64-linux-gnu.so 22 | │   │   ├── ceres_python_bindings 23 | │   │   | └── python_bindings 24 | │   │   | └── custom_cpp_cost_functions.cpp 25 | │   │   └── CMakeLists.txt 26 | │   └── custom_cpp_cost_functions.cpp 27 | ├── code 28 | ├── datasets 29 | ``` 30 | ## Set up 31 | 1. Clone the Ceres-Solver repository into the bundle_adjustment folder and check out version 2.1.0: 32 | 33 | ``` 34 | cd bundle_adjustment 35 | git clone https://ceres-solver.googlesource.com/ceres-solver -b 2.1.0 36 | ``` 37 | 38 | 39 | 2. Clone the ceres_python_bindings package inside the ceres-solver folder: 40 | 41 | ``` 42 | cd ceres-solver 43 | git clone https://github.com/Edwinem/ceres_python_bindings 44 | ``` 45 | 46 | 47 | 3. Copy the file "custom_cpp_cost_functions.cpp" over the file "ceres-solver/ceres_python_bindings/python_bindings/custom_cpp_cost_functions.cpp".
48 | This file contains projective and Euclidean custom bundle adjustment functions. 49 | 50 | ``` 51 | cp ../custom_cpp_cost_functions.cpp ceres_python_bindings/python_bindings/custom_cpp_cost_functions.cpp 52 | ``` 53 | 54 | Next, you need to build ceres_python_bindings and ceres-solver and create a shared object file that Python can call. 55 | You can either continue with the instructions here or follow the instructions in the ceres_python_bindings repository. 56 | 57 | 4. Run: 58 | 59 | ``` 60 | cd ceres_python_bindings 61 | git submodule init 62 | git submodule update 63 | ``` 64 | 65 | 66 | 5. Make sure that the C++ standard used during the build is recent enough, and not hard-coded to C++11 by pybind11. Check which C++ standard your compiler supports and adjust the command below accordingly (this example uses C++17): 67 | 68 | ``` 69 | sed -i 's/set(PYBIND11_CPP_STANDARD -std=c++11)/set(PYBIND11_CPP_STANDARD -std=c++17)/g' AddToCeres.cmake 70 | ``` 71 | 72 | 73 | 6. Append the line "include(ceres_python_bindings/AddToCeres.cmake)" to the end of the file ceres-solver/CMakeLists.txt: 74 | 75 | ``` 76 | cd .. 77 | echo "include(ceres_python_bindings/AddToCeres.cmake)" >> CMakeLists.txt 78 | ``` 79 | 80 | 81 | 7. Inside the ceres-solver folder, run: 82 | 83 | 84 | ``` 85 | mkdir ceres-bin 86 | cd ceres-bin 87 | cmake .. 88 | make -j8 89 | make test 90 | ``` 91 | 92 | 8. If everything worked, you should see the following file (the `cpython-39` part of the name depends on the Python version you built against): 93 | 94 | ``` 95 | bundle_adjustment/ceres-solver/ceres-bin/lib/PyCeres.cpython-39-x86_64-linux-gnu.so 96 | ``` 97 | 98 | 9. If you want to use this bundle adjustment implementation in a different project, make sure the directory containing the shared object is on Python's module search path (`sys.path`, e.g. via `PYTHONPATH`); the code in this repository already does this for you.
In a Python project, this would look like, for example: 99 | 100 | ``` 101 | import sys 102 | sys.path.append('../bundle_adjustment/ceres-solver/ceres-bin/lib/') 103 | import PyCeres 104 | ``` 105 | 106 | For examples of how the PyCeres functions are used, see code/utils/ceres_utils.py and code/utils/ba_functions.py. 107 | 108 | ## Note 109 | If you encounter problems while compiling PyCeres, you can also refer to [GASFM](https://github.com/lucasbrynte/gasfm)'s [bundle adjustment instructions](https://github.com/lucasbrynte/gasfm/blob/main/bundle_adjustment/README.md) or create a separate environment for compilation. 110 | If PyCeres is compiled in a different environment from the one used to run the network, set `run_ba` to False in the conf file during training and inference; after obtaining the predicted results, run run_ba.py separately (in the PyCeres environment) to perform BA. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based Mapping
2 | 3 | ### [Paper](https://www.sciencedirect.com/science/article/pii/S1569843224005466) | [arXiv](https://arxiv.org/abs/2402.01134) 4 |

5 | 6 |

7 | 8 | This is the implementation of the DeepAAT architecture presented in our JAG paper, DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based Mapping. The codebase is forked from the implementation of the ICCV 2021 paper [Deep Permutation Equivariant Structure from Motion](https://openaccess.thecvf.com/content/ICCV2021/html/Moran_Deep_Permutation_Equivariant_Structure_From_Motion_ICCV_2021_paper.html), available at [https://github.com/drormoran/Equivariant-SFM](https://github.com/drormoran/Equivariant-SFM). That architecture is also used as a baseline and referred to as ESFM in our paper. 9 | 10 | DeepAAT considers both the spatial and spectral characteristics of imagery, enhancing its ability to resolve erroneous matching pairs and accurately predict image poses. DeepAAT marks a significant leap in AAT efficiency while ensuring thorough scene coverage and precision. It runs hundreds of times faster than incremental AAT methods and tens of times faster than global AAT methods while maintaining a comparable level of reconstruction accuracy. Additionally, DeepAAT's scene clustering and merging strategy facilitates rapid localization and pose determination for large-scale UAV image sets, even under constrained computing resources. The experimental results demonstrate DeepAAT's substantial improvements over conventional AAT methods, highlighting its potential for efficient and accurate UAV-based 3D reconstruction. 11 | 12 | 13 | ## Contents 14 | 15 | - [Setup](#setup) 16 | - [Usage](#usage) 17 | - [Citation](#citation) 18 | 19 | --- 20 | 21 | ## Setup 22 | This repository is implemented in Python 3.8, and running bundle adjustment requires Linux. We have used Ubuntu 22.04. You should also have a CUDA-capable GPU.
23 | 24 | ## Directory structure 25 | The repository should contain the following directories and files: 26 | ``` 27 | DeepAAT 28 | ├── bundle_adjustment 29 | ├── code 30 | ├── scripts 31 | ├── environment.yml 32 | ``` 33 | 34 | ## Conda environment 35 | Create the environment using the following commands: 36 | ``` 37 | conda env create -f environment.yml 38 | conda activate deepaat 39 | ``` 40 | 41 | ## PyCeres 42 | Next, follow the [bundle adjustment instructions](bundle_adjustment/README.md). 43 | 44 | ## Data and pretrained models 45 | Both the datasets and the pretrained models for Euclidean reconstruction of novel scenes are available [here](https://www.dropbox.com/scl/fo/gtju43lxu9zgn36ft86ly/AArhmE-Q2QxlmGwwvWUiyKc?rlkey=4lse57283dskw1mfy6oii3ti5&st=ky4qc4nk&dl=0). 46 | Download the data and pretrained models, and then update the paths in the conf file accordingly. Because the training data is very large, only the pretrained models and test data are provided. 47 | 48 | ## Usage 49 | To execute the code, first navigate to the `code` subdirectory. Also make sure that the conda environment is activated. 50 | 51 | To train a model from scratch for reconstruction of novel test scenes, run the following (make sure the corresponding data paths and configuration in the conf file are set correctly first): 52 | ``` 53 | python multiple_scenes_learning.py --conf path/to/conf 54 | ``` 55 | where `path/to/conf` is relative to `code/confs/`, and may e.g. be `training.conf` for training a Euclidean reconstruction model using data augmentation. 56 | 57 | The training phase is followed by bundle adjustment, evaluation, and, by default, separate fine-tuning of the model parameters on every test scene.
58 | 59 | To infer a new scene, run (Please make sure to modify the corresponding data path and configuration in the conf file correctly): 60 | ``` 61 | python inference.py --conf inference.conf 62 | ``` 63 | 64 | ## Citation 65 | If you find this work useful, please cite our paper: 66 | ``` 67 | @article{chen2024deepaat, 68 | title={DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based Mapping}, 69 | author={Chen, Zequan and Li, Jianping and Li, Qusheng and Dong, Zhen and Yang, Bisheng}, 70 | journal={International Journal of Applied Earth Observation and Geoinformation}, 71 | volume={134}, 72 | pages={104190}, 73 | year={2024}, 74 | publisher={Elsevier} 75 | } 76 | ``` 77 | 78 | ## Acknowledgement 79 | We make improvements based on [ESFM](https://github.com/drormoran/Equivariant-SFM). We thank the authors for releasing the source code. -------------------------------------------------------------------------------- /code/loss_functions.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | from utils import geo_utils 4 | from torch import dtype, nn 5 | from torch.nn import functional as F 6 | from pytorch3d import transforms as py3d_trans 7 | 8 | 9 | class ESFMLoss(nn.Module): 10 | def __init__(self, conf): 11 | super().__init__() 12 | self.infinity_pts_margin = conf.get_float("loss.infinity_pts_margin") 13 | self.normalize_grad = conf.get_bool("loss.normalize_grad") 14 | 15 | self.hinge_loss = conf.get_bool("loss.hinge_loss") 16 | if self.hinge_loss: 17 | self.hinge_loss_weight = conf.get_float("loss.hinge_loss_weight") 18 | else: 19 | self.hinge_loss_weight = 0 20 | 21 | def forward(self, pred_cam, data, epoch=None): 22 | Ps = pred_cam["Ps_norm"] 23 | pts_2d = Ps @ pred_cam["pts3D"] # [m, 3, n] 24 | 25 | if self.normalize_grad: 26 | pts_2d.register_hook(lambda grad: F.normalize(grad, dim=1) / data.valid_pts.sum()) 27 | 28 | projected_points = 
geo_utils.get_positive_projected_pts_mask(pts_2d, self.infinity_pts_margin) 29 | else: 30 | projected_points = geo_utils.get_projected_pts_mask(pts_2d, self.infinity_pts_margin) 31 | 32 | # Calculate hinge Loss 33 | hinge_loss = (self.infinity_pts_margin - pts_2d[:, 2, :]) * self.hinge_loss_weight 34 | 35 | # Calculate reprojection error 36 | pts_2d = (pts_2d / torch.where(projected_points, pts_2d[:, 2, :], torch.ones_like(projected_points).float()).unsqueeze(dim=1)) 37 | reproj_err = (pts_2d[:, 0:2, :] - data.norm_M.reshape(Ps.shape[0], 2, -1)).norm(dim=1) 38 | 39 | return torch.where(projected_points, reproj_err, hinge_loss)[data.valid_pts].mean() 40 | 41 | 42 | class GTLoss(nn.Module): 43 | def __init__(self, conf): 44 | super().__init__() 45 | # self.calibrated = conf.get_bool('dataset.calibrated') 46 | self.train_trans = conf.get_bool('train.train_trans') 47 | self.alpha = conf.get_float("dataset.alpha") 48 | self.beta = conf.get_float("dataset.beta") 49 | 50 | def forward(self, pred_cam, data, epoch=None): 51 | 52 | if self.train_trans: 53 | ts_gt = data.ts - data.gpss 54 | orient_err = (data.quats - pred_cam["quats"]).norm(2, dim=1) 55 | translation_err = (ts_gt - pred_cam["ts"]).norm(2, dim=1) 56 | orient_loss = orient_err.mean() * self.alpha 57 | trans_loss = translation_err.mean() * self.beta 58 | loss = orient_loss + trans_loss 59 | 60 | if epoch is not None and epoch % 1000 == 0: 61 | # Print loss 62 | print("GTloss = {}, orient err = {}, trans err = {}".format(loss, orient_loss, trans_loss)) 63 | return loss, orient_loss, trans_loss 64 | 65 | else: 66 | orient_err = (data.quats - pred_cam["quats"]).norm(2, dim=1) 67 | orient_loss = orient_err.mean() 68 | 69 | if epoch is not None and epoch % 1000 == 0: 70 | # Print loss 71 | print("orient err = {}".format(orient_loss)) 72 | return orient_loss 73 | 74 | 75 | class BCELoss(nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | self.bceloss = nn.BCELoss() 79 | 80 | def forward(self, pred_mask, 
data, epoch=None): 81 | # pred_mask_ = pred_mask.reshape(-1, 1) # m,n,1 -> m*n,1 82 | # mask_ = data.mask # 2m,n 83 | # mask_ = mask_[0:mask_.shape[0]:2,:].reshape(-1, 1) # 2m,n -> m*n,1 84 | pred_mask_ = torch.as_tensor(pred_mask.values).reshape(len(pred_mask.values), 1) 85 | gt_mask_ = torch.as_tensor(data.mask_sparse.values).reshape(len(data.mask_sparse.values), 1) 86 | 87 | loss = self.bceloss(pred_mask_, gt_mask_).mean() 88 | if epoch is not None and epoch % 1000 == 0: 89 | # Print loss 90 | print("BCEloss = {}".format(loss)) 91 | return loss 92 | 93 | class BCEWithLogitLoss(nn.Module): 94 | def __init__(self): 95 | super().__init__() 96 | self.lbceloss = nn.BCEWithLogitsLoss() 97 | 98 | def forward(self, pred_mask, data, epoch=None): 99 | 100 | pred_mask_ = torch.as_tensor(pred_mask.values).reshape(len(pred_mask.values), 1) 101 | gt_mask_ = torch.as_tensor(data.mask_sparse.values).reshape(len(data.mask_sparse.values), 1) 102 | 103 | loss = self.lbceloss(pred_mask_, gt_mask_).mean() 104 | if epoch is not None and epoch % 1000 == 0: 105 | # Print loss 106 | print("BCEloss = {}".format(loss)) 107 | return loss -------------------------------------------------------------------------------- /code/datasets/ScenesDataSet.py: -------------------------------------------------------------------------------- 1 | from datasets import SceneData 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | from datasets import Euclidean 5 | 6 | class DataLoader(): 7 | def __init__(self, dataset, batch_size=1, shuffle=False): 8 | self.n = len(dataset) 9 | self.dataset = dataset 10 | self.batch_size = batch_size 11 | self.num_batches = int(np.ceil(self.n / self.batch_size)) 12 | self.shuffle=shuffle 13 | self.permutation = self.init_permutation() 14 | self.current_batch = 0 15 | self.device = 'cpu' 16 | 17 | def init_permutation(self): 18 | return np.random.permutation(self.n) if self.shuffle else np.arange(self.n) 19 | 20 | def __iter__(self): 21 | 
self.current_batch = 0 22 | self.permutation = self.init_permutation() 23 | return self 24 | 25 | def __next__(self): 26 | if self.current_batch == self.num_batches: 27 | raise StopIteration 28 | start_ind = self.current_batch*self.batch_size 29 | end_ind = min((self.current_batch+1)*self.batch_size, self.n) 30 | current_indices = self.permutation[start_ind:end_ind] 31 | self.current_batch += 1 32 | return [self.dataset[i].to(self.device) for i in current_indices] 33 | 34 | def __len__(self): 35 | return self.n 36 | 37 | def to(self, device, **kwargs): 38 | self.device = device 39 | return self 40 | 41 | 42 | class myDataSetds(Dataset): 43 | def __init__(self, conf, datalist, flag): 44 | self.datalist = datalist 45 | self.conf = conf 46 | self.flag = flag 47 | self.dilute_M = conf.get_bool('dataset.diluteM', default=False) 48 | 49 | def __len__(self): 50 | return len(self.datalist) 51 | 52 | def __getitem__(self, idx): 53 | file = self.datalist[idx] 54 | M, Ns, Rs, ts, quats, mask = Euclidean.get_raw_data(self.conf, file, self.flag) 55 | data = SceneData.SceneData(M, Ns, Rs, ts, quats, mask, file, self.dilute_M) 56 | return data 57 | 58 | 59 | class myDataSet(): 60 | def __init__(self, conf, flag, datalist, batch_size=1, shuffle=False): 61 | self.n = len(datalist) 62 | self.datalist = datalist 63 | self.batch_size = batch_size 64 | self.num_batches = int(np.ceil(self.n / self.batch_size)) 65 | self.shuffle=shuffle 66 | self.permutation = self.init_permutation() 67 | self.current_batch = 0 68 | self.device = 'cpu' 69 | self.conf = conf 70 | self.flag = flag 71 | self.dilute_M = conf.get_bool('dataset.diluteM', default=False) 72 | 73 | # print(f"num_batches {self.num_batches}") 74 | # print(f"permutation {self.permutation}") 75 | # print(f"flag {self.flag}") 76 | 77 | def init_permutation(self): 78 | return np.random.permutation(self.n) if self.shuffle else np.arange(self.n) 79 | 80 | def __iter__(self): 81 | self.current_batch = 0 82 | self.permutation = 
self.init_permutation() 83 | # print(f"in iter, perm: {self.permutation}") 84 | return self 85 | 86 | def __next__(self): 87 | if self.current_batch == self.num_batches: 88 | raise StopIteration 89 | start_ind = self.current_batch*self.batch_size 90 | # print(f"start ind: {start_ind}") 91 | end_ind = min((self.current_batch+1)*self.batch_size, self.n) 92 | current_indices = self.permutation[start_ind:end_ind] 93 | # print(f"current_indices: {current_indices}") 94 | # print(f"current data: {self.datalist[current_indices[0]]}") 95 | self.current_batch += 1 96 | return [self.loaddata(self.datalist[i]).to(self.device) for i in current_indices] 97 | 98 | def __len__(self): 99 | return self.n 100 | 101 | def to(self, device, **kwargs): 102 | self.device = device 103 | return self 104 | 105 | def loaddata(self, file): 106 | M, Ns, Rs, ts, quats, mask, gpss, color, scale, use_spatial_encoder, dsc_idx, dsc_data, dsc_shape = Euclidean.get_raw_data(self.conf, file, self.flag) 107 | data = SceneData.SceneData(M, Ns, Rs, ts, quats, mask, file, gpss, color, scale, self.dilute_M, use_spatial_encoder, 108 | dsc_idx, dsc_data, dsc_shape) 109 | return data 110 | 111 | 112 | class ScenesDataSet: 113 | def __init__(self, data_list, return_all, min_sample_size=10, max_sample_size=30): 114 | super().__init__() 115 | self.data_list = data_list 116 | self.return_all = return_all 117 | self.min_sample_size = min_sample_size 118 | self.max_sample_size = max_sample_size 119 | 120 | def __getitem__(self, item): 121 | current_data = self.data_list[item] 122 | if self.return_all: 123 | return current_data 124 | else: 125 | max_sample = min(self.max_sample_size, len(current_data.scan_name)) 126 | if self.min_sample_size >= max_sample: 127 | sample_fraction = max_sample 128 | else: 129 | sample_fraction = np.random.randint(self.min_sample_size, max_sample + 1) 130 | return SceneData.sample_data(current_data, sample_fraction) 131 | 132 | def __len__(self): 133 | return len(self.data_list) 
-------------------------------------------------------------------------------- /code/run_ba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import geo_utils, ba_functions, general_utils 3 | import os 4 | import pandas as pd 5 | 6 | 7 | def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined): 8 | outputs = {} 9 | 10 | file = dict(np.load(npzpath)) 11 | scan_name = file['scan_name'] 12 | xs = file['xs'] 13 | Rs_pred = file['Rs'] 14 | ts_pred = file['ts'] 15 | Rs_gt = file['Rs_gt'] 16 | ts_gt = file['ts_gt'] 17 | Ks = file['Ks'] 18 | Ns = np.linalg.inv(Ks) 19 | Xs = file['pts3D_pred'] 20 | Ps = file['Ps'] 21 | raw_xs = file['raw_xs'] 22 | 23 | outputs['xs'] = xs 24 | outputs['Rs_gt'] = Rs_gt 25 | outputs['ts_gt'] = ts_gt 26 | 27 | ba_res = ba_functions.euc_ba(xs, raw_xs, Rs=Rs_pred, ts=ts_pred, Ks=Ks, Xs=Xs.T, Ps=Ps, Ns=Ns, 28 | repeat=repeat, max_iter=max_iter, ba_times=ba_times, 29 | repro_thre=repro_thre, refined=refined) # Rs, ts, Ps, Xs 30 | outputs['Rs_ba'] = ba_res['Rs'] 31 | outputs['ts_ba'] = ba_res['ts'] 32 | outputs['Xs_ba'] = ba_res['Xs'].T # 4,n 33 | outputs['Ps_ba'] = ba_res['Ps'] 34 | if refined: 35 | outputs['valid_points'] = ba_res['valid_points'] 36 | outputs['new_xs'] = ba_res['new_xs'] 37 | outputs['colidx'] = ba_res['colidx'] 38 | 39 | R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt, 40 | return_alignment=True) # Align Rs_fixed, tx_fixed 41 | outputs['Rs_ba_fixed'] = R_ba_fixed 42 | outputs['ts_ba_fixed'] = t_ba_fixed 43 | outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba']) 44 | # outputs['Rs_ba_fixed'] = ba_res['Rs'] 45 | # outputs['ts_ba_fixed'] = ba_res['ts'] 46 | # outputs['Xs_ba_fixed'] = ba_res['Xs'].T 47 | file.update(outputs) 48 | np.savez(npzpath, **file) 49 | 50 | return file, scan_name 51 | 52 | 53 | def _compute_errors(outputs, refined=False): 54 | model_errors = {} 55 | 56 | 
premask = outputs['premask'] 57 | gtmask = outputs['gtmask'] 58 | 59 | tp, fp, tn, fn = general_utils.compute_confusion_matrix(premask, gtmask) 60 | accuracy, precision, recall, F1 = general_utils.compute_indexes(tp, fp, tn, fn) 61 | model_errors["TP"] = tp 62 | model_errors["FP"] = fp 63 | model_errors["TN"] = tn 64 | model_errors["FN"] = fn 65 | model_errors["Accuracy"] = accuracy 66 | model_errors["Precision"] = precision 67 | model_errors["Recall"] = recall 68 | model_errors["F1"] = F1 69 | 70 | pts3D_pred = outputs['pts3D_pred'] 71 | Ps = outputs['Ps'] 72 | Rs_fixed = outputs['Rs'] 73 | ts_fixed = outputs['ts'] 74 | Rs_gt = outputs['Rs_gt'] 75 | ts_gt = outputs['ts_gt'] 76 | xs = outputs['xs'] 77 | Xs_ba = outputs['Xs_ba'] 78 | Ps_ba = outputs['Ps_ba'] 79 | our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs) 80 | if not our_repro_error.shape: return model_errors 81 | model_errors["our_repro"] = np.nanmean(our_repro_error) 82 | model_errors["our_repro_max"] = np.nanmax(our_repro_error) 83 | 84 | Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_fixed, ts_fixed, Rs_gt, ts_gt) 85 | model_errors["ts_mean"] = np.mean(ts_error) 86 | model_errors["ts_med"] = np.median(ts_error) 87 | model_errors["ts_max"] = np.max(ts_error) 88 | model_errors["Rs_mean"] = np.mean(Rs_error) 89 | model_errors["Rs_med"] = np.median(Rs_error) 90 | model_errors["Rs_max"] = np.max(Rs_error) 91 | 92 | if refined: 93 | valid_points = outputs['valid_points'] 94 | new_xs = outputs['new_xs'] 95 | repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, new_xs, visible_points=valid_points) 96 | else: 97 | repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs) 98 | model_errors['repro_ba'] = np.nanmean(repro_ba_error) 99 | model_errors['repro_ba_max'] = np.nanmax(repro_ba_error) 100 | 101 | Rs_ba_fixed = outputs['Rs_ba_fixed'] 102 | ts_ba_fixed = outputs['ts_ba_fixed'] 103 | Rs_ba_error, ts_ba_error = 
geo_utils.tranlsation_rotation_errors(Rs_ba_fixed, ts_ba_fixed, Rs_gt, ts_gt) 104 | model_errors["ts_ba_mean"] = np.mean(ts_ba_error) 105 | model_errors["ts_ba_med"] = np.median(ts_ba_error) 106 | model_errors["ts_ba_max"] = np.max(ts_ba_error) 107 | model_errors["Rs_ba_mean"] = np.mean(Rs_ba_error) 108 | model_errors["Rs_ba_med"] = np.median(Rs_ba_error) 109 | model_errors["Rs_ba_max"] = np.max(Rs_ba_error) 110 | 111 | return model_errors 112 | 113 | 114 | def runba(): 115 | basedir = "/home/zeeq/results/test10" 116 | refined = True 117 | repeat = True 118 | max_iter = 50 119 | ba_times = 1 120 | repro_thre = 5 121 | 122 | errors_list = [] 123 | for _,_,files in os.walk(basedir): 124 | for f in files: 125 | print(f'processing {f} ...') 126 | npzpath = os.path.join(basedir, f) 127 | outputs, scan_name = _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined) 128 | errors = _compute_errors(outputs, refined) 129 | errors['Scene'] = scan_name 130 | errors_list.append(errors) 131 | 132 | df_errors = pd.DataFrame(errors_list) 133 | mean_errors = df_errors.mean(numeric_only=True) 134 | # df_errors = df_errors.append(mean_errors, ignore_index=True) 135 | df_errors = pd.concat([df_errors, mean_errors], ignore_index=True) 136 | df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean" 137 | df_errors.set_index("Scene", inplace=True) 138 | df_errors = df_errors.round(3) 139 | print(df_errors.to_string(), flush=True) 140 | 141 | if __name__=="__main__": 142 | runba() 143 | -------------------------------------------------------------------------------- /code/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import geo_utils, general_utils, sparse_utils 3 | import numpy as np 4 | 5 | 6 | def is_valid_sample(data, min_pts_per_cam): 7 | return data.x.pts_per_cam.min().item() >= min_pts_per_cam 8 | 9 | 10 | def divide_indices_to_train_test(N, n_val, n_test=0): 11 | perm = 
np.random.permutation(N) 12 | test_indices = perm[:n_test] if n_test>0 else [] 13 | val_indices = perm[n_test:n_test+n_val] 14 | train_indices = perm[n_test+n_val:] 15 | return train_indices, val_indices, test_indices 16 | 17 | 18 | def sample_indices(N, num_samples, adjacent): 19 | if num_samples == 1: # Return all the data 20 | indices = np.arange(N) 21 | else: 22 | if num_samples < 1: 23 | num_samples = int(np.ceil(num_samples * N)) 24 | num_samples = max(2, num_samples) 25 | if num_samples>=N: 26 | return np.arange(N) 27 | if adjacent: 28 | start_ind = np.random.randint(0,N-num_samples+1) 29 | end_ind = start_ind+num_samples 30 | indices = np.arange(start_ind, end_ind) 31 | else: 32 | indices = np.random.choice(N,num_samples,replace=False) 33 | return indices 34 | 35 | 36 | def radius_sample(ts, num_samples): 37 | total_num = len(ts) 38 | if num_samples>=total_num: 39 | return np.arange(total_num) 40 | idx = np.random.randint(total_num) 41 | center = ts[idx] 42 | dist = ((ts-center)**2).sum(axis=1) 43 | sorted_ids = np.argsort(dist) 44 | return sorted_ids[:num_samples] 45 | 46 | 47 | def simulate_sample(ts, num_samples, pts_per_cam): 48 | total_num = len(ts) 49 | fsamples = 2*num_samples 50 | if fsamples>=total_num: 51 | return np.arange(total_num) 52 | idx = np.random.randint(total_num) 53 | center = ts[idx] 54 | dist = ((ts-center)**2).sum(axis=1) 55 | sorted_ids = np.argsort(dist) 56 | fids = sorted_ids[:fsamples] 57 | visits = pts_per_cam[fids] 58 | sampled_ids = np.argsort(visits) 59 | return fids[sampled_ids] 60 | 61 | 62 | def save_cameras(outputs, conf, curr_epoch, phase): 63 | # xs = outputs['xs'] 64 | # M = geo_utils.xs_to_M(xs) 65 | general_utils.save_camera_mat(conf, outputs, outputs['scan_name'], phase, curr_epoch) 66 | 67 | 68 | def get_data_statistics(all_data): 69 | valid_pts = all_data.valid_pts 70 | valid_pts_stat = valid_pts.sum(dim=0).float() 71 | stats = {"Max_2d_pt": all_data.M.max().item(), "Num_2d_pts": valid_pts.sum().item(), 
"n_pts":all_data.M.shape[-1], 72 | "Cameras_per_pts_mean": valid_pts_stat.mean().item(), "Cameras_per_pts_std": valid_pts_stat.std().item(), 73 | "Num of cameras": all_data.ts.shape[0]} 74 | return stats 75 | 76 | 77 | def get_data_statistics2(all_data, pred_data): 78 | constructed_data = pred_data['pts3D_pred'] 79 | constructed_pt3ds_num = constructed_data.shape[1] 80 | valid_pts = all_data.valid_pts 81 | valid_pts_stat = valid_pts.sum(dim=0).float() 82 | mask = all_data.mask 83 | stats = {"Max_2d_pt": all_data.M.max().item(), "Num_2d_pts": valid_pts.sum().item(), "all_pts":all_data.M.shape[-1], 84 | "constructed_pts": constructed_pt3ds_num, "ground_truth_pts": mask[:,mask.sum(axis=0)!=0].shape[1], 85 | "Cameras_per_pts_mean": valid_pts_stat.mean().item(), 86 | "Cameras_per_pts_std": valid_pts_stat.std().item(), "Num of cameras": all_data.ts.shape[0]} 87 | return stats 88 | 89 | 90 | def correct_matches_global(M, Ps, Ns): 91 | M_invalid_pts = np.logical_not(get_M_valid_points(M)) 92 | 93 | Xs = geo_utils.n_view_triangulation(Ps, M, Ns) 94 | xs = geo_utils.batch_pflat((Ps @ Xs))[:, 0:2, :] 95 | 96 | # Remove invalid points 97 | xs[np.isnan(xs)] = 0 98 | xs[np.stack((M_invalid_pts, M_invalid_pts), axis=1)] = 0 99 | 100 | return xs.reshape(M.shape) 101 | 102 | 103 | def get_M_valid_points(M): 104 | n_pts = M.shape[-1] 105 | 106 | if type(M) is torch.Tensor: 107 | # m✖2✖n -> m✖n 108 | M_valid_pts = torch.abs(M.reshape(-1, 2, n_pts)).sum(dim=1) != 0 109 | M_valid_pts[:, M_valid_pts.sum(dim=0) < 2] = False 110 | else: 111 | M_valid_pts = np.abs(M.reshape(-1, 2, n_pts)).sum(axis=1) != 0 112 | M_valid_pts[:, M_valid_pts.sum(axis=0) < 2] = False 113 | 114 | return M_valid_pts 115 | 116 | 117 | def M2sparse(M, normalize=False, Ns=None): 118 | n_pts = M.shape[1] 119 | n_cams = int(M.shape[0] / 2) 120 | 121 | # Get indices 122 | valid_pts = get_M_valid_points(M) # m,n 123 | cam_per_pts = valid_pts.sum(dim=0).unsqueeze(1) # n 124 | pts_per_cam = 
valid_pts.sum(dim=1).unsqueeze(1) # m 125 | mat_indices = torch.nonzero(valid_pts).T 126 | 127 | # Get Values 128 | # reshaped_M = M.reshape(n_cams, 2, n_pts).transpose(1, 2) # [2m, n] -> [m, 2, n] -> [m, n, 2] 129 | if normalize: 130 | norm_M = geo_utils.normalize_M(M, Ns) 131 | mat_vals = norm_M[mat_indices[0], mat_indices[1], :] # nnz,2 132 | else: 133 | mat_vals = M.reshape(n_cams, 2, n_pts).transpose(1, 2)[mat_indices[0], mat_indices[1], :] 134 | 135 | mat_shape = (n_cams, n_pts, 2) 136 | return sparse_utils.SparseMat(mat_vals, mat_indices, cam_per_pts, pts_per_cam, mat_shape) 137 | 138 | 139 | def order_indices(indices, shuffle): 140 | if shuffle: 141 | np.random.shuffle(indices) 142 | M_indices = np.zeros(len(indices)*2, dtype=np.int64) 143 | M_indices[::2] = 2 * indices 144 | M_indices[1::2] = 2 * indices + 1 145 | else: 146 | indices.sort() 147 | M_indices = np.sort(np.concatenate((2 * indices, 2 * indices + 1))) 148 | return indices, M_indices 149 | 150 | 151 | def get_mask_by_reproj(M, M_gt, thred_square): 152 | n_pts = M.shape[-1] 153 | reproj_err = np.square(M.reshape(-1, 2, n_pts).transpose(0,2,1) - M_gt.reshape(-1, 2, n_pts).transpose(0,2,1)).sum(axis=-1) # m*n 154 | out_mat = reproj_err>thred_square 155 | ans = np.zeros((out_mat.shape[0], out_mat.shape[1], 2), dtype = bool) 156 | ans[:,:,0] = out_mat 157 | ans[:,:,1] = out_mat 158 | ans = ans.transpose(0,2,1).reshape(-1, n_pts) 159 | tmpM = np.copy(M) 160 | tmpM[ans] = 0 161 | return np.where(tmpM!=0, 1., 0.) 
-------------------------------------------------------------------------------- /code/models/SetOfSet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.baseNet import BaseNet 4 | from models.layers import * 5 | from utils import geo_utils 6 | 7 | 8 | class SetOfSetBlock(nn.Module): 9 | def __init__(self, d_in, d_out, conf): 10 | super(SetOfSetBlock, self).__init__() 11 | self.block_size = conf.get_int("model.block_size") 12 | self.use_skip = conf.get_bool("model.use_skip") 13 | 14 | modules = [] 15 | # modules.extend([SetOfSetLayer(d_in, 16), 16 | # NormalizationLayer(), 17 | # ActivationLayer(), 18 | # SetOfSetLayer(16, 64), 19 | # NormalizationLayer(), 20 | # ActivationLayer(), 21 | # SetOfSetLayer(64, 256), 22 | # NormalizationLayer(), 23 | # ]) 24 | modules.extend([SetOfSetLayer(d_in, d_out), NormalizationLayer()]) 25 | for i in range(1, self.block_size): 26 | modules.extend([ActivationLayer(), SetOfSetLayer(d_out, d_out), NormalizationLayer()]) 27 | self.layers = nn.Sequential(*modules) 28 | 29 | self.final_act = ActivationLayer() 30 | 31 | if self.use_skip: 32 | if d_in == d_out: 33 | self.skip = IdentityLayer() 34 | else: 35 | self.skip = nn.Sequential(ProjLayer(d_in, d_out), NormalizationLayer()) 36 | 37 | def forward(self, x): 38 | # x is [m,n,d] sparse matrix 39 | xl = self.layers(x) 40 | if self.use_skip: 41 | xl = self.skip(x) + xl 42 | 43 | out = self.final_act(xl) 44 | return out 45 | 46 | # def __init__(self, d_in, d_out, conf): 47 | # super(SetOfSetBlock, self).__init__() 48 | # self.block_size = conf.get_int("model.block_size") 49 | 50 | # self.firstlayer = SetOfSetLayer(d_in, d_out) 51 | # self.submodules1 = nn.Sequential(NormalizationLayer(), ActivationLayer(), SetOfSetLayer(d_out, d_out)) 52 | # self.submodules2 = nn.Sequential(NormalizationLayer(), ActivationLayer(), SetOfSetLayer(d_out, d_out)) 53 | # self.lastlayers = 
nn.Sequential(NormalizationLayer(), ActivationLayer()) 54 | 55 | # def forward(self, x): 56 | # # x is [m,n,d] sparse matrix 57 | # x = self.firstlayer(x) 58 | # x = self.submodules1(x) + x 59 | # x = self.submodules2(x) + x 60 | # out = self.lastlayers(x) 61 | # return out 62 | 63 | 64 | class SetOfSetNet(BaseNet): 65 | def __init__(self, conf): 66 | super(SetOfSetNet, self).__init__(conf) 67 | # n is the number of points and m is the number of cameras 68 | num_blocks = conf.get_int('model.num_blocks') 69 | num_feats = conf.get_int('model.num_features') 70 | self.train_trans = conf.get_bool('train.train_trans') 71 | self.use_spatial_encoder = conf.get_bool('dataset.use_spatial_encoder') 72 | self.x_embed_rank = conf.get_int('dataset.x_embed_rank') 73 | self.egps_embed_rank = conf.get_int('dataset.egps_embed_rank') 74 | self.dsc_egps_embed_width = conf.get_int('dataset.dsc_egps_embed_width') 75 | self.gps_embed_width = conf.get_int('dataset.gps_embed_width') 76 | 77 | m_d_out_rot = 4 78 | m_d_out_trans = 3 79 | d_in = 2 80 | 81 | self.embed_x = EmbeddingLayer(self.x_embed_rank, d_in) 82 | self.embed_egps = EmbeddingLayer(self.egps_embed_rank, 3) 83 | self.embed_dsc_egps = ProjLayer(128+self.embed_egps.d_out, self.dsc_egps_embed_width) 84 | 85 | if self.use_spatial_encoder: 86 | self.equivariant_blocks = torch.nn.ModuleList([SetOfSetBlock(self.embed_x.d_out + self.dsc_egps_embed_width, num_feats, conf)]) 87 | else: 88 | self.equivariant_blocks = torch.nn.ModuleList([SetOfSetBlock(2, num_feats, conf)]) 89 | 90 | for i in range(num_blocks - 1): 91 | self.equivariant_blocks.append(SetOfSetBlock(num_feats, num_feats, conf)) 92 | 93 | self.pt_net = nn.Sequential(ProjLayer(num_feats, num_feats//2), 94 | RCNormLayer(), 95 | ActivationLayer(), 96 | ProjLayer(num_feats, num_feats//2), 97 | RCNormLayer(), 98 | ActivationLayer(), 99 | ProjLayer(num_feats, 1)) 100 | 101 | # self.embed_gps = PosiEmbedding(self.egps_embed_rank) #nn.Linear(3, 128) 102 | self.embed_gps = 
nn.Linear(3, self.gps_embed_width) 103 | 104 | # self.m_net = get_linear_layers([num_feats] * 2 + [m_d_out], final_layer=True, batchnorm=False) 105 | # self.n_net = get_linear_layers([num_feats] * 2 + [n_d_out], final_layer=True, batchnorm=False) 106 | self.m_net_rot = get_linear_layers([num_feats] * 2 + [m_d_out_rot], final_layer=True, batchnorm=False) 107 | self.m_net_tran = get_linear_layers([num_feats+self.gps_embed_width, 256, m_d_out_trans], final_layer=True, batchnorm=False) 108 | # self.m_net_tran = get_linear_layers([num_feats+6*self.egps_embed_rank+3, 256, m_d_out_trans], final_layer=True, batchnorm=False) 109 | 110 | def forward(self, data): 111 | x = data.x # x is [m,n,d] sparse matrix 112 | # x = self.embed_x(x) 113 | if self.use_spatial_encoder: 114 | x = self.embed_x(x) 115 | egps = self.embed_egps(data.egps_sparse) 116 | dsc_egps = data.dsc.last_dim_cat(egps) 117 | embed_de = self.embed_dsc_egps(dsc_egps) 118 | x = x.last_dim_cat(embed_de) 119 | # x = x.last_dim_cat(self.embed_dsc_egps(data.dsc.last_dim_cat(self.embed_egps(data.egps_sparse)))) 120 | 121 | for eq_block in self.equivariant_blocks: 122 | x = eq_block(x) # [m,n,d_in] -> [m,n,d_out] 123 | 124 | # pt class predictions 125 | pt_out = self.pt_net(x) # [m,n,d_out] -> [m,n,1] 126 | 127 | # Cameras predictions 128 | # x = x*pt_out # [m,n,d_out] -> [m,n,d_out] 129 | m_input = x.mean(dim=1) # [m,d_out] 130 | # m_out = self.m_net(m_input) # [m, d_m] 131 | # pred_cam = self.extract_camera_outputs(m_out) 132 | rot = self.m_net_rot(m_input) 133 | # normed_gps = geo_utils.scale_ts_norm(data.gpss) 134 | if self.train_trans: 135 | tran = self.m_net_tran(torch.cat([m_input, self.embed_gps(data.gpss)], -1)) 136 | # tran = self.m_net_tran(m_input) 137 | pred_cam = {"quats": rot, "ts": tran} 138 | else: 139 | pred_cam = {"quats": rot} 140 | # tran = self.m_net_tran(m_input) 141 | # pred_cam = {"quats": rot, "ts": data.gpss} 142 | 143 | return pt_out, pred_cam 
-------------------------------------------------------------------------------- /code/utils/general_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datetime import datetime 3 | from scipy.io import savemat 4 | import shutil 5 | from pyhocon import HOCONConverter,ConfigTree 6 | import sys 7 | import json 8 | from pyhocon import ConfigFactory 9 | import argparse 10 | import os 11 | import numpy as np 12 | import pandas as pd 13 | import portalocker 14 | from utils.Phases import Phases 15 | from utils.path_utils import path_to_exp, path_to_cameras, path_to_code_logs, path_to_conf 16 | import random 17 | 18 | 19 | def log_code(conf): 20 | code_path = path_to_code_logs(conf) 21 | 22 | files_to_log = ["train.py", "single_scene_optimization.py", "multiple_scenes_learning.py", "loss_functions.py"] 23 | for file_name in files_to_log: 24 | shutil.copyfile('{}'.format(file_name), os.path.join(code_path, file_name)) 25 | 26 | dirs_to_log = ["datasets", "models"] 27 | for dir_name in dirs_to_log: 28 | shutil.copytree('{}'.format(dir_name), os.path.join(code_path, dir_name)) 29 | 30 | # Print conf 31 | with open(os.path.join(code_path, 'exp.conf'), 'w') as conf_log_file: 32 | conf_log_file.write(HOCONConverter.convert(conf, 'hocon')) 33 | 34 | 35 | def save_camera_mat(conf, save_cam_dict, scan, phase, epoch=None): 36 | path_cameras = path_to_cameras(conf, phase, epoch=epoch, scan=scan) 37 | np.savez(path_cameras, **save_cam_dict) 38 | #savemat(path_cameras, save_cam_dict) 39 | 40 | 41 | def write_results(conf, df, file_name="Results", append=False): 42 | exp_path = path_to_exp(conf) 43 | results_file_path = os.path.join(exp_path, '{}.xlsx'.format(file_name)) 44 | 45 | if append: 46 | locker_file = os.path.join(exp_path, '{}.lock'.format(file_name)) 47 | lock = portalocker.Lock(locker_file, timeout=1000) 48 | with lock: 49 | if os.path.exists(results_file_path): 50 | prev_df = 
pd.read_excel(results_file_path).set_index("Scene") 51 | merged_err_df = prev_df.append(df) 52 | else: 53 | merged_err_df = df 54 | 55 | merged_err_df.to_excel(results_file_path) 56 | else: 57 | df.to_excel(results_file_path) 58 | 59 | 60 | def init_exp_version(): 61 | return '{:%Y_%m_%d_%H_%M_%S}'.format(datetime.now()) 62 | 63 | 64 | def get_class(kls): 65 | parts = kls.split('.') # models SetOfSet SetOfSetNet 66 | module = ".".join(parts[:-1]) # models.SetOfSet 67 | m = __import__(module) 68 | for comp in parts[1:]: # SetOfSet SetOfSetNet 69 | m = getattr(m, comp) 70 | return m 71 | 72 | 73 | def count_parameters(model): 74 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 75 | 76 | 77 | def print_error(err_string): 78 | print(err_string, file=sys.stderr) 79 | 80 | 81 | def config_tree_to_string(config): 82 | config_dict={} 83 | for it in config.keys(): 84 | if isinstance(config[it],ConfigTree): 85 | it_dict = {key:val for key,val in config[it].items()} 86 | config_dict[it]=it_dict 87 | else: 88 | config_dict[it] = config[it] 89 | return json.dumps(config_dict) 90 | 91 | 92 | def bmvm(bmats, bvecs): 93 | return torch.bmm(bmats, bvecs.unsqueeze(-1)).squeeze() 94 | 95 | 96 | def get_full_conf_vals(conf): 97 | # return a conf file as a dictionary as follow: 98 | # "key.key.key...key": value 99 | # Useful for the conf.put() command 100 | full_vals = {} 101 | for key, val in conf.items(): 102 | if isinstance(val, dict): 103 | part_vals = get_full_conf_vals(val) 104 | for part_key, part_val in part_vals.items(): 105 | full_vals[key + "." 
+part_key] = part_val 106 | else: 107 | full_vals[key] = val 108 | 109 | return full_vals 110 | 111 | 112 | def parse_external_params(ext_params_str, conf): 113 | for param in ext_params_str.split(','): 114 | key_val = param.split(':') 115 | if len(key_val) == 3: 116 | conf[key_val[0]][key_val[1]] = key_val[2] 117 | elif len(key_val) == 2: 118 | conf[key_val[0]] = key_val[1] 119 | return conf 120 | 121 | 122 | def init_exp(default_phase): 123 | # Parse Arguments 124 | parser = argparse.ArgumentParser() 125 | parser.add_argument('--conf', type=str) 126 | parser.add_argument('--scan', type=str, default=None) 127 | parser.add_argument('--exp_version', type=str, default=None) 128 | parser.add_argument('--external_params', type=str, default=None) 129 | parser.add_argument('--phase', type=str, default=default_phase) 130 | opt = parser.parse_args() 131 | 132 | # Init Device 133 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 134 | 135 | # Init Conf 136 | conf_file_path = path_to_conf(opt.conf) 137 | conf = ConfigFactory.parse_file(conf_file_path) 138 | conf["original_file_name"] = opt.conf 139 | 140 | # Init external params 141 | if opt.external_params is not None: 142 | conf = parse_external_params(opt.external_params, conf) 143 | 144 | # Init Version 145 | if opt.exp_version is None: 146 | exp_version = init_exp_version() 147 | else: 148 | exp_version = opt.exp_version 149 | conf['exp_version'] = exp_version 150 | 151 | # Init scan 152 | if opt.scan is not None: 153 | conf['dataset']['scan'] = opt.scan 154 | elif 'scan' not in conf['dataset'].keys(): 155 | conf['dataset']['scan'] = 'Multiple_Scenes' 156 | 157 | # Init Seed 158 | seed = conf.get_int('random_seed', default=None) 159 | if seed is not None: 160 | torch.manual_seed(seed) 161 | np.random.seed(seed) 162 | 163 | # Init Phase 164 | phase = Phases[opt.phase] 165 | 166 | return conf, device, phase 167 | 168 | 169 | def compute_confusion_matrix(precited, expected): 170 | part = precited ^ 
expected 171 | pcount = np.bincount(part) 172 | tp_list = list(precited & expected) 173 | fp_list = list(precited & ~expected) 174 | tp = tp_list.count(1) 175 | fp = fp_list.count(1) 176 | tn = pcount[0] - tp 177 | if len(pcount)==2: fn = pcount[1] - fp 178 | else: fn = 0 179 | return tp, fp, tn, fn 180 | 181 | 182 | def compute_indexes(tp, fp, tn, fn): 183 | try: 184 | accuracy = (tp+tn) / (tp+tn+fp+fn) 185 | except(ZeroDivisionError): 186 | print("ZeroDivisionError: division by zero") 187 | accuracy = np.zeros_like(tp) 188 | 189 | try: 190 | precision = tp / (tp+fp) 191 | except(ZeroDivisionError): 192 | print("ZeroDivisionError: division by zero") 193 | precision = np.zeros_like(tp) 194 | 195 | try: 196 | recall = tp / (tp+fn) 197 | except(ZeroDivisionError): 198 | print("ZeroDivisionError: division by zero") 199 | recall = np.zeros_like(tp) 200 | 201 | try: 202 | F1 = (2*precision*recall) / (precision+recall) # F1 203 | except(ZeroDivisionError): 204 | print("ZeroDivisionError: division by zero") 205 | F1 = np.zeros_like(tp) 206 | 207 | return accuracy, precision, recall, F1 -------------------------------------------------------------------------------- /code/re_ba.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import geo_utils, ba_functions, general_utils 3 | import os 4 | import pandas as pd 5 | from scipy.sparse import coo_matrix 6 | 7 | 8 | # get Ps_gt 9 | # m,3,3 m,3,3 m,3 10 | def get_Ps_gt(K_gt, R_gt, T_gt): 11 | m = R_gt.shape[0] 12 | K_RT = np.matmul(K_gt, R_gt.transpose((0, 2, 1))) # K@R.T 13 | Is = np.expand_dims(np.identity(3), 0).repeat(m, axis = 0) 14 | tmp = np.concatenate((Is, -np.expand_dims(T_gt, 2)), axis = 2) # I|-t 15 | Ps_gt = np.matmul(K_RT, tmp) 16 | return Ps_gt 17 | 18 | 19 | def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined): 20 | outputs = {} 21 | 22 | scan_name = npzpath.split('/')[-1].split('.')[0] 23 | file = dict(np.load(npzpath)) 24 | 
M_row = file['M_row'] 25 | M_col = file['M_col'] 26 | M_data = file['M_data'] 27 | M_shape = file['M_shape'] 28 | mask_row = file['mask_row'] 29 | mask_col = file['mask_col'] 30 | mask_data = file['mask_data'] 31 | mask_shape = file['mask_shape'] 32 | M = coo_matrix((M_data, (M_row, M_col)), shape=M_shape).todense().A 33 | mask = coo_matrix((mask_data, (mask_row, mask_col)), shape=mask_shape).todense().A 34 | 35 | enu = file['enu'] 36 | Ns = file['Ns'] 37 | Rs = file['Rs'] 38 | ts = file['ts'] 39 | quats = file['quats'] 40 | color = file['rgbs'] 41 | Ks = np.linalg.inv(Ns) 42 | Ps = get_Ps_gt(Ks, Rs, enu) 43 | 44 | M_gt = M*mask 45 | colidx = (mask!=0).sum(axis=0)>2 46 | M_gt = M_gt[:, colidx] 47 | color_gt = color[:, colidx] 48 | xs_gt = geo_utils.M_to_xs(M_gt) 49 | 50 | xs_raw = geo_utils.M_to_xs(M) 51 | Xs_gt = geo_utils.n_view_triangulation(Ps, M=M_gt, Ns=Ns) 52 | Xs_raw = geo_utils.n_view_triangulation(Ps, M=M, Ns=Ns) 53 | 54 | Rs_error_b, ts_error_b = geo_utils.tranlsation_rotation_errors(Rs, ts, Rs, enu) 55 | print(f"before BA, Rs_error={Rs_error_b}, ts_error={ts_error_b}") 56 | 57 | ba_res = ba_functions.euc_ba(xs_gt, xs_raw, Rs=Rs, ts=enu, Ks=Ks, Xs=Xs_gt.T, Ps=Ps, Ns=Ns, 58 | repeat=repeat, max_iter=max_iter, ba_times=ba_times, 59 | repro_thre=repro_thre, refined=refined) # Rs, ts, Ps, Xs 60 | outputs['Rs_ba'] = ba_res['Rs'] 61 | outputs['ts_ba'] = ba_res['ts'] 62 | outputs['Xs_ba'] = ba_res['Xs'].T # 4,n 63 | outputs['Ps_ba'] = ba_res['Ps'] 64 | if refined: 65 | outputs['valid_points'] = ba_res['valid_points'] 66 | outputs['new_xs'] = ba_res['new_xs'] 67 | outputs['colidx'] = ba_res['colidx'] 68 | # M_new = geo_utils.xs_to_M(outputs['new_xs']) 69 | # mask_new = np.array((M_new!=0), dtype=np.int8) 70 | # colidx2 = (M_new!=0).sum(axis=0)>2 71 | # M_new = M_new[:,colidx2] 72 | 73 | R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs, ba_res['ts'], enu, 74 | return_alignment=True) # Align Rs_fixed, tx_fixed 75 | 
outputs['Rs_ba_fixed'] = R_ba_fixed 76 | outputs['ts_ba_fixed'] = t_ba_fixed 77 | outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba']) 78 | 79 | Rs_error_a, ts_error_a = geo_utils.tranlsation_rotation_errors(R_ba_fixed, t_ba_fixed, Rs, enu) 80 | print(f"after BA, Rs_error={Rs_error_a}, ts_error={ts_error_a}") 81 | np.savetxt(ts_path, t_ba_fixed, delimiter=',') 82 | np.savetxt(enu_path, enu, delimiter=',') 83 | input() 84 | file.update(outputs) 85 | # np.savez(npzpath, **file) 86 | 87 | return file, scan_name 88 | 89 | 90 | def _compute_errors(outputs, refined=False): 91 | model_errors = {} 92 | 93 | premask = outputs['premask'] 94 | gtmask = outputs['gtmask'] 95 | 96 | tp, fp, tn, fn = general_utils.compute_confusion_matrix(premask, gtmask) 97 | accuracy, precision, recall, F1 = general_utils.compute_indexes(tp, fp, tn, fn) 98 | model_errors["TP"] = tp 99 | model_errors["FP"] = fp 100 | model_errors["TN"] = tn 101 | model_errors["FN"] = fn 102 | model_errors["Accuracy"] = accuracy 103 | model_errors["Precision"] = precision 104 | model_errors["Recall"] = recall 105 | model_errors["F1"] = F1 106 | 107 | pts3D_pred = outputs['pts3D_pred'] 108 | Ps = outputs['Ps'] 109 | Rs_fixed = outputs['Rs'] 110 | ts_fixed = outputs['ts'] 111 | Rs_gt = outputs['Rs_gt'] 112 | ts_gt = outputs['ts_gt'] 113 | xs = outputs['xs'] 114 | Xs_ba = outputs['Xs_ba'] 115 | Ps_ba = outputs['Ps_ba'] 116 | our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs) 117 | if not our_repro_error.shape: return model_errors 118 | model_errors["our_repro"] = np.nanmean(our_repro_error) 119 | model_errors["our_repro_max"] = np.nanmax(our_repro_error) 120 | 121 | Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_fixed, ts_fixed, Rs_gt, ts_gt) 122 | model_errors["ts_mean"] = np.mean(ts_error) 123 | model_errors["ts_med"] = np.median(ts_error) 124 | model_errors["ts_max"] = np.max(ts_error) 125 | model_errors["Rs_mean"] = np.mean(Rs_error) 126 | 
model_errors["Rs_med"] = np.median(Rs_error) 127 | model_errors["Rs_max"] = np.max(Rs_error) 128 | 129 | if refined: 130 | valid_points = outputs['valid_points'] 131 | new_xs = outputs['new_xs'] 132 | repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, new_xs, visible_points=valid_points) 133 | else: 134 | repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs) 135 | model_errors['repro_ba'] = np.nanmean(repro_ba_error) 136 | model_errors['repro_ba_max'] = np.nanmax(repro_ba_error) 137 | 138 | Rs_ba_fixed = outputs['Rs_ba_fixed'] 139 | ts_ba_fixed = outputs['ts_ba_fixed'] 140 | Rs_ba_error, ts_ba_error = geo_utils.tranlsation_rotation_errors(Rs_ba_fixed, ts_ba_fixed, Rs_gt, ts_gt) 141 | model_errors["ts_ba_mean"] = np.mean(ts_ba_error) 142 | model_errors["ts_ba_med"] = np.median(ts_ba_error) 143 | model_errors["ts_ba_max"] = np.max(ts_ba_error) 144 | model_errors["Rs_ba_mean"] = np.mean(Rs_ba_error) 145 | model_errors["Rs_ba_med"] = np.median(Rs_ba_error) 146 | model_errors["Rs_ba_max"] = np.max(Rs_ba_error) 147 | 148 | return model_errors 149 | 150 | 151 | def runba(): 152 | basedir = "" 153 | refined = True 154 | repeat = True 155 | max_iter = 50 156 | ba_times = 1 157 | repro_thre = 2 158 | 159 | errors_list = [] 160 | for _,_,files in os.walk(basedir): 161 | for f in files: 162 | print(f'processing {f} ...') 163 | npzpath = os.path.join(basedir, f) 164 | outputs, scan_name = _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined) 165 | errors = _compute_errors(outputs, refined) 166 | errors['Scene'] = scan_name 167 | errors_list.append(errors) 168 | 169 | df_errors = pd.DataFrame(errors_list) 170 | mean_errors = df_errors.mean(numeric_only=True) 171 | df_errors = df_errors.append(mean_errors, ignore_index=True) 172 | df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean" 173 | df_errors.set_index("Scene", inplace=True) 174 | df_errors = df_errors.round(3) 175 | print(df_errors.to_string(), flush=True) 
176 | 177 | if __name__=="__main__": 178 | runba() -------------------------------------------------------------------------------- /code/datasets/Euclidean.py: -------------------------------------------------------------------------------- 1 | import cv2 # Do not remove 2 | import torch 3 | 4 | import sys 5 | sys.path.append('/path/to/deepaat/code') 6 | import utils.path_utils 7 | from utils import geo_utils, general_utils, dataset_utils 8 | import scipy.io as sio 9 | import numpy as np 10 | import os.path 11 | from scipy.sparse import coo_matrix 12 | 13 | 14 | def get_raw_data(conf, scan, flag): 15 | """ 16 | :param conf: 17 | :return: 18 | M - Points Matrix (2mxn) 19 | Ns - Inversed Calibration matrix (Ks-1) (mx3x3) 20 | Ps_gt - GT projection matrices (mx3x4) 21 | NBs - Normzlize Bifocal Tensor (En) (3mx3m) 22 | mask - Inlier points mask (2mxn) 23 | triplets 24 | """ 25 | 26 | # Init 27 | # dataset_path_format = os.path.join(utils.path_utils.path_to_datasets(), 'Euclidean', '{}.npz') 28 | # dataset_path_format = os.path.join(conf.get_string('dataset.dataset_path'), '{}.npz') 29 | if flag==0: 30 | dataset_path_format = os.path.join(conf.get_string('dataset.trainset_path'), '{}.npz') 31 | elif flag==1: 32 | dataset_path_format = os.path.join(conf.get_string('dataset.valset_path'), '{}.npz') 33 | else: 34 | dataset_path_format = os.path.join(conf.get_string('dataset.testset_path'), '{}.npz') 35 | 36 | # Get conf parameters 37 | if scan is None: 38 | scan = conf.get_string('dataset.scan') 39 | 40 | # Get raw data 41 | dataset = np.load(dataset_path_format.format(scan)) 42 | 43 | # Get bifocal tensors and 2D points 44 | # M = dataset['M'] 45 | # mask = dataset['mask'] 46 | Rs = dataset['Rs'] 47 | ts = dataset['ts'] 48 | # ts = dataset['enu'] 49 | quats = dataset['quats'] 50 | Ns = dataset['Ns'] 51 | M_col = dataset['M_col'] 52 | M_row = dataset['M_row'] 53 | M_data = dataset['M_data'] 54 | M_shape = dataset['M_shape'] 55 | M = coo_matrix((M_data, (M_row, M_col)), 
shape=M_shape).todense().A 56 | mask_col = dataset['mask_col'] 57 | mask_row = dataset['mask_row'] 58 | mask_data = dataset['mask_data'] 59 | mask_shape = dataset['mask_shape'] 60 | mask = coo_matrix((mask_data, (mask_row, mask_col)), shape=mask_shape).todense().A 61 | # M_gt = dataset_utils.correct_matches_global(M, Ps_gt, Ns) 62 | # mask = dataset_utils.get_mask_by_reproj(M, M_gt, 2) 63 | gpss = dataset['enu_noisy'] 64 | # gpss = dataset['enu'] 65 | 66 | scale = 1.0 67 | if conf.get_bool('train.train_trans'): 68 | t0 = ts[0] 69 | nts = ts-t0 70 | scale = np.max(np.linalg.norm(nts, axis=1)) 71 | ts = nts/scale 72 | gpss = (gpss-t0)/scale 73 | 74 | # use_gt = conf.get_bool('dataset.use_gt') 75 | # if use_gt: 76 | # M = torch.from_numpy(dataset_utils.correct_matches_global(M, Ps_gt, Ns)).float() 77 | 78 | use_spatial_encoder = conf.get_bool('dataset.use_spatial_encoder') 79 | 80 | if flag==0: # shuffle row and col of train set 81 | indices = torch.randperm(Ns.shape[0]) 82 | M_indices = torch.zeros(len(indices)*2, dtype=torch.int64) 83 | M_indices[::2] = 2 * indices 84 | M_indices[1::2] = 2 * indices + 1 85 | M = torch.from_numpy(M).float()[M_indices] 86 | Rs = torch.from_numpy(Rs).float()[indices] 87 | ts = torch.from_numpy(ts).float()[indices] 88 | quats = torch.from_numpy(quats).float()[indices] 89 | Ns = torch.from_numpy(Ns).float()[indices] 90 | mask = torch.from_numpy(mask).float()[M_indices] 91 | gpss = torch.from_numpy(gpss).float()[indices] 92 | # shuffle column 93 | idx = torch.randperm(M.shape[1]) 94 | M = M[:,idx] 95 | mask = mask[:,idx] 96 | color = None 97 | 98 | if use_spatial_encoder: 99 | dsc_idx_np = dataset['D_idxs'] 100 | dsc_data_np = dataset['D_data'] 101 | dsc_shape_np = dataset['D_shape'] 102 | ord_data = np.arange(dsc_idx_np.shape[1]) + 1 103 | ord_shape = (dsc_shape_np[0], dsc_shape_np[1]) 104 | ord_dense = coo_matrix((ord_data, (dsc_idx_np[0], dsc_idx_np[1])), shape=ord_shape).todense().A 105 | ord_dense = ord_dense[indices,:] 106 | 
ord_dense = ord_dense[:,idx] 107 | ord_sparse = coo_matrix(ord_dense) 108 | ord_new = ord_sparse.data - 1 109 | dsc_idx = torch.from_numpy(np.vstack((ord_sparse.row, ord_sparse.col))).int() 110 | dsc_data = torch.from_numpy(dsc_data_np[ord_new, :]/128.0).float() # nnz,128 111 | dsc_shape = torch.from_numpy(dsc_shape_np).int() 112 | else: 113 | dsc_idx = None 114 | dsc_data = None 115 | dsc_shape = None 116 | 117 | else: 118 | M = torch.from_numpy(M).float() 119 | Rs = torch.from_numpy(Rs).float() 120 | ts = torch.from_numpy(ts).float() 121 | quats = torch.from_numpy(quats).float() 122 | Ns = torch.from_numpy(Ns).float() 123 | mask = torch.from_numpy(mask).float() 124 | gpss = torch.from_numpy(gpss).float() 125 | color = torch.from_numpy(dataset['rgbs']).int() 126 | 127 | if use_spatial_encoder: 128 | dsc_idx = torch.from_numpy(dataset['D_idxs']).int() 129 | dsc_data = torch.from_numpy(dataset['D_data']/128.0).float() 130 | dsc_shape = torch.from_numpy(dataset['D_shape']).int() 131 | else: 132 | dsc_idx = None 133 | dsc_data = None 134 | dsc_shape = None 135 | 136 | # Add Noise 137 | if conf.get_bool("dataset.addNoise"): 138 | noise_mean = conf.get_float("dataset.noise_mean") 139 | noise_std = conf.get_float("dataset.noise_std") 140 | noise_radio = conf.get_float("dataset.noise_radio") 141 | M = geo_utils.addNoise(M, noise_mean, noise_std, noise_radio) 142 | # dsc_data = geo_utils.addNoise(dsc_data, noise_mean, noise_std, noise_radio) 143 | 144 | return M, Ns, Rs, ts, quats, mask, gpss, color, scale, use_spatial_encoder, dsc_idx, dsc_data, dsc_shape 145 | 146 | 147 | def test_Ps_M(Ps, M, Ns): 148 | global_rep_err = geo_utils.calc_global_reprojection_error(Ps.numpy(), M.numpy(), Ns.numpy()) 149 | print("Reprojection Error: Mean = {}, Max = {}".format(np.nanmean(global_rep_err), np.nanmax(global_rep_err))) 150 | 151 | 152 | def test_euclidean_dataset(scan): 153 | # dataset_path_format = os.path.join(utils.path_utils.path_to_datasets(), 'Euclidean', '{}.npz') 154 | 155 
| # # Get raw data 156 | # dataset = np.load(dataset_path_format.format(scan)) 157 | dataset = np.load("/path/to/data.npz") 158 | 159 | # Get bifocal tensors and 2D points 160 | M = dataset['M'] 161 | # M_col = dataset['M_col'] 162 | # M_row = dataset['M_row'] 163 | # M_data = dataset['M_data'] 164 | # M_shape = dataset['M_shape'] 165 | # M = coo_matrix((M_data, (M_row, M_col)), shape=M_shape).todense().A 166 | # mask_col = dataset['mask_col'] 167 | # mask_row = dataset['mask_row'] 168 | # mask_data = dataset['mask_data'] 169 | # mask_shape = dataset['mask_shape'] 170 | # mask = coo_matrix((mask_data, (mask_row, mask_col)), shape=mask_shape).todense().A 171 | Ps_gt = dataset['Ps_gt'] 172 | Ns = dataset['Ns'] 173 | 174 | print(M.shape) 175 | M_gt = torch.from_numpy(dataset_utils.correct_matches_global(M, Ps_gt, Ns)).float() 176 | print(M_gt.shape) 177 | 178 | M = torch.from_numpy(M).float() 179 | Ps_gt = torch.from_numpy(Ps_gt).float() 180 | Ns = torch.from_numpy(Ns).float() 181 | 182 | print("Test Ps and M") 183 | test_Ps_M(Ps_gt, M, Ns) 184 | 185 | print("Test Ps and M_gt") 186 | test_Ps_M(Ps_gt, M_gt, Ns) 187 | 188 | 189 | if __name__ == "__main__": 190 | scan = "Alcatraz Courtyard" 191 | test_euclidean_dataset(scan) -------------------------------------------------------------------------------- /code/models/layers.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | from cv2 import norm 3 | import torch 4 | from torch.nn import Linear, ReLU, BatchNorm1d, Sequential, Module, Identity, Sigmoid, Tanh 5 | from utils.sparse_utils import SparseMat 6 | from utils.pos_enc_utils import get_embedder 7 | 8 | 9 | def get_linear_layers(feats, final_layer=False, batchnorm=True): 10 | layers = [] 11 | 12 | # feats = 256*2 + out_dim 13 | # Add layers 14 | for i in range(len(feats) - 2): 15 | layers.append(Linear(feats[i], feats[i + 1])) 16 | 17 | if batchnorm: 18 | layers.append(BatchNorm1d(feats[i + 1], 
track_running_stats=False)) 19 | 20 | layers.append(ReLU()) 21 | 22 | # Add final layer 23 | layers.append(Linear(feats[-2], feats[-1])) 24 | if not final_layer: 25 | if batchnorm: 26 | layers.append(BatchNorm1d(feats[-1], track_running_stats=False)) 27 | 28 | layers.append(ReLU()) 29 | 30 | return Sequential(*layers) 31 | 32 | 33 | class Parameter3DPts(torch.nn.Module): 34 | def __init__(self, n_pts): 35 | super().__init__() 36 | 37 | # Init points randomly 38 | pts_3d = torch.normal(mean=0, std=0.1, size=(3, n_pts), requires_grad=True) 39 | 40 | self.pts_3d = torch.nn.Parameter(pts_3d) 41 | 42 | def forward(self): 43 | return self.pts_3d 44 | 45 | 46 | class SetOfSetLayer(Module): 47 | def __init__(self, d_in, d_out): 48 | super(SetOfSetLayer, self).__init__() 49 | # n is the number of points and m is the number of cameras 50 | self.lin_all = Linear(d_in, d_out) # w1 51 | self.lin_n = Linear(d_in, d_out) # w2 52 | self.lin_m = Linear(d_in, d_out) # w3 53 | self.lin_both = Linear(d_in, d_out) # w4 54 | 55 | def forward(self, x): 56 | # x is [m,n,d] sparse matrix 57 | out_all = self.lin_all(x.values) # [nnz,d_in] -> [nnz,d_out] 58 | 59 | mean_rows = x.mean(dim=0) # [m,n,d_in] -> [n,d_in] 60 | out_rows = self.lin_n(mean_rows) # [n,d_in] -> [n,d_out] 61 | 62 | mean_cols = x.mean(dim=1) # [m,n,d_in] -> [m,d_in] 63 | out_cols = self.lin_m(mean_cols) # [m,d_in] -> [m,d_out] 64 | 65 | out_both = self.lin_both(x.values.mean(dim=0, keepdim=True)) # [1,d_in] -> [1,d_out] 66 | 67 | new_features = (out_all + out_rows[x.indices[1], :] + out_cols[x.indices[0], :] + out_both) / 4 # [nnz,d_out] 68 | new_shape = (x.shape[0], x.shape[1], new_features.shape[1]) 69 | 70 | return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 71 | 72 | 73 | class ProjLayer(Module): 74 | def __init__(self, d_in, d_out): 75 | super(ProjLayer, self).__init__() 76 | # n is the number of points and m is the number of cameras 77 | self.lin_all = Linear(d_in, d_out) 78 | 79 | def 
forward(self, x): 80 | # x is [m,n,d] sparse matrix 81 | new_features = self.lin_all(x.values) # [nnz,d_in] -> [nnz,d_out] 82 | new_shape = (x.shape[0], x.shape[1], new_features.shape[1]) 83 | return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 84 | 85 | 86 | class NormalizationLayer(Module): 87 | def forward(self, x): 88 | features = x.values 89 | norm_features = features - features.mean(dim=0, keepdim=True) 90 | norm_features = norm_features / norm_features.std(dim=0, keepdim=True) 91 | return SparseMat(norm_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) 92 | 93 | 94 | class CenterNorm(Module): 95 | def forward(self, x): 96 | features = x.values 97 | norm_features = features - features.mean(dim=0, keepdim=True) 98 | return SparseMat(norm_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) 99 | 100 | 101 | class LastNorm(Module): 102 | def forward(self, x): 103 | features = x.values 104 | norm_features = features - features.mean(dim=0, keepdim=True) 105 | maxnum = torch.max(norm_features.max(), -norm_features.min()) 106 | norm_features = norm_features*5/maxnum 107 | return SparseMat(norm_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) 108 | 109 | 110 | class RCNormLayer(Module): 111 | def forward(self, x): # x.shape = m,n,d 112 | # dim0 113 | mean0 = x.mean(dim=0) # n,d 114 | meandif0 = x.values - mean0[x.indices[1], :] # nnz,d 115 | # dim1 116 | mean1 = x.mean(dim=1) # m,d 117 | meandif1 = x.values - mean1[x.indices[0], :] # nnz,d 118 | # output 119 | new_shape = (x.shape[0], x.shape[1], x.shape[2]*2) # m,n,2d 120 | new_vals = torch.cat([meandif0, meandif1], -1) # m,n,2d 121 | return SparseMat(new_vals, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 122 | 123 | 124 | 125 | class InstanceNormLayer(Module): 126 | # def forward(self, x): # x.shape = m,n,d 127 | # n_features = x.shape[2] # d 128 | # # dim0 129 | # mean0 = x.mean(dim=0) # n,d 130 | # meandif0 = x.values - mean0[x.indices[1], :] # 
nnz,d 131 | # sum0 = torch.zeros(x.shape[1], n_features, device=x.device) # n,d 132 | # sum0.index_add(0, x.indices[1], meandif0.pow(2.0)) 133 | # norm0 = (sum0 / x.cam_per_pts).pow(0.5) # n,d 134 | # normx0 = meandif0 / norm0[x.indices[1], :] # nnz,d 135 | # # dim1 136 | # mean1 = x.mean(dim=1) # m,d 137 | # meandif1 = x.values - mean1[x.indices[0], :] # nnz,d 138 | # sum1 = torch.zeros(x.shape[0], n_features, device=x.device) # m,d 139 | # sum1.index_add(0, x.indices[0], meandif1.pow(2.0)) 140 | # norm1 = (sum1 / x.pts_per_cam).pow(0.5) # m,d 141 | # normx1 = meandif1 / norm1[x.indices[0], :] # nnz,d 142 | # # output 143 | # new_shape = (x.shape[0], x.shape[1], x.shape[2]*2) # m,n,2d 144 | # new_vals = torch.cat([normx0, normx1], -1) # m,n,2d 145 | # return SparseMat(new_vals, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 146 | 147 | def forward(self, x): # x.shape = m,n,d 148 | # dim0 149 | mean0 = x.mean(dim=0) # n,d 150 | std0 = x.std(dim=0) # n,d 151 | norm0 = (x.values - mean0[x.indices[1], :]) / std0[x.indices[1], :] # nnz,d 152 | # dim1 153 | mean1 = x.mean(dim=1) # n,d 154 | std1 = x.std(dim=1) # n,d 155 | norm1 = (x.values - mean1[x.indices[0], :]) / std1[x.indices[0], :] # nnz,d 156 | # output 157 | new_shape = (x.shape[0], x.shape[1], x.shape[2]*2) # m,n,2d 158 | new_vals = torch.cat([norm0, norm1], -1) # m,n,2d 159 | return SparseMat(new_vals, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 160 | 161 | 162 | class ActivationLayer(Module): 163 | def __init__(self): 164 | super(ActivationLayer, self).__init__() 165 | self.relu = ReLU() 166 | 167 | def forward(self, x): 168 | new_features = self.relu(x.values) 169 | return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) 170 | 171 | 172 | class IdentityLayer(Module): 173 | def forward(self, x): 174 | return x 175 | 176 | 177 | class EmbeddingLayer(Module): 178 | def __init__(self, multires, in_dim): 179 | super(EmbeddingLayer, self).__init__() 180 | if multires > 0: 181 
| self.embed, self.d_out = get_embedder(multires, in_dim) 182 | else: 183 | self.embed, self.d_out = (Identity(), in_dim) 184 | 185 | def forward(self, x): 186 | embeded_features = self.embed(x.values) 187 | new_shape = (x.shape[0], x.shape[1], embeded_features.shape[1]) 188 | return SparseMat(embeded_features, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape) 189 | 190 | 191 | class PosiEmbedding(Module): 192 | def __init__(self, num_freqs: int, logscale=True): 193 | """ 194 | Defines a function that embeds x to (x, sin(2^k x), cos(2^k x), ...) 195 | """ 196 | super(PosiEmbedding, self).__init__() 197 | 198 | if logscale: 199 | self.freq_bands = 2 ** torch.linspace(0, num_freqs - 1, num_freqs) 200 | else: 201 | self.freq_bands = torch.linspace(1, 2 ** (num_freqs - 1), num_freqs) 202 | 203 | def forward(self, x: torch.Tensor) -> torch.Tensor: 204 | out = [x] 205 | for freq in self.freq_bands: 206 | out += [torch.sin(freq * x), torch.cos(freq * x)] 207 | return torch.cat(out, -1) 208 | 209 | 210 | class SigmoidScoreLayer(Module): 211 | def __init__(self): 212 | super(SigmoidScoreLayer, self).__init__() 213 | self.sigmoid = Sigmoid() 214 | 215 | def forward(self, x): 216 | new_features = self.sigmoid(x.values) 217 | return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) 218 | 219 | 220 | class TanhScoreLayer(Module): 221 | def __init__(self): 222 | super(TanhScoreLayer, self).__init__() 223 | self.relu = ReLU() 224 | self.tanh = Tanh() 225 | 226 | def forward(self, x): 227 | new_features = self.relu(x.values) 228 | new_features = self.tanh(new_features) 229 | return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape) -------------------------------------------------------------------------------- /code/joint_optimization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.sparse import coo_matrix 3 | from utils import geo_utils, ceres_utils, 
ba_functions 4 | import pandas as pd 5 | 6 | # m,3,3 m,3,3 m,3 7 | def get_Ps_gt(K_gt, R_gt, T_gt): 8 | m = R_gt.shape[0] 9 | K_RT = np.matmul(K_gt, R_gt.transpose((0, 2, 1))) # K@R.T 10 | Is = np.expand_dims(np.identity(3), 0).repeat(m, axis = 0) 11 | tmp = np.concatenate((Is, -np.expand_dims(T_gt, 2)), axis = 2) # I|-t 12 | Ps_gt = np.matmul(K_RT, tmp) 13 | return Ps_gt 14 | 15 | 16 | def loadtxt(classed_txt): 17 | ids_list = [] 18 | with open(classed_txt, 'r') as f: 19 | for line in f: 20 | ids0 = list(map(int, line[:-1].strip().split(' '))) 21 | ids = [x-1 for x in ids0] 22 | ids.sort() 23 | ids_list.append(ids) 24 | return ids_list 25 | 26 | 27 | def get_col_ids(cond, minval): 28 | ids = np.arange(cond.shape[1]) 29 | return ids[np.sum(cond!=0, axis=0)>=minval] 30 | 31 | 32 | def make_full_scene_by_predRt_rawM(pred_npzs, raw_npzs, full_npz, classed_txt, outpath): 33 | 34 | full_f = np.load(full_npz) 35 | Rs_gt = full_f['newRs'] 36 | ts_gt = full_f['newts'] 37 | Ns = full_f['Ns'] 38 | Ks = np.linalg.inv(Ns) 39 | M_data = full_f['M_data'] 40 | M_row = full_f['M_row'] 41 | M_col = full_f['M_col'] 42 | M_shape = full_f['M_shape'] 43 | M = coo_matrix((M_data, (M_row, M_col)), shape=M_shape).todense().A 44 | xs = geo_utils.M_to_xs(M) 45 | 46 | ids_list = loadtxt(classed_txt) 47 | Rs = np.zeros(Rs_gt.shape) 48 | ts = np.zeros(ts_gt.shape) 49 | 50 | scan_name = "ortho1_full" 51 | for i in range(len(pred_npzs)): 52 | fpred = np.load(pred_npzs[i]) 53 | fraw = np.load(raw_npzs[i]) 54 | ids = ids_list[i] 55 | Rs[ids,:] = fpred['Rs'] 56 | ts[ids,:] = fpred['ts'] + fraw['ts'][0] 57 | 58 | Ps = get_Ps_gt(Ks, Rs, ts) 59 | pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=M, Ns=Ns) 60 | 61 | results = {} 62 | results['scan_name'] = scan_name 63 | results['xs'] = xs 64 | results['Rs'] = Rs 65 | results['ts'] = ts 66 | results['Rs_gt'] = Rs_gt 67 | results['ts_gt'] = ts_gt 68 | results['Ks'] = Ks 69 | results['pts3D_pred'] = pts3D_triangulated.T 70 | results['Ps'] = Ps 71 
| results['raw_xs'] = xs 72 | results['raw_color'] = full_f['rgbs'] 73 | np.savez(outpath, **results) 74 | 75 | 76 | def make_full_scene_by_abaRtM(aba_npzs, raw_npzs, full_npz, classed_txt, outpath): 77 | print("in make_full_scene_by_abaRtM() now ...") 78 | print("1. get full raw data ...") 79 | full_f = np.load(full_npz) 80 | Rs_gt = full_f['newRs'] 81 | ts_gt = full_f['newts'] 82 | Ns = full_f['Ns'] 83 | Ks = np.linalg.inv(Ns) 84 | M_data = full_f['M_data'] 85 | M_row = full_f['M_row'] 86 | M_col = full_f['M_col'] 87 | M_shape = full_f['M_shape'] 88 | raw_M = coo_matrix((M_data, (M_row, M_col)), shape=M_shape).todense().A 89 | raw_xs = geo_utils.M_to_xs(raw_M) 90 | xs = np.zeros((M_shape[0]//2, M_shape[1], 2)) 91 | 92 | ids_list = loadtxt(classed_txt) 93 | Rs = np.zeros(Rs_gt.shape) 94 | ts = np.zeros(ts_gt.shape) 95 | 96 | print("2. process splited data ...") 97 | scan_name = "ortho1_full" 98 | for i in range(len(aba_npzs)): 99 | faba = np.load(aba_npzs[i]) 100 | fraw = np.load(raw_npzs[i]) 101 | camids = ids_list[i] 102 | colids = fraw['pt3d_idx'][faba['colidx']] 103 | xstmp = np.zeros((len(camids), M_shape[1], 2)) 104 | xstmp[:,colids,:] = faba['new_xs'] 105 | xs[camids, :, :] = xstmp 106 | Rs[camids,:] = faba['Rs_ba_fixed'] 107 | ts[camids,:] = faba['ts_ba_fixed'] + fraw['ts'][0] 108 | 109 | M=geo_utils.xs_to_M(xs) 110 | valid_colidx = get_col_ids(M,4) 111 | 112 | Ps = get_Ps_gt(Ks, Rs, ts) 113 | pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=M[:,valid_colidx], Ns=Ns) 114 | 115 | print("3. 
write results ...") 116 | results = {} 117 | results['precolor'] = full_f['rgbs'][:, valid_colidx] 118 | results['scan_name'] = scan_name 119 | results['xs'] = xs[:,valid_colidx,:] 120 | results['Rs'] = Rs 121 | results['ts'] = ts 122 | results['Rs_gt'] = Rs_gt 123 | results['ts_gt'] = ts_gt 124 | results['Ks'] = Ks 125 | results['pts3D_pred'] = pts3D_triangulated 126 | results['Ps'] = Ps 127 | results['raw_xs'] = raw_xs 128 | results['raw_color'] = full_f['rgbs'] 129 | np.savez(outpath, **results) 130 | 131 | 132 | def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined, proj_first, proj_second): 133 | print("in _runba() now ...") 134 | outputs = {} 135 | 136 | print("1. load data ...") 137 | file = dict(np.load(npzpath)) 138 | scan_name = file['scan_name'] 139 | xs = file['xs'] 140 | Rs_pred = file['Rs'] 141 | ts_pred = file['ts'] 142 | Rs_gt = file['Rs_gt'] 143 | ts_gt = file['ts_gt'] 144 | Ks = file['Ks'] 145 | Ns = np.linalg.inv(Ks) 146 | # file['pts3D_pred'] = file['pts3D_pred'].T 147 | Xs = file['pts3D_pred'] 148 | Ps = file['Ps'] 149 | raw_xs = file['raw_xs'] 150 | 151 | outputs['xs'] = xs 152 | outputs['Rs_gt'] = Rs_gt 153 | outputs['ts_gt'] = ts_gt 154 | 155 | print("2. ba now ...") 156 | ba_res = ba_functions.merged_ba(xs, raw_xs, Rs=Rs_pred, ts=ts_pred, Ks=Ks, Xs=Xs.T, Ps=Ps, Ns=Ns, 157 | repeat=repeat, max_iter=max_iter, ba_times=ba_times, 158 | repro_thre=repro_thre, refined=refined, proj_first=proj_first, proj_second=proj_second) # Rs, ts, Ps, Xs 159 | outputs['Rs_ba'] = ba_res['Rs'] 160 | outputs['ts_ba'] = ba_res['ts'] 161 | outputs['Xs_ba'] = ba_res['Xs'].T # 4,n 162 | outputs['Ps_ba'] = ba_res['Ps'] 163 | if refined: 164 | outputs['valid_points'] = ba_res['valid_points'] 165 | outputs['new_xs'] = ba_res['new_xs'] 166 | outputs['colidx'] = ba_res['colidx'] 167 | 168 | print("3. 
align cams ...") 169 | R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt, 170 | return_alignment=True) # Align Rs_fixed, tx_fixed 171 | outputs['Rs_ba_fixed'] = R_ba_fixed 172 | outputs['ts_ba_fixed'] = t_ba_fixed 173 | outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba']) 174 | # outputs['Rs_ba_fixed'] = ba_res['Rs'] 175 | # outputs['ts_ba_fixed'] = ba_res['ts'] 176 | # outputs['Xs_ba_fixed'] = ba_res['Xs'].T 177 | 178 | print("4. save results ...") 179 | file.update(outputs) 180 | np.savez(npzpath, **file) 181 | 182 | return file, scan_name 183 | 184 | 185 | def _compute_errors(outputs, results_file_path, refined=False): 186 | print("in _compute_errors() now ...") 187 | model_errors = {} 188 | 189 | pts3D_pred = outputs['pts3D_pred'] 190 | Ps = outputs['Ps'] 191 | Rs_fixed = outputs['Rs'] 192 | ts_fixed = outputs['ts'] 193 | Rs_gt = outputs['Rs_gt'] 194 | ts_gt = outputs['ts_gt'] 195 | xs = outputs['xs'] 196 | Xs_ba = outputs['Xs_ba'] 197 | Ps_ba = outputs['Ps_ba'] 198 | our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs) 199 | if not our_repro_error.shape: return model_errors 200 | model_errors["our_repro"] = np.nanmean(our_repro_error) 201 | model_errors["our_repro_max"] = np.nanmax(our_repro_error) 202 | 203 | Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_fixed, ts_fixed, Rs_gt, ts_gt) 204 | model_errors["ts_mean"] = np.mean(ts_error) 205 | model_errors["ts_med"] = np.median(ts_error) 206 | model_errors["ts_max"] = np.max(ts_error) 207 | model_errors["Rs_mean"] = np.mean(Rs_error) 208 | model_errors["Rs_med"] = np.median(Rs_error) 209 | model_errors["Rs_max"] = np.max(Rs_error) 210 | 211 | if refined: 212 | valid_points = outputs['valid_points'] 213 | new_xs = outputs['new_xs'] 214 | repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, new_xs, visible_points=valid_points) 215 | else: 216 | repro_ba_error = 
geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs) 217 | model_errors['repro_ba'] = np.nanmean(repro_ba_error) 218 | model_errors['repro_ba_max'] = np.nanmax(repro_ba_error) 219 | 220 | Rs_ba_fixed = outputs['Rs_ba_fixed'] 221 | ts_ba_fixed = outputs['ts_ba_fixed'] 222 | Rs_ba_error, ts_ba_error = geo_utils.tranlsation_rotation_errors(Rs_ba_fixed, ts_ba_fixed, Rs_gt, ts_gt) 223 | model_errors["ts_ba_mean"] = np.mean(ts_ba_error) 224 | model_errors["ts_ba_med"] = np.median(ts_ba_error) 225 | model_errors["ts_ba_max"] = np.max(ts_ba_error) 226 | model_errors["Rs_ba_mean"] = np.mean(Rs_ba_error) 227 | model_errors["Rs_ba_med"] = np.median(Rs_ba_error) 228 | model_errors["Rs_ba_max"] = np.max(Rs_ba_error) 229 | 230 | errors_list = [] 231 | errors_list.append(model_errors) 232 | df_errors = pd.DataFrame(errors_list) 233 | mean_errors = df_errors.mean(numeric_only=True) 234 | # df_errors = pd.concat([df_errors,mean_errors], axis=0, ignore_index=True) 235 | df_errors = df_errors.append(mean_errors, ignore_index=True) 236 | df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean" 237 | df_errors.set_index("Scene", inplace=True) 238 | # df_errors = df_errors.round(3) 239 | print(df_errors.to_string(), flush=True) 240 | df_errors.to_excel(results_file_path) 241 | 242 | 243 | if __name__=="__main__": 244 | 245 | refined = True 246 | repeat = True 247 | max_iter = 50 248 | ba_times = 2 249 | repro_thre = 2 250 | proj_first = True 251 | proj_second = True 252 | outputs, scan_name = _runba(out_npz_path, repeat, max_iter, ba_times, repro_thre, refined, proj_first, proj_second) 253 | _compute_errors(outputs, out_xlsx_path, refined=refined) 254 | -------------------------------------------------------------------------------- /code/utils/ba_functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from utils import ceres_utils 3 | from utils import geo_utils 4 | 5 | 6 | def euc_ba(xs, raw_xs, Rs, ts, Ks, 
Xs, Ps=None, Ns=None, repeat=True, max_iter=100, ba_times=2, repro_thre=5, refined=False): 7 | """ 8 | Computes bundle adjustment with ceres solver 9 | :param xs: 2d points [m,n,2] 10 | :param Rs: rotations [m,3,3] 11 | :param ts: translations [m,3] 12 | :param Ks: inner parameters, calibration matrices [m,3,3] 13 | :param Xs_our: initial 3d points [n,3] or None if triangulation needed 14 | :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]] 15 | :param Ns: normalization matrices. If Ks are known, Ns = inv(Ks) 16 | :param repeat: run ba twice. default: True 17 | :param triangulation: For initial point run triangulation. default: False 18 | :param return_repro: compute and return the reprojection errors before and after. 19 | :return: results. The new camera parameters, 3d points, and if requested the reprojection errors. 20 | """ 21 | results = {} 22 | 23 | visible_points = xs[:, :, 0] != 0 # m,3dptnum 24 | point_indices = np.stack(np.where(visible_points)) # 2,2dptnum 25 | visible_xs = xs[visible_points] # 2dptnum,2 26 | 27 | # if Ps is None: 28 | # Ps = geo_utils.batch_get_camera_matrix_from_rtk(Rs, ts, Ks) 29 | 30 | # if triangulation: 31 | # if Ns is None: 32 | # Ns = np.linalg.inv(Ks) 33 | # norm_P, norm_x = geo_utils.normalize_points_cams(Ps, xs, Ns) 34 | # Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 35 | # else: 36 | # Xs = Xs_our 37 | 38 | if refined: 39 | new_Rs, new_ts, new_Ps, _, _, _, _, _ = ceres_utils.run_euc_ceres_iter( 40 | Ps, Xs, xs, Rs, ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre) 41 | # new_Xs = np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1) 42 | if repeat: 43 | visible_points = raw_xs[:, :, 0] != 0 # m,3dptnum 44 | point_indices = np.stack(np.where(visible_points)) # 2,2dptnum 45 | norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, raw_xs, Ns) 46 | new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 47 | proj_first = True 48 | # print(f"new_Xs1: 
{new_Xs.shape}") 49 | new_Rs, new_ts, new_Ps, new_Xs, new_xs, visible_points, point_indices, colidx = ceres_utils.run_euc_ceres_iter( 50 | new_Ps, new_Xs, raw_xs, new_Rs, new_ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre, proj_first) 51 | else: 52 | new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres( 53 | Xs, visible_xs, Rs, ts, Ks, point_indices) 54 | if repeat: 55 | norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, xs, Ns) 56 | new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 57 | new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres( 58 | new_Xs, visible_xs, new_Rs, new_ts, Ks, point_indices) 59 | new_Xs = np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1) 60 | 61 | results['Rs'] = new_Rs 62 | results['ts'] = new_ts 63 | results['Ps'] = new_Ps 64 | results['Xs'] = new_Xs 65 | if refined: 66 | results['valid_points'] = visible_points 67 | results['new_xs'] = new_xs 68 | results['colidx'] = colidx 69 | 70 | return results 71 | 72 | 73 | def merged_ba(xs, raw_xs, Rs, ts, Ks, Xs, Ps=None, Ns=None, repeat=True, max_iter=100, ba_times=2, repro_thre=5, refined=False, 74 | proj_first=False, proj_second=True): 75 | """ 76 | Computes bundle adjustment with ceres solver 77 | :param xs: 2d points [m,n,2] 78 | :param Rs: rotations [m,3,3] 79 | :param ts: translations [m,3] 80 | :param Ks: inner parameters, calibration matrices [m,3,3] 81 | :param Xs_our: initial 3d points [n,3] or None if triangulation needed 82 | :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]] 83 | :param Ns: normalization matrices. If Ks are known, Ns = inv(Ks) 84 | :param repeat: run ba twice. default: True 85 | :param triangulation: For initial point run triangulation. default: False 86 | :param return_repro: compute and return the reprojection errors before and after. 87 | :return: results. The new camera parameters, 3d points, and if requested the reprojection errors. 
88 | """ 89 | results = {} 90 | 91 | visible_points = xs[:, :, 0] != 0 # m,3dptnum 92 | point_indices = np.stack(np.where(visible_points)) # 2,2dptnum 93 | 94 | # if Ps is None: 95 | # Ps = geo_utils.batch_get_camera_matrix_from_rtk(Rs, ts, Ks) 96 | 97 | # if triangulation: 98 | # if Ns is None: 99 | # Ns = np.linalg.inv(Ks) 100 | # norm_P, norm_x = geo_utils.normalize_points_cams(Ps, xs, Ns) 101 | # Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 102 | # else: 103 | # Xs = Xs_our 104 | 105 | if refined: 106 | new_Rs, new_ts, new_Ps, new_Xs, new_xs, visible_points, point_indices, colidx = ceres_utils.run_euc_ceres_iter( 107 | Ps, Xs, xs, Rs, ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre, proj_first) 108 | # new_Xs = np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1) 109 | if repeat: 110 | visible_points = raw_xs[:, :, 0] != 0 # m,3dptnum 111 | point_indices = np.stack(np.where(visible_points)) # 2,2dptnum 112 | # new_Ps = geo_utils.batch_get_camera_matrix_from_rtk(new_Rs, new_ts, Ks) 113 | norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, raw_xs, Ns) 114 | new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 115 | # print(f"new_Xs1: {new_Xs.shape}") 116 | new_Rs, new_ts, new_Ps, new_Xs, new_xs, visible_points, point_indices, colidx = ceres_utils.run_euc_ceres_iter( 117 | new_Ps, new_Xs, raw_xs, new_Rs, new_ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre, proj_second) 118 | else: 119 | visible_xs = xs[visible_points] # 2dptnum,2 120 | new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres( 121 | Xs, visible_xs, Rs, ts, Ks, point_indices) 122 | if repeat: 123 | norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, xs, Ns) 124 | new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 125 | new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres( 126 | new_Xs, visible_xs, new_Rs, new_ts, Ks, point_indices) 127 | new_Xs = 
np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1) 128 | 129 | results['Rs'] = new_Rs 130 | results['ts'] = new_ts 131 | results['Ps'] = new_Ps 132 | results['Xs'] = new_Xs 133 | if refined: 134 | results['valid_points'] = visible_points 135 | results['new_xs'] = new_xs 136 | results['colidx'] = colidx 137 | 138 | return results 139 | 140 | 141 | def proj_ba(Ps, xs, Xs_our=None, Ns=None, repeat=True, triangulation=False, return_repro=True,normalize_in_tri=True): 142 | """ 143 | Computes bundle adjustment with ceres solve 144 | :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]] 145 | :param xs: 2d points [m,n,2] 146 | :param Xs_our: initial 3d points [n,3] or None if triangulation needed 147 | :param Ns: normalization matrices. 148 | :param repeat: run ba twice. default: True 149 | :param triangulation: For initial point run triangulation. default: False 150 | :param return_repro: compute and return the reprojection errors before and after. 151 | :param normalize_in_tri: Normalize the points and the cameras when computing triangulation. default: True 152 | :return: results. The new camera parameters, 3d points, and if requested the reprojection errors. 
153 | """ 154 | results = {} 155 | 156 | visible_points = xs[:, :, 0] != 0 157 | point_indices = np.stack(np.where(visible_points)) 158 | visible_xs = xs[visible_points] 159 | 160 | if triangulation: 161 | if normalize_in_tri: 162 | if Ns is None: 163 | Ns = geo_utils.batch_get_normalization_matrices(xs) 164 | norm_P, norm_x = geo_utils.normalize_points_cams(Ps, xs, Ns) 165 | Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 166 | else: 167 | Xs = geo_utils.dlt_triangulation(Ps, xs, visible_points) 168 | else: 169 | Xs = Xs_our 170 | 171 | if return_repro: 172 | results['repro_before'] = np.nanmean(geo_utils.reprojection_error_with_points(Ps, Xs, xs, visible_points)) 173 | 174 | new_Ps, new_Xs = ceres_utils.run_projective_python_ceres(Ps, Xs, visible_xs, point_indices) 175 | 176 | if repeat: 177 | if return_repro: 178 | results['repro_middle'] = np.nanmean(geo_utils.reprojection_error_with_points(new_Ps, new_Xs, xs, visible_points)) 179 | 180 | if normalize_in_tri: 181 | if Ns is None: 182 | Ns = geo_utils.batch_get_normalization_matrices(xs) 183 | norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, xs, Ns) 184 | new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points) 185 | else: 186 | new_Xs = geo_utils.dlt_triangulation(new_Ps, xs, visible_points) 187 | 188 | new_Ps, new_Xs = ceres_utils.run_projective_python_ceres(new_Ps, new_Xs, visible_xs, point_indices) 189 | 190 | if return_repro: 191 | results['repro_after'] = np.nanmean(geo_utils.reprojection_error_with_points(new_Ps, new_Xs, xs, visible_points)) 192 | 193 | new_Xs = np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1) 194 | 195 | results['Ps'] = new_Ps 196 | results['Xs'] = new_Xs 197 | return results 198 | 199 | -------------------------------------------------------------------------------- /code/datasets/SceneData.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import geo_utils, 
dataset_utils, sparse_utils 3 | from datasets import Projective, Euclidean 4 | import os.path 5 | from pyhocon import ConfigFactory 6 | import numpy as np 7 | import warnings 8 | from pytorch3d import transforms as py3d_trans 9 | 10 | 11 | class SceneData: 12 | def __init__(self, M, Ns, Rs, ts, quats, mask, scan_name, gpss, color, scale=1.0, dilute_M=False, 13 | use_spatial_encoder=True, dsc_idx=None, dsc_data=None, dsc_shape=None): 14 | n_images = Ns.shape[0] 15 | 16 | # Set attribute 17 | self.scan_name = scan_name 18 | # self.y = Ps_gt 19 | self.M = M 20 | self.Ns = Ns 21 | self.Ks = torch.inverse(Ns) 22 | self.mask = mask 23 | self.gpss = gpss 24 | 25 | # get R, t 26 | # Rs_gt, ts_gt = geo_utils.decompose_camera_matrix(Ps_gt, self.Ks) 27 | self.Rs = Rs 28 | self.ts = ts 29 | self.quats = quats 30 | self.color = color 31 | self.scale = scale 32 | 33 | # Dilute M 34 | if dilute_M: 35 | self.M = geo_utils.dilutePoint(M) 36 | 37 | # M to sparse matrix 38 | self.x = dataset_utils.M2sparse(M, normalize=True, Ns=Ns) 39 | 40 | # mask to sparse matrix 41 | mask_trim = mask[::2].unsqueeze(2) # 2m,n -> m,n,1 42 | mask_indices = self.x.indices 43 | mask_values = mask_trim[mask_indices[0], mask_indices[1], :] 44 | mask_shape = (self.x.shape[0], self.x.shape[1], 1) 45 | self.mask_sparse = sparse_utils.SparseMat(mask_values, mask_indices, 46 | self.x.cam_per_pts, self.x.pts_per_cam, mask_shape) 47 | 48 | egps_indices = self.x.indices 49 | egps_values = self.gpss[self.x.indices[0],:] 50 | egps_shape = (self.x.shape[0], self.x.shape[1], 3) 51 | self.egps_sparse = sparse_utils.SparseMat(egps_values, egps_indices, 52 | self.x.cam_per_pts, self.x.pts_per_cam, egps_shape) 53 | 54 | if use_spatial_encoder: 55 | self.dsc = sparse_utils.SparseMat(dsc_data, dsc_idx, self.x.cam_per_pts, self.x.pts_per_cam, dsc_shape) 56 | 57 | # Get valid points 58 | self.valid_pts = dataset_utils.get_M_valid_points(M) 59 | 60 | # Normalize M 61 | self.norm_M = geo_utils.normalize_M(M, Ns, 
self.valid_pts).transpose(1, 2).reshape(n_images * 2, -1) 62 | 63 | 64 | def to(self, *args, **kwargs): 65 | for key in self.__dict__: 66 | if not key.startswith('__'): 67 | attr = getattr(self, key) 68 | #if not callable(attr) and (isinstance(attr, sparse_utils.SparseMat) or torch.is_tensor(attr)): 69 | if isinstance(attr, sparse_utils.SparseMat) or torch.is_tensor(attr): 70 | setattr(self, key, attr.to(*args, **kwargs)) 71 | 72 | return self 73 | 74 | 75 | def create_scene_data(conf, flag): 76 | # Init 77 | scan = conf.get_string('dataset.scan') 78 | calibrated = conf.get_bool('dataset.calibrated') 79 | dilute_M = conf.get_bool('dataset.diluteM', default=False) 80 | 81 | # Get raw data 82 | if calibrated: 83 | M, Ns, Ps_gt, mask, gpss = Euclidean.get_raw_data(conf, scan, flag) 84 | else: 85 | M, Ns, Ps_gt, mask = Projective.get_raw_data(conf, scan) 86 | 87 | return SceneData(M, Ns, Ps_gt, mask, scan, gpss, dilute_M) 88 | 89 | 90 | def sample_data(data, num_samples, adjacent=True): 91 | # Get indices 92 | # indices = dataset_utils.sample_indices(len(data.y), num_samples, adjacent=adjacent) 93 | indices = dataset_utils.radius_sample(data.ts.numpy(), num_samples) 94 | # indices = dataset_utils.simulate_sample(data.gpss.numpy(), num_samples, data.x.pts_per_cam.numpy()) 95 | indices, M_indices = dataset_utils.order_indices(indices, shuffle=True) 96 | 97 | indices = torch.from_numpy(indices).squeeze() 98 | M_indices = torch.from_numpy(M_indices).squeeze() 99 | 100 | # Get sampled data 101 | Rs = data.Rs[indices] 102 | ts = data.ts[indices] 103 | quats = data.quats[indices] 104 | Ns = data.Ns[indices] 105 | M = data.M[M_indices] 106 | mask = data.mask[M_indices] 107 | mask = mask[:,(M!=0).sum(dim=0)>2] 108 | M = M[:,(M!=0).sum(dim=0)>2] 109 | 110 | # shuffle column 111 | idx = torch.randperm(M.shape[1]) 112 | M = M[:,idx] 113 | mask = mask[:,idx] 114 | 115 | sampled_data = SceneData(M, Ns, Rs, ts, quats, mask, data.scan_name) 116 | if (sampled_data.x.pts_per_cam == 
0).any(): 117 | warnings.warn('Cameras with no points for dataset '+ data.scan_name) 118 | 119 | return sampled_data 120 | 121 | 122 | # flag=0 means train; flag=1 means val; flag=2 means test 123 | def create_scene_data_from_list(scan_names_list, conf, flag): 124 | data_list = [] 125 | for scan_name in scan_names_list: 126 | conf["dataset"]["scan"] = scan_name 127 | data = create_scene_data(conf, flag) 128 | data_list.append(data) 129 | 130 | return data_list 131 | 132 | 133 | def create_scene_data_from_dir(conf, flag): 134 | if flag==0: 135 | datadir = conf.get_string("dataset.trainset_path") 136 | elif flag==1: 137 | datadir = conf.get_string("dataset.valset_path") 138 | else: 139 | datadir = conf.get_string("dataset.testset_path") 140 | 141 | data_list = [] 142 | dilute_M = conf.get_bool('dataset.diluteM', default=False) 143 | for _,_,files in os.walk(datadir): 144 | for f in files: 145 | # Get raw data 146 | f = f.split('.')[0] # get name only 147 | M, Ns, Rs, ts, quats, mask = Euclidean.get_raw_data(conf, f, flag) 148 | data = SceneData(M, Ns, Rs, ts, quats, mask, f, dilute_M) 149 | data_list.append(data) 150 | return data_list 151 | 152 | 153 | def get_data_list(conf, flag): 154 | if flag==0: 155 | datadir = conf.get_string("dataset.trainset_path") 156 | elif flag==1: 157 | datadir = conf.get_string("dataset.valset_path") 158 | else: 159 | datadir = conf.get_string("dataset.testset_path") 160 | 161 | data_list = [] 162 | for _,_,files in os.walk(datadir): 163 | for f in files: 164 | data_list.append(f.split('.')[0]) 165 | return data_list 166 | 167 | 168 | def test_dataset(): 169 | # Prepare configuration 170 | dataset_dict = {"images_path": "/home/labs/waic/hodaya/PycharmProjects/GNN-for-SFM/datasets/images/", 171 | "normalize_pts": True, 172 | "normalize_f": True, 173 | "use_gt": False, 174 | "calibrated": False, 175 | "scan": "Alcatraz Courtyard", 176 | "edge_min_inliers": 30, 177 | "use_all_edges": True, 178 | } 179 | 180 | train_dict = 
{"infinity_pts_margin": 1e-4, 181 | "hinge_loss_weight": 1, 182 | } 183 | loss_dict = {"infinity_pts_margin": 1e-4, 184 | "normalize_grad": False, 185 | "hinge_loss": True, 186 | "hinge_loss_weight" : 1 187 | } 188 | conf_dict = {"dataset": dataset_dict, "loss":loss_dict} 189 | 190 | print("Test projective") 191 | conf = ConfigFactory.from_dict(conf_dict) 192 | data = create_scene_data(conf) 193 | test_data(data, conf) 194 | 195 | print('Test move to device') 196 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 197 | new_data = data.to(device) 198 | 199 | print(os.linesep) 200 | print("Test Euclidean") 201 | conf = ConfigFactory.from_dict(conf_dict) 202 | conf["dataset"]["calibrated"] = True 203 | data = create_scene_data(conf) 204 | test_data(data, conf) 205 | 206 | print(os.linesep) 207 | print("Test use_gt GT") 208 | conf = ConfigFactory.from_dict(conf_dict) 209 | conf["dataset"]["use_gt"] = True 210 | data = create_scene_data(conf) 211 | test_data(data, conf) 212 | 213 | 214 | def test_data(data, conf): 215 | import loss_functions 216 | 217 | # Test Losses of GT and random on data 218 | repLoss = loss_functions.ESFMLoss(conf) 219 | cams_gt = prepare_cameras_for_loss_func(data.y, data) 220 | cams_rand = prepare_cameras_for_loss_func(torch.rand(data.y.shape), data) 221 | 222 | print("Loss for GT: Reprojection = {}".format(repLoss(cams_gt, data))) 223 | print("Loss for rand: Reprojection = {}".format(repLoss(cams_rand, data))) 224 | 225 | 226 | def prepare_cameras_for_loss_func(Ps, data): 227 | Vs_invT = Ps[:, 0:3, 0:3] 228 | Vs = torch.inverse(Vs_invT).transpose(1, 2) 229 | ts = torch.bmm(-Vs.transpose(1, 2), Ps[:, 0:3, 3].unsqueeze(dim=-1)).squeeze() 230 | pts_3D = torch.from_numpy(geo_utils.n_view_triangulation(Ps.numpy(), data.M.numpy(), data.Ns.numpy())).float() 231 | return {"Ps": torch.bmm(data.Ns, Ps), "pts3D": pts_3D} 232 | 233 | 234 | def get_subset(data, subset_size): 235 | # Get subset indices 236 | valid_pts = 
dataset_utils.get_M_valid_points(data.M) 237 | n_cams = valid_pts.shape[0] 238 | 239 | first_idx = valid_pts.sum(dim=1).argmax().item() 240 | curr_pts = valid_pts[first_idx].clone() 241 | valid_pts[first_idx] = False 242 | indices = [first_idx] 243 | 244 | for i in range(subset_size - 1): 245 | shared_pts = curr_pts.expand(n_cams, -1) & valid_pts 246 | next_idx = shared_pts.sum(dim=1).argmax().item() 247 | curr_pts = curr_pts | valid_pts[next_idx] 248 | valid_pts[next_idx] = False 249 | indices.append(next_idx) 250 | 251 | print("Cameras are:") 252 | print(indices) 253 | 254 | indices = torch.tensor(indices) 255 | M_indices = torch.sort(torch.cat((2 * indices, 2 * indices + 1)))[0] 256 | Rs = data.Rs[indices] 257 | ts = data.ts[indices] 258 | quats = data.quats[indices] 259 | Ns = data.Ns[indices] 260 | M = data.M[M_indices] 261 | M = M[:, (M != 0).sum(dim=0) > 2] 262 | mask = data.mask[M_indices] 263 | mask = mask[:,(M!=0).sum(dim=0)>2] 264 | return SceneData(M, Ns, Rs, ts, quats, mask, data.scan_name + "_{}".format(subset_size)) 265 | 266 | 267 | if __name__ == "__main__": 268 | test_dataset() 269 | 270 | -------------------------------------------------------------------------------- /code/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import plotly 2 | import os 3 | import utils.path_utils 4 | import plotly.express as px 5 | from plotly.subplots import make_subplots 6 | import numpy as np 7 | import torch 8 | import plotly.graph_objects as go 9 | from utils import geo_utils 10 | from matplotlib import image 11 | 12 | 13 | def plot_img_sets_error_bar(err_list, imgs_sets, path, title): 14 | max_val = 20 15 | 16 | # Prepare illegal values 17 | #illegal_idx = (err_list == float("inf")) 18 | illegal_idx = torch.logical_or((err_list > max_val), (err_list == float("inf"))) 19 | final_err = err_list.clone() 20 | final_err[illegal_idx] = max_val 21 | colors = np.array(['#636efa', ] * final_err.shape[0]) 22 | 
colors[illegal_idx] = 'crimson' 23 | 24 | colors = colors.tolist() 25 | final_err = final_err.tolist() 26 | 27 | # Create figure 28 | fig = px.bar(x=imgs_sets, y=final_err) 29 | fig.update_xaxes(title='Images sets') 30 | fig.update_yaxes(title=title, range=[0, max_val]) 31 | fig.update_traces(marker_color=colors) 32 | fig.update_layout(xaxis_type='category') 33 | 34 | plotly.offline.plot(fig, filename=path) 35 | 36 | 37 | def plot_img_reprojection_error_bar(err_list, img_list): 38 | return go.Bar(x=img_list, y=err_list.tolist()) 39 | # max_val = conf.get_int('plot.reproj_err_bar_max', default=20) 40 | # 41 | # # Create figure 42 | # fig = go.Figure([go.Bar(x=img_list.tolist(), y=err_list.tolist())]) 43 | # fig.update_xaxes(title='Images') 44 | # fig.update_yaxes(title='Reprojection Error', range=[0, max_val]) 45 | # # fig.update_traces(marker_color='crimson') 46 | # fig.update_layout(xaxis_type='category') 47 | # 48 | # path = os.path.join(general_utils.path_to_exp(conf), 'reprojection_err.html') 49 | # plotly.offline.plot(fig, filename=path) 50 | 51 | 52 | def plot_error_per_images_bar(repreoj_err, symetric_epipolar_dist, imgs_sets, conf, sub_name=""): 53 | path = os.path.join(utils.path_utils.path_to_exp(conf), 'reprojection_err' + sub_name) 54 | plot_img_sets_error_bar(repreoj_err, imgs_sets, path, 'Mean Reprojection Error') 55 | 56 | path = os.path.join(utils.path_utils.path_to_exp(conf), 'SymEpDist' + sub_name) 57 | plot_img_sets_error_bar(symetric_epipolar_dist, imgs_sets, path, 'Mean Symmetric Epipolar Distance') 58 | 59 | 60 | def plot_matrix_heatmap(data_matrix, indices, zmax): 61 | mask = data_matrix == 0 62 | mat = data_matrix.clone().numpy() 63 | mat[mask] = None 64 | #fig = px.imshow(mat, x=indices, y=indices) 65 | hm = go.Heatmap(z=mat, x=indices, y=indices, zmin=0, zmax=zmax) 66 | 67 | return hm 68 | 69 | 70 | def plot_heatmaps(repreoj_err, symetric_epipolar_dist, global_reprojection_error, edges, img_list, conf, path=None, static_path=None): 71 | 
repreoj_err_edges = torch.zeros(repreoj_err.shape) 72 | repreoj_err_edges[edges[0], edges[1]] = repreoj_err[edges[0], edges[1]] 73 | 74 | symetric_epipolar_dist_edges = torch.zeros(symetric_epipolar_dist.shape) 75 | symetric_epipolar_dist_edges[edges[0], edges[1]] = symetric_epipolar_dist[edges[0], edges[1]] 76 | 77 | zmax = conf.get_int('plot.color_bar_max', default=5) 78 | hm_rep_err = plot_matrix_heatmap(repreoj_err, list(map(str, img_list)), zmax) 79 | hm_rep_err_edges = plot_matrix_heatmap(repreoj_err_edges, list(map(str, img_list)), zmax) 80 | 81 | hm_sed = plot_matrix_heatmap(symetric_epipolar_dist, list(map(str, img_list)), zmax) 82 | hm_sed_edges = plot_matrix_heatmap(symetric_epipolar_dist_edges, list(map(str, img_list)), zmax) 83 | 84 | bar_global_rep_err = plot_img_reprojection_error_bar(global_reprojection_error, img_list) 85 | 86 | fig = make_subplots(2, 4, subplot_titles=['Reprojection Error', 'Symmetric Epipolar Distance', 87 | 'Reprojection Error - Triplets', 'Symmetric Epipolar Distance - Triplets', 88 | 'Global Reprojection Error'], 89 | specs=[[{}, {}, {}, {}], [{"colspan": 4}, None, None, None]]) 90 | 91 | fig.add_trace(hm_rep_err, 1, 1) 92 | fig.update_xaxes(type='category', title='Image', row=1, col=1) 93 | fig.update_yaxes(type='category', title='Image', row=1, col=1) 94 | 95 | fig.add_trace(hm_sed, 1, 2) 96 | fig.update_xaxes(type='category', title='Image', row=1, col=2) 97 | fig.update_yaxes(type='category', title='Image', row=1, col=2) 98 | 99 | fig.add_trace(hm_rep_err_edges, 1, 3) 100 | fig.update_xaxes(type='category', title='Image', row=1, col=3) 101 | fig.update_yaxes(type='category', title='Image', row=1, col=3) 102 | 103 | fig.add_trace(hm_sed_edges, 1, 4) 104 | fig.update_xaxes(type='category', title='Image', row=1, col=4) 105 | fig.update_yaxes(type='category', title='Image', row=1, col=4) 106 | 107 | fig.add_trace(bar_global_rep_err, 2, 1) 108 | max_rep_err = conf.get_int('plot.reproj_err_bar_max', default=20) 109 | 
fig.update_xaxes(type='category', title='Image', row=2, col=1) 110 | fig.update_yaxes(title='Reprojection Error', range=[0, max_rep_err], row=2, col=1) 111 | 112 | # fig.update_layout(width=1000) 113 | if path is None: 114 | path = os.path.join(utils.path_utils.path_to_exp(conf), 'errors_heatmap.html') 115 | plotly.offline.plot(fig, filename=path) 116 | if static_path is not None: 117 | fig.write_image(static_path) 118 | 119 | 120 | def plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan, epoch=None, bundle_adjustment=False): 121 | Rs_gt = outputs['Rs_gt'] 122 | ts_gt = outputs['ts_gt'] 123 | 124 | # Rs_pred = outputs['Rs_fixed'] 125 | # ts_pred = outputs['ts_fixed'] 126 | # pts3D = outputs['pts3D_pred_fixed'][:3,:] 127 | Rs_pred = outputs['Rs'] 128 | ts_pred = outputs['ts'] 129 | pts3D = outputs['pts3D_pred'][:3,:] 130 | Rs_error = errors['Rs_mean'] 131 | ts_error = errors['ts_mean'] 132 | plot_cameras(Rs_pred, ts_pred, pts3D, Rs_gt, ts_gt, Rs_error, ts_error, conf, phase, scan=scan, epoch=epoch) 133 | 134 | if bundle_adjustment: 135 | Rs_pred = outputs['Rs_ba_fixed'] 136 | ts_pred = outputs['ts_ba_fixed'] 137 | pts3D = outputs['Xs_ba_fixed'][:3,:] 138 | Rs_error = errors['Rs_ba_mean'] 139 | ts_error = errors['ts_ba_mean'] 140 | plot_cameras(Rs_pred, ts_pred, pts3D, Rs_gt, ts_gt, Rs_error, ts_error, conf, phase, scan=scan+'_ba', epoch=epoch) 141 | 142 | def get_points_colors(images_path, image_names, xs, first_occurence=False): 143 | m, n, _ = xs.shape 144 | points_colors = np.zeros([n, 3]) 145 | if first_occurence: 146 | images_indices = (geo_utils.xs_valid_points(xs)).argmax(axis=0) 147 | unique_images = np.unique(images_indices) 148 | for i, image_ind in enumerate(unique_images): 149 | image_name = str(image_names[image_ind][0]).split('/')[1] 150 | im = image.imread(os.path.join(images_path, image_name)) 151 | # read the image to ndarray 152 | points_in_image = np.where(image_ind == images_indices)[0] 153 | for point_ind in points_in_image: 154 
| point_2d_in_image = xs[image_ind, point_ind].astype(int) 155 | points_colors[point_ind] = im[point_2d_in_image[1], point_2d_in_image[0]] 156 | else: 157 | valid_points = geo_utils.xs_valid_points(xs) 158 | colors = np.zeros([m, n, 3]) 159 | for image_ind in range(m): 160 | image_name = str(image_names[image_ind][0]).split('/')[1] 161 | im = image.imread(os.path.join(images_path, image_name)) 162 | points_in_image = np.where(valid_points[image_ind])[0] 163 | for point_ind in points_in_image: 164 | point_2d_in_image = xs[image_ind, point_ind].astype(int) 165 | colors[image_ind, point_ind] = im[point_2d_in_image[1], point_2d_in_image[0]] 166 | for point_ind in range(n): 167 | points_colors[point_ind] = np.mean(colors[valid_points[:, point_ind], point_ind], axis=0) 168 | 169 | return points_colors 170 | 171 | def plot_cameras(Rs_pred, ts_pred, pts3D, Rs_gt, ts_gt, Rs_error, ts_error, conf, phase, scan=None, epoch=None): 172 | data = [] 173 | data.append(get_3D_quiver_trace(ts_gt, Rs_gt[:, :3, 2], color='#86CE00', name='cam_gt', cam_size=2)) 174 | data.append(get_3D_quiver_trace(ts_pred, Rs_pred[:, :3, 2], color='#C4451C', name='cam_learn', cam_size=2)) 175 | # data.append(get_3D_scater_trace(ts_gt.T, color='#86CE00', name='cam_gt', size=2)) 176 | # data.append(get_3D_scater_trace(ts_pred.T, color='#C4451C', name='cam_learn', size=2)) 177 | data.append(get_3D_scater_trace(pts3D, '#3366CC', '3D points', size=0.5)) 178 | 179 | fig = go.Figure(data=data) 180 | fig.update_layout(title='Cameras: Rotation Mean = {:.5f}, Translation Mean = {:.5f}, Points num = {}'.format(Rs_error.mean(), ts_error.mean(), pts3D.shape[1]), showlegend=True) 181 | 182 | path = utils.path_utils.path_to_plots(conf, phase, epoch=epoch, scan=scan) 183 | plotly.offline.plot(fig, filename=path, auto_open=False) 184 | 185 | # return path 186 | 187 | 188 | def get_3D_quiver_trace(points, directions, color='#bd1540', name='', cam_size=1): 189 | assert points.shape[1] == 3, "3d cone plot input points are 
not correctely shaped " 190 | assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped " 191 | assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped " 192 | assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped " 193 | 194 | trace = go.Cone( 195 | name=name, 196 | x=points[:, 0], 197 | y=points[:, 1], 198 | z=points[:, 2], 199 | u=directions[:, 0], 200 | v=directions[:, 1], 201 | w=directions[:, 2], 202 | sizemode='absolute', 203 | sizeref=cam_size, 204 | showscale=False, 205 | colorscale=[[0, color], [1, color]], 206 | anchor="tail" 207 | ) 208 | 209 | return trace 210 | 211 | 212 | def get_3D_scater_trace(points, color, name,size=0.5): 213 | assert points.shape[0] == 3, "3d plot input points are not correctely shaped " 214 | assert len(points.shape) == 2, "3d plot input points are not correctely shaped " 215 | 216 | trace = go.Scatter3d( 217 | name=name, 218 | x=points[0, :], 219 | y=points[1, :], 220 | z=points[2, :], 221 | mode='markers', 222 | marker=dict( 223 | size=size, 224 | color=color, 225 | ) 226 | ) 227 | 228 | return trace 229 | 230 | 231 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dsfm 2 | channels: 3 | - bottler 4 | - comet_ml 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 6 | - conda-forge 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 10 | - defaults 11 | dependencies: 12 | - _libgcc_mutex=0.1=conda_forge 13 | - _openmp_mutex=4.5=1_llvm 14 | - aadict=0.2.3=pyh9f0ad1d_0 15 | - alsa-lib=1.2.3=h516909a_0 16 | - asset=0.6.13=pyh9f0ad1d_0 17 | - attrs=21.2.0=pyhd8ed1ab_0 18 | - binutils_impl_linux-64=2.36.1=h193b22a_2 19 | - 
binutils_linux-64=2.36=hf3e587d_1 20 | - blas=2.111=mkl 21 | - blas-devel=3.9.0=11_linux64_mkl 22 | - bokeh=2.4.0=py38h578d9bd_0 23 | - brotlipy=0.7.0=py38h497a2fe_1001 24 | - bzip2=1.0.8=h7f98852_4 25 | - c-ares=1.17.2=h7f98852_0 26 | - ca-certificates=2024.7.2=h06a4308_0 27 | - cairo=1.16.0=h6cf1ce9_1008 28 | - certifi=2024.7.4=py38h06a4308_0 29 | - cffi=1.14.6=py38h3931269_1 30 | - chardet=4.0.0=py38h578d9bd_1 31 | - charset-normalizer=2.0.0=pyhd8ed1ab_0 32 | - click=8.0.3=py38h578d9bd_0 33 | - cloudpickle=2.0.0=pyhd8ed1ab_0 34 | - colorama=0.4.4=pyh9f0ad1d_0 35 | - comet_ml=3.18.1=py38 36 | - configobj=5.0.6=py_0 37 | - cryptography=3.4.8=py38ha5dfef3_0 38 | - cudatoolkit=10.2.89=h8f6ccaa_9 39 | - cvxpy=1.1.15=py38h578d9bd_0 40 | - cvxpy-base=1.1.15=py38h43a58ef_0 41 | - cycler=0.10.0=py_2 42 | - cytoolz=0.11.0=py38h497a2fe_3 43 | - dask=2021.9.1=pyhd8ed1ab_0 44 | - dask-core=2021.9.1=pyhd8ed1ab_0 45 | - dataclasses=0.8=pyhc8e2a94_3 46 | - dbus=1.13.6=h48d8840_2 47 | - distributed=2021.9.1=py38h578d9bd_0 48 | - dulwich=0.20.25=py38h497a2fe_0 49 | - ecos=2.0.8=py38h6c62de6_1 50 | - eigen=3.4.0=h4bd325d_0 51 | - et_xmlfile=1.0.1=py_1001 52 | - everett=2.0.0=pyhd8ed1ab_0 53 | - expat=2.4.1=h9c3ff4c_0 54 | - ffmpeg=4.3.2=hca11adc_0 55 | - fontconfig=2.13.1=hba837de_1005 56 | - freeglut=3.2.1=h9c3ff4c_2 57 | - freetype=2.10.4=h0708190_1 58 | - fsspec=2021.10.0=pyhd8ed1ab_0 59 | - fvcore=0.1.5.post20210804=pyhd8ed1ab_0 60 | - gcc_impl_linux-64=11.2.0=h82a94d6_10 61 | - gcc_linux-64=11.2.0=h39a9532_1 62 | - gettext=0.19.8.1=h73d1719_1008 63 | - gflags=2.2.2=he1b5a44_1004 64 | - glib=2.68.4=h9c3ff4c_1 65 | - glib-tools=2.68.4=h9c3ff4c_1 66 | - globre=0.1.5=pyh9f0ad1d_0 67 | - glog=0.5.0=h48cff8f_0 68 | - gmp=6.2.1=h58526e2_0 69 | - gnutls=3.6.13=h85f3911_1 70 | - graphite2=1.3.13=h58526e2_1001 71 | - gst-plugins-base=1.18.5=hf529b03_0 72 | - gstreamer=1.18.5=h76c114f_0 73 | - gxx_impl_linux-64=11.2.0=h82a94d6_10 74 | - gxx_linux-64=11.2.0=hacbe6df_1 75 | - 
harfbuzz=3.0.0=h83ec7ef_1 76 | - hdf5=1.12.1=nompi_h2750804_101 77 | - heapdict=1.0.1=py_0 78 | - icu=68.1=h58526e2_0 79 | - idna=3.1=pyhd3deb0d_0 80 | - importlib-metadata=4.8.1=py38h578d9bd_0 81 | - iopath=0.1.8=pyhd8ed1ab_0 82 | - jasper=2.0.14=ha77e612_2 83 | - jbig=2.1=h7f98852_2003 84 | - jinja2=3.0.2=pyhd8ed1ab_0 85 | - jpeg=9d=h36c2ea0_0 86 | - jsonschema=4.1.0=pyhd8ed1ab_0 87 | - kernel-headers_linux-64=2.6.32=he073ed8_14 88 | - kiwisolver=1.3.2=py38h1fd1430_0 89 | - krb5=1.19.2=hcc1bbae_2 90 | - lame=3.100=h7f98852_1001 91 | - lcms2=2.12=hddcbb42_0 92 | - ld_impl_linux-64=2.36.1=hea4e1c9_2 93 | - lerc=2.2.1=h9c3ff4c_0 94 | - libblas=3.9.0=11_linux64_mkl 95 | - libcblas=3.9.0=11_linux64_mkl 96 | - libclang=11.1.0=default_ha53f305_1 97 | - libcurl=7.79.1=h2574ce0_1 98 | - libdeflate=1.7=h7f98852_5 99 | - libedit=3.1.20191231=he28a2e2_2 100 | - libev=4.33=h516909a_1 101 | - libevent=2.1.10=h9b69904_4 102 | - libffi=3.4.2=h9c3ff4c_4 103 | - libgcc-devel_linux-64=11.2.0=h0952999_10 104 | - libgcc-ng=11.2.0=h1d223b6_10 105 | - libgfortran-ng=11.2.0=h69a702a_10 106 | - libgfortran5=11.2.0=h5c6108e_10 107 | - libglib=2.68.4=h174f98d_1 108 | - libglu=9.0.0=he1b5a44_1001 109 | - libgomp=11.2.0=h1d223b6_10 110 | - libiconv=1.16=h516909a_0 111 | - liblapack=3.9.0=11_linux64_mkl 112 | - liblapacke=3.9.0=11_linux64_mkl 113 | - libllvm11=11.1.0=hf817b99_2 114 | - libnghttp2=1.43.0=h812cca2_1 115 | - libogg=1.3.4=h7f98852_1 116 | - libopencv=4.5.3=py38h66b0e6b_5 117 | - libopus=1.3.1=h7f98852_1 118 | - libpng=1.6.37=h21135ba_2 119 | - libpq=13.3=hd57d9b9_1 120 | - libprotobuf=3.18.1=h780b84a_0 121 | - libsanitizer=11.2.0=he4da1e4_10 122 | - libssh2=1.10.0=ha56f1ee_2 123 | - libstdcxx-devel_linux-64=11.2.0=h0952999_10 124 | - libstdcxx-ng=11.2.0=he4da1e4_10 125 | - libtiff=4.3.0=hf544144_1 126 | - libuuid=2.32.1=h7f98852_1000 127 | - libuv=1.42.0=h7f98852_0 128 | - libvorbis=1.3.7=h9c3ff4c_0 129 | - libwebp-base=1.2.1=h7f98852_0 130 | - libxcb=1.13=h7f98852_1003 131 | - 
libxkbcommon=1.0.3=he3ba5ed_0 132 | - libxml2=2.9.12=h72842e0_0 133 | - libzlib=1.2.11=h36c2ea0_1013 134 | - llvm-openmp=12.0.1=h4bd325d_1 135 | - locket=0.2.0=py_2 136 | - lz4-c=1.9.3=h9c3ff4c_1 137 | - matplotlib=3.4.3=py38h578d9bd_1 138 | - matplotlib-base=3.4.3=py38hf4fb855_1 139 | - metis=5.1.0=h58526e2_1006 140 | - mkl=2021.3.0=h726a3e6_557 141 | - mkl-devel=2021.3.0=ha770c72_558 142 | - mkl-include=2021.3.0=h726a3e6_557 143 | - mpfr=4.1.0=h9202a9a_1 144 | - msgpack-python=1.0.2=py38h1fd1430_1 145 | - mysql-common=8.0.25=ha770c72_2 146 | - mysql-libs=8.0.25=hfa10184_2 147 | - ncurses=6.2=h58526e2_4 148 | - nettle=3.6=he412f7d_0 149 | - ninja=1.10.2=h4bd325d_1 150 | - nspr=4.30=h9c3ff4c_0 151 | - nss=3.69=hb5efdd6_1 152 | - numpy=1.21.2=py38he2449b9_0 153 | - nvidia-ml=7.352.0=py_0 154 | - nvidiacub=1.10.0=0 155 | - olefile=0.46=pyh9f0ad1d_1 156 | - opencv=4.5.3=py38h578d9bd_5 157 | - openh264=2.1.1=h780b84a_0 158 | - openjpeg=2.4.0=hb52868f_1 159 | - openpyxl=3.0.9=pyhd8ed1ab_0 160 | - openssl=1.1.1w=h7f8727e_0 161 | - osqp=0.6.2=py38h43a58ef_2 162 | - packaging=21.0=pyhd8ed1ab_0 163 | - pandas=1.3.3=py38h43a58ef_0 164 | - partd=1.2.0=pyhd8ed1ab_0 165 | - pcre=8.45=h9c3ff4c_0 166 | - pip=21.3=pyhd8ed1ab_0 167 | - pixman=0.40.0=h36c2ea0_0 168 | - plotly=5.3.1=pyhd8ed1ab_0 169 | - portalocker=2.3.2=py38h578d9bd_0 170 | - psutil=5.8.0=py38h497a2fe_1 171 | - pthread-stubs=0.4=h36c2ea0_1001 172 | - py-opencv=4.5.3=py38he5a9106_5 173 | - pycparser=2.20=pyh9f0ad1d_2 174 | - pyhocon=0.3.58=pyhd8ed1ab_0 175 | - pyopenssl=21.0.0=pyhd8ed1ab_0 176 | - pyparsing=2.4.7=pyh9f0ad1d_0 177 | - pyqt=5.12.3=py38h578d9bd_7 178 | - pyqt-impl=5.12.3=py38h7400c14_7 179 | - pyqt5-sip=4.19.18=py38h709712a_7 180 | - pyqtchart=5.12=py38h7400c14_7 181 | - pyqtwebengine=5.12.1=py38h7400c14_7 182 | - pyrsistent=0.17.3=py38h497a2fe_2 183 | - pysocks=1.7.1=py38h578d9bd_3 184 | - python-dateutil=2.8.2=pyhd8ed1ab_0 185 | - pytz=2021.3=pyhd8ed1ab_0 186 | - pyyaml=5.4.1=py38h497a2fe_1 187 | - 
qdldl-python=0.1.5=py38h43a58ef_1 188 | - qt=5.12.9=hda022c4_4 189 | - readline=8.1=h46c0cb4_0 190 | - requests=2.26.0=pyhd8ed1ab_0 191 | - requests-toolbelt=0.9.1=py_0 192 | - rhash=1.4.1=h7f98852_0 193 | - scipy=1.7.1=py38h56a6a73_0 194 | - scs=2.1.4=py38h6afa1d1_0 195 | - semantic_version=2.8.5=pyh9f0ad1d_0 196 | - setuptools=58.2.0=py38h578d9bd_0 197 | - six=1.16.0=pyh6c4a22f_0 198 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 199 | - sqlite=3.36.0=h9cd32fc_2 200 | - suitesparse=5.10.1=h9e50725_1 201 | - sysroot_linux-64=2.12=he073ed8_14 202 | - tabulate=0.8.9=pyhd8ed1ab_0 203 | - tbb=2021.3.0=h4bd325d_0 204 | - tblib=1.7.0=pyhd8ed1ab_0 205 | - tenacity=8.0.1=pyhd8ed1ab_0 206 | - termcolor=1.1.0=py_2 207 | - tk=8.6.11=h27826a3_1 208 | - toolz=0.11.1=py_0 209 | - tornado=6.1=py38h497a2fe_1 210 | - tqdm=4.62.3=pyhd8ed1ab_0 211 | - urllib3=1.26.7=pyhd8ed1ab_0 212 | - websocket-client=0.57.0=py38h578d9bd_4 213 | - wheel=0.37.0=pyhd8ed1ab_1 214 | - wrapt=1.13.1=py38h497a2fe_0 215 | - wurlitzer=3.0.2=py38h578d9bd_0 216 | - x264=1!161.3030=h7f98852_1 217 | - xlrd=2.0.1=pyhd8ed1ab_3 218 | - xorg-fixesproto=5.0=h7f98852_1002 219 | - xorg-inputproto=2.3.2=h7f98852_1002 220 | - xorg-kbproto=1.0.7=h7f98852_1002 221 | - xorg-libice=1.0.10=h7f98852_0 222 | - xorg-libsm=1.2.3=hd9c2040_1000 223 | - xorg-libx11=1.7.2=h7f98852_0 224 | - xorg-libxau=1.0.9=h7f98852_0 225 | - xorg-libxdmcp=1.1.3=h7f98852_0 226 | - xorg-libxext=1.3.4=h7f98852_1 227 | - xorg-libxfixes=5.0.3=h7f98852_1004 228 | - xorg-libxi=1.7.10=h7f98852_0 229 | - xorg-libxrender=0.9.10=h7f98852_1003 230 | - xorg-renderproto=0.11.1=h7f98852_1002 231 | - xorg-xextproto=7.3.0=h7f98852_1002 232 | - xorg-xproto=7.0.31=h7f98852_1007 233 | - xz=5.2.5=h516909a_1 234 | - yacs=0.1.6=py_0 235 | - yaml=0.2.5=h516909a_0 236 | - zict=2.0.0=py_0 237 | - zipp=3.6.0=pyhd8ed1ab_0 238 | - zlib=1.2.11=h36c2ea0_1013 239 | - zstd=1.5.0=ha95c52a_0 240 | - pip: 241 | - absl-py==2.1.0 242 | - aiohappyeyeballs==2.4.0 243 | - aiohttp==3.10.5 244 | 
- aiosignal==1.3.1 245 | - async-timeout==4.0.3 246 | - cachetools==5.4.0 247 | - cmake==3.30.2 248 | - easydict==1.13 249 | - filelock==3.15.4 250 | - frozenlist==1.4.1 251 | - google-auth==2.33.0 252 | - google-auth-oauthlib==1.0.0 253 | - grpcio==1.65.4 254 | - joblib==1.4.2 255 | - lit==18.1.8 256 | - markdown==3.6 257 | - markupsafe==2.1.5 258 | - mpmath==1.3.0 259 | - multidict==6.0.5 260 | - networkx==3.1 261 | - nvidia-cublas-cu11==11.10.3.66 262 | - nvidia-cuda-cupti-cu11==11.7.101 263 | - nvidia-cuda-nvrtc-cu11==11.7.99 264 | - nvidia-cuda-runtime-cu11==11.7.99 265 | - nvidia-cudnn-cu11==8.5.0.96 266 | - nvidia-cufft-cu11==10.9.0.58 267 | - nvidia-curand-cu11==10.2.10.91 268 | - nvidia-cusolver-cu11==11.4.0.1 269 | - nvidia-cusparse-cu11==11.7.4.91 270 | - nvidia-nccl-cu11==2.14.3 271 | - nvidia-nvtx-cu11==11.7.91 272 | - oauthlib==3.2.2 273 | - pillow==10.4.0 274 | - protobuf==5.27.3 275 | - pyasn1==0.6.0 276 | - pyasn1-modules==0.4.0 277 | - pyceres==2.3 278 | - pytorch3d==0.3.0 279 | - requests-oauthlib==2.0.0 280 | - rsa==4.9 281 | - scikit-learn==1.3.2 282 | - sympy==1.13.2 283 | - tensorboard==2.14.0 284 | - tensorboard-data-server==0.7.2 285 | - threadpoolctl==3.5.0 286 | - torch==2.0.0 287 | - torch-cluster==1.6.3 288 | - torch-geometric==2.5.3 289 | - torch-scatter==2.1.2 290 | - torch-sparse==0.6.18 291 | - torch-spline-conv==1.2.2 292 | - torchaudio==2.0.1 293 | - torchvision==0.15.1 294 | - triton==2.0.0 295 | - typing-extensions==4.12.2 296 | - werkzeug==3.0.3 297 | - yarl==1.9.4 298 | -------------------------------------------------------------------------------- /code/utils/ceres_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/path/to/ceres-solver/ceres-bin/lib/") # so 3 | 4 | import PyCeres 5 | import numpy as np 6 | import scipy.io as sio 7 | import cv2 8 | from utils import geo_utils 9 | 10 | 11 | def order_cam_param_for_c(Rs, ts, Ks): 12 | """ 13 | Orders a 
[m, 12] matrix for the ceres function as follows: 14 | Ps_for_c[i, 0:3] 3 parameters for the vector representing the rotation 15 | Ps_for_c[i, 3:6] 3 parameters for the location of the camera 16 | Ps_for_c[i, 6:11] 5 parameters for the upper triangular part of the calibration matrix 17 | :param Rs: [m,3,3] 18 | :param ts: [m,3] 19 | :param Ks: [m,3,3] 20 | :return: Ps_for_c [m, 12] 21 | """ 22 | n_cam = len(Rs) 23 | Ps_for_c = np.zeros([n_cam, 12]) 24 | for i in range(n_cam): 25 | Ps_for_c[i, 0:3] = cv2.Rodrigues(Rs[i].T)[0].T 26 | Ps_for_c[i, 3:6] = (-Rs[i].T @ ts[i].reshape([3, 1])).T 27 | Ps_for_c[i, 6:11] = [Ks[i, 0, 0], Ks[i, 0, 1], Ks[i, 0, 2], Ks[i, 1, 1], Ks[i, 1, 2]] 28 | Ps_for_c[i, -1] = 1.0 29 | return Ps_for_c 30 | 31 | 32 | def reorder_from_c_to_py(Ps_for_c, Ks): 33 | """ 34 | Read back the camera parameters from the 35 | :param Ps_for_c: 36 | :return: Rs, ts, Ps 37 | """ 38 | n_cam = len(Ps_for_c) 39 | Rs = np.zeros([n_cam, 3, 3]) 40 | ts = np.zeros([n_cam, 3]) 41 | Ps = np.zeros([n_cam, 3,4]) 42 | for i in range(n_cam): 43 | Rs[i] = cv2.Rodrigues(Ps_for_c[i, 0:3])[0].T 44 | ts[i] = -Rs[i] @ Ps_for_c[i, 3:6].reshape([3, 1]).flatten() 45 | Ps[i] = geo_utils.get_camera_matrix(R=Rs[i], t=ts[i], K=Ks[i]) 46 | return Rs, ts, Ps 47 | 48 | 49 | def run_euclidean_ceres(Xs, xs, Rs, ts, Ks, point_indices): 50 | """ 51 | Calls a c++ function that optimizes the camera parameters and the 3D points for a lower reprojection error. 
52 | :param Xs: [n, 3] 53 | :param xs: [v,2] 54 | :param Rs: [m,3,3] 55 | :param ts: [m,3] 56 | :param Ks: [m,3,3] 57 | :param point_indices: [2,v] 58 | :return: 59 | new_Rs, new_ts, new_Ps, new_Xs Which have a lower reprojection error 60 | """ 61 | if Xs.shape[-1] == 4: 62 | Xs = Xs[:,:3] 63 | assert Xs.shape[-1] == 3 64 | assert xs.shape[-1] == 2 65 | n_cam = len(Rs) 66 | n_pts = Xs.shape[0] 67 | n_observe = xs.shape[0] 68 | 69 | Ps_for_c = order_cam_param_for_c(Rs, ts, Ks).astype(np.double) 70 | Xs_flat = Xs.flatten("C").astype(np.double) 71 | Ps_for_c_flat = Ps_for_c.flatten("C").astype(np.double) 72 | xs_flat = xs.flatten("C").astype(np.double) 73 | point_indices = point_indices.flatten("C") 74 | 75 | Xsu = np.zeros_like(Xs_flat) 76 | Psu = np.zeros_like(Ps_for_c_flat) 77 | 78 | PyCeres.eucPythonFunctionOursBA(Xs_flat, xs_flat, Ps_for_c_flat, point_indices, Xsu, Psu, n_cam, n_pts, n_observe) 79 | 80 | new_Ps_for_c = Ps_for_c + Psu.reshape([n_cam, 12], order="C") 81 | 82 | new_Rs, new_ts, new_Ps = reorder_from_c_to_py(new_Ps_for_c, Ks) 83 | new_Xs = Xs + Xsu.reshape([n_pts,3], order="C") 84 | 85 | return new_Rs, new_ts, new_Ps, new_Xs 86 | 87 | 88 | def run_projective_ceres(Ps, Xs, xs, point_indices): 89 | """ 90 | Calls the c++ function, that loops over the variables: 91 | for i in range(v): 92 | xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex) 93 | :param Ps: [m, 3, 4] 94 | :param Xs: [n, 3] 95 | :param xs: [v, 2] 96 | :param point_indices: [2,v] 97 | :return: new_Ps: [m, 12] 98 | new_Xs: [n,3] 99 | """ 100 | if Xs.shape[-1] == 4: 101 | Xs = Xs[:,:3] 102 | assert Xs.shape[-1] == 3 103 | assert xs.shape[-1] == 2 104 | m = Ps.shape[0] 105 | n = Xs.shape[0] 106 | v = point_indices.shape[1] 107 | Ps_single_flat = Ps.reshape([-1, 12], order="F") # [m, 12] Each camera is in *column* major as in matlab! 
the cpp code assumes it because the original code was in matlab 108 | 109 | Ps_flat = Ps_single_flat.flatten("C") # row major as in python 110 | Xs_flat = Xs.flatten("C") 111 | xs_flat = xs.flatten("C") 112 | point_idx_flat = point_indices.flatten("C") 113 | 114 | Psu = np.zeros_like(Ps_flat) 115 | Xsu = np.zeros_like(Xs_flat) 116 | 117 | PyCeres.pythonFunctionOursBA(Xs_flat, xs_flat, Ps_flat, point_idx_flat, Xsu, Psu, m, n, v) 118 | Psu = Psu.reshape([m,12], order="C") 119 | Psu = Psu.reshape([m,3,4], order="F") # [m, 12] Each camera is in *column* major as in matlab! the cpp code assumes it because the original code was in matlab 120 | Xsu = Xsu.reshape([n,3]) 121 | 122 | new_Ps = Ps + Psu 123 | new_Xs = Xs + Xsu 124 | 125 | return new_Ps, new_Xs 126 | 127 | def run_euclidean_python_ceres(Xs, xs, Rs, ts, Ks, point_indices, print_out=True, max_iter=100): 128 | """ 129 | Calls a c++ function that optimizes the camera parameters and the 3D points for a lower reprojection error. 130 | :param Xs: [n, 3] 131 | :param xs: [v,2] 132 | :param Rs: [m,3,3] 133 | :param ts: [m,3] 134 | :param Ks: [m,3,3] 135 | :param point_indices: [2,v] 136 | :return: 137 | new_Rs, new_ts, new_Ps, new_Xs Which have a lower reprojection error 138 | """ 139 | if Xs.shape[-1] == 4: 140 | Xs = Xs[:,:3] 141 | assert Xs.shape[-1] == 3 142 | assert xs.shape[-1] == 2 143 | n_cam = len(Rs) 144 | n_pts = Xs.shape[0] 145 | n_observe = xs.shape[0] 146 | 147 | Ps_for_c = order_cam_param_for_c(Rs, ts, Ks).astype(np.double) 148 | Xs_flat = Xs.flatten("C").astype(np.double) 149 | Ps_for_c_flat = Ps_for_c.flatten("C").astype(np.double) 150 | xs_flat = xs.flatten("C").astype(np.double) 151 | point_indices = point_indices.flatten("C") 152 | 153 | Xsu = np.zeros_like(Xs_flat) 154 | Psu = np.zeros_like(Ps_for_c_flat) 155 | 156 | problem = PyCeres.Problem() 157 | for i in range(n_observe): # loop over the observations 158 | camIndex = int(point_indices[i]) 159 | point3DIndex = int(point_indices[i + n_observe]) 
160 | 161 | cost_function = PyCeres.eucReprojectionError(xs_flat[2 * i], xs_flat[2 * i + 1], 162 | Ps_for_c_flat[12 * camIndex:12 * (camIndex + 1)], 163 | Xs_flat[3 * point3DIndex:3 * (point3DIndex + 1)]) 164 | 165 | loss_function = PyCeres.HuberLoss(0.1) 166 | problem.AddResidualBlock(cost_function, loss_function, Psu[12 * camIndex:12 * (camIndex + 1)], 167 | Xsu[3 * point3DIndex:3 * (point3DIndex + 1)]) 168 | 169 | options = PyCeres.SolverOptions() 170 | 171 | options.function_tolerance = 0.0001 172 | options.max_num_iterations = max_iter 173 | options.num_threads = 8 174 | 175 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_SCHUR 176 | options.minimizer_progress_to_stdout = True 177 | if not print_out: 178 | PyCeres.LoggingType = PyCeres.LoggingType.SILENT 179 | 180 | summary = PyCeres.Summary() 181 | PyCeres.Solve(options, problem, summary) 182 | if print_out: 183 | print(summary.FullReport()) 184 | 185 | if ~Psu.any(): 186 | print('Warning no change to Ps') 187 | if ~Xsu.any(): 188 | print('Warning no change to Xs') 189 | 190 | new_Ps_for_c = Ps_for_c + Psu.reshape([n_cam, 12], order="C") 191 | 192 | new_Rs, new_ts, new_Ps = reorder_from_c_to_py(new_Ps_for_c, Ks) 193 | new_Xs = Xs + Xsu.reshape([n_pts,3], order="C") 194 | 195 | return new_Rs, new_ts, new_Ps, new_Xs 196 | 197 | def get_new_scenespt(Ps, Xs, xs, valid_points, repro_thre): 198 | errors = geo_utils.reprojection_error_with_points(Ps, Xs, xs, valid_points) 199 | outliers_mask = errors[:]>repro_thre 200 | outliers_col = outliers_mask.sum(axis=0)>0 201 | inliers_col = ~outliers_col 202 | Xs = Xs[inliers_col, :] 203 | xs = xs[:, inliers_col, :] 204 | valid_points = xs.sum(axis=-1)!=0 205 | point_indices = np.stack(np.where(valid_points)) 206 | return Xs, xs, valid_points, point_indices, np.where(inliers_col)[0] 207 | 208 | def run_euc_ceres_iter(Ps, Xs, xs, Rs, ts, Ks, point_indices, valid_points, max_iter=100, ba_times=2, repro_thre=5, proj_first=False): 209 | # Ns = np.linalg.inv(Ks) 210 
| # norm_P, norm_x = geo_utils.normalize_points_cams(Ps, xs, Ns) 211 | colidx = np.where(valid_points.sum(axis=0)!=0)[0] 212 | if proj_first: 213 | Xs, xs, valid_points, point_indices, colidx = get_new_scenespt(Ps, Xs, xs, valid_points, repro_thre) 214 | 215 | for _ in range(ba_times): 216 | valid_xs = xs[valid_points] 217 | Rs, ts, Ps, Xs = run_euclidean_python_ceres(Xs, valid_xs, Rs, ts, Ks, point_indices, max_iter=max_iter) 218 | Xs, xs, valid_points, point_indices, newcolidx = get_new_scenespt(Ps, Xs, xs, valid_points, repro_thre) 219 | colidx = colidx[newcolidx] 220 | return Rs, ts, Ps, Xs, xs, valid_points, point_indices, colidx 221 | 222 | 223 | def run_projective_python_ceres(Ps, Xs, xs, point_indices, print_out=True): 224 | """ 225 | Calls the c++ function, that loops over the variables: 226 | for i in range(v): 227 | xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex) 228 | :param Ps: [m, 3, 4] 229 | :param Xs: [n, 3] 230 | :param xs: [v, 2] 231 | :param point_indices: [2,v] 232 | :return: new_Ps: [m, 12] 233 | new_Xs: [n,3] 234 | """ 235 | if Xs.shape[-1] == 4: 236 | Xs = Xs[:,:3] 237 | assert Xs.shape[-1] == 3 238 | assert xs.shape[-1] == 2 239 | m = Ps.shape[0] 240 | n = Xs.shape[0] 241 | v = point_indices.shape[1] 242 | Ps_single_flat = Ps.reshape([-1, 12], order="F") # [m, 12] Each camera is in *column* major as in matlab! 
the cpp code assumes it because the original code was in matlab 243 | 244 | Ps_flat = Ps_single_flat.flatten("C").astype(np.double) # row major as in python 245 | Xs_flat = Xs.flatten("C").astype(np.double) 246 | xs_flat = xs.flatten("C") 247 | point_idx_flat = point_indices.flatten("C") 248 | 249 | Psu = np.zeros_like(Ps_flat) 250 | Xsu = np.zeros_like(Xs_flat) 251 | 252 | problem = PyCeres.Problem() 253 | for i in range(v): # loop over the observations 254 | camIndex = int(point_idx_flat[i]) 255 | point3DIndex = int(point_idx_flat[i + v]) 256 | 257 | cost_function = PyCeres.projReprojectionError(xs_flat[2*i], xs_flat[2*i + 1], Ps_flat[12*camIndex:12*(camIndex+1)], Xs_flat[3 *point3DIndex:3*(point3DIndex+1)]) 258 | 259 | loss_function = PyCeres.HuberLoss(0.1) 260 | problem.AddResidualBlock(cost_function, loss_function, Psu[12*camIndex:12*(camIndex+1)], Xsu[3 *point3DIndex:3*(point3DIndex+1)]) 261 | 262 | 263 | options = PyCeres.SolverOptions() 264 | 265 | options.function_tolerance = 0.0001 266 | options.max_num_iterations = 100 267 | options.num_threads = 8 268 | 269 | options.linear_solver_type = PyCeres.LinearSolverType.DENSE_SCHUR 270 | options.minimizer_progress_to_stdout = True 271 | 272 | summary = PyCeres.Summary() 273 | PyCeres.Solve(options, problem, summary) 274 | if print_out: 275 | print(summary.FullReport()) 276 | Psu = Psu.reshape([m,12], order="C") 277 | Psu = Psu.reshape([m,3,4], order="F") # [m, 12] Each camera is in *column* major as in matlab! 
the cpp code assumes it because the original code was in matlab 278 | Xsu = Xsu.reshape([n,3]) 279 | 280 | new_Ps = Ps + Psu 281 | new_Xs = Xs + Xsu 282 | 283 | return new_Ps, new_Xs -------------------------------------------------------------------------------- /bundle_adjustment/custom_cpp_cost_functions.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | 9 | namespace py = pybind11; 10 | 11 | 12 | 13 | // Parses a numpy array and extracts the pointer to the first element. 14 | // Requires that the numpy array be either an array or a row/column vector 15 | double* _ParseNumpyData(py::array_t& np_buf) { 16 | py::buffer_info info = np_buf.request(); 17 | // This is essentially just all error checking. As it will always be the info 18 | // ptr 19 | if (info.ndim > 2) { 20 | std::string error_msg("Number of dimensions must be <=2. This function" 21 | "only allows either an array or row/column vector(2D matrix) " 22 | + std::to_string(info.ndim)); 23 | throw std::runtime_error( 24 | error_msg); 25 | } 26 | if (info.ndim == 2) { 27 | // Row or Column Vector. 
Represents 1 parameter 28 | if (info.shape[0] == 1 || info.shape[1] == 1) { 29 | } 30 | else { 31 | std::string error_msg 32 | ("Matrix is not a row or column vector and instead has size " 33 | + std::to_string(info.shape[0]) + "x" 34 | + std::to_string(info.shape[1])); 35 | throw std::runtime_error( 36 | error_msg); 37 | } 38 | } 39 | return (double*)info.ptr; 40 | } 41 | 42 | struct ExampleFunctor { 43 | template 44 | bool operator()(const T *const x, T *residual) const { 45 | residual[0] = T(10.0) - x[0]; 46 | return true; 47 | } 48 | 49 | static ceres::CostFunction *Create() { 50 | return new ceres::AutoDiffCostFunction(new ExampleFunctor); 53 | } 54 | }; 55 | 56 | struct projReprojectionError { 57 | projReprojectionError(double observed_x, double observed_y, double* Porig, double* Xorig, double blockNum) 58 | : _observed_x(observed_x), _observed_y(observed_y), _Porig(Porig), _Xorig(Xorig), _blockNum(blockNum) {} 59 | 60 | template 61 | bool operator()(const T* const P, const T* const X, 62 | T* residuals) const { 63 | 64 | T Pnow[12]; 65 | T Xnow[4]; 66 | T projection[3]; 67 | for (int i = 0; i < 12; i++) { 68 | Pnow[i] = P[i] + _Porig[i]; 69 | } 70 | for (int i = 0; i < 3; i++) { 71 | Xnow[i] = X[i] + _Xorig[i]; 72 | } 73 | Xnow[3] = T(1.0); 74 | projection[0] = Pnow[0] * Xnow[0] + Pnow[3] * Xnow[1] + Pnow[6] * Xnow[2] + Pnow[9] * Xnow[3]; 75 | projection[1] = Pnow[1] * Xnow[0] + Pnow[4] * Xnow[1] + Pnow[7] * Xnow[2] + Pnow[10] * Xnow[3]; 76 | projection[2] = Pnow[2] * Xnow[0] + Pnow[5] * Xnow[1] + Pnow[8] * Xnow[2] + Pnow[11] * Xnow[3]; 77 | 78 | projection[0] = projection[0] / projection[2]; 79 | projection[1] = projection[1] / projection[2]; 80 | 81 | residuals[0] = (projection[0] - _observed_x) / _blockNum; 82 | residuals[1] = (projection[1] - _observed_y) / _blockNum; 83 | 84 | 85 | return true; 86 | } 87 | 88 | static ceres::CostFunction* CreateMyRep(const double observed_x, 89 | const double observed_y, double* Porig, double* Xorig, double blocksNum) { 
90 | return (new ceres::AutoDiffCostFunction( 91 | new projReprojectionError(observed_x, observed_y, Porig, Xorig, blocksNum))); 92 | } 93 | 94 | double _observed_x; 95 | double _observed_y; 96 | 97 | double _blockNum; 98 | double* _Porig; 99 | double* _Xorig; 100 | 101 | 102 | }; 103 | 104 | 105 | struct eucReprojectionError { 106 | eucReprojectionError(double observed_x, double observed_y, double * Porig, double * Xorig,double blockNum) 107 | : _observed_x(observed_x), _observed_y(observed_y), _Porig(Porig), _Xorig(Xorig), _blockNum(blockNum){} 108 | 109 | template 110 | bool operator()(const T* const P, const T* const X, 111 | T* residuals) const { 112 | // camera[0,1,2] are the angle-axis rotation. 113 | 114 | T Rnow[3]; 115 | T tnow[3]; 116 | 117 | 118 | T Xnow[4]; 119 | T Xrot[3]; 120 | 121 | T projection[3]; 122 | for (int i = 0; i < 3; i++){ 123 | tnow[i] = P[i+3] + _Porig[i+3]; // P[3:6] camera location 124 | Rnow[i] = P[i] + _Porig[i]; // P[0:3] camera rotation 125 | } 126 | for (int i = 0; i < 3; i++){ 127 | Xnow[i] = X[i] + _Xorig[i]; 128 | } 129 | ceres::AngleAxisRotatePoint(Rnow, Xnow, Xrot); 130 | Xrot[0]+=tnow[0]; 131 | Xrot[1]+=tnow[1]; 132 | Xrot[2]+=tnow[2]; 133 | 134 | projection[0]=(Xrot[0]*_Porig[6]+Xrot[1]*_Porig[7]+Xrot[2]*_Porig[8])/Xrot[2]; // P[6:5] K 135 | projection[1]=(Xrot[1]*_Porig[9]+Xrot[2]*_Porig[10])/Xrot[2]; 136 | 137 | residuals[0] = (projection[0] - _observed_x) ; 138 | residuals[1] = (projection[1] - _observed_y) ; 139 | //residuals[0] = ceres::sqrt((projection[0] - _observed_x)*(projection[0] - _observed_x) + (projection[1] - _observed_y)*(projection[1] - _observed_y)); 140 | //residuals[1] = T(0.0); 141 | 142 | return true; 143 | } 144 | 145 | static ceres::CostFunction* CreateMyRepEuc(const double observed_x, 146 | const double observed_y, double * Porig, double * Xorig, double blocksNum) { 147 | return (new ceres::AutoDiffCostFunction( 148 | new eucReprojectionError(observed_x, observed_y, Porig, Xorig, blocksNum))); 149 
| } 150 | double _observed_x; 151 | double _observed_y; 152 | double _blockNum; 153 | double * _Porig; 154 | double * _Xorig; 155 | }; 156 | 157 | 158 | 159 | //****** Complete function here 160 | /* The gateway function */ 161 | void pythonFunctionOursBA(double* Xs, double* xs, double* Ps, double* camPointmap, double* Xsu, double* Psu, int n_cam, int n_pts, int n_observe) 162 | { 163 | 164 | ceres::Problem problem; 165 | for (int i = 0; i < n_observe; i++) { 166 | 167 | int camIndex = int(camPointmap[i]); 168 | int point3DIndex = int(camPointmap[i + n_observe]); 169 | 170 | ceres::CostFunction* cost_function = 171 | projReprojectionError::CreateMyRep(xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex), 1); 172 | 173 | ceres::LossFunction* loss_function = new ceres::HuberLoss(0.1); 174 | problem.AddResidualBlock(cost_function, 175 | loss_function, 176 | Psu + 12 * (camIndex), Xsu + 3 * (point3DIndex)); 177 | } 178 | 179 | ceres::Solver::Options options; 180 | options.function_tolerance = 0.0001; 181 | options.max_num_iterations = 100; 182 | options.num_threads = 24; 183 | 184 | options.linear_solver_type = ceres::DENSE_SCHUR; 185 | options.minimizer_progress_to_stdout = true; 186 | 187 | ceres::Solver::Summary summary; 188 | ceres::Solve(options, &problem, &summary); 189 | std::cout << summary.FullReport().c_str() << "\n"; 190 | } 191 | void eucPythonFunctionOursBA(double* Xs, double* xs, double* Ps, double* camPointmap , double* Xsu, double* Psu, int nrows, int ncols, int n_observe) 192 | { 193 | 194 | ceres::Problem problem; 195 | for (int i = 0; i < n_observe; i++){ 196 | 197 | int camIndex = int(camPointmap[i]); 198 | int point3DIndex = int(camPointmap[i+n_observe]); 199 | 200 | ceres::CostFunction* cost_function = 201 | eucReprojectionError::CreateMyRepEuc(xs[2*i], xs[2*i + 1], Ps+(12*camIndex), Xs+(3*point3DIndex),1); 202 | ceres::LossFunction * loss_function = new ceres::HuberLoss(0.1); 203 | 204 | problem.AddResidualBlock(cost_function, 205 
| loss_function, 206 | Psu+(12*camIndex), Xsu+(3*point3DIndex)); 207 | } 208 | 209 | ceres::Solver::Options options; 210 | options.function_tolerance = 0.0001; 211 | //options.max_num_iterations = 100; 212 | options.max_num_iterations = 100; 213 | //options.num_threads = 8; 214 | options.num_threads = 24; 215 | 216 | options.linear_solver_type = ceres::DENSE_SCHUR; 217 | options.minimizer_progress_to_stdout = true; 218 | 219 | ceres::Solver::Summary summary; 220 | ceres::Solve(options, &problem, &summary); 221 | std::cout << summary.FullReport() << "\n"; 222 | } 223 | 224 | 225 | //***** complete function ends here 226 | 227 | void add_custom_cost_functions(py::module &m) { 228 | 229 | // // Use pybind11 code to wrap your own cost function which is defined in C++s 230 | 231 | 232 | // // Here is an example 233 | // m.def("CreateCustomExampleCostFunction", &ExampleFunctor::Create); 234 | 235 | m.def("projReprojectionError", []( 236 | double observed_x, double observed_y, py::array_t& _Porig, py::array_t& _Xorig 237 | ) { 238 | double* Porig = _ParseNumpyData(_Porig); 239 | double* Xorig = _ParseNumpyData(_Xorig); 240 | py::gil_scoped_release release; 241 | 242 | return projReprojectionError::CreateMyRep(observed_x, observed_y, Porig, Xorig, 1); 243 | } , py::return_value_policy::reference); 244 | m.def("eucReprojectionError", []( 245 | double observed_x, double observed_y, py::array_t& _Porig, py::array_t& _Xorig 246 | ) { 247 | double* Porig = _ParseNumpyData(_Porig); 248 | double* Xorig = _ParseNumpyData(_Xorig); 249 | py::gil_scoped_release release; 250 | 251 | return eucReprojectionError::CreateMyRepEuc(observed_x, observed_y, Porig, Xorig, 1); 252 | } , py::return_value_policy::reference); 253 | 254 | 255 | m.def("pythonFunctionOursBA", []( 256 | py::array_t& _Xs, 257 | py::array_t& _xs, 258 | py::array_t& _Ps, 259 | py::array_t& _camPointmap, 260 | py::array_t& _Xsu, 261 | py::array_t& _Psu, 262 | int n_cam, int n_pts, int n_observe 263 | ) { 264 | double* Xs 
= _ParseNumpyData(_Xs); 265 | double* xs = _ParseNumpyData(_xs); 266 | double* Ps = _ParseNumpyData(_Ps); 267 | double* camPointmap = _ParseNumpyData(_camPointmap); 268 | double* Xsu = _ParseNumpyData(_Xsu); 269 | double* Psu = _ParseNumpyData(_Psu); 270 | py::gil_scoped_release release; 271 | 272 | pythonFunctionOursBA(Xs, xs, Ps, camPointmap, Xsu, Psu, n_cam, n_pts, n_observe); 273 | 274 | }); 275 | 276 | m.def("eucPythonFunctionOursBA", []( 277 | py::array_t& _Xs, 278 | py::array_t& _xs, 279 | py::array_t& _Ps, 280 | py::array_t& _camPointmap, 281 | py::array_t& _Xsu, 282 | py::array_t& _Psu, 283 | int n_cam, int n_pts, int n_observe 284 | ) { 285 | double* Xs = _ParseNumpyData(_Xs); 286 | double* xs = _ParseNumpyData(_xs); 287 | double* Ps = _ParseNumpyData(_Ps); 288 | double* camPointmap = _ParseNumpyData(_camPointmap); 289 | double* Xsu = _ParseNumpyData(_Xsu); 290 | double* Psu = _ParseNumpyData(_Psu); 291 | py::gil_scoped_release release; 292 | 293 | eucPythonFunctionOursBA(Xs, xs, Ps, camPointmap, Xsu, Psu, n_cam, n_pts, n_observe); 294 | 295 | }); 296 | } 297 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | import loss_functions 5 | import evaluation 6 | import copy 7 | from utils import path_utils, dataset_utils, plot_utils, geo_utils 8 | from time import time 9 | import pandas as pd 10 | from utils.Phases import Phases 11 | from utils.path_utils import path_to_exp 12 | import os 13 | from torch.utils.data.dataloader import DataLoader 14 | from torch.utils.tensorboard import SummaryWriter 15 | 16 | # import fitlog 17 | 18 | # for each epoch 19 | def epoch_train(train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts): 20 | model.train() 21 | train_losses = [] 22 | bce_losses = [] 23 | orient_losses = [] 24 | trans_losses = [] 25 | 26 | # for each batch 27 
| for train_batch in train_data: # Loop over all sets - 30 28 | batch_loss = torch.tensor([0.0], device=train_data.device) 29 | optimizer.zero_grad() 30 | 31 | # for each data 32 | for curr_data in train_batch: 33 | 34 | pred_mask, pred_cam = model(curr_data) 35 | bceloss = loss_func1(pred_mask, curr_data, epoch) 36 | gtloss, orient_loss, trans_loss = loss_func2(pred_cam, curr_data, epoch) 37 | loss = bceloss + gtloss 38 | batch_loss += loss 39 | train_losses.append(loss.item()) 40 | bce_losses.append(bceloss.item()) 41 | orient_losses.append(orient_loss.item()) 42 | trans_losses.append(trans_loss.item()) 43 | 44 | if batch_loss.item()>0: 45 | batch_loss.backward() 46 | optimizer.step() 47 | 48 | scheduler.step() 49 | mean_loss = torch.tensor(train_losses).mean() 50 | mean_bceloss = torch.tensor(bce_losses).mean() 51 | mean_oriloss = torch.tensor(orient_losses).mean() 52 | mean_trloss = torch.tensor(trans_losses).mean() 53 | return mean_loss, mean_bceloss, mean_oriloss, mean_trloss 54 | 55 | 56 | def sim_epoch_train(conf, tbwriter, train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts, phase, validation_data, best_validation_metric): 57 | model.train() 58 | train_losses = [] 59 | bce_losses = [] 60 | orient_losses = [] 61 | trans_losses = [] 62 | 63 | i = 0 64 | epoch_size = len(train_data) 65 | print(f"there are {epoch_size} data in each epoch") 66 | eval_intervals = conf.get_int('train.eval_intervals', default=500) 67 | validation_metric = conf.get_list('train.validation_metric', default=["our_repro"]) 68 | no_ba_during_training = not conf.get_bool('ba.only_last_eval') 69 | train_trans = conf.get_bool('train.train_trans') 70 | save_predictions = conf.get_bool('train.save_predictions') 71 | 72 | best_epoch = 0 73 | best_model = torch.empty(0) 74 | 75 | for train_batch in train_data: # Loop over all sets 76 | batch_loss = torch.tensor([0.0], device=train_data.device) 77 | optimizer.zero_grad() 78 | 79 | for curr_data in train_batch: 
80 | i+=1 81 | epoch0 = epoch*epoch_size+i 82 | # if not dataset_utils.is_valid_sample(curr_data, min_valid_pts): 83 | # print('{} {} has a camera with not enough points'.format(epoch0, curr_data.scan_name)) 84 | # continue 85 | 86 | pred_mask, pred_cam = model(curr_data) 87 | bceloss = loss_func1(pred_mask, curr_data, epoch0) 88 | if train_trans: 89 | gtloss, orient_loss, trans_loss = loss_func2(pred_cam, curr_data, epoch0) 90 | loss = bceloss + gtloss 91 | batch_loss += loss 92 | train_losses.append(loss.item()) 93 | bce_losses.append(bceloss.item()) 94 | orient_losses.append(orient_loss.item()) 95 | trans_losses.append(trans_loss.item()) 96 | else: 97 | orient_loss = loss_func2(pred_cam, curr_data, epoch0) 98 | loss = bceloss + orient_loss 99 | batch_loss += loss 100 | train_losses.append(loss.item()) 101 | bce_losses.append(bceloss.item()) 102 | orient_losses.append(orient_loss.item()) 103 | # if torch.isnan(gtloss): 104 | # print(epoch0) 105 | # continue 106 | # print(bceloss.item()) 107 | # print(gtloss.item()) 108 | 109 | # print(batch_loss.item()) 110 | if batch_loss.item()>0: 111 | batch_loss.backward() 112 | optimizer.step() 113 | 114 | mean_loss = torch.tensor(train_losses).mean() 115 | mean_bceloss = torch.tensor(bce_losses).mean() 116 | mean_oriloss = torch.tensor(orient_losses).mean() 117 | if train_trans: mean_trloss = torch.tensor(trans_losses).mean() 118 | 119 | epoch0 = epoch*epoch_size+i 120 | if epoch0 % 100 == 0: 121 | print('{} Train Loss: {}'.format(epoch0, mean_loss)) 122 | tbwriter.add_scalar("Total_Loss", mean_loss, epoch0) 123 | tbwriter.add_scalar("BCE_Loss", mean_bceloss, epoch0) 124 | tbwriter.add_scalar("Orient_Loss", mean_oriloss, epoch0) 125 | if train_trans: tbwriter.add_scalar("Trans_Loss", mean_trloss, epoch0) 126 | # if epoch0 % 20 == 0: 127 | # scheduler.step() 128 | 129 | if epoch0!=0 and (epoch0 % eval_intervals == 0): # Eval current results 130 | if phase is Phases.TRAINING: 131 | validation_errors = 
epoch_evaluation(validation_data, model, conf, epoch0, Phases.VALIDATION, save_predictions=save_predictions,bundle_adjustment=no_ba_during_training) 132 | else: 133 | validation_errors = epoch_evaluation(train_data, model, conf, epoch0, phase, save_predictions=save_predictions,bundle_adjustment=no_ba_during_training) 134 | 135 | metric = validation_errors.loc[["Mean"], validation_metric].sum(axis=1).values.item() 136 | # fitlog.add_metric({"dev":{"Acc":metric}}, step=epoch) 137 | 138 | if metric < best_validation_metric: 139 | best_validation_metric = metric 140 | best_epoch = epoch0 141 | best_model = copy.deepcopy(model) 142 | print('Updated best validation metric: {}'.format(best_validation_metric)) 143 | path = path_utils.path_to_model(conf, phase, epoch=epoch0) 144 | torch.save(best_model.state_dict(), path) 145 | 146 | scheduler.step() 147 | return best_epoch, best_validation_metric, best_model 148 | 149 | 150 | def epoch_evaluation(data_loader, model, conf, epoch, phase, save_predictions=False, bundle_adjustment=True): 151 | refined = conf.get_bool("ba.refined") 152 | errors_list = [] 153 | model.eval() 154 | with torch.no_grad(): 155 | for batch_data in data_loader: 156 | for curr_data in batch_data: 157 | # Get predictions 158 | begin_time = time() 159 | pred_mask, pred_cam = model(curr_data) 160 | # pred_mask = model(curr_data) 161 | pred_time = time() - begin_time 162 | 163 | # Eval results 164 | # outputs = evaluation.prepare_ptpredictions(curr_data, pred_mask) 165 | outputs = evaluation.prepare_predictions_2(curr_data, pred_mask, pred_cam, conf, epoch, bundle_adjustment, refined=refined) 166 | errors = evaluation.compute_errors(outputs, conf, bundle_adjustment, refined=refined) 167 | 168 | errors['Inference time'] = pred_time 169 | errors['Scene'] = curr_data.scan_name 170 | 171 | # Get scene statistics on final evaluation 172 | if epoch is None: 173 | # stats = dataset_utils.get_data_statistics(curr_data) 174 | stats = 
dataset_utils.get_data_statistics2(curr_data, outputs) 175 | errors.update(stats) 176 | 177 | errors_list.append(errors) 178 | 179 | if save_predictions: 180 | dataset_utils.save_cameras(outputs, conf, curr_epoch=epoch, phase=phase) 181 | if conf.get_bool('dataset.calibrated'): 182 | plot_utils.plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan=curr_data.scan_name, epoch=epoch, bundle_adjustment=bundle_adjustment) 183 | 184 | df_errors = pd.DataFrame(errors_list) 185 | mean_errors = df_errors.mean(numeric_only=True) 186 | df_errors = df_errors.append(mean_errors, ignore_index=True) 187 | df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean" 188 | df_errors.set_index("Scene", inplace=True) 189 | df_errors = df_errors.round(3) 190 | print(df_errors.to_string(), flush=True) 191 | model.train() 192 | 193 | return df_errors 194 | 195 | 196 | def train(conf, train_data, model, phase, validation_data=None, test_data=None): 197 | # fitlog.set_log_dir(os.path.join(conf.get_string('dataset.results_path'), 'logs')) 198 | tbwriter = SummaryWriter(log_dir=os.path.join(path_to_exp(conf), 'tb'), flush_secs=60) 199 | 200 | num_of_epochs = conf.get_int('train.num_of_epochs') 201 | eval_intervals = conf.get_int('train.eval_intervals', default=500) 202 | validation_metric = conf.get_list('train.validation_metric', default=["our_repro"]) 203 | 204 | # Loss functions 205 | loss_func2 = getattr(loss_functions, conf.get_string('loss.func'))(conf) 206 | loss_func1 = loss_functions.BCEWithLogitLoss() 207 | 208 | # Optimizer params 209 | lr = conf.get_float('train.lr') 210 | scheduler_milestone = conf.get_list('train.scheduler_milestone') 211 | gamma = conf.get_float('train.gamma', default=0.1) 212 | 213 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 214 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_milestone, gamma=gamma) 215 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) 216 | 217 | 
best_validation_metric = math.inf 218 | best_epoch = 0 219 | best_model = torch.empty(0) 220 | converge_time = -1 221 | begin_time = time() 222 | 223 | no_ba_during_training = not conf.get_bool('ba.only_last_eval') 224 | min_valid_pts = conf.get_int('train.min_valid_pts', default=20) 225 | 226 | if not conf.get_bool("dataset.sample"): # using preprocessed dataset 227 | for epoch in range(num_of_epochs): 228 | print(f"epoch {epoch} now ...") 229 | bepoch, bmetric, bmodel = sim_epoch_train(conf, tbwriter, train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts, phase, validation_data, best_validation_metric) 230 | if bepoch>best_epoch: 231 | best_epoch=bepoch 232 | best_model=bmodel 233 | if bmetricthred 249 | prediction = premask.values.numpy().squeeze() 250 | outputs['prediction'] = prediction 251 | premask = prediction>thred 252 | gtmask = data.mask_sparse.values.cpu().numpy().squeeze() 253 | outputs['premask'] = premask+0 254 | outputs['gtmask'] = gtmask.astype(int) 255 | 256 | # precisions, recalls, thresholds = precision_recall_curve(gtmask.reshape(-1), prediction.reshape(-1)) 257 | # plt.plot(precisions, recalls) 258 | # plt.xlabel('Recall') 259 | # plt.ylabel('Precision') 260 | # print(f"pr thresholds: {thresholds}") 261 | # prpath = os.path.join(conf.get_string('dataset.results_path'), "pr_img") 262 | # if not os.path.exists(prpath): os.mkdir(prpath) 263 | # plt.savefig(prpath+"/"+str(epoch)+".jpg") 264 | 265 | # processing M 266 | M = data.M.cpu().numpy() 267 | mask = np.zeros(M.shape) 268 | mask[0:mask.shape[0]:2]=predensemask 269 | mask[1:mask.shape[0]:2]=predensemask 270 | preM = M*mask 271 | colidx = (mask!=0).sum(axis=0)>2 272 | preM = preM[:, colidx] 273 | precolor = color[:, colidx] 274 | outputs['precolor'] = precolor 275 | outputs['raw_color'] = color 276 | # M = M[(mask!=0).sum(axis=1)!=0, :] 277 | # print(M.shape) 278 | # input() 279 | raw_xs = geo_utils.M_to_xs(M) 280 | xs = geo_utils.M_to_xs(preM) 281 | 282 | 
Rs_pred = py3d_trans.quaternion_to_matrix(geo_utils.norm_quats(pred_cam["quats"])).cpu().numpy() 283 | if train_trans: 284 | ts_pred = (data.gpss.cpu().numpy() + pred_cam["ts"].cpu().numpy()) * data.scale 285 | # ts_pred = pred_cam["ts"].cpu().numpy() * data.scale 286 | else: 287 | ts_pred = data.gpss.cpu().numpy() * data.scale 288 | 289 | #Rs_gt, ts_gt = geo_utils.decompose_camera_matrix(data.y.cpu().numpy(), Ks) # For alignment and R,t errors 290 | Rs_gt, ts_gt = data.Rs.cpu().numpy(), data.ts.cpu().numpy() 291 | outputs['Rs_gt'] = Rs_gt 292 | outputs['Rs'] = Rs_pred 293 | outputs['ts_gt'] = ts_gt * data.scale 294 | outputs['ts'] = ts_pred 295 | 296 | Ks = data.Ks.cpu().numpy() # data.Ns.inverse().cpu().numpy() 297 | outputs['Ks'] = Ks 298 | Ps = geo_utils.batch_get_camera_matrix_from_rtk(Rs_pred, ts_pred, Ks) 299 | pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=preM, Ns=Ns) 300 | outputs['xs'] = xs # to compute reprojection error later 301 | outputs['raw_xs'] = raw_xs 302 | outputs['Ps'] = Ps # Ps = K@(R|t) 303 | outputs['pts3D_pred'] = pts3D_triangulated # 4,n 304 | 305 | #Rs_pred, ts_pred = geo_utils.decompose_camera_matrix(Ps_norm) 306 | 307 | # Rs_fixed, ts_fixed, similarity_mat = geo_utils.align_cameras(Rs_pred, Rs_gt, ts_pred, ts_gt, return_alignment=True) # Align Rs_fixed, tx_fixed 308 | # outputs['Rs_fixed'] = Rs_fixed 309 | # outputs['ts_fixed'] = ts_fixed 310 | # outputs['pts3D_pred_fixed'] = (similarity_mat @ pts3D_triangulated) 311 | # outputs['Rs_fixed'] = Rs_pred 312 | # outputs['ts_fixed'] = ts_pred 313 | # outputs['pts3D_pred_fixed'] = pts3D_triangulated 314 | 315 | if bundle_adjustment: 316 | repeat = conf.get_bool('ba.repeat') 317 | max_iter = conf.get_int('ba.max_iter') 318 | ba_times = conf.get_int('ba.ba_times') 319 | repro_thre = conf.get_float('ba.repro_thre') 320 | ba_res = ba_functions.euc_ba(xs, raw_xs, Rs=Rs_pred, ts=ts_pred, Ks=np.linalg.inv(Ns), 321 | Xs=pts3D_triangulated.T, Ps=Ps, Ns=Ns, 322 | repeat=repeat, 
max_iter=max_iter, ba_times=ba_times, 323 | repro_thre=repro_thre, refined=refined) # Rs, ts, Ps, Xs 324 | outputs['Rs_ba'] = ba_res['Rs'] 325 | outputs['ts_ba'] = ba_res['ts'] 326 | outputs['Xs_ba'] = ba_res['Xs'].T # 4,n 327 | outputs['Ps_ba'] = ba_res['Ps'] 328 | if refined: 329 | outputs['valid_points'] = ba_res['valid_points'] 330 | outputs['new_xs'] = ba_res['new_xs'] 331 | 332 | # R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt, 333 | # return_alignment=True) # Align Rs_fixed, tx_fixed 334 | # outputs['Rs_ba_fixed'] = R_ba_fixed 335 | # outputs['ts_ba_fixed'] = t_ba_fixed 336 | # outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba']) 337 | outputs['Rs_ba_fixed'] = ba_res['Rs'] 338 | outputs['ts_ba_fixed'] = ba_res['ts'] 339 | outputs['Xs_ba_fixed'] = ba_res['Xs'].T 340 | 341 | # else: 342 | # if bundle_adjustment: 343 | # repeat = conf.get_bool('ba.repeat') 344 | # triangulation = conf.get_bool('ba.triangulation') 345 | # ba_res = ba_functions.proj_ba(Ps=Ps, xs=xs, Xs_our=pts3D_triangulated.T, Ns=Ns, repeat=repeat, 346 | # triangulation=triangulation, return_repro=True, normalize_in_tri=True) # Ps, Xs 347 | # outputs['Xs_ba'] = ba_res['Xs'].T # 4,n 348 | # outputs['Ps_ba'] = ba_res['Ps'] 349 | 350 | return outputs 351 | 352 | 353 | def prepare_ptpredictions(data, pred_mask): 354 | # Take the inputs from pred cam and turn to ndarray 355 | outputs = {} 356 | outputs['scan_name'] = data.scan_name 357 | 358 | premask = pred_mask.to('cpu') 359 | prediction = premask.values.numpy().squeeze() 360 | outputs['prediction'] = prediction 361 | premask = prediction>0.8 362 | gtmask = data.mask_sparse.values.cpu().numpy().squeeze() 363 | outputs['premask'] = premask+0 364 | outputs['gtmask'] = gtmask.astype(int) 365 | 366 | return outputs --------------------------------------------------------------------------------