├── system.png
├── code
├── models
│ ├── .SetOfSet.py.swp
│ ├── baseNet.py
│ ├── SetOfSet.py
│ └── layers.py
├── utils
│ ├── __pycache__
│ │ └── general_utils.cpython-39.pyc
│ ├── Phases.py
│ ├── pos_enc_utils.py
│ ├── ba_io.py
│ ├── path_utils.py
│ ├── sparse_utils.py
│ ├── dataset_utils.py
│ ├── general_utils.py
│ ├── ba_functions.py
│ ├── plot_utils.py
│ └── ceres_utils.py
├── .vscode
│ └── launch.json
├── confs
│ ├── inference.conf
│ └── training.conf
├── datasets
│ ├── Projective.py
│ ├── ScenesDataSet.py
│ ├── Euclidean.py
│ └── SceneData.py
├── inference.py
├── multiple_scenes_learning.py
├── loss_functions.py
├── run_ba.py
├── re_ba.py
├── joint_optimization.py
├── train.py
└── evaluation.py
├── .vscode
└── settings.json
├── LICENSE
├── bundle_adjustment
├── README.md
└── custom_cpp_cost_functions.cpp
├── README.md
└── environment.yml
/system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/system.png
--------------------------------------------------------------------------------
/code/models/.SetOfSet.py.swp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/code/models/.SetOfSet.py.swp
--------------------------------------------------------------------------------
/code/utils/__pycache__/general_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WHU-USI3DV/DeepAAT/HEAD/code/utils/__pycache__/general_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/code/utils/Phases.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class Phases(Enum):
    """Named stages an experiment run can be in.

    Members are used both as enum values and through ``.name`` (for example
    ``Phases.TRAINING.name``), so names and numeric values should stay stable.
    """
    OPTIMIZATION = 1
    TRAINING = 2
    VALIDATION = 3
    TEST = 4
    FINE_TUNE = 5
    SHORT_OPTIMIZATION = 6
    INFERENCE = 7
--------------------------------------------------------------------------------
/code/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | // Use IntelliSense to learn about possible attributes.
3 | // Hover to view descriptions of existing attributes.
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 | "version": "0.2.0",
6 | "configurations": [
7 | {
8 | "name": "Python: 当前文件",
9 | "type": "python",
10 | "request": "launch",
11 | "program": "${file}",
12 | "console": "integratedTerminal"
13 | }
14 | ]
15 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "MicroPython.executeButton": [
3 | {
4 | "text": "▶",
5 | "tooltip": "Run",
6 | "alignment": "left",
7 | "command": "extension.executeFile",
8 | "priority": 3.5
9 | }
10 | ],
11 | "MicroPython.syncButton": [
12 | {
13 | "text": "$(sync)",
14 | "tooltip": "sync",
15 | "alignment": "left",
16 | "command": "extension.execute",
17 | "priority": 4
18 | }
19 | ]
20 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code/confs/inference.conf:
--------------------------------------------------------------------------------
1 | exp_name = Inference
2 | num_iter = 1
3 | dataset
4 | {
5 | use_gt = False
6 | calibrated = True
7 | flag = 2 # 1 for val, 2 for test
8 | valset_path = /home/zeeq/data/testset/npzs100-130
9 | testset_path = /home/zeeq/data/testset/npzs100-130
10 | results_path = /home/zeeq/results/test10
11 | use_spatial_encoder = True
12 | gps_embed_width = 128
13 | egps_embed_rank = 4
14 | x_embed_rank = 4
15 | dsc_egps_embed_width = 32 # 0 disables this embedding
16 | addNoise = True
17 | noise_mean = 0
18 | noise_std = 0.01
19 | noise_radio = 1.0
20 | alpha = 0.9 # rot
21 | beta = 0.1 # trans
22 | }
23 | model
24 | {
25 | model_path = /path/to/trained/model.pt # //deepaat/models/Model_Ep44000_embed44.pt
26 | type = SetOfSet.SetOfSetNet
27 | num_features = 256
28 | num_blocks = 3
29 | block_size = 2
30 | use_skip = True
31 | }
32 | train
33 | {
34 | lr = 1e-3
35 | num_of_epochs = 50
36 | eval_intervals = 2000
37 | train_trans = True
38 | save_predictions = False
39 | }
40 | loss
41 | {
42 | func = GTLoss
43 | mask_thred = 0.5
44 | }
45 | ba
46 | {
47 | run_ba = False
48 | repeat=True # If repeat, the first time is from our points and the second from triangulation
49 | triangulation=True
50 | only_last_eval = True
51 | refined = False
52 | max_iter = 50
53 | ba_times = 2
54 | repro_thre = 2
55 | }
56 |
--------------------------------------------------------------------------------
/code/utils/pos_enc_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
class Embedder:
    """Positional encoder assembled from keyword options.

    Expected keys: ``input_dims``, ``include_input``, ``max_freq_log2``,
    ``num_freqs``, ``log_sampling`` and ``periodic_fns``.  After construction
    ``embed_fns`` holds the component functions and ``out_dim`` the width of
    the concatenated encoding.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build the list of encoding functions and record the output width."""
        opts = self.kwargs
        in_dims = opts['input_dims']

        # Optionally pass the raw input through as the first component.
        fns = [lambda x: x] if opts['include_input'] else []

        top_exp = opts['max_freq_log2']
        n_bands = opts['num_freqs']
        if opts['log_sampling']:
            bands = 2. ** torch.linspace(0., top_exp, steps=n_bands)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top_exp, steps=n_bands)

        for band in bands:
            # Default arguments freeze the current function/frequency pair.
            fns.extend([
                (lambda x, p_fn=periodic, freq=band: p_fn(x * freq))
                for periodic in opts['periodic_fns']
            ])

        self.embed_fns = fns
        # Every component contributes ``in_dims`` channels.
        self.out_dim = in_dims * len(fns)

    def embed(self, inputs):
        """Apply every component to ``inputs`` and concatenate on the last axis."""
        pieces = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(pieces, -1)
35 |
36 |
def get_embedder(multires, in_dim):
    """Create a sin/cos positional encoder with ``multires`` log-spaced bands.

    :param multires: number of frequency bands; the highest exponent is
        ``multires - 1``.
    :param in_dim: dimensionality of the inputs to be encoded.
    :return: ``(embed, out_dim)`` - a callable applying the encoding and the
        width of its output.
    """
    encoder = Embedder(
        include_input=True,
        input_dims=in_dim,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=encoder):
        return eo.embed(x)

    return embed, encoder.out_dim
--------------------------------------------------------------------------------
/code/utils/ba_io.py:
--------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import numpy as np
3 | import os
4 |
5 |
def read_mat_files(path):
    """Load cameras, 3D points and 2D tracks from ``<path>.mat``.

    :param path: file path without the ``.mat`` extension.
    :return: dict with ``'Ps'`` (stacked ``Ps`` entry), ``'Xs'`` (transposed
        ``Points3D``) and ``'xs'`` (``M`` regrouped so the last axis holds the
        two stacked rows belonging to each camera).
    """
    mat = sio.loadmat(path + '.mat', squeeze_me=True)
    points = mat['Points3D'].T
    measurements = mat['M']
    rows, n_pts = measurements.shape
    n_cams = rows // 2
    # Rows come in pairs per camera: (2m, n) -> (m, 2, n) -> (m, n, 2).
    tracks = measurements.reshape([n_cams, 2, n_pts]).transpose([0, 2, 1])
    cameras = np.stack(mat['Ps'])
    return {'Ps': cameras, 'Xs': points, 'xs': tracks}
16 |
def read_euc_gt_mat_files(path):
    """Load ground-truth Euclidean cameras and 2D tracks from ``<path>.mat``.

    :param path: file path without the ``.mat`` extension.
    :return: dict with ``'Rs'``, ``'ts'``, ``'Ks'`` (stacked ``R_gt``,
        ``T_gt``, ``K_gt`` entries) and ``'xs'`` (``M`` regrouped per camera).
    """
    mat = sio.loadmat(path + '.mat', squeeze_me=True)
    measurements = mat['M']
    is_dense = isinstance(measurements, (np.ndarray, np.generic))
    if not is_dense:
        # Anything that is not a NumPy array is expected to offer todense().
        measurements = np.asarray(measurements.todense())
    rotations = np.stack(mat['R_gt'])
    translations = np.stack(mat['T_gt'])
    intrinsics = np.stack(mat['K_gt'])
    rows, n_pts = measurements.shape
    n_cams = rows // 2
    tracks = measurements.reshape([n_cams, 2, n_pts]).transpose([0, 2, 1])
    return {'Rs': rotations, 'ts': translations, 'Ks': intrinsics, 'xs': tracks}
30 |
def read_proj_gt_mat_files(path):
    """Load the 2D tracks of a projective scene from ``<path>.mat``.

    :param path: file path without the ``.mat`` extension.
    :return: dict with ``'xs'`` - the measurement matrix ``M`` regrouped so the
        last axis holds the two stacked rows belonging to each camera.
    """
    raw_data = sio.loadmat(path + '.mat', squeeze_me=True)
    M = raw_data['M']
    if not isinstance(M, (np.ndarray, np.generic)):
        # A sparse M would become a 0-d object array under np.asarray and break
        # the shape unpacking below; densify it first, as
        # read_euc_gt_mat_files does.
        M = M.todense()
    M = np.asarray(M)
    m, n = M.shape
    m = m // 2
    xs = M.reshape([m, 2, n]).transpose([0, 2, 1])
    data = {'xs': xs}
    return data
39 |
def read_euc_our_mat_files(path, name='Final_Cameras'):
    """Load predicted Euclidean cameras from ``<path>/cameras/<name>.mat``.

    :param path: scan results folder containing a ``cameras`` sub-folder.
    :param name: file name without the ``.mat`` extension.
    :return: dict with ``'Xs'`` (first three rows of ``pts3D``, transposed and
        cast to double), ``'Rs'``, ``'ts'`` and ``'Ks'``.
    """
    mat_file = os.path.join(path, 'cameras', name) + '.mat'
    mat = sio.loadmat(mat_file, squeeze_me=True)
    # Keep only the first three coordinate rows of the stored points.
    points = mat['pts3D'][:3].T.astype(np.double)
    return {'Xs': points, 'Rs': mat['Rs'], 'ts': mat['ts'], 'Ks': mat['Ks']}
48 |
49 |
def read_proj_our_mat_files(path, name='Final_Cameras'):
    """Load predicted projective cameras from ``<path>/cameras/<name>.mat``.

    :param path: scan results folder containing a ``cameras`` sub-folder.
    :param name: file name without the ``.mat`` extension.
    :return: dict with ``'Ps'`` and ``'Xs'`` (first three rows of ``pts3D``,
        transposed and cast to double).
    """
    mat_file = os.path.join(path, 'cameras', name) + '.mat'
    mat = sio.loadmat(mat_file, squeeze_me=True)
    points = mat['pts3D'][:3].T.astype(np.double)
    return {'Ps': mat['Ps'], 'Xs': points}
56 |
57 |
--------------------------------------------------------------------------------
/code/confs/training.conf:
--------------------------------------------------------------------------------
1 | exp_name = Learning_Euc
2 | # random_seed=0
3 | dataset
4 | {
5 | batch_size = 1
6 | shuffle_data = True # if true, then shuffle rows for each data
7 | trainset_path = /home/zeeq/data/trainset/npzs100-130 # Change to your own path
8 | valset_path = /home/zeeq/data/valset/npzs100-130
9 | testset_path = /home/zeeq/data/valset/npzs100-130
10 | results_path = /home/zeeq/results/val10
11 | use_spatial_encoder = True
12 | gps_embed_width = 128
13 | egps_embed_rank = 4
14 | x_embed_rank = 4
15 | dsc_egps_embed_width = 32 # 0 disables this embedding
16 | addNoise = True
17 | noise_mean = 0
18 | noise_std = 0.01
19 | noise_radio = 1.0
20 | alpha = 0.9 # rot
21 | beta = 0.1 # trans
22 | }
23 | model
24 | {
25 | type = SetOfSet.SetOfSetNet
26 | num_features = 256
27 | num_blocks = 3
28 | block_size = 2
29 | use_skip = True
30 |
31 | multires = 10 # discard
32 | }
33 | train
34 | {
35 | lr = 1e-3
36 | num_of_epochs = 50
37 | eval_intervals = 1000
38 | train_trans = True
39 | save_predictions = False
40 |
41 | scheduler_milestone = [12000,16000,18000] #[20000, 30000, 40000, 45000] # 20000
42 | gamma = 0.2
43 | optimization_num_of_epochs = 500
44 | optimization_eval_intervals = 250
45 | optimization_lr = 1e-3
46 | min_valid_pts = 100 # a data batch is skipped if any camera in it observes fewer points than this
47 | }
48 | loss
49 | {
50 | func = GTLoss
51 | mask_thred = 0.5
52 |
53 | infinity_pts_margin = 1e-4
54 | normalize_grad = True
55 | hinge_loss = True
56 | hinge_loss_weight = 1
57 | }
58 | ba
59 | {
60 | run_ba = False
61 | repeat=True # If repeat, the first time is from our points and the second from triangulation
62 | triangulation=True
63 | only_last_eval = True
64 | refined = False # multi-ba
65 | max_iter = 50
66 | ba_times = 2
67 | repro_thre = 2
68 | }
--------------------------------------------------------------------------------
/code/datasets/Projective.py:
--------------------------------------------------------------------------------
1 | import cv2 # Do not remove
2 | import torch
3 | from utils import geo_utils, general_utils, dataset_utils, path_utils
4 | import scipy.io as sio
5 | import numpy as np
6 | import os.path
7 |
8 |
def get_raw_data(conf, scan):
    """Load one projective scene from ``<datasets>/Projective/<scan>.npz``.

    :param conf: configuration object; ``dataset.use_gt`` is always read and
        ``dataset.scan`` is read when ``scan`` is None.
    :param scan: scene name, or None to take it from the configuration.
    :return: tuple ``(M, Ns, Ps_gt, mask)`` of torch tensors:
        M - Points Matrix (2mxn), float
        Ns - Normalization matrices (mx3x3), float
        Ps_gt - Olsson's estimated camera matrices (mx3x4), float
        mask - the stored ``mask`` array, int
    """
    # Init
    dataset_path_format = os.path.join(path_utils.path_to_datasets(), 'Projective', '{}.npz')

    # Get conf parameters
    if scan is None:
        scan = conf.get_string('dataset.scan')
    use_gt = conf.get_bool('dataset.use_gt')

    # Get raw data; the context manager closes the archive once the arrays
    # have been read out of it.
    with np.load(dataset_path_format.format(scan)) as dataset:
        M = dataset['M']
        Ps_gt = dataset['Ps_gt']
        Ns = dataset['Ns']
        mask = dataset['mask']

    if use_gt:
        # Keep the corrected matches as a NumPy array: the conversion to a
        # tensor happens once below (converting twice raises TypeError).
        M = dataset_utils.correct_matches_global(M, Ps_gt, Ns)

    M = torch.from_numpy(M).float()
    Ps_gt = torch.from_numpy(Ps_gt).float()
    Ns = torch.from_numpy(Ns).float()
    mask = torch.from_numpy(mask).int()

    return M, Ns, Ps_gt, mask
45 |
46 |
def test_Ps_M(Ps, M, Ns):
    """Print the mean and max global reprojection error, ignoring NaNs.

    :param Ps: camera tensor, converted with ``.numpy()``.
    :param M: point-matrix tensor, converted with ``.numpy()``.
    :param Ns: normalization tensor, converted with ``.numpy()``.
    """
    errors = geo_utils.calc_global_reprojection_error(Ps.numpy(), M.numpy(), Ns.numpy())
    mean_err = np.nanmean(errors)
    max_err = np.nanmax(errors)
    print("Reprojection Error: Mean = {}, Max = {}".format(mean_err, max_err))
50 |
51 |
def test_projective_dataset(scan):
    """Report reprojection errors of a stored scene, raw and GT-corrected.

    Loads ``<datasets>/Projective/<scan>.npz`` and prints the error of the
    stored cameras against the raw point matrix and against the matrix
    returned by ``dataset_utils.correct_matches_global``.

    :param scan: scene name (file stem of the ``.npz`` archive).
    """
    path_template = os.path.join(path_utils.path_to_datasets(), 'Projective', '{}.npz')
    archive = np.load(path_template.format(scan))

    M_np = archive['M']
    Ps_np = archive['Ps_gt']
    Ns_np = archive['Ns']

    # The corrected matrix must be computed from the NumPy arrays.
    M_gt = torch.from_numpy(dataset_utils.correct_matches_global(M_np, Ps_np, Ns_np)).float()

    M = torch.from_numpy(M_np).float()
    Ps_gt = torch.from_numpy(Ps_np).float()
    Ns = torch.from_numpy(Ns_np).float()

    for title, points in (("Test Ps and M", M), ("Test Ps and M_gt", M_gt)):
        print(title)
        test_Ps_M(Ps_gt, points, Ns)
74 |
75 |
if __name__ == "__main__":
    # Manual sanity check on a single scene when the module is run directly.
    test_projective_dataset("Alcatraz Courtyard")
79 |
80 |
81 |
--------------------------------------------------------------------------------
/code/inference.py:
--------------------------------------------------------------------------------
1 | import cv2 # DO NOT REMOVE
2 | from time import time
3 | from utils import general_utils, dataset_utils, plot_utils
4 | from utils.Phases import Phases
5 | from datasets import SceneData, ScenesDataSet
6 | from datasets.ScenesDataSet import DataLoader, myDataSet
7 | from single_scene_optimization import train_single_model
8 | import evaluation
9 | import torch
10 | import pandas as pd
11 |
def inference(conf, device, phase):
    """Run a trained model over every scene of the selected split and report errors.

    For each scene the model is evaluated, predictions are prepared and scored
    through ``evaluation``, cameras and plots are saved, and a per-scene error
    table (plus a final "Mean" row) is printed and written to disk.

    :param conf: experiment configuration; reads ``dataset.flag``,
        ``ba.run_ba``, ``ba.refined``, ``model.model_path`` and ``model.type``.
    :param device: device the model and data are moved to.
    :param phase: phase forwarded to the camera-saving and plotting helpers.
    """
    # Get conf
    flag = conf.get_int("dataset.flag")  # 1 means val, 2 means test
    scans_list = SceneData.get_data_list(conf, flag)
    bundle_adjustment = conf.get_bool("ba.run_ba")
    refined = conf.get_bool("ba.refined")

    # Create model
    model_path = conf.get_string('model.model_path')
    model = general_utils.get_class("models." + conf.get_string("model.type"))(conf).to(device)
    # map_location lets a checkpoint saved on one device (e.g. a GPU) be
    # restored on whatever device this run uses.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    errors_list = []
    with torch.no_grad():
        data_loader = myDataSet(conf, flag, scans_list).to(device)
        for batch_data in data_loader:
            for scene_data in batch_data:
                print(f"processing {scene_data.scan_name} ...")

                # Optimize Scene
                begin_time = time()
                pred_mask, pred_cam = model(scene_data)
                pred_time = time() - begin_time
                outputs = evaluation.prepare_predictions_2(scene_data, pred_mask, pred_cam, conf, 0, bundle_adjustment, refined=refined)
                errors = evaluation.compute_errors(outputs, conf, bundle_adjustment, refined=refined)

                errors['Inference time'] = pred_time
                errors['Scene'] = scene_data.scan_name
                errors['all_pts'] = scene_data.M.shape[-1]
                errors['pred_pts'] = outputs['pts3D_pred'].shape[1]
                # Count the columns whose mask has at least one non-zero entry.
                errors['gt_pts'] = scene_data.mask[:, scene_data.mask.sum(axis=0) != 0].shape[1]

                errors_list.append(errors)
                dataset_utils.save_cameras(outputs, conf, curr_epoch=None, phase=phase)
                plot_utils.plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan=scene_data.scan_name, epoch=None, bundle_adjustment=bundle_adjustment)

    if not errors_list:
        # Without any scene there is no "Scene" column to index by below.
        print("No scenes were processed; nothing to report.", flush=True)
        return

    # Write results
    df_errors = pd.DataFrame(errors_list)
    mean_errors = df_errors.mean(numeric_only=True)
    # DataFrame.append was removed in pandas 2.0; concat with a one-row frame
    # appends the same "mean" row on old and new pandas alike.
    df_errors = pd.concat([df_errors, mean_errors.to_frame().T], ignore_index=True)
    df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean"
    df_errors.set_index("Scene", inplace=True)
    df_errors = df_errors.round(3)
    print(df_errors.to_string(), flush=True)
    general_utils.write_results(conf, df_errors, file_name="Inference")
61 |
62 |
if __name__ == "__main__":
    # init_exp returns (conf, device, phase) - exactly inference's arguments.
    inference(*general_utils.init_exp(Phases.INFERENCE.name))
--------------------------------------------------------------------------------
/code/utils/path_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from utils.Phases import Phases
3 |
4 |
def join_and_create(path, folder):
    """Join ``path`` and ``folder`` and make sure the directory exists.

    Missing parent directories are created as well, and an already existing
    directory is accepted, so concurrent runs cannot race between an
    existence check and the creation.

    :param path: parent directory.
    :param folder: sub-folder name (may itself contain several components).
    :return: the joined path.
    :raises FileExistsError: if the joined path exists but is not a directory.
    """
    full_path = os.path.join(path, folder)
    os.makedirs(full_path, exist_ok=True)
    return full_path
11 |
12 |
def path_to_datasets():
    """Return the relative location of the datasets folder (one level up)."""
    parent_dir = '..'
    return os.path.join(parent_dir, 'datasets')
15 |
16 |
def path_to_condition(conf):
    """Return (and create) ``<dataset.results_path>/<exp_name>``.

    The results root may be several levels deep (for example
    ``/home/user/results/test10``); every missing level is created.

    :param conf: configuration providing ``dataset.results_path`` and
        ``exp_name`` through ``get_string``.
    :return: path of the experiment-condition directory.
    """
    experiments_folder = conf.get_string('dataset.results_path')
    exp_name = conf.get_string('exp_name')
    condition_path = os.path.join(experiments_folder, exp_name)
    # A single makedirs call creates the results root and the experiment
    # folder, and tolerates either of them already being present.
    os.makedirs(condition_path, exist_ok=True)
    return condition_path
23 |
24 |
def path_to_exp(conf):
    """Return (and create) the folder of this experiment version.

    :param conf: configuration providing ``exp_version``.
    :return: ``<condition path>/<exp_version>``.
    """
    version = conf.get_string('exp_version')
    return join_and_create(path_to_condition(conf), version)
30 |
31 |
def path_to_phase(conf, phase):
    """Return (and create) the sub-folder of the experiment named after ``phase``.

    :param conf: experiment configuration.
    :param phase: object whose ``name`` attribute is the folder name.
    """
    return join_and_create(path_to_exp(conf), phase.name)
35 |
36 |
def path_to_scan(conf, phase, scan=None):
    """Return (and create) the folder of one scan inside a phase folder.

    :param conf: experiment configuration; ``dataset.scan`` is used when
        ``scan`` is None.
    :param phase: phase whose folder contains the scan folder.
    :param scan: scan name, or None to take it from the configuration.
    """
    phase_dir = path_to_phase(conf, phase)
    if scan is None:
        scan = conf.get_string("dataset.scan")
    return join_and_create(phase_dir, scan)
41 |
42 |
def path_to_model(conf, phase, epoch=None, scan=None):
    """Return the checkpoint file path, creating its ``models`` folder.

    Training, validation and test phases share the experiment folder; any
    other phase stores models under the scan folder.

    :param conf: experiment configuration.
    :param phase: current phase.
    :param epoch: epoch number, or None for the final model.
    :param scan: scan name forwarded to ``path_to_scan``.
    """
    shared_phases = (Phases.TRAINING, Phases.VALIDATION, Phases.TEST)
    parent = path_to_exp(conf) if phase in shared_phases else path_to_scan(conf, phase, scan=scan)
    models_dir = join_and_create(parent, 'models')

    file_name = "Final_Model.pt" if epoch is None else "Model_Ep{}.pt".format(epoch)
    return os.path.join(models_dir, file_name)
57 |
58 |
def path_to_learning_data(conf, phase):
    """Return (and create) the learning-data folder of ``phase``.

    :param conf: experiment configuration.
    :param phase: either a ``Phases`` member or the folder name as a string.
        Enum members are mapped to their ``name``, matching how
        ``path_to_phase`` names phase folders; strings are used unchanged.
    :return: ``<condition path>/<phase folder>``.
    """
    # os.path.join cannot take an Enum member, so resolve it to its name.
    folder = phase.name if isinstance(phase, Phases) else phase
    return join_and_create(path_to_condition(conf), folder)
61 |
62 |
def path_to_cameras(conf, phase, epoch=None, scan=None):
    """Return the cameras file path (no extension), creating its folder.

    :param conf: experiment configuration.
    :param phase: current phase.
    :param epoch: epoch number, or None for the final cameras.
    :param scan: scan name forwarded to ``path_to_scan``.
    """
    cameras_dir = join_and_create(path_to_scan(conf, phase, scan=scan), 'cameras')
    stem = "Final_Cameras" if epoch is None else "Cameras_Ep{}".format(epoch)
    return os.path.join(cameras_dir, stem)
73 |
74 |
def path_to_plots(conf, phase, epoch=None, scan=None):
    """Return the HTML plot file path, creating its ``plots`` folder.

    :param conf: experiment configuration.
    :param phase: current phase.
    :param epoch: epoch number, or None for the final plot.
    :param scan: scan name forwarded to ``path_to_scan``.
    """
    plots_dir = join_and_create(path_to_scan(conf, phase, scan=scan), 'plots')
    file_name = "Final_plots.html" if epoch is None else "Plot_Ep{}.html".format(epoch)
    return os.path.join(plots_dir, file_name)
85 |
86 |
def path_to_logs(conf, phase):
    """Return (and create) the ``logs`` folder inside the phase folder."""
    return join_and_create(path_to_phase(conf, phase), "logs")
91 |
92 |
def path_to_code_logs(conf):
    """Return (and create) the ``code`` folder inside the experiment folder."""
    return join_and_create(path_to_exp(conf), "code")
97 |
98 |
def path_to_conf(conf_file):
    """Return the relative path of ``conf_file`` inside the ``confs`` folder."""
    confs_dir = 'confs'
    return os.path.join(confs_dir, conf_file)
--------------------------------------------------------------------------------
/code/multiple_scenes_learning.py:
--------------------------------------------------------------------------------
1 | import cv2 # DO NOT REMOVE
2 | from utils import general_utils, dataset_utils
3 | from utils.Phases import Phases
4 | from datasets.ScenesDataSet import ScenesDataSet, DataLoader, myDataSet
5 | from datasets import SceneData
6 | from single_scene_optimization import train_single_model
7 | import train
8 | import copy
9 | import time
10 |
11 |
def main():
    """Train a model on multiple scenes and write train/validation/test results.

    Builds the data loaders (sampled sub-scenes when ``dataset.sample`` is
    enabled, otherwise ``myDataSet`` loaders), trains the configured model and
    stores the statistics and errors returned by ``train.train``.
    """
    # Init Experiment
    conf, device, phase = general_utils.init_exp(Phases.TRAINING.name)
    general_utils.log_code(conf)

    # Get configuration
    # ``dataset.sample`` is optional: a conf without the key uses the
    # non-sampling loaders instead of failing on the missing entry.
    sample = conf.get_bool('dataset.sample', default=False)
    batch_size = conf.get_int('dataset.batch_size')

    train_list = SceneData.get_data_list(conf, 0)
    val_list = SceneData.get_data_list(conf, 1)
    test_list = SceneData.get_data_list(conf, 2)

    if sample:
        min_sample_size = conf.get_int('dataset.min_sample_size')
        max_sample_size = conf.get_int('dataset.max_sample_size')

        # Create train, test and validation sets
        train_scenes = SceneData.create_scene_data_from_list(train_list, conf, 0)
        validation_scenes = SceneData.create_scene_data_from_list(val_list, conf, 1)
        test_scenes = SceneData.create_scene_data_from_list(test_list, conf, 2)

        train_set = ScenesDataSet(train_scenes, return_all=False, min_sample_size=min_sample_size, max_sample_size=max_sample_size)
        validation_set = ScenesDataSet(validation_scenes, return_all=True)
        test_set = ScenesDataSet(test_scenes, return_all=True)

        # Create dataloaders
        train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True).to(device)
        validation_loader = DataLoader(validation_set, batch_size=1, shuffle=False).to(device)
        test_loader = DataLoader(test_set, batch_size=1, shuffle=False).to(device)
    else:
        train_loader = myDataSet(conf, 0, train_list, batch_size, True).to(device)
        validation_loader = myDataSet(conf, 1, val_list).to(device)
        test_loader = myDataSet(conf, 2, test_list).to(device)

    # Train model
    model = general_utils.get_class("models." + conf.get_string("model.type"))(conf).to(device)
    train_stat, train_errors, validation_errors, test_errors = train.train(conf, train_loader, model, phase, validation_loader, test_loader)

    # Write results
    general_utils.write_results(conf, train_stat, file_name="Train_Stats")
    general_utils.write_results(conf, train_errors, file_name="Train")
    general_utils.write_results(conf, validation_errors, file_name="Validation")
    general_utils.write_results(conf, test_errors, file_name="Test")
58 |
59 |
def optimization_all_sets(conf, device, phase):
    """Run the single-scene optimisation once per scan in ``dataset.scans_list``.

    ``conf["dataset"]["scan"]`` is overwritten with the current scan before
    each call to ``train_single_model``.
    """
    for scan_name in conf.get_list('dataset.scans_list'):
        conf["dataset"]["scan"] = scan_name
        train_single_model(conf, device, phase)
66 |
67 |
# Script entry point: run the training pipeline only when executed directly.
if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/code/models/baseNet.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
3 | import torch
4 | from utils import geo_utils
5 | from pytorch3d import transforms as py3d_trans
6 | import numpy as np
7 |
8 |
class BaseNet(torch.nn.Module):
    """Common base for the camera-prediction networks.

    Reads the output parameterisation from the configuration and converts raw
    network outputs into camera matrices.  Subclasses implement ``forward``.
    """

    def __init__(self, conf):
        """
        :param conf: configuration exposing ``get_bool`` / ``get_string``.
            ``dataset.calibrated`` is required; ``model.normalize_output``
            (default ``None``) and ``model.rot_representation`` (default
            ``'quat'``) are optional.
        :raises ValueError: calibrated setup with an unknown rotation
            representation.
        """
        super(BaseNet, self).__init__()

        self.calibrated = conf.get_bool('dataset.calibrated')
        self.normalize_output = conf.get_string('model.normalize_output', default=None)
        self.rot_representation = conf.get_string('model.rot_representation', default='quat')
        self.soft_sign = torch.nn.Softsign()

        # Number of values the network predicts per camera.
        if not self.calibrated:
            self.out_channels = 12  # full 3x4 projective camera
        elif self.rot_representation == '6d':
            print('rot representation: ' + self.rot_representation)
            self.out_channels = 9  # 6 rotation values + 3 translation values
        elif self.rot_representation == 'quat':
            self.out_channels = 7  # 4 quaternion values + 3 translation values
        elif self.rot_representation == 'svd':
            self.out_channels = 12  # 9 matrix values + 3 translation values
        else:
            # Raise instead of print + exit(): library code should not end the
            # interpreter, and callers get a message they can act on.
            raise ValueError(
                "Illegal output format: unsupported rot_representation "
                "{!r}".format(self.rot_representation)
            )

    @abc.abstractmethod
    def forward(self, data):
        pass

    def extract_model_outputs(self, x, pts_3D, data):
        """Turn raw per-camera outputs into camera matrices.

        :param x: per-camera network output with ``out_channels`` columns.
        :param pts_3D: predicted 3D points, passed through
            ``geo_utils.ones_padding``.
        :param data: unused here; kept for interface compatibility.
        :return: dict with ``"Ps_norm"`` (n x 3 x 4 cameras) and ``"pts3D"``.
        :raises ValueError: calibrated setup with an unknown rotation
            representation.
        """
        # Get points
        pts_3D = geo_utils.ones_padding(pts_3D)

        # Get calibrated predictions
        if self.calibrated:
            # Get rotation
            if self.rot_representation == '6d':
                RTs = py3d_trans.rotation_6d_to_matrix(x[:, :6])
            elif self.rot_representation == 'svd':
                m = x[:, :9].reshape(-1, 3, 3)
                RTs = geo_utils.project_to_rot(m)
            elif self.rot_representation == 'quat':
                RTs = py3d_trans.quaternion_to_matrix(x[:, :4])
            else:
                raise ValueError(
                    "Illegal output format: unsupported rot_representation "
                    "{!r}".format(self.rot_representation)
                )

            # Get translation (always the last three channels)
            minRTts = x[:, -3:]

            # Get camera matrix: rotation block with the translation column
            Ps = torch.cat((RTs, minRTts.unsqueeze(dim=-1)), dim=-1)

        else:  # Projective
            Ps = x.reshape(-1, 3, 4)

        # Normalize predictions
        if self.normalize_output == "Chirality":
            scale = torch.sign(Ps[:, 0:3, 0:3].det()) / Ps[:, 2, 0:3].norm(dim=1)
            Ps = Ps * scale.reshape(-1, 1, 1)
        elif self.normalize_output == "Differentiable Chirality":
            # Softsign of the scaled determinant stands in for torch.sign.
            scale = self.soft_sign(Ps[:, 0:3, 0:3].det() * 10e3) / Ps[:, 2, 0:3].norm(dim=1)
            Ps = Ps * scale.reshape(-1, 1, 1)
        elif self.normalize_output == "Frobenius":
            Ps = Ps / Ps.norm(dim=(1, 2), p='fro', keepdim=True)

        # The model outputs a normalized camera! Meaning from world coordinates to camera coordinates, not to pixels in the image.
        pred_cams = {"Ps_norm": Ps, "pts3D": pts_3D}
        return pred_cams

    def extract_camera_outputs(self, x):
        """Split raw outputs into quaternion and translation parts.

        :param x: per-camera network output; the first four columns are
            returned as ``"quats"`` and the last three as ``"ts"``, without
            any normalisation.
        :return: dict with ``"quats"`` and ``"ts"``.
        """
        pred_cams = {"quats": x[:, :4], "ts": x[:, -3:]}
        return pred_cams
--------------------------------------------------------------------------------
/code/utils/sparse_utils.py:
--------------------------------------------------------------------------------
1 | from hashlib import new
2 | import torch
3 |
4 |
class SparseMat:
    """Sparse 3-D tensor stored as one feature row per non-zero entry.

    ``shape`` is ``(m, n, d)``.  ``values`` holds one row of ``d`` features per
    stored entry and ``indices`` holds, for each entry, its position along
    axis 0 (``indices[0]``) and axis 1 (``indices[1]``).  Going by the
    attribute names, axis 0 indexes cameras and axis 1 points.

    ``cam_per_pts`` / ``pts_per_cam`` are the divisors used by ``mean`` and
    ``std``; they must broadcast against ``(n, d)`` and ``(m, d)`` results
    respectively (assumed to be entry counts - confirm with the caller).
    """

    def __init__(self, values, indices, cam_per_pts, pts_per_cam, shape):
        assert len(shape) == 3
        self.values = values
        self.indices = indices
        self.shape = shape
        self.cam_per_pts = cam_per_pts
        self.pts_per_cam = pts_per_cam
        self.device = self.values.device

    @property
    def size(self):
        """Alias of ``shape``."""
        return self.shape

    def sum(self, dim):
        """Sum the stored entries over axis ``dim`` (0 or 1).

        :return: dense ``(n, d)`` tensor for ``dim == 0`` and ``(m, d)`` for
            ``dim == 1``, in the dtype of ``values``.
        """
        assert dim == 1 or dim == 0
        n_features = self.shape[2]
        out_size = self.shape[0] if dim == 1 else self.shape[1]
        # Summing over one axis groups the entries by their index on the other.
        indices_index = 0 if dim == 1 else 1
        # Allocate in the dtype of the values: index_add needs matching dtypes,
        # so a float32-only buffer would reject e.g. float64 values.
        mat_sum = torch.zeros(out_size, n_features, device=self.device, dtype=self.values.dtype)
        return mat_sum.index_add(0, self.indices[indices_index], self.values)

    def mean(self, dim):
        """Sum over axis ``dim`` divided by the matching per-row divisor."""
        assert dim == 1 or dim == 0
        if dim == 0:
            return self.sum(dim=0) / self.cam_per_pts
        else:
            return self.sum(dim=1) / self.pts_per_cam

    def std(self, dim):
        """Standard deviation over axis ``dim``, using the same divisors as ``mean``.

        :return: dense ``(n, d)`` tensor for ``dim == 0`` and ``(m, d)`` for
            ``dim == 1``.
        """
        assert dim == 1 or dim == 0
        if dim == 0:
            counts, group_index = self.cam_per_pts, self.indices[1]
        else:
            counts, group_index = self.pts_per_cam, self.indices[0]

        # Gather each entry's group mean directly.  This selects the same rows
        # as repeating the mean into a dense (m, n, d) tensor and indexing it
        # with both index rows, without allocating the dense tensor.
        entry_means = self.mean(dim=dim)[group_index]
        sq_dev = SparseMat((self.values - entry_means) ** 2, self.indices,
                           self.cam_per_pts, self.pts_per_cam, self.shape)
        return (sq_dev.sum(dim=dim) / counts).pow(0.5)

    # concat in last dim
    def last_dim_cat(self, other):
        """Concatenate the features of ``other`` after this matrix's features.

        Both operands are assumed to share the same ``indices``.
        """
        new_values = torch.cat([self.values, other.values], -1)
        new_shape = (self.shape[0], self.shape[1], new_values.shape[-1])
        return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, new_shape)

    def to(self, device, **kwargs):
        """Move all tensors to ``device`` in place and return ``self``.

        NOTE(review): ``kwargs`` are forwarded to every tensor, including
        ``indices``; passing e.g. a floating ``dtype`` would also convert the
        indices.
        """
        self.device = device
        self.values = self.values.to(device, **kwargs)
        self.indices = self.indices.to(device, **kwargs)
        self.pts_per_cam = self.pts_per_cam.to(device, **kwargs)
        self.cam_per_pts = self.cam_per_pts.to(device, **kwargs)
        return self

    def __add__(self, other):
        assert self.shape == other.shape
        # Index equality is not checked (removed due to runtime); operands are
        # assumed to share the same sparsity pattern.
        new_values = self.values + other.values
        return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape)

    def __sub__(self, other):
        assert self.shape == other.shape
        new_values = self.values - other.values
        return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape)

    def __mul__(self, other):
        # Only the first two dims must agree, so features may broadcast.
        assert self.shape[:-1] == other.shape[:-1]
        new_values = self.values * other.values
        return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape)

    def __truediv__(self, other):
        assert self.shape[:-1] == other.shape[:-1]
        new_values = self.values / other.values
        return SparseMat(new_values, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape)

    def __pow__(self, other):
        return SparseMat(self.values ** other, self.indices, self.cam_per_pts, self.pts_per_cam, self.shape)
--------------------------------------------------------------------------------
/bundle_adjustment/README.md:
--------------------------------------------------------------------------------
1 | # Python Bundle Adjustment
2 | These instructions are based on and almost identical to the ones at [https://github.com/drormoran/gasfm/tree/main/bundle_adjustment/README.md](https://github.com/drormoran/gasfm/tree/main/bundle_adjustment/README.md).
3 |
4 |
5 | ## Conda environment
6 | Use the deepaat environment.
7 | ```
8 | conda activate deepaat
9 | export PYBIND11_PYTHON_VERSION="3.8"
10 | export PYTHON_VERSION="3.8"
11 | ```
12 |
13 | ## Directory structure
14 | After this setup, the directory structure should be:
15 | ```
16 | DeepAAT
17 | ├── bundle_adjustment
18 | │ ├── ceres-solver
19 | │ │ ├── ceres-bin
20 | │ │ | └── lib
21 | │ │ | └── PyCeres.cpython-39-x86_64-linux-gnu.so
22 | │ │ ├── ceres_python_bindings
23 | │ │ | └── python_bindings
24 | │ │ | └── custom_cpp_cost_functions.cpp
25 | │ │ └── CMakeLists.txt
26 | │ └── custom_cpp_cost_functions.cpp
27 | ├── code
28 | ├── datasets
29 | ```
30 | ## Set up
31 | 1. Clone the Ceres-Solver repository to the bundle_adjustment folder and check out version 2.1.0:
32 |
33 | ```
34 | cd bundle_adjustment
35 | git clone https://ceres-solver.googlesource.com/ceres-solver -b 2.1.0
36 | ```
37 |
38 |
39 | 2. Clone the ceres_python_bindings package inside the ceres-solver folder:
40 |
41 | ```
42 | cd ceres-solver
43 | git clone https://github.com/Edwinem/ceres_python_bindings
44 | ```
45 |
46 |
47 | 3. Copy the file "custom_cpp_cost_functions.cpp" and replace the file "ceres-solver/ceres_python_bindings/python_bindings/custom_cpp_cost_functions.cpp".
48 | This file contains projective and euclidean custom bundle adjustment functions.
49 |
50 | ```
51 | cp ../custom_cpp_cost_functions.cpp ceres_python_bindings/python_bindings/custom_cpp_cost_functions.cpp
52 | ```
53 |
54 | Next, you need to build ceres_python_bindings and ceres-solver and create a shared object file that python can call.
55 | You can either continue with the instructions here or follow the instructions at the ceres_python_bindings repository.
56 |
57 | 1. run:
58 |
59 | ```
60 | cd ceres_python_bindings
61 | git submodule init
62 | git submodule update
63 | ```
64 |
65 |
66 | 1. Make sure that the C++ standard used during the build is recent enough and is not hard-coded to C++11 by pybind11. Check which standard your C++ compiler supports and adjust the command below accordingly (this example uses C++17):
67 |
68 | ```
69 | sed -i 's/set(PYBIND11_CPP_STANDARD -std=c++11)/set(PYBIND11_CPP_STANDARD -std=c++17)/g' AddToCeres.cmake
70 | ```
71 |
72 |
73 | 5. Add to the end of the file ceres-solver/CMakeLists.txt the line: "include(ceres_python_bindings/AddToCeres.cmake)":
74 |
75 | ```
76 | cd ..
77 | echo "include(ceres_python_bindings/AddToCeres.cmake)" >> CMakeLists.txt
78 | ```
79 |
80 |
81 | 6. Inside ceres-solver folder run:
82 |
83 |
84 | ```
85 | mkdir ceres-bin
86 | cd ceres-bin
87 | cmake ..
88 | make -j8
89 | make test
90 | ```
91 |
92 | 7. If everything worked you should see the following file:
93 |
94 | ```
95 | bundle_adjustment/ceres-solver/ceres-bin/lib/PyCeres.cpython-39-x86_64-linux-gnu.so
96 | ```
97 |
98 | 8. If you want to use this bundle adjustment implementation in a different project, make sure to add the directory containing the shared object to Python's module search path, `sys.path` (in this code base this is done for you). In a Python project this would look like, for example:
99 |
100 | ```
101 | import sys
102 | sys.path.append('../bundle_adjustment/ceres-solver/ceres-bin/lib/')
103 | import PyCeres
104 | ```
105 |
106 | To see the usage of the PyCeres functions go to code/utils/ceres_utils and code/utils/ba_functions.
107 |
108 | ## Note
109 | If you encounter problems while compiling PyCeres, you can also refer to [GASFM](https://github.com/lucasbrynte/gasfm)'s [bundle adjustment instruction](https://github.com/lucasbrynte/gasfm/blob/main/bundle_adjustment/README.md) or create a new environment for compilation.
110 | If the environment for compiling PyCeres is not the same as the network environment, then run_ba in conf file can be set to False during training and inference, and after obtaining the predicted results, run run_ba.py separately for BA.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based mapping
2 |
3 | ### [Paper](https://www.sciencedirect.com/science/article/pii/S1569843224005466) | [arXiv](https://arxiv.org/abs/2402.01134)
4 |
5 |
6 |
7 |
8 | This is the implementation of the DeepAAT architecture, presented in our JAG paper DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based Mapping. The codebase is forked from the implementation of the ICCV 2021 paper [Deep Permutation Equivariant Structure from Motion](https://openaccess.thecvf.com/content/ICCV2021/html/Moran_Deep_Permutation_Equivariant_Structure_From_Motion_ICCV_2021_paper.html), available at [https://github.com/drormoran/Equivariant-SFM](https://github.com/drormoran/Equivariant-SFM). That architecture is also used as a baseline and referred to as ESFM in our paper.
9 |
10 | DeepAAT considers both spatial and spectral characteristics of imagery, enhancing its capability to resolve erroneous matching pairs and accurately predict image poses. DeepAAT marks a significant leap in AAT's efficiency, ensuring thorough scene coverage and precision. Its processing speed outpaces incremental AAT methods by hundreds of times and global AAT methods by tens of times while maintaining a comparable level of reconstruction accuracy. Additionally, DeepAAT's scene clustering and merging strategy facilitate rapid localization and pose determination for large-scale UAV images, even under constrained computing resources. The experimental results demonstrate DeepAAT's substantial improvements over conventional AAT methods, highlighting its potential in the efficiency and accuracy of UAV-based 3D reconstruction tasks.
11 |
12 |
13 | ## Contents
14 |
15 | - [Setup](#Setup)
16 | - [Usage](#Usage)
17 | - [Citation](#Citation)
18 |
19 | ---
20 |
21 | ## Setup
22 | This repository is implemented with Python 3.8. Running bundle adjustment requires Linux; we have used Ubuntu 22.04. You should also have a CUDA-capable GPU.
23 |
24 | ## Directory structure
25 | The repository should contain the following directories:
26 | ```
27 | DeepAAT
28 | ├── bundle_adjustment
29 | ├── code
30 | ├── scripts
31 | ├── environment.yml
32 | ```
33 |
34 | ## Conda environment
35 | Create the environment using the following commands:
36 | ```
37 | conda env create -f environment.yml
38 | conda activate deepaat
39 | ```
40 |
41 | ## PyCeres
42 | Next, follow the bundle adjustment instructions in `bundle_adjustment/README.md`.
43 |
44 | ## Data and pretrained models
45 | Both the datasets and the pretrained models for Euclidean reconstruction of novel scenes are available at [this link](https://www.dropbox.com/scl/fo/gtju43lxu9zgn36ft86ly/AArhmE-Q2QxlmGwwvWUiyKc?rlkey=4lse57283dskw1mfy6oii3ti5&st=ky4qc4nk&dl=0).
46 | Download the data and pretrained model, and then modify the path in the conf file accordingly. Due to the large amount of training data, only pretrained models and test data are provided.
47 |
48 | ## Usage
49 | To execute the code, first navigate to the `code` subdirectory. Also make sure that the conda environment is activated.
50 |
51 | To train a model from scratch for reconstruction of novel test scenes, run (Please make sure to modify the corresponding data path and configuration in the conf file correctly):
52 | ```
53 | python multiple_scenes_learning.py --conf path/to/conf
54 | ```
55 | where `path/to/conf` is relative to `code/confs/`, and may e.g. be `training.conf` for training a Euclidean reconstruction model using data augmentation.
56 |
57 | The training phase is succeeded by bundle adjustment, evaluation, and by default also by separate fine-tuning of the model parameters on every test scene.
58 |
59 | To infer a new scene, run (Please make sure to modify the corresponding data path and configuration in the conf file correctly):
60 | ```
61 | python inference.py --conf inference.conf
62 | ```
63 |
64 | ## Citation
65 | If you find this work useful, please cite our paper:
66 | ```
67 | @article{chen2024deepaat,
68 | title={DeepAAT: Deep Automated Aerial Triangulation for Fast UAV-based Mapping},
69 | author={Chen, Zequan and Li, Jianping and Li, Qusheng and Dong, Zhen and Yang, Bisheng},
70 | journal={International Journal of Applied Earth Observation and Geoinformation},
71 | volume={134},
72 | pages={104190},
73 | year={2024},
74 | publisher={Elsevier}
75 | }
76 | ```
77 |
78 | ## Acknowledgement
79 | We make improvements based on [ESFM](https://github.com/drormoran/Equivariant-SFM). We thank the authors for releasing the source code.
--------------------------------------------------------------------------------
/code/loss_functions.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | import torch
3 | from utils import geo_utils
4 | from torch import dtype, nn
5 | from torch.nn import functional as F
6 | from pytorch3d import transforms as py3d_trans
7 |
8 |
class ESFMLoss(nn.Module):
    """Reprojection loss with a hinge fallback for points rejected by the projection mask."""

    def __init__(self, conf):
        # conf must expose get_float / get_bool for the "loss.*" keys read below.
        super().__init__()
        # Margin handed to the geo_utils mask helpers and reused as the hinge threshold.
        self.infinity_pts_margin = conf.get_float("loss.infinity_pts_margin")
        self.normalize_grad = conf.get_bool("loss.normalize_grad")

        self.hinge_loss = conf.get_bool("loss.hinge_loss")
        if self.hinge_loss:
            self.hinge_loss_weight = conf.get_float("loss.hinge_loss_weight")
        else:
            # A zero weight makes the hinge term in forward() vanish.
            self.hinge_loss_weight = 0

    def forward(self, pred_cam, data, epoch=None):
        """Return the mean loss over the entries selected by ``data.valid_pts``.

        Reads ``pred_cam["Ps_norm"]`` and ``pred_cam["pts3D"]`` and compares the
        projections against ``data.norm_M``. ``epoch`` is accepted but unused.
        """
        Ps = pred_cam["Ps_norm"]
        pts_2d = Ps @ pred_cam["pts3D"]  # [m, 3, n]

        # NOTE(review): the mask variant is chosen by normalize_grad here and the
        # hinge_loss flag is never consulted in forward(); confirm that this is
        # intended rather than a dropped `if self.hinge_loss:` line.
        if self.normalize_grad:
            # Rescale the gradient flowing into pts_2d: unit norm along dim 1,
            # divided by the number of valid observations.
            pts_2d.register_hook(lambda grad: F.normalize(grad, dim=1) / data.valid_pts.sum())

            projected_points = geo_utils.get_positive_projected_pts_mask(pts_2d, self.infinity_pts_margin)
        else:
            projected_points = geo_utils.get_projected_pts_mask(pts_2d, self.infinity_pts_margin)

        # Calculate hinge Loss
        hinge_loss = (self.infinity_pts_margin - pts_2d[:, 2, :]) * self.hinge_loss_weight

        # Calculate reprojection error
        # Masked-out entries are divided by one instead of their third coordinate.
        pts_2d = (pts_2d / torch.where(projected_points, pts_2d[:, 2, :], torch.ones_like(projected_points).float()).unsqueeze(dim=1))
        reproj_err = (pts_2d[:, 0:2, :] - data.norm_M.reshape(Ps.shape[0], 2, -1)).norm(dim=1)

        # Reprojection error where the mask holds, hinge term elsewhere.
        return torch.where(projected_points, reproj_err, hinge_loss)[data.valid_pts].mean()
40 |
41 |
class GTLoss(nn.Module):
    """Supervised pose loss against ground-truth quaternions and, optionally, translations."""

    def __init__(self, conf):
        super().__init__()
        self.train_trans = conf.get_bool('train.train_trans')
        self.alpha = conf.get_float("dataset.alpha")
        self.beta = conf.get_float("dataset.beta")

    def forward(self, pred_cam, data, epoch=None):
        """Compute the pose loss.

        Returns ``(loss, orient_loss, trans_loss)`` when translations are
        trained, otherwise only the (unweighted) orientation loss. The values
        are printed whenever ``epoch`` is a multiple of 1000.
        """
        should_print = epoch is not None and epoch % 1000 == 0

        if not self.train_trans:
            orient_loss = (data.quats - pred_cam["quats"]).norm(2, dim=1).mean()
            if should_print:
                # Print loss
                print("orient err = {}".format(orient_loss))
            return orient_loss

        # Translation target is the offset of the ground-truth position from the GPS reading.
        ts_gt = data.ts - data.gpss
        per_cam_orient = (data.quats - pred_cam["quats"]).norm(2, dim=1)
        per_cam_trans = (ts_gt - pred_cam["ts"]).norm(2, dim=1)
        orient_loss = per_cam_orient.mean() * self.alpha
        trans_loss = per_cam_trans.mean() * self.beta
        loss = orient_loss + trans_loss

        if should_print:
            # Print loss
            print("GTloss = {}, orient err = {}, trans err = {}".format(loss, orient_loss, trans_loss))
        return loss, orient_loss, trans_loss
73 |
74 |
class BCELoss(nn.Module):
    """Binary cross-entropy between predicted and ground-truth sparse mask values."""

    def __init__(self):
        super().__init__()
        self.bceloss = nn.BCELoss()

    def forward(self, pred_mask, data, epoch=None):
        """Return the BCE of ``pred_mask.values`` against ``data.mask_sparse.values``.

        Both value arrays are viewed as single-column tensors. The loss is
        printed whenever ``epoch`` is a multiple of 1000.
        """
        raw_pred = pred_mask.values
        pred_column = torch.as_tensor(raw_pred).reshape(len(raw_pred), 1)
        raw_gt = data.mask_sparse.values
        gt_column = torch.as_tensor(raw_gt).reshape(len(raw_gt), 1)

        loss = self.bceloss(pred_column, gt_column).mean()
        if epoch is not None and epoch % 1000 == 0:
            # Print loss
            print("BCEloss = {}".format(loss))
        return loss
92 |
class BCEWithLogitLoss(nn.Module):
    """Binary cross-entropy on raw logits against ground-truth sparse mask values."""

    def __init__(self):
        super().__init__()
        self.lbceloss = nn.BCEWithLogitsLoss()

    def forward(self, pred_mask, data, epoch=None):
        """Return the logit-BCE of ``pred_mask.values`` against ``data.mask_sparse.values``.

        Both value arrays are viewed as single-column tensors. The loss is
        printed whenever ``epoch`` is a multiple of 1000.
        """
        raw_pred = pred_mask.values
        logit_column = torch.as_tensor(raw_pred).reshape(len(raw_pred), 1)
        raw_gt = data.mask_sparse.values
        gt_column = torch.as_tensor(raw_gt).reshape(len(raw_gt), 1)

        loss = self.lbceloss(logit_column, gt_column).mean()
        if epoch is not None and epoch % 1000 == 0:
            # Print loss
            print("BCEloss = {}".format(loss))
        return loss
--------------------------------------------------------------------------------
/code/datasets/ScenesDataSet.py:
--------------------------------------------------------------------------------
1 | from datasets import SceneData
2 | import numpy as np
3 | from torch.utils.data.dataset import Dataset
4 | from datasets import Euclidean
5 |
class DataLoader():
    """Minimal batch iterator over an indexable dataset of objects exposing ``.to(device)``."""

    def __init__(self, dataset, batch_size=1, shuffle=False):
        self.n = len(dataset)
        self.dataset = dataset
        self.batch_size = batch_size
        self.num_batches = int(np.ceil(self.n / self.batch_size))
        self.shuffle = shuffle
        self.permutation = self.init_permutation()
        self.current_batch = 0
        self.device = 'cpu'

    def init_permutation(self):
        """Return the visiting order: a random permutation when shuffling, else 0..n-1."""
        if self.shuffle:
            return np.random.permutation(self.n)
        return np.arange(self.n)

    def __iter__(self):
        # Every new pass restarts from the first batch with a fresh ordering.
        self.current_batch = 0
        self.permutation = self.init_permutation()
        return self

    def __next__(self):
        if self.current_batch == self.num_batches:
            raise StopIteration
        first = self.current_batch * self.batch_size
        # Slicing clamps at the end of the permutation, so the last batch may be shorter.
        batch_ids = self.permutation[first:first + self.batch_size]
        self.current_batch += 1
        return [self.dataset[i].to(self.device) for i in batch_ids]

    def __len__(self):
        # Number of samples (not number of batches).
        return self.n

    def to(self, device, **kwargs):
        """Remember the device that samples are moved to; returns ``self`` for chaining."""
        self.device = device
        return self
40 |
41 |
class myDataSetds(Dataset):
    """Torch ``Dataset`` that loads one scene per entry of ``datalist`` on access."""

    def __init__(self, conf, datalist, flag):
        self.datalist = datalist
        self.conf = conf
        self.flag = flag
        self.dilute_M = conf.get_bool('dataset.diluteM', default=False)

    def __len__(self):
        return len(self.datalist)

    def __getitem__(self, idx):
        # NOTE(review): myDataSet.loaddata in this module unpacks thirteen values
        # from the same Euclidean.get_raw_data call; confirm this six-value
        # unpacking still matches the current loader.
        scene_file = self.datalist[idx]
        M, Ns, Rs, ts, quats, mask = Euclidean.get_raw_data(self.conf, scene_file, self.flag)
        return SceneData.SceneData(M, Ns, Rs, ts, quats, mask, scene_file, self.dilute_M)
57 |
58 |
class myDataSet():
    """Batch iterator that lazily loads scenes from a list of files."""

    def __init__(self, conf, flag, datalist, batch_size=1, shuffle=False):
        self.n = len(datalist)
        self.datalist = datalist
        self.batch_size = batch_size
        self.num_batches = int(np.ceil(self.n / self.batch_size))
        self.shuffle = shuffle
        self.permutation = self.init_permutation()
        self.current_batch = 0
        self.device = 'cpu'
        self.conf = conf
        self.flag = flag
        self.dilute_M = conf.get_bool('dataset.diluteM', default=False)

    def init_permutation(self):
        """Return the visiting order: a random permutation when shuffling, else 0..n-1."""
        if self.shuffle:
            return np.random.permutation(self.n)
        return np.arange(self.n)

    def __iter__(self):
        # Every new pass restarts from the first batch with a fresh ordering.
        self.current_batch = 0
        self.permutation = self.init_permutation()
        return self

    def __next__(self):
        if self.current_batch == self.num_batches:
            raise StopIteration
        first = self.current_batch * self.batch_size
        # Slicing clamps at the end of the permutation, so the last batch may be shorter.
        batch_ids = self.permutation[first:first + self.batch_size]
        self.current_batch += 1
        return [self.loaddata(self.datalist[i]).to(self.device) for i in batch_ids]

    def __len__(self):
        # Number of files (not number of batches).
        return self.n

    def to(self, device, **kwargs):
        """Remember the device that loaded scenes are moved to; returns ``self``."""
        self.device = device
        return self

    def loaddata(self, file):
        """Read the raw arrays for ``file`` and wrap them in a ``SceneData`` object."""
        (M, Ns, Rs, ts, quats, mask, gpss, color, scale, use_spatial_encoder,
         dsc_idx, dsc_data, dsc_shape) = Euclidean.get_raw_data(self.conf, file, self.flag)
        return SceneData.SceneData(M, Ns, Rs, ts, quats, mask, file, gpss, color, scale, self.dilute_M,
                                   use_spatial_encoder, dsc_idx, dsc_data, dsc_shape)
110 |
111 |
class ScenesDataSet:
    """Indexable collection of scenes that returns either whole scenes or random sub-samples."""

    def __init__(self, data_list, return_all, min_sample_size=10, max_sample_size=30):
        super().__init__()
        self.data_list = data_list
        self.return_all = return_all
        self.min_sample_size = min_sample_size
        self.max_sample_size = max_sample_size

    def __getitem__(self, item):
        scene = self.data_list[item]
        if self.return_all:
            return scene

        # NOTE(review): the upper bound is taken from len(scene.scan_name);
        # confirm that this is the intended measure of scene size.
        upper = min(self.max_sample_size, len(scene.scan_name))
        if self.min_sample_size >= upper:
            sample_fraction = upper
        else:
            # randint's upper end is exclusive, hence the +1.
            sample_fraction = np.random.randint(self.min_sample_size, upper + 1)
        return SceneData.sample_data(scene, sample_fraction)

    def __len__(self):
        return len(self.data_list)
--------------------------------------------------------------------------------
/code/run_ba.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import geo_utils, ba_functions, general_utils
3 | import os
4 | import pandas as pd
5 |
6 |
def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined):
    """Run Euclidean bundle adjustment on one saved prediction and store the results.

    The ``.npz`` archive at ``npzpath`` is read, ``ba_functions.euc_ba`` is run on
    the predicted cameras and points, the refined cameras are aligned to the
    ground truth, and the archive is rewritten in place with the added entries.

    Returns:
        A tuple ``(file, scan_name)``: the updated dictionary of arrays and the
        ``scan_name`` entry read from the archive.
    """
    outputs = {}

    # dict() materialises every array, so the archive handle can be closed
    # before the same path is overwritten by np.savez below.
    with np.load(npzpath) as archive:
        file = dict(archive)
    scan_name = file['scan_name']
    xs = file['xs']
    Rs_pred = file['Rs']
    ts_pred = file['ts']
    Rs_gt = file['Rs_gt']
    ts_gt = file['ts_gt']
    Ks = file['Ks']
    Ns = np.linalg.inv(Ks)
    Xs = file['pts3D_pred']
    Ps = file['Ps']
    raw_xs = file['raw_xs']

    outputs['xs'] = xs
    outputs['Rs_gt'] = Rs_gt
    outputs['ts_gt'] = ts_gt

    ba_res = ba_functions.euc_ba(xs, raw_xs, Rs=Rs_pred, ts=ts_pred, Ks=Ks, Xs=Xs.T, Ps=Ps, Ns=Ns,
                                 repeat=repeat, max_iter=max_iter, ba_times=ba_times,
                                 repro_thre=repro_thre, refined=refined)  # Rs, ts, Ps, Xs
    outputs['Rs_ba'] = ba_res['Rs']
    outputs['ts_ba'] = ba_res['ts']
    outputs['Xs_ba'] = ba_res['Xs'].T  # 4,n
    outputs['Ps_ba'] = ba_res['Ps']
    if refined:
        # These entries are only read from the BA result when refinement is requested.
        outputs['valid_points'] = ba_res['valid_points']
        outputs['new_xs'] = ba_res['new_xs']
        outputs['colidx'] = ba_res['colidx']

    # Align the refined cameras to the ground truth and apply the same
    # similarity transform to the refined points.
    R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt,
                                                                     return_alignment=True)
    outputs['Rs_ba_fixed'] = R_ba_fixed
    outputs['ts_ba_fixed'] = t_ba_fixed
    outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba'])

    file.update(outputs)
    np.savez(npzpath, **file)

    return file, scan_name
51 |
52 |
def _compute_errors(outputs, refined=False):
    """Collect classification, reprojection and pose errors for one scene.

    ``outputs`` is the dictionary produced by ``_runba``. When ``refined`` is
    true the post-BA reprojection error is evaluated on ``new_xs`` restricted
    to ``valid_points``. Returns a dict of metrics; if the pre-BA reprojection
    error has an empty shape, only the classification metrics are returned.
    """
    model_errors = {}

    # Outlier-classification metrics.
    tp, fp, tn, fn = general_utils.compute_confusion_matrix(outputs['premask'], outputs['gtmask'])
    accuracy, precision, recall, F1 = general_utils.compute_indexes(tp, fp, tn, fn)
    model_errors.update({
        "TP": tp,
        "FP": fp,
        "TN": tn,
        "FN": fn,
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1": F1,
    })

    # Fetch every array up front so that a missing key fails before any geometry is evaluated.
    pts3D_pred = outputs['pts3D_pred']
    Ps = outputs['Ps']
    Rs_pred, ts_pred = outputs['Rs'], outputs['ts']
    Rs_gt, ts_gt = outputs['Rs_gt'], outputs['ts_gt']
    xs = outputs['xs']
    Xs_ba, Ps_ba = outputs['Xs_ba'], outputs['Ps_ba']

    # Network prediction: reprojection error.
    our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs)
    if not our_repro_error.shape:
        return model_errors
    model_errors.update({
        "our_repro": np.nanmean(our_repro_error),
        "our_repro_max": np.nanmax(our_repro_error),
    })

    # Network prediction: pose errors.
    Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_pred, ts_pred, Rs_gt, ts_gt)
    model_errors.update({
        "ts_mean": np.mean(ts_error),
        "ts_med": np.median(ts_error),
        "ts_max": np.max(ts_error),
        "Rs_mean": np.mean(Rs_error),
        "Rs_med": np.median(Rs_error),
        "Rs_max": np.max(Rs_error),
    })

    # After bundle adjustment: reprojection error.
    if refined:
        valid_points = outputs['valid_points']
        new_xs = outputs['new_xs']
        repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, new_xs,
                                                                  visible_points=valid_points)
    else:
        repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs)
    model_errors.update({
        'repro_ba': np.nanmean(repro_ba_error),
        'repro_ba_max': np.nanmax(repro_ba_error),
    })

    # After bundle adjustment: pose errors of the aligned cameras.
    Rs_ba_fixed = outputs['Rs_ba_fixed']
    ts_ba_fixed = outputs['ts_ba_fixed']
    Rs_ba_error, ts_ba_error = geo_utils.tranlsation_rotation_errors(Rs_ba_fixed, ts_ba_fixed, Rs_gt, ts_gt)
    model_errors.update({
        "ts_ba_mean": np.mean(ts_ba_error),
        "ts_ba_med": np.median(ts_ba_error),
        "ts_ba_max": np.max(ts_ba_error),
        "Rs_ba_mean": np.mean(Rs_ba_error),
        "Rs_ba_med": np.median(Rs_ba_error),
        "Rs_ba_max": np.max(Rs_ba_error),
    })

    return model_errors
112 |
113 |
def runba(basedir="/home/zeeq/results/test10", refined=True, repeat=True, max_iter=50, ba_times=1,
          repro_thre=5):
    """Run bundle adjustment on every ``.npz`` result under ``basedir`` and print an error table.

    Args:
        basedir: Directory that is walked recursively for ``.npz`` result files.
        refined: Forwarded to ``_runba`` and ``_compute_errors``.
        repeat: Forwarded to ``_runba``.
        max_iter: Forwarded to ``_runba``.
        ba_times: Forwarded to ``_runba``.
        repro_thre: Forwarded to ``_runba``.

    The defaults reproduce the previously hard-coded settings. Nothing is
    returned; the per-scene table with a final "Mean" row is printed.
    """
    errors_list = []
    for root, _, files in os.walk(basedir):
        for f in files:
            if not f.endswith('.npz'):
                # Only saved prediction archives can be processed.
                continue
            print(f'processing {f} ...')
            # Join with the directory currently being walked so that files in
            # sub-directories resolve to their real location.
            npzpath = os.path.join(root, f)
            outputs, scan_name = _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined)
            errors = _compute_errors(outputs, refined)
            errors['Scene'] = scan_name
            errors_list.append(errors)

    if not errors_list:
        print(f'no .npz result files found under {basedir}', flush=True)
        return

    df_errors = pd.DataFrame(errors_list)
    mean_errors = df_errors.mean(numeric_only=True)
    # The Series of means must become a one-row frame; concatenating the bare
    # Series would append one row per metric under a new column instead.
    df_errors = pd.concat([df_errors, mean_errors.to_frame().T], ignore_index=True)
    df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean"
    df_errors.set_index("Scene", inplace=True)
    df_errors = df_errors.round(3)
    print(df_errors.to_string(), flush=True)


if __name__ == "__main__":
    runba()
143 |
--------------------------------------------------------------------------------
/code/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import geo_utils, general_utils, sparse_utils
3 | import numpy as np
4 |
5 |
def is_valid_sample(data, min_pts_per_cam):
    """Return True when the smallest entry of ``data.x.pts_per_cam`` is at least ``min_pts_per_cam``."""
    fewest_pts = data.x.pts_per_cam.min().item()
    return fewest_pts >= min_pts_per_cam
8 |
9 |
def divide_indices_to_train_test(N, n_val, n_test=0):
    """Randomly split ``range(N)`` into train / validation / test index sets.

    The first ``n_test`` shuffled indices go to test, the next ``n_val`` to
    validation and the remainder to train. With ``n_test`` not positive the
    test set is an empty list.

    Returns:
        ``(train_indices, val_indices, test_indices)``.
    """
    shuffled = np.random.permutation(N)
    val_end = n_test + n_val
    if n_test > 0:
        test_indices = shuffled[:n_test]
    else:
        test_indices = []
    return shuffled[val_end:], shuffled[n_test:val_end], test_indices
16 |
17 |
def sample_indices(N, num_samples, adjacent):
    """Choose indices out of ``range(N)``.

    ``num_samples == 1`` selects everything. A value below one is read as a
    fraction of ``N`` (rounded up). At least two indices are drawn, and all
    indices are returned when the request covers ``N`` or more. With
    ``adjacent`` a random contiguous run is returned, otherwise a random
    subset without replacement.
    """
    if num_samples == 1:  # Return all the data
        return np.arange(N)

    if num_samples < 1:
        num_samples = int(np.ceil(num_samples * N))
    num_samples = max(2, num_samples)
    if num_samples >= N:
        return np.arange(N)

    if not adjacent:
        return np.random.choice(N, num_samples, replace=False)

    first = np.random.randint(0, N - num_samples + 1)
    return np.arange(first, first + num_samples)
34 |
35 |
def radius_sample(ts, num_samples):
    """Return the indices of the ``num_samples`` rows of ``ts`` nearest to a randomly chosen row.

    All indices are returned, in order, when ``num_samples`` is at least the
    number of rows.
    """
    n_total = len(ts)
    if num_samples >= n_total:
        return np.arange(n_total)

    anchor = ts[np.random.randint(n_total)]
    sq_dist = ((ts - anchor) ** 2).sum(axis=1)
    return np.argsort(sq_dist)[:num_samples]
45 |
46 |
def simulate_sample(ts, num_samples, pts_per_cam):
    """Pick the ``2 * num_samples`` rows of ``ts`` nearest to a random row, ordered by visibility.

    The returned indices are sorted by ascending ``pts_per_cam``. Note that
    ``2 * num_samples`` indices are returned, not ``num_samples``. All indices
    are returned, in order, when ``2 * num_samples`` is at least the number
    of rows.
    """
    n_total = len(ts)
    n_neighbours = 2 * num_samples
    if n_neighbours >= n_total:
        return np.arange(n_total)

    anchor = ts[np.random.randint(n_total)]
    sq_dist = ((ts - anchor) ** 2).sum(axis=1)
    nearest = np.argsort(sq_dist)[:n_neighbours]
    by_visibility = np.argsort(pts_per_cam[nearest])
    return nearest[by_visibility]
60 |
61 |
def save_cameras(outputs, conf, curr_epoch, phase):
    """Forward ``outputs`` to ``general_utils.save_camera_mat`` under its ``scan_name``."""
    scan_name = outputs['scan_name']
    general_utils.save_camera_mat(conf, outputs, scan_name, phase, curr_epoch)
66 |
67 |
def get_data_statistics(all_data):
    """Summarise a scene: largest measurement, observation counts and camera count."""
    valid_pts = all_data.valid_pts
    # Number of cameras observing each point.
    cams_per_pt = valid_pts.sum(dim=0).float()

    stats = {}
    stats["Max_2d_pt"] = all_data.M.max().item()
    stats["Num_2d_pts"] = valid_pts.sum().item()
    stats["n_pts"] = all_data.M.shape[-1]
    stats["Cameras_per_pts_mean"] = cams_per_pt.mean().item()
    stats["Cameras_per_pts_std"] = cams_per_pt.std().item()
    stats["Num of cameras"] = all_data.ts.shape[0]
    return stats
75 |
76 |
def get_data_statistics2(all_data, pred_data):
    """Summarise a scene together with the number of reconstructed and ground-truth points."""
    constructed_pt3ds_num = pred_data['pts3D_pred'].shape[1]
    valid_pts = all_data.valid_pts
    # Number of cameras observing each point.
    cams_per_pt = valid_pts.sum(dim=0).float()
    mask = all_data.mask
    # Columns of the mask that contain at least one non-zero entry.
    nonempty_cols = mask.sum(axis=0) != 0

    stats = {}
    stats["Max_2d_pt"] = all_data.M.max().item()
    stats["Num_2d_pts"] = valid_pts.sum().item()
    stats["all_pts"] = all_data.M.shape[-1]
    stats["constructed_pts"] = constructed_pt3ds_num
    stats["ground_truth_pts"] = mask[:, nonempty_cols].shape[1]
    stats["Cameras_per_pts_mean"] = cams_per_pt.mean().item()
    stats["Cameras_per_pts_std"] = cams_per_pt.std().item()
    stats["Num of cameras"] = all_data.ts.shape[0]
    return stats
88 |
89 |
def correct_matches_global(M, Ps, Ns):
    """Replace the measurements in ``M`` by reprojections of their triangulated points.

    Points are triangulated from all views with ``geo_utils.n_view_triangulation``
    and projected back through ``Ps``. NaN results and entries that are not
    valid observations in ``M`` are set to zero. The result has the shape of ``M``.
    """
    invalid_obs = np.logical_not(get_M_valid_points(M))

    points_3d = geo_utils.n_view_triangulation(Ps, M, Ns)
    reprojected = geo_utils.batch_pflat((Ps @ points_3d))[:, 0:2, :]

    # Remove invalid points
    reprojected[np.isnan(reprojected)] = 0
    # The same mask is applied to both coordinates of every observation.
    both_coords_invalid = np.stack((invalid_obs, invalid_obs), axis=1)
    reprojected[both_coords_invalid] = 0

    return reprojected.reshape(M.shape)
101 |
102 |
def get_M_valid_points(M):
    """Return a boolean mask of valid observations in the measurement matrix ``M``.

    ``M`` is viewed as ``(-1, 2, n_pts)``. An observation counts as present
    when its two coordinates are not both zero, and every point seen by fewer
    than two cameras is cleared entirely.

    Args:
        M: ``torch.Tensor`` (including subclasses) or NumPy array whose last
            dimension indexes points.

    Returns:
        Boolean mask of shape ``(m, n_pts)``, of the same array family as ``M``.
    """
    n_pts = M.shape[-1]

    # isinstance (rather than an exact type comparison) routes tensor
    # subclasses through the torch branch as well.
    if isinstance(M, torch.Tensor):
        # m x 2 x n -> m x n
        M_valid_pts = torch.abs(M.reshape(-1, 2, n_pts)).sum(dim=1) != 0
        M_valid_pts[:, M_valid_pts.sum(dim=0) < 2] = False
    else:
        M_valid_pts = np.abs(M.reshape(-1, 2, n_pts)).sum(axis=1) != 0
        M_valid_pts[:, M_valid_pts.sum(axis=0) < 2] = False

    return M_valid_pts
115 |
116 |
def M2sparse(M, normalize=False, Ns=None):
    """Convert the dense measurement matrix ``M`` into a ``sparse_utils.SparseMat``.

    Only the observations flagged by ``get_M_valid_points`` are stored. With
    ``normalize`` the values are taken from ``geo_utils.normalize_M(M, Ns)``
    instead of from ``M`` itself. The sparse shape is ``(n_cams, n_pts, 2)``.
    """
    n_cams, n_pts = M.shape[0] // 2, M.shape[1]

    # Get indices
    valid_pts = get_M_valid_points(M)  # m,n
    cam_per_pts = valid_pts.sum(dim=0).unsqueeze(1)
    pts_per_cam = valid_pts.sum(dim=1).unsqueeze(1)
    mat_indices = torch.nonzero(valid_pts).T
    cam_idx, pt_idx = mat_indices[0], mat_indices[1]

    # Get values
    if normalize:
        per_observation = geo_utils.normalize_M(M, Ns)
    else:
        # [2m, n] -> [m, 2, n] -> [m, n, 2]
        per_observation = M.reshape(n_cams, 2, n_pts).transpose(1, 2)
    mat_vals = per_observation[cam_idx, pt_idx, :]

    return sparse_utils.SparseMat(mat_vals, mat_indices, cam_per_pts, pts_per_cam, (n_cams, n_pts, 2))
137 |
138 |
def order_indices(indices, shuffle):
    """Order camera indices and derive the matching row indices of ``M``.

    ``indices`` is modified in place (shuffled or sorted). Camera ``i`` maps to
    rows ``2*i`` and ``2*i + 1``.

    Returns:
        ``(indices, M_indices)``.
    """
    if not shuffle:
        indices.sort()
        doubled = 2 * indices
        M_indices = np.sort(np.concatenate((doubled, doubled + 1)))
        return indices, M_indices

    np.random.shuffle(indices)
    M_indices = np.zeros(len(indices) * 2, dtype=np.int64)
    # Interleave the two rows of every camera in the shuffled camera order.
    M_indices[::2] = 2 * indices
    M_indices[1::2] = 2 * indices + 1
    return indices, M_indices
149 |
150 |
def get_mask_by_reproj(M, M_gt, thred_square):
    """Return a 0/1 float mask of the entries of ``M`` that agree with ``M_gt``.

    An observation is rejected when its squared distance to the corresponding
    observation in ``M_gt`` exceeds ``thred_square``. Rejected observations and
    entries of ``M`` that are already zero map to 0., everything else to 1.
    ``M`` itself is not modified.
    """
    n_pts = M.shape[-1]
    diff = M.reshape(-1, 2, n_pts) - M_gt.reshape(-1, 2, n_pts)
    sq_err = np.square(diff).sum(axis=1)  # m*n
    rejected = sq_err > thred_square
    # Duplicate every camera row so that the mask covers both coordinate rows of M.
    rejected_rows = np.repeat(rejected, 2, axis=0)

    kept = np.copy(M)
    kept[rejected_rows] = 0
    return np.where(kept != 0, 1., 0.)
--------------------------------------------------------------------------------
/code/models/SetOfSet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models.baseNet import BaseNet
4 | from models.layers import *
5 | from utils import geo_utils
6 |
7 |
class SetOfSetBlock(nn.Module):
    """Stack of ``SetOfSetLayer``s with normalisation, activation and an optional skip path."""

    def __init__(self, d_in, d_out, conf):
        super().__init__()
        self.block_size = conf.get_int("model.block_size")
        self.use_skip = conf.get_bool("model.use_skip")

        # The first layer maps d_in -> d_out; the remaining block_size - 1 layers keep d_out.
        stack = [SetOfSetLayer(d_in, d_out), NormalizationLayer()]
        for _ in range(1, self.block_size):
            stack += [ActivationLayer(), SetOfSetLayer(d_out, d_out), NormalizationLayer()]
        self.layers = nn.Sequential(*stack)

        self.final_act = ActivationLayer()

        if self.use_skip:
            # A projection is only needed when the feature width changes.
            if d_in != d_out:
                self.skip = nn.Sequential(ProjLayer(d_in, d_out), NormalizationLayer())
            else:
                self.skip = IdentityLayer()

    def forward(self, x):
        # x is [m,n,d] sparse matrix
        features = self.layers(x)
        if self.use_skip:
            features = self.skip(x) + features
        return self.final_act(features)
45 |
46 | # def __init__(self, d_in, d_out, conf):
47 | # super(SetOfSetBlock, self).__init__()
48 | # self.block_size = conf.get_int("model.block_size")
49 |
50 | # self.firstlayer = SetOfSetLayer(d_in, d_out)
51 | # self.submodules1 = nn.Sequential(NormalizationLayer(), ActivationLayer(), SetOfSetLayer(d_out, d_out))
52 | # self.submodules2 = nn.Sequential(NormalizationLayer(), ActivationLayer(), SetOfSetLayer(d_out, d_out))
53 | # self.lastlayers = nn.Sequential(NormalizationLayer(), ActivationLayer())
54 |
55 | # def forward(self, x):
56 | # # x is [m,n,d] sparse matrix
57 | # x = self.firstlayer(x)
58 | # x = self.submodules1(x) + x
59 | # x = self.submodules2(x) + x
60 | # out = self.lastlayers(x)
61 | # return out
62 |
63 |
class SetOfSetNet(BaseNet):
    """Network built from SetOfSetBlocks that predicts a per-entry output and per-camera poses.

    ``forward`` returns the output of ``pt_net`` on the sparse features together
    with a dict holding the predicted quaternions ("quats") and, when
    ``train.train_trans`` is enabled, the predicted translations ("ts").
    """

    def __init__(self, conf):
        super(SetOfSetNet, self).__init__(conf)
        # n is the number of points and m is the number of cameras
        num_blocks = conf.get_int('model.num_blocks')
        num_feats = conf.get_int('model.num_features')
        self.train_trans = conf.get_bool('train.train_trans')
        self.use_spatial_encoder = conf.get_bool('dataset.use_spatial_encoder')
        self.x_embed_rank = conf.get_int('dataset.x_embed_rank')
        self.egps_embed_rank = conf.get_int('dataset.egps_embed_rank')
        self.dsc_egps_embed_width = conf.get_int('dataset.dsc_egps_embed_width')
        self.gps_embed_width = conf.get_int('dataset.gps_embed_width')

        # Output widths of the rotation (quaternion) and translation heads,
        # and the width of the raw 2D input.
        m_d_out_rot = 4
        m_d_out_trans = 3
        d_in = 2

        self.embed_x = EmbeddingLayer(self.x_embed_rank, d_in)
        self.embed_egps = EmbeddingLayer(self.egps_embed_rank, 3)
        # NOTE(review): the constant 128 is presumably the descriptor width of data.dsc - confirm.
        self.embed_dsc_egps = ProjLayer(128+self.embed_egps.d_out, self.dsc_egps_embed_width)

        # The first block's input width depends on whether the embedded
        # features are concatenated in forward().
        if self.use_spatial_encoder:
            self.equivariant_blocks = torch.nn.ModuleList([SetOfSetBlock(self.embed_x.d_out + self.dsc_egps_embed_width, num_feats, conf)])
        else:
            self.equivariant_blocks = torch.nn.ModuleList([SetOfSetBlock(2, num_feats, conf)])

        for i in range(num_blocks - 1):
            self.equivariant_blocks.append(SetOfSetBlock(num_feats, num_feats, conf))

        # NOTE(review): each ProjLayer after an RCNormLayer takes num_feats inputs although
        # the preceding ProjLayer outputs num_feats//2; this assumes RCNormLayer doubles the
        # feature width - confirm against models.layers.
        self.pt_net = nn.Sequential(ProjLayer(num_feats, num_feats//2),
                                    RCNormLayer(),
                                    ActivationLayer(),
                                    ProjLayer(num_feats, num_feats//2),
                                    RCNormLayer(),
                                    ActivationLayer(),
                                    ProjLayer(num_feats, 1))

        # self.embed_gps = PosiEmbedding(self.egps_embed_rank) #nn.Linear(3, 128)
        self.embed_gps = nn.Linear(3, self.gps_embed_width)

        # self.m_net = get_linear_layers([num_feats] * 2 + [m_d_out], final_layer=True, batchnorm=False)
        # self.n_net = get_linear_layers([num_feats] * 2 + [n_d_out], final_layer=True, batchnorm=False)
        self.m_net_rot = get_linear_layers([num_feats] * 2 + [m_d_out_rot], final_layer=True, batchnorm=False)
        # The translation head consumes the camera features concatenated with the embedded GPS.
        self.m_net_tran = get_linear_layers([num_feats+self.gps_embed_width, 256, m_d_out_trans], final_layer=True, batchnorm=False)
        # self.m_net_tran = get_linear_layers([num_feats+6*self.egps_embed_rank+3, 256, m_d_out_trans], final_layer=True, batchnorm=False)

    def forward(self, data):
        """Return ``(pt_out, pred_cam)`` for one scene.

        Reads ``data.x`` and, depending on the configuration, ``data.egps_sparse``,
        ``data.dsc`` and ``data.gpss``.
        """
        x = data.x  # x is [m,n,d] sparse matrix
        # x = self.embed_x(x)
        if self.use_spatial_encoder:
            # Embed the 2D measurements and append the projected
            # descriptor + embedded-egps features along the last dimension.
            x = self.embed_x(x)
            egps = self.embed_egps(data.egps_sparse)
            dsc_egps = data.dsc.last_dim_cat(egps)
            embed_de = self.embed_dsc_egps(dsc_egps)
            x = x.last_dim_cat(embed_de)
            # x = x.last_dim_cat(self.embed_dsc_egps(data.dsc.last_dim_cat(self.embed_egps(data.egps_sparse))))

        for eq_block in self.equivariant_blocks:
            x = eq_block(x)  # [m,n,d_in] -> [m,n,d_out]

        # pt class predictions
        pt_out = self.pt_net(x)  # [m,n,d_out] -> [m,n,1]

        # Cameras predictions
        # x = x*pt_out # [m,n,d_out] -> [m,n,d_out]
        m_input = x.mean(dim=1)  # [m,d_out]
        # m_out = self.m_net(m_input) # [m, d_m]
        # pred_cam = self.extract_camera_outputs(m_out)
        rot = self.m_net_rot(m_input)
        # normed_gps = geo_utils.scale_ts_norm(data.gpss)
        if self.train_trans:
            # Translations are predicted from the camera features joined with the embedded GPS.
            tran = self.m_net_tran(torch.cat([m_input, self.embed_gps(data.gpss)], -1))
            # tran = self.m_net_tran(m_input)
            pred_cam = {"quats": rot, "ts": tran}
        else:
            pred_cam = {"quats": rot}
            # tran = self.m_net_tran(m_input)
            # pred_cam = {"quats": rot, "ts": data.gpss}

        return pt_out, pred_cam
--------------------------------------------------------------------------------
/code/utils/general_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from datetime import datetime
3 | from scipy.io import savemat
4 | import shutil
5 | from pyhocon import HOCONConverter,ConfigTree
6 | import sys
7 | import json
8 | from pyhocon import ConfigFactory
9 | import argparse
10 | import os
11 | import numpy as np
12 | import pandas as pd
13 | import portalocker
14 | from utils.Phases import Phases
15 | from utils.path_utils import path_to_exp, path_to_cameras, path_to_code_logs, path_to_conf
16 | import random
17 |
18 |
def log_code(conf):
    """Snapshot the experiment's source files and configuration into the code-log folder.

    The destination folder comes from ``path_to_code_logs(conf)``. Source paths
    are relative to the current working directory.

    A listed source file that does not exist is skipped with a message on
    stderr instead of aborting the whole snapshot, so the remaining files,
    directories and the configuration dump are still written.

    :param conf: experiment configuration (written out as ``exp.conf`` in HOCON form).
    """
    code_path = path_to_code_logs(conf)

    files_to_log = ["train.py", "single_scene_optimization.py", "multiple_scenes_learning.py", "loss_functions.py"]
    for file_name in files_to_log:
        if not os.path.isfile(file_name):
            # Not every checkout ships every script; keep logging the rest.
            print("log_code: skipping missing file '{}'".format(file_name), file=sys.stderr)
            continue
        shutil.copyfile(file_name, os.path.join(code_path, file_name))

    dirs_to_log = ["datasets", "models"]
    for dir_name in dirs_to_log:
        shutil.copytree(dir_name, os.path.join(code_path, dir_name))

    # Print conf
    with open(os.path.join(code_path, 'exp.conf'), 'w') as conf_log_file:
        conf_log_file.write(HOCONConverter.convert(conf, 'hocon'))
33 |
34 |
def save_camera_mat(conf, save_cam_dict, scan, phase, epoch=None):
    """Store every entry of ``save_cam_dict`` in the cameras ``.npz`` file of this run.

    The target path is resolved by ``path_to_cameras`` from the configuration,
    phase, epoch and scan name.
    """
    cameras_file = path_to_cameras(conf, phase, epoch=epoch, scan=scan)
    np.savez(cameras_file, **save_cam_dict)
39 |
40 |
def write_results(conf, df, file_name="Results", append=False):
    """Write a results table to ``<experiment folder>/<file_name>.xlsx``.

    :param conf: experiment configuration, used to resolve the experiment folder.
    :param df: ``pandas.DataFrame`` with the results to store.
    :param file_name: base name (without extension) of the spreadsheet.
    :param append: when True, rows of ``df`` are added below the rows already
        stored in the spreadsheet (read back with ``"Scene"`` as index). The
        read-merge-write sequence runs under a file lock so that concurrent
        processes do not overwrite each other's rows.
    """
    exp_path = path_to_exp(conf)
    results_file_path = os.path.join(exp_path, '{}.xlsx'.format(file_name))

    if not append:
        df.to_excel(results_file_path)
        return

    locker_file = os.path.join(exp_path, '{}.lock'.format(file_name))
    with portalocker.Lock(locker_file, timeout=1000):
        if os.path.exists(results_file_path):
            prev_df = pd.read_excel(results_file_path).set_index("Scene")
            # DataFrame.append was removed in pandas 2.0; concat is the equivalent.
            merged_err_df = pd.concat([prev_df, df])
        else:
            merged_err_df = df

        merged_err_df.to_excel(results_file_path)
58 |
59 |
def init_exp_version():
    """Return the current local time formatted as ``YYYY_MM_DD_HH_MM_SS``."""
    return datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
62 |
63 |
def get_class(kls):
    """Resolve a dotted path such as ``"models.SetOfSet.SetOfSetNet"`` to the object it names.

    Everything before the last dot is imported as a module; the last component
    is looked up as an attribute of that module.

    :param kls: dotted ``module.path.AttributeName`` string.
    :return: the attribute (typically a class) named by ``kls``.
    :raises ValueError: if ``kls`` contains no module part.
    :raises ImportError: if the module cannot be imported.
    :raises AttributeError: if the module has no such attribute.
    """
    import importlib

    module_name, _, attr_name = kls.rpartition('.')
    if not module_name:
        raise ValueError("Expected a dotted path 'module.attribute', got {!r}".format(kls))

    # import_module returns the leaf module itself, so no attribute walk from
    # the top-level package is needed.
    return getattr(importlib.import_module(module_name), attr_name)
71 |
72 |
def count_parameters(model):
    """Return the number of scalar entries across all parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
75 |
76 |
def print_error(err_string):
    """Write ``err_string`` followed by a newline to standard error."""
    sys.stderr.write(str(err_string) + "\n")
79 |
80 |
def config_tree_to_string(config):
    """Serialise a configuration to a JSON string.

    Top-level values that are ``ConfigTree`` sections are copied into plain
    dicts (one level deep); every other value is passed to ``json.dumps`` as is.
    """
    plain = {
        name: dict(config[name].items()) if isinstance(config[name], ConfigTree) else config[name]
        for name in config.keys()
    }
    return json.dumps(plain)
90 |
91 |
def bmvm(bmats, bvecs):
    """Batched matrix-vector product.

    :param bmats: tensor of shape (b, n, k) - one matrix per batch element.
    :param bvecs: tensor of shape (b, k) - one vector per batch element.
    :return: tensor of shape (b, n) with ``bmats[i] @ bvecs[i]`` in row ``i``.

    Only the trailing singleton column added for ``torch.bmm`` is squeezed, so
    the result keeps shape (b, n) even when ``b == 1`` or ``n == 1``.
    """
    return torch.bmm(bmats, bvecs.unsqueeze(-1)).squeeze(-1)
94 |
95 |
def get_full_conf_vals(conf):
    """Flatten a nested configuration into ``{"key.key...key": value}``.

    Nested dict values are expanded recursively and their keys are joined to
    the parent key with a dot. The flat form is convenient for ``conf.put()``.
    """
    flat = {}
    for name, value in conf.items():
        if not isinstance(value, dict):
            flat[name] = value
            continue
        nested = get_full_conf_vals(value)
        flat.update({name + "." + sub_name: leaf for sub_name, leaf in nested.items()})
    return flat
110 |
111 |
def parse_external_params(ext_params_str, conf):
    """Apply command-line overrides of the form ``key:value`` or ``section:key:value``.

    Overrides are separated by commas and written into ``conf`` in place; the
    values are stored as strings. Items with any other number of ``:``-separated
    parts are ignored. Returns ``conf``.
    """
    for assignment in ext_params_str.split(','):
        tokens = assignment.split(':')
        n_tokens = len(tokens)
        if n_tokens == 2:
            name, value = tokens
            conf[name] = value
        elif n_tokens == 3:
            section, name, value = tokens
            conf[section][name] = value
    return conf
120 |
121 |
def init_exp(default_phase):
    """Parse the command line, load the experiment configuration and seed the RNGs.

    Command-line options: ``--conf``, ``--scan``, ``--exp_version``,
    ``--external_params`` and ``--phase``.

    :param default_phase: phase used when ``--phase`` is not given; either the
        name of a ``Phases`` member or a ``Phases`` member itself.
    :return: ``(conf, device, phase)`` - the loaded configuration (with
        ``original_file_name``, ``exp_version`` and ``dataset.scan`` filled in),
        the torch device (CUDA when available, otherwise CPU) and the selected
        ``Phases`` member.
    :raises KeyError: if the phase name is not a member of ``Phases``.
    """
    # Parse Arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--conf', type=str)
    parser.add_argument('--scan', type=str, default=None)
    parser.add_argument('--exp_version', type=str, default=None)
    parser.add_argument('--external_params', type=str, default=None)
    parser.add_argument('--phase', type=str, default=default_phase)
    opt = parser.parse_args()

    # Init Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Init Conf
    conf_file_path = path_to_conf(opt.conf)
    conf = ConfigFactory.parse_file(conf_file_path)
    conf["original_file_name"] = opt.conf

    # Init external params
    if opt.external_params is not None:
        conf = parse_external_params(opt.external_params, conf)

    # Init Version
    if opt.exp_version is None:
        exp_version = init_exp_version()
    else:
        exp_version = opt.exp_version
    conf['exp_version'] = exp_version

    # Init scan
    if opt.scan is not None:
        conf['dataset']['scan'] = opt.scan
    elif 'scan' not in conf['dataset'].keys():
        conf['dataset']['scan'] = 'Multiple_Scenes'

    # Init Seed: seed every RNG this code base may draw from, including
    # Python's own generator, so runs with a fixed seed are repeatable.
    seed = conf.get_int('random_seed', default=None)
    if seed is not None:
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)

    # Init Phase: argparse hands a non-string default through untouched, so a
    # Phases member given as default_phase must not be looked up by name.
    phase = opt.phase if isinstance(opt.phase, Phases) else Phases[opt.phase]

    return conf, device, phase
167 |
168 |
def compute_confusion_matrix(precited, expected):
    """Count the entries of a binary confusion matrix.

    :param precited: predicted labels (array-like; non-zero / True means positive).
    :param expected: ground-truth labels of the same shape.
    :return: ``(tp, fp, tn, fn)`` as Python ints. Empty input yields all zeros.
    """
    predicted_pos = np.asarray(precited).astype(bool)
    expected_pos = np.asarray(expected).astype(bool)

    tp = int(np.count_nonzero(predicted_pos & expected_pos))
    fp = int(np.count_nonzero(predicted_pos & ~expected_pos))
    tn = int(np.count_nonzero(~predicted_pos & ~expected_pos))
    fn = int(np.count_nonzero(~predicted_pos & expected_pos))
    return tp, fp, tn, fn
180 |
181 |
def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, using 0 wherever the denominator is zero."""
    if np.ndim(numerator) == 0 and np.ndim(denominator) == 0:
        if denominator == 0:
            print("ZeroDivisionError: division by zero")
            return 0.0
        return float(numerator) / float(denominator)

    num = np.asarray(numerator, dtype=float)
    den = np.asarray(denominator, dtype=float)
    if np.any(den == 0):
        print("ZeroDivisionError: division by zero")
    out = np.zeros(np.broadcast(num, den).shape, dtype=float)
    np.divide(num, den, out=out, where=den != 0)
    return out


def compute_indexes(tp, fp, tn, fn):
    """Derive accuracy, precision, recall and F1 from confusion-matrix counts.

    :param tp: true positives.
    :param fp: false positives.
    :param tn: true negatives.
    :param fn: false negatives.
    :return: ``(accuracy, precision, recall, F1)``.

    A metric whose denominator is zero is reported as 0 (with a console
    message). This holds for Python numbers and NumPy scalars alike; NumPy
    scalars do not raise ``ZeroDivisionError`` and would otherwise yield NaN.
    """
    accuracy = _safe_ratio(tp + tn, tp + tn + fp + fn)
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    F1 = _safe_ratio(2 * precision * recall, precision + recall)  # F1

    return accuracy, precision, recall, F1
--------------------------------------------------------------------------------
/code/re_ba.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import geo_utils, ba_functions, general_utils
3 | import os
4 | import pandas as pd
5 | from scipy.sparse import coo_matrix
6 |
7 |
8 | # get Ps_gt
9 | # m,3,3 m,3,3 m,3
def get_Ps_gt(K_gt, R_gt, T_gt):
    """Assemble one 3x4 matrix ``K @ R.T @ [I | -t]`` per camera.

    :param K_gt: (m, 3, 3) array.
    :param R_gt: (m, 3, 3) array; each matrix is transposed before use.
    :param T_gt: (m, 3) array.
    :return: (m, 3, 4) array.
    """
    n_cams = R_gt.shape[0]
    K_Rt = np.matmul(K_gt, np.swapaxes(R_gt, 1, 2))  # K@R.T
    eye_stack = np.tile(np.eye(3), (n_cams, 1, 1))
    neg_t_col = -np.asarray(T_gt)[:, :, np.newaxis]
    I_minus_t = np.concatenate([eye_stack, neg_t_col], axis=2)  # I|-t
    return np.matmul(K_Rt, I_minus_t)
17 |
18 |
def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined, ts_path=None, enu_path=None):
    """Run Euclidean bundle adjustment on one scene archive and collect the results.

    The archive must provide the sparse parts of the measurement matrix
    (``M_row``, ``M_col``, ``M_data``, ``M_shape``), the matching ``mask_*``
    entries, and ``enu``, ``Ns``, ``Rs`` and ``ts``.

    :param npzpath: path of the scene ``.npz`` archive.
    :param repeat, max_iter, ba_times, repro_thre, refined: forwarded unchanged
        to ``ba_functions.euc_ba``. When ``refined`` is true the additional
        ``valid_points``, ``new_xs`` and ``colidx`` results are stored as well.
    :param ts_path: optional CSV path for the aligned camera positions after BA.
    :param enu_path: optional CSV path for the ``enu`` positions of the archive.
    :return: ``(file, scan_name)`` - the archive contents as a dict extended
        with the BA outputs, and the archive's base name up to its first dot.
    """
    outputs = {}

    scan_name = os.path.basename(npzpath).split('.')[0]
    # Read everything eagerly so that the archive handle can be closed right away.
    with np.load(npzpath) as archive:
        file = dict(archive)

    M = coo_matrix((file['M_data'], (file['M_row'], file['M_col'])), shape=file['M_shape']).toarray()
    mask = coo_matrix((file['mask_data'], (file['mask_row'], file['mask_col'])),
                      shape=file['mask_shape']).toarray()

    enu = file['enu']
    Ns = file['Ns']
    Rs = file['Rs']
    ts = file['ts']
    Ks = np.linalg.inv(Ns)
    Ps = get_Ps_gt(Ks, Rs, enu)

    # Keep only masked measurements, and only columns with more than two non-zero mask rows.
    M_gt = M*mask
    colidx = (mask!=0).sum(axis=0)>2
    M_gt = M_gt[:, colidx]
    xs_gt = geo_utils.M_to_xs(M_gt)

    xs_raw = geo_utils.M_to_xs(M)
    Xs_gt = geo_utils.n_view_triangulation(Ps, M=M_gt, Ns=Ns)

    Rs_error_b, ts_error_b = geo_utils.tranlsation_rotation_errors(Rs, ts, Rs, enu)
    print(f"before BA, Rs_error={Rs_error_b}, ts_error={ts_error_b}")

    ba_res = ba_functions.euc_ba(xs_gt, xs_raw, Rs=Rs, ts=enu, Ks=Ks, Xs=Xs_gt.T, Ps=Ps, Ns=Ns,
                                 repeat=repeat, max_iter=max_iter, ba_times=ba_times,
                                 repro_thre=repro_thre, refined=refined)  # Rs, ts, Ps, Xs
    outputs['Rs_ba'] = ba_res['Rs']
    outputs['ts_ba'] = ba_res['ts']
    outputs['Xs_ba'] = ba_res['Xs'].T  # 4,n
    outputs['Ps_ba'] = ba_res['Ps']
    if refined:
        outputs['valid_points'] = ba_res['valid_points']
        outputs['new_xs'] = ba_res['new_xs']
        outputs['colidx'] = ba_res['colidx']

    R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs, ba_res['ts'], enu,
                                                                     return_alignment=True)  # Align Rs_fixed, tx_fixed
    outputs['Rs_ba_fixed'] = R_ba_fixed
    outputs['ts_ba_fixed'] = t_ba_fixed
    outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba'])

    Rs_error_a, ts_error_a = geo_utils.tranlsation_rotation_errors(R_ba_fixed, t_ba_fixed, Rs, enu)
    print(f"after BA, Rs_error={Rs_error_a}, ts_error={ts_error_a}")

    # Optional CSV dumps. The target paths used to be undefined names, which
    # made every call fail; they are now explicit, optional arguments.
    if ts_path is not None:
        np.savetxt(ts_path, t_ba_fixed, delimiter=',')
    if enu_path is not None:
        np.savetxt(enu_path, enu, delimiter=',')

    file.update(outputs)

    return file, scan_name
88 |
89 |
def _compute_errors(outputs, refined=False):
    """Collect classification, reprojection and pose error statistics for one scene.

    :param outputs: dict holding the masks, predictions, ground truth and
        bundle-adjustment results read below.
    :param refined: when true, the BA reprojection error is evaluated on
        ``outputs['new_xs']`` restricted to ``outputs['valid_points']``.
    :return: dict of scalar metrics. If the predicted reprojection error has an
        empty shape, only the classification metrics are returned.
    """
    model_errors = {}

    # Outlier-classification metrics.
    tp, fp, tn, fn = general_utils.compute_confusion_matrix(outputs['premask'], outputs['gtmask'])
    accuracy, precision, recall, F1 = general_utils.compute_indexes(tp, fp, tn, fn)
    model_errors.update({
        "TP": tp,
        "FP": fp,
        "TN": tn,
        "FN": fn,
        "Accuracy": accuracy,
        "Precision": precision,
        "Recall": recall,
        "F1": F1,
    })

    pts3D_pred = outputs['pts3D_pred']
    Ps = outputs['Ps']
    Rs_fixed, ts_fixed = outputs['Rs'], outputs['ts']
    Rs_gt, ts_gt = outputs['Rs_gt'], outputs['ts_gt']
    xs = outputs['xs']
    Xs_ba = outputs['Xs_ba']
    Ps_ba = outputs['Ps_ba']

    # Reprojection error of the predicted reconstruction.
    our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs)
    if not our_repro_error.shape:
        return model_errors
    model_errors["our_repro"] = np.nanmean(our_repro_error)
    model_errors["our_repro_max"] = np.nanmax(our_repro_error)

    # Pose errors of the prediction.
    Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_fixed, ts_fixed, Rs_gt, ts_gt)
    model_errors.update({
        "ts_mean": np.mean(ts_error),
        "ts_med": np.median(ts_error),
        "ts_max": np.max(ts_error),
        "Rs_mean": np.mean(Rs_error),
        "Rs_med": np.median(Rs_error),
        "Rs_max": np.max(Rs_error),
    })

    # Reprojection error after bundle adjustment.
    if not refined:
        repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs)
    else:
        valid_points = outputs['valid_points']
        new_xs = outputs['new_xs']
        repro_ba_error = geo_utils.reprojection_error_with_points(
            Ps_ba, Xs_ba.T, new_xs, visible_points=valid_points)
    model_errors['repro_ba'] = np.nanmean(repro_ba_error)
    model_errors['repro_ba_max'] = np.nanmax(repro_ba_error)

    # Pose errors after bundle adjustment and alignment.
    Rs_ba_error, ts_ba_error = geo_utils.tranlsation_rotation_errors(
        outputs['Rs_ba_fixed'], outputs['ts_ba_fixed'], Rs_gt, ts_gt)
    model_errors.update({
        "ts_ba_mean": np.mean(ts_ba_error),
        "ts_ba_med": np.median(ts_ba_error),
        "ts_ba_max": np.max(ts_ba_error),
        "Rs_ba_mean": np.mean(Rs_ba_error),
        "Rs_ba_med": np.median(Rs_ba_error),
        "Rs_ba_max": np.max(Rs_ba_error),
    })

    return model_errors
149 |
150 |
def runba():
    """Run bundle adjustment on every ``.npz`` scene below ``basedir`` and print an error table.

    The table has one row per scene plus a final ``"Mean"`` row with the
    column-wise mean of the numeric metrics, rounded to three decimals.
    """
    basedir = ""
    refined = True
    repeat = True
    max_iter = 50
    ba_times = 1
    repro_thre = 2

    errors_list = []
    for root, _, files in os.walk(basedir):
        for f in files:
            if not f.endswith('.npz'):
                continue  # only scene archives can be loaded by _runba
            print(f'processing {f} ...')
            # Join with the directory being walked so that files in
            # sub-directories resolve to their real location.
            npzpath = os.path.join(root, f)
            outputs, scan_name = _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined)
            errors = _compute_errors(outputs, refined)
            errors['Scene'] = scan_name
            errors_list.append(errors)

    if not errors_list:
        print(f"no .npz scenes found under '{basedir}'", flush=True)
        return

    df_errors = pd.DataFrame(errors_list)
    # DataFrame.append was removed in pandas 2.0; build the summary row and concat it.
    mean_row = df_errors.mean(numeric_only=True).to_dict()
    mean_row["Scene"] = "Mean"
    df_errors = pd.concat([df_errors, pd.DataFrame([mean_row])], ignore_index=True)
    df_errors.set_index("Scene", inplace=True)
    df_errors = df_errors.round(3)
    print(df_errors.to_string(), flush=True)


if __name__=="__main__":
    runba()
--------------------------------------------------------------------------------
/code/datasets/Euclidean.py:
--------------------------------------------------------------------------------
1 | import cv2 # Do not remove
2 | import torch
3 |
4 | import sys
5 | sys.path.append('/path/to/deepaat/code')
6 | import utils.path_utils
7 | from utils import geo_utils, general_utils, dataset_utils
8 | import scipy.io as sio
9 | import numpy as np
10 | import os.path
11 | from scipy.sparse import coo_matrix
12 |
13 |
def get_raw_data(conf, scan, flag):
    """
    Load one scene archive (``<split path>/<scan>.npz``) and return its contents as tensors.

    :param conf: configuration; read for the dataset paths, ``dataset.scan``,
        ``train.train_trans``, ``dataset.use_spatial_encoder`` and the
        ``dataset.addNoise`` / ``noise_*`` settings.
    :param scan: scene name; when None, ``dataset.scan`` from the configuration is used.
    :param flag: selects the split: 0 -> ``dataset.trainset_path`` (cameras and
        point columns are randomly permuted), 1 -> ``dataset.valset_path``,
        anything else -> ``dataset.testset_path``.
    :return: tuple
        ``(M, Ns, Rs, ts, quats, mask, gpss, color, scale, use_spatial_encoder,
        dsc_idx, dsc_data, dsc_shape)`` where
        M - dense points matrix rebuilt from its sparse parts (two rows per camera)
        Ns - the ``Ns`` array of the archive (inverse calibration matrices per the original notes)
        Rs, ts, quats - the pose arrays of the archive
        mask - dense mask rebuilt from its sparse parts, same layout as M
        gpss - the ``enu_noisy`` array of the archive
        color - ``rgbs`` as an int tensor, or None for flag 0
        scale - normalisation factor applied to ``ts`` and ``gpss`` (1.0 when
            ``train.train_trans`` is off)
        dsc_idx, dsc_data, dsc_shape - sparse descriptor parts (data divided by
            128.0), or None when the spatial encoder is disabled
    """

    # Init
    # dataset_path_format = os.path.join(utils.path_utils.path_to_datasets(), 'Euclidean', '{}.npz')
    # dataset_path_format = os.path.join(conf.get_string('dataset.dataset_path'), '{}.npz')
    if flag==0:
        dataset_path_format = os.path.join(conf.get_string('dataset.trainset_path'), '{}.npz')
    elif flag==1:
        dataset_path_format = os.path.join(conf.get_string('dataset.valset_path'), '{}.npz')
    else:
        dataset_path_format = os.path.join(conf.get_string('dataset.testset_path'), '{}.npz')

    # Get conf parameters
    if scan is None:
        scan = conf.get_string('dataset.scan')

    # Get raw data
    dataset = np.load(dataset_path_format.format(scan))

    # Get poses, calibration and 2D points; M and mask are stored in COO form
    # M = dataset['M']
    # mask = dataset['mask']
    Rs = dataset['Rs']
    ts = dataset['ts']
    # ts = dataset['enu']
    quats = dataset['quats']
    Ns = dataset['Ns']
    M_col = dataset['M_col']
    M_row = dataset['M_row']
    M_data = dataset['M_data']
    M_shape = dataset['M_shape']
    M = coo_matrix((M_data, (M_row, M_col)), shape=M_shape).todense().A
    mask_col = dataset['mask_col']
    mask_row = dataset['mask_row']
    mask_data = dataset['mask_data']
    mask_shape = dataset['mask_shape']
    mask = coo_matrix((mask_data, (mask_row, mask_col)), shape=mask_shape).todense().A
    # M_gt = dataset_utils.correct_matches_global(M, Ps_gt, Ns)
    # mask = dataset_utils.get_mask_by_reproj(M, M_gt, 2)
    gpss = dataset['enu_noisy']
    # gpss = dataset['enu']

    # Normalise translations: shift by the first camera's ts and divide by the
    # largest distance from it; gpss gets the same shift and scale.
    scale = 1.0
    if conf.get_bool('train.train_trans'):
        t0 = ts[0]
        nts = ts-t0
        scale = np.max(np.linalg.norm(nts, axis=1))
        ts = nts/scale
        gpss = (gpss-t0)/scale

    # use_gt = conf.get_bool('dataset.use_gt')
    # if use_gt:
    #     M = torch.from_numpy(dataset_utils.correct_matches_global(M, Ps_gt, Ns)).float()

    use_spatial_encoder = conf.get_bool('dataset.use_spatial_encoder')

    if flag==0:  # shuffle row and col of train set
        indices = torch.randperm(Ns.shape[0])
        # Each camera owns rows 2i and 2i+1 of M / mask, so the pair moves together.
        M_indices = torch.zeros(len(indices)*2, dtype=torch.int64)
        M_indices[::2] = 2 * indices
        M_indices[1::2] = 2 * indices + 1
        M = torch.from_numpy(M).float()[M_indices]
        Rs = torch.from_numpy(Rs).float()[indices]
        ts = torch.from_numpy(ts).float()[indices]
        quats = torch.from_numpy(quats).float()[indices]
        Ns = torch.from_numpy(Ns).float()[indices]
        mask = torch.from_numpy(mask).float()[M_indices]
        gpss = torch.from_numpy(gpss).float()[indices]
        # shuffle column
        idx = torch.randperm(M.shape[1])
        M = M[:,idx]
        mask = mask[:,idx]
        color = None

        if use_spatial_encoder:
            dsc_idx_np = dataset['D_idxs']
            dsc_data_np = dataset['D_data']
            dsc_shape_np = dataset['D_shape']
            # Track every descriptor through the shuffle: store its 1-based
            # position at its (row, col) cell, permute the dense grid like the
            # cameras/points, then read the positions back in the new sparse
            # order. The +1 / -1 keeps position 0 distinct from an empty cell.
            ord_data = np.arange(dsc_idx_np.shape[1]) + 1
            ord_shape = (dsc_shape_np[0], dsc_shape_np[1])
            ord_dense = coo_matrix((ord_data, (dsc_idx_np[0], dsc_idx_np[1])), shape=ord_shape).todense().A
            ord_dense = ord_dense[indices,:]
            ord_dense = ord_dense[:,idx]
            ord_sparse = coo_matrix(ord_dense)
            ord_new = ord_sparse.data - 1
            dsc_idx = torch.from_numpy(np.vstack((ord_sparse.row, ord_sparse.col))).int()
            dsc_data = torch.from_numpy(dsc_data_np[ord_new, :]/128.0).float()  # nnz,128
            dsc_shape = torch.from_numpy(dsc_shape_np).int()
        else:
            dsc_idx = None
            dsc_data = None
            dsc_shape = None

    else:
        # Validation / test: keep the stored order and also return the colours.
        M = torch.from_numpy(M).float()
        Rs = torch.from_numpy(Rs).float()
        ts = torch.from_numpy(ts).float()
        quats = torch.from_numpy(quats).float()
        Ns = torch.from_numpy(Ns).float()
        mask = torch.from_numpy(mask).float()
        gpss = torch.from_numpy(gpss).float()
        color = torch.from_numpy(dataset['rgbs']).int()

        if use_spatial_encoder:
            dsc_idx = torch.from_numpy(dataset['D_idxs']).int()
            dsc_data = torch.from_numpy(dataset['D_data']/128.0).float()
            dsc_shape = torch.from_numpy(dataset['D_shape']).int()
        else:
            dsc_idx = None
            dsc_data = None
            dsc_shape = None

    # Add Noise (to the 2D measurements only)
    if conf.get_bool("dataset.addNoise"):
        noise_mean = conf.get_float("dataset.noise_mean")
        noise_std = conf.get_float("dataset.noise_std")
        noise_radio = conf.get_float("dataset.noise_radio")
        M = geo_utils.addNoise(M, noise_mean, noise_std, noise_radio)
        # dsc_data = geo_utils.addNoise(dsc_data, noise_mean, noise_std, noise_radio)

    return M, Ns, Rs, ts, quats, mask, gpss, color, scale, use_spatial_encoder, dsc_idx, dsc_data, dsc_shape
145 |
146 |
def test_Ps_M(Ps, M, Ns):
    """Print the NaN-ignoring mean and maximum of the global reprojection error for tensors ``Ps``, ``M``, ``Ns``."""
    errors = geo_utils.calc_global_reprojection_error(Ps.numpy(), M.numpy(), Ns.numpy())
    mean_err = np.nanmean(errors)
    max_err = np.nanmax(errors)
    print("Reprojection Error: Mean = {}, Max = {}".format(mean_err, max_err))
150 |
151 |
def test_euclidean_dataset(scan):
    """Manual sanity check: compare reprojection errors of raw and corrected matches.

    NOTE(review): ``scan`` is currently unused; the archive path below is fixed.
    """
    dataset = np.load("/path/to/data.npz")

    # Dense points matrix, projection matrices and Ns straight from the archive.
    M, Ps_gt, Ns = dataset['M'], dataset['Ps_gt'], dataset['Ns']

    print(M.shape)
    corrected = dataset_utils.correct_matches_global(M, Ps_gt, Ns)
    M_gt = torch.from_numpy(corrected).float()
    print(M_gt.shape)

    M = torch.from_numpy(M).float()
    Ps_gt = torch.from_numpy(Ps_gt).float()
    Ns = torch.from_numpy(Ns).float()

    for title, points in (("Test Ps and M", M), ("Test Ps and M_gt", M_gt)):
        print(title)
        test_Ps_M(Ps_gt, points, Ns)


if __name__ == "__main__":
    scan = "Alcatraz Courtyard"
    test_euclidean_dataset(scan)
--------------------------------------------------------------------------------
/code/models/layers.py:
--------------------------------------------------------------------------------
1 | from turtle import forward
2 | from cv2 import norm
3 | import torch
4 | from torch.nn import Linear, ReLU, BatchNorm1d, Sequential, Module, Identity, Sigmoid, Tanh
5 | from utils.sparse_utils import SparseMat
6 | from utils.pos_enc_utils import get_embedder
7 |
8 |
def get_linear_layers(feats, final_layer=False, batchnorm=True):
    """Build an MLP as a ``Sequential`` from a list of layer widths.

    :param feats: widths ``[d_0, d_1, ..., d_k]``; one ``Linear`` is created per
        consecutive pair.
    :param final_layer: when True the last ``Linear`` is the output and gets no
        normalisation or activation after it.
    :param batchnorm: when True a ``BatchNorm1d`` (without running statistics)
        is inserted before every ``ReLU``.
    """
    def _post_linear(width):
        # Optional batch norm followed by the activation.
        tail = [BatchNorm1d(width, track_running_stats=False)] if batchnorm else []
        tail.append(ReLU())
        return tail

    modules = []

    # Hidden layers
    for fan_in, fan_out in zip(feats[:-2], feats[1:-1]):
        modules.append(Linear(fan_in, fan_out))
        modules.extend(_post_linear(fan_out))

    # Last layer
    modules.append(Linear(feats[-2], feats[-1]))
    if not final_layer:
        modules.extend(_post_linear(feats[-1]))

    return Sequential(*modules)
31 |
32 |
class Parameter3DPts(torch.nn.Module):
    """Holds a learnable (3, n_pts) tensor of 3D points."""

    def __init__(self, n_pts):
        super().__init__()

        # Random start: zero-mean normal samples with standard deviation 0.1.
        initial_pts = torch.normal(mean=0, std=0.1, size=(3, n_pts), requires_grad=True)
        self.pts_3d = torch.nn.Parameter(initial_pts)

    def forward(self):
        """Return the parameter tensor itself."""
        return self.pts_3d
44 |
45 |
class SetOfSetLayer(Module):
    """Linear layer on a sparse [m, n, d] matrix mixing per-entry, per-row, per-column and global terms."""

    def __init__(self, d_in, d_out):
        super().__init__()
        # n is the number of points and m is the number of cameras
        self.lin_all = Linear(d_in, d_out)   # w1: every entry on its own
        self.lin_n = Linear(d_in, d_out)     # w2: mean over dim 0 (one vector per point)
        self.lin_m = Linear(d_in, d_out)     # w3: mean over dim 1 (one vector per camera)
        self.lin_both = Linear(d_in, d_out)  # w4: mean over all stored entries

    def forward(self, x):
        """Map a sparse [m, n, d_in] matrix to a sparse [m, n, d_out] matrix with the same indices."""
        per_entry = self.lin_all(x.values)             # [nnz,d_in] -> [nnz,d_out]
        per_point = self.lin_n(x.mean(dim=0))          # [n,d_in] -> [n,d_out]
        per_camera = self.lin_m(x.mean(dim=1))         # [m,d_in] -> [m,d_out]
        global_term = self.lin_both(x.values.mean(dim=0, keepdim=True))  # [1,d_in] -> [1,d_out]

        # Broadcast the pooled terms back to the stored entries and average the four parts.
        gathered_points = per_point[x.indices[1], :]
        gathered_cameras = per_camera[x.indices[0], :]
        new_features = (per_entry + gathered_points + gathered_cameras + global_term) / 4  # [nnz,d_out]
        new_shape = (x.shape[0], x.shape[1], new_features.shape[1])

        return SparseMat(new_features, x.indices, x.cam_per_pts, x.pts_per_cam, new_shape)
71 |
72 |
class ProjLayer(Module):
    """Applies one shared ``Linear`` map to the feature vector of every stored entry."""

    def __init__(self, d_in, d_out):
        super().__init__()
        # n is the number of points and m is the number of cameras
        self.lin_all = Linear(d_in, d_out)

    def forward(self, x):
        """Map a sparse [m, n, d_in] matrix to a sparse [m, n, d_out] matrix with the same indices."""
        projected = self.lin_all(x.values)  # [nnz,d_in] -> [nnz,d_out]
        out_shape = (x.shape[0], x.shape[1], projected.shape[1])
        return SparseMat(projected, x.indices, x.cam_per_pts, x.pts_per_cam, out_shape)
84 |
85 |
class NormalizationLayer(Module):
    """Standardises every feature channel over all stored entries (subtract the mean, divide by the std)."""

    def forward(self, x):
        values = x.values
        centered = values - values.mean(dim=0, keepdim=True)
        standardized = centered / centered.std(dim=0, keepdim=True)
        return SparseMat(standardized, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
92 |
93 |
class CenterNorm(Module):
    """Subtracts the per-channel mean taken over all stored entries."""

    def forward(self, x):
        values = x.values
        centered = values - values.mean(dim=0, keepdim=True)
        return SparseMat(centered, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
99 |
100 |
class LastNorm(Module):
    """Centres the features per channel and rescales them so the largest absolute value becomes 5."""

    def forward(self, x):
        values = x.values
        centered = values - values.mean(dim=0, keepdim=True)
        # Largest magnitude over all entries and channels.
        largest_abs = torch.max(centered.max(), -centered.min())
        rescaled = centered*5/largest_abs
        return SparseMat(rescaled, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
108 |
109 |
class RCNormLayer(Module):
    """Centres the features along each matrix axis separately and concatenates both results.

    The output has twice as many channels as the input.
    """

    def forward(self, x):  # x.shape = m,n,d
        cam_idx, pt_idx = x.indices[0], x.indices[1]

        # Centre with the mean over dim 0 (one mean per point).
        centered_dim0 = x.values - x.mean(dim=0)[pt_idx, :]   # nnz,d
        # Centre with the mean over dim 1 (one mean per camera).
        centered_dim1 = x.values - x.mean(dim=1)[cam_idx, :]  # nnz,d

        stacked = torch.cat([centered_dim0, centered_dim1], -1)  # nnz,2d
        out_shape = (x.shape[0], x.shape[1], x.shape[2]*2)       # m,n,2d
        return SparseMat(stacked, x.indices, x.cam_per_pts, x.pts_per_cam, out_shape)
122 |
123 |
124 |
class InstanceNormLayer(Module):
    """Standardises the features along each matrix axis separately and concatenates both results.

    The output has twice as many channels as the input.
    """

    def forward(self, x):  # x.shape = m,n,d
        cam_idx, pt_idx = x.indices[0], x.indices[1]

        # Statistics over dim 0: one mean / std per point.
        mean_dim0 = x.mean(dim=0)  # n,d
        std_dim0 = x.std(dim=0)    # n,d
        normed_dim0 = (x.values - mean_dim0[pt_idx, :]) / std_dim0[pt_idx, :]  # nnz,d

        # Statistics over dim 1: one mean / std per camera.
        mean_dim1 = x.mean(dim=1)  # m,d
        std_dim1 = x.std(dim=1)    # m,d
        normed_dim1 = (x.values - mean_dim1[cam_idx, :]) / std_dim1[cam_idx, :]  # nnz,d

        stacked = torch.cat([normed_dim0, normed_dim1], -1)  # nnz,2d
        out_shape = (x.shape[0], x.shape[1], x.shape[2]*2)   # m,n,2d
        return SparseMat(stacked, x.indices, x.cam_per_pts, x.pts_per_cam, out_shape)
160 |
161 |
class ActivationLayer(Module):
    """ReLU applied to the stored values of a sparse matrix."""

    def __init__(self):
        super().__init__()
        self.relu = ReLU()

    def forward(self, x):
        activated = self.relu(x.values)
        return SparseMat(activated, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
170 |
171 |
class IdentityLayer(Module):
    """Pass-through layer: returns its input unchanged."""

    def forward(self, x):
        return x
175 |
176 |
class EmbeddingLayer(Module):
    """Applies an embedding function to the stored values of a sparse matrix.

    With ``multires > 0`` the embedder comes from ``get_embedder(multires, in_dim)``;
    otherwise the values pass through unchanged and ``d_out`` equals ``in_dim``.
    """

    def __init__(self, multires, in_dim):
        super().__init__()
        if multires > 0:
            embedder, out_dim = get_embedder(multires, in_dim)
        else:
            embedder, out_dim = Identity(), in_dim
        self.embed = embedder
        self.d_out = out_dim

    def forward(self, x):
        embedded = self.embed(x.values)
        out_shape = (x.shape[0], x.shape[1], embedded.shape[1])
        return SparseMat(embedded, x.indices, x.cam_per_pts, x.pts_per_cam, out_shape)
189 |
190 |
class PosiEmbedding(Module):
    """Frequency embedding: x -> (x, sin(f_1 x), cos(f_1 x), ..., sin(f_k x), cos(f_k x))."""

    def __init__(self, num_freqs: int, logscale=True):
        """
        :param num_freqs: number of frequency bands k.
        :param logscale: when True the bands are 2^0 ... 2^(k-1); otherwise they
            are spaced linearly between 1 and 2^(k-1).
        """
        super().__init__()
        self.freq_bands = (
            2 ** torch.linspace(0, num_freqs - 1, num_freqs)
            if logscale
            else torch.linspace(1, 2 ** (num_freqs - 1), num_freqs)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x`` with its sine / cosine features along the last dimension."""
        pieces = [x]
        for freq in self.freq_bands:
            scaled = freq * x
            pieces.extend((torch.sin(scaled), torch.cos(scaled)))
        return torch.cat(pieces, -1)
208 |
209 |
class SigmoidScoreLayer(Module):
    """Sigmoid applied to the stored values of a sparse matrix."""

    def __init__(self):
        super().__init__()
        self.sigmoid = Sigmoid()

    def forward(self, x):
        scores = self.sigmoid(x.values)
        return SparseMat(scores, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
218 |
219 |
class TanhScoreLayer(Module):
    """ReLU followed by tanh on the stored values of a sparse matrix."""

    def __init__(self):
        super().__init__()
        self.relu = ReLU()
        self.tanh = Tanh()

    def forward(self, x):
        scores = self.tanh(self.relu(x.values))
        return SparseMat(scores, x.indices, x.cam_per_pts, x.pts_per_cam, x.shape)
--------------------------------------------------------------------------------
/code/joint_optimization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.sparse import coo_matrix
3 | from utils import geo_utils, ceres_utils, ba_functions
4 | import pandas as pd
5 |
6 | # m,3,3 m,3,3 m,3
def get_Ps_gt(K_gt, R_gt, T_gt):
    """Build the (m, 3, 4) stack ``K @ R.T @ [I | -t]`` from (m,3,3), (m,3,3) and (m,3) inputs."""
    n_cams = R_gt.shape[0]

    # [I | -t] for every camera
    I_minus_t = np.zeros((n_cams, 3, 4))
    I_minus_t[:, :, :3] = np.identity(3)
    I_minus_t[:, :, 3] = -np.asarray(T_gt)

    K_Rt = np.matmul(K_gt, R_gt.transpose((0, 2, 1)))  # K@R.T
    return np.matmul(K_Rt, I_minus_t)
14 |
15 |
def loadtxt(classed_txt):
    """Read groups of 1-based integer ids, one whitespace-separated group per line.

    :param classed_txt: path of the text file.
    :return: list with one sorted list of 0-based ids per non-blank line.
    :raises ValueError: if a token is not an integer.

    Lines are split on any run of whitespace, so a final line without a
    trailing newline is read in full. Blank lines are skipped.
    """
    ids_list = []
    with open(classed_txt, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                continue
            ids_list.append(sorted(int(token) - 1 for token in tokens))
    return ids_list
25 |
26 |
def get_col_ids(cond, minval):
    """Return the indices of the columns of ``cond`` that have at least ``minval`` non-zero entries."""
    nonzero_per_col = (cond != 0).sum(axis=0)
    keep = nonzero_per_col >= minval
    return np.arange(cond.shape[1])[keep]
30 |
31 |
def make_full_scene_by_predRt_rawM(pred_npzs, raw_npzs, full_npz, classed_txt, outpath):
    """Merge per-block predicted poses into one scene and save it as an npz.

    The full scene (ground-truth poses, normalisation matrices, sparse track
    matrix, colours) is read from ``full_npz``. For block ``i`` the predicted
    ``Rs``/``ts`` from ``pred_npzs[i]`` are written into the rows listed on
    line ``i`` of ``classed_txt``; the block's first raw translation
    (``raw_npzs[i]['ts'][0]``) is added to the predicted translations.
    Points are then triangulated from the merged cameras and everything is
    written to ``outpath``.
    """
    full_data = np.load(full_npz)
    gt_rotations = full_data['newRs']
    gt_translations = full_data['newts']
    Ns = full_data['Ns']
    Ks = np.linalg.inv(Ns)

    # Rebuild the dense track matrix from its COO components.
    sparse_M = coo_matrix(
        (full_data['M_data'], (full_data['M_row'], full_data['M_col'])),
        shape=full_data['M_shape'],
    )
    M = sparse_M.toarray()
    xs = geo_utils.M_to_xs(M)

    ids_list = loadtxt(classed_txt)
    Rs = np.zeros(gt_rotations.shape)
    ts = np.zeros(gt_translations.shape)

    for block_idx, pred_path in enumerate(pred_npzs):
        block_pred = np.load(pred_path)
        block_raw = np.load(raw_npzs[block_idx])
        cam_ids = ids_list[block_idx]
        Rs[cam_ids, :] = block_pred['Rs']
        ts[cam_ids, :] = block_pred['ts'] + block_raw['ts'][0]

    Ps = get_Ps_gt(Ks, Rs, ts)
    pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=M, Ns=Ns)

    results = {
        'scan_name': "ortho1_full",
        'xs': xs,
        'Rs': Rs,
        'ts': ts,
        'Rs_gt': gt_rotations,
        'ts_gt': gt_translations,
        'Ks': Ks,
        'pts3D_pred': pts3D_triangulated.T,
        'Ps': Ps,
        'raw_xs': xs,
        'raw_color': full_data['rgbs'],
    }
    np.savez(outpath, **results)
74 |
75 |
def make_full_scene_by_abaRtM(aba_npzs, raw_npzs, full_npz, classed_txt, outpath):
    """Merge per-block bundle-adjusted results into one scene and save it.

    For block ``i`` the aligned BA poses (``Rs_ba_fixed``/``ts_ba_fixed``) and
    the refined 2D tracks (``new_xs``) from ``aba_npzs[i]`` are scattered into
    full-scene arrays: rows come from line ``i`` of ``classed_txt``, columns
    from ``raw_npzs[i]['pt3d_idx'][colidx]``. The block's first raw
    translation is added to the BA translations. Only columns with at least
    4 non-zero entries in the merged track matrix are triangulated and saved.
    """
    print("in make_full_scene_by_abaRtM() now ...")
    print("1. get full raw data ...")
    full_data = np.load(full_npz)
    gt_rotations = full_data['newRs']
    gt_translations = full_data['newts']
    Ns = full_data['Ns']
    Ks = np.linalg.inv(Ns)
    M_shape = full_data['M_shape']
    sparse_M = coo_matrix(
        (full_data['M_data'], (full_data['M_row'], full_data['M_col'])),
        shape=M_shape,
    )
    raw_M = sparse_M.toarray()
    raw_xs = geo_utils.M_to_xs(raw_M)

    n_points = M_shape[1]
    xs = np.zeros((M_shape[0]//2, n_points, 2))

    ids_list = loadtxt(classed_txt)
    Rs = np.zeros(gt_rotations.shape)
    ts = np.zeros(gt_translations.shape)

    print("2. process splited data ...")
    for block_idx, aba_path in enumerate(aba_npzs):
        block_ba = np.load(aba_path)
        block_raw = np.load(raw_npzs[block_idx])
        cam_ids = ids_list[block_idx]
        # Map the block-local surviving columns back to full-scene columns.
        full_cols = block_raw['pt3d_idx'][block_ba['colidx']]
        block_xs = np.zeros((len(cam_ids), n_points, 2))
        block_xs[:, full_cols, :] = block_ba['new_xs']
        xs[cam_ids, :, :] = block_xs
        Rs[cam_ids, :] = block_ba['Rs_ba_fixed']
        ts[cam_ids, :] = block_ba['ts_ba_fixed'] + block_raw['ts'][0]

    M = geo_utils.xs_to_M(xs)
    valid_colidx = get_col_ids(M, 4)

    Ps = get_Ps_gt(Ks, Rs, ts)
    pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=M[:, valid_colidx], Ns=Ns)

    print("3. write results ...")
    results = {
        'precolor': full_data['rgbs'][:, valid_colidx],
        'scan_name': "ortho1_full",
        'xs': xs[:, valid_colidx, :],
        'Rs': Rs,
        'ts': ts,
        'Rs_gt': gt_rotations,
        'ts_gt': gt_translations,
        'Ks': Ks,
        'pts3D_pred': pts3D_triangulated,
        'Ps': Ps,
        'raw_xs': raw_xs,
        'raw_color': full_data['rgbs'],
    }
    np.savez(outpath, **results)
130 |
131 |
def _runba(npzpath, repeat, max_iter, ba_times, repro_thre, refined, proj_first, proj_second):
    """Run merged bundle adjustment on a saved scene and store the results in place.

    Loads the scene npz at ``npzpath``, calls ``ba_functions.merged_ba`` with
    the stored predictions, aligns the adjusted cameras to the ground truth
    and writes the original entries plus the new ones back to ``npzpath``.

    :return: ``(file, scan_name)`` where ``file`` is the updated dict that was saved
    """
    print("in _runba() now ...")

    print("1. load data ...")
    file = dict(np.load(npzpath))
    scan_name = file['scan_name']
    xs = file['xs']
    Rs_gt = file['Rs_gt']
    ts_gt = file['ts_gt']
    Ks = file['Ks']
    Ns = np.linalg.inv(Ks)

    outputs = {'xs': xs, 'Rs_gt': Rs_gt, 'ts_gt': ts_gt}

    print("2. ba now ...")
    ba_res = ba_functions.merged_ba(
        xs, file['raw_xs'],
        Rs=file['Rs'], ts=file['ts'], Ks=Ks, Xs=file['pts3D_pred'].T, Ps=file['Ps'], Ns=Ns,
        repeat=repeat, max_iter=max_iter, ba_times=ba_times,
        repro_thre=repro_thre, refined=refined, proj_first=proj_first, proj_second=proj_second,
    )  # Rs, ts, Ps, Xs
    outputs['Rs_ba'] = ba_res['Rs']
    outputs['ts_ba'] = ba_res['ts']
    outputs['Xs_ba'] = ba_res['Xs'].T  # 4,n
    outputs['Ps_ba'] = ba_res['Ps']
    if refined:
        # Only the refined path reports which observations/columns survived.
        for key in ('valid_points', 'new_xs', 'colidx'):
            outputs[key] = ba_res[key]

    print("3. align cams ...")
    aligned_Rs, aligned_ts, similarity_mat = geo_utils.align_cameras(
        ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt, return_alignment=True)
    outputs['Rs_ba_fixed'] = aligned_Rs
    outputs['ts_ba_fixed'] = aligned_ts
    outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba'])

    print("4. save results ...")
    file.update(outputs)
    np.savez(npzpath, **file)

    return file, scan_name
183 |
184 |
def _compute_errors(outputs, results_file_path, refined=False):
    """Compute reprojection and pose errors before/after BA, print and export them.

    :param outputs: dict with the entries produced by ``_runba`` (``pts3D_pred``,
        ``Ps``, ``Rs``, ``ts``, ``Rs_gt``, ``ts_gt``, ``xs``, ``Xs_ba``, ``Ps_ba``,
        ``Rs_ba_fixed``, ``ts_ba_fixed`` and, when ``refined``, ``valid_points``
        and ``new_xs``)
    :param results_file_path: destination passed to ``DataFrame.to_excel``
    :param refined: evaluate the post-BA reprojection error on ``new_xs`` /
        ``valid_points`` instead of ``xs``
    :return: the (empty) error dict if the initial reprojection error is
        0-dimensional; otherwise nothing is returned
    """
    print("in _compute_errors() now ...")
    model_errors = {}

    pts3D_pred = outputs['pts3D_pred']
    Ps = outputs['Ps']
    Rs_fixed = outputs['Rs']
    ts_fixed = outputs['ts']
    Rs_gt = outputs['Rs_gt']
    ts_gt = outputs['ts_gt']
    xs = outputs['xs']
    Xs_ba = outputs['Xs_ba']
    Ps_ba = outputs['Ps_ba']

    our_repro_error = geo_utils.reprojection_error_with_points(Ps, pts3D_pred.T, xs)
    if not our_repro_error.shape:
        return model_errors
    model_errors["our_repro"] = np.nanmean(our_repro_error)
    model_errors["our_repro_max"] = np.nanmax(our_repro_error)

    Rs_error, ts_error = geo_utils.tranlsation_rotation_errors(Rs_fixed, ts_fixed, Rs_gt, ts_gt)
    model_errors["ts_mean"] = np.mean(ts_error)
    model_errors["ts_med"] = np.median(ts_error)
    model_errors["ts_max"] = np.max(ts_error)
    model_errors["Rs_mean"] = np.mean(Rs_error)
    model_errors["Rs_med"] = np.median(Rs_error)
    model_errors["Rs_max"] = np.max(Rs_error)

    if refined:
        valid_points = outputs['valid_points']
        new_xs = outputs['new_xs']
        repro_ba_error = geo_utils.reprojection_error_with_points(
            Ps_ba, Xs_ba.T, new_xs, visible_points=valid_points)
    else:
        repro_ba_error = geo_utils.reprojection_error_with_points(Ps_ba, Xs_ba.T, xs)
    model_errors['repro_ba'] = np.nanmean(repro_ba_error)
    model_errors['repro_ba_max'] = np.nanmax(repro_ba_error)

    Rs_ba_fixed = outputs['Rs_ba_fixed']
    ts_ba_fixed = outputs['ts_ba_fixed']
    Rs_ba_error, ts_ba_error = geo_utils.tranlsation_rotation_errors(Rs_ba_fixed, ts_ba_fixed, Rs_gt, ts_gt)
    model_errors["ts_ba_mean"] = np.mean(ts_ba_error)
    model_errors["ts_ba_med"] = np.median(ts_ba_error)
    model_errors["ts_ba_max"] = np.max(ts_ba_error)
    model_errors["Rs_ba_mean"] = np.mean(Rs_ba_error)
    model_errors["Rs_ba_med"] = np.median(Rs_ba_error)
    model_errors["Rs_ba_max"] = np.max(Rs_ba_error)

    df_errors = pd.DataFrame([model_errors])
    mean_errors = df_errors.mean(numeric_only=True)
    # DataFrame.append was removed in pandas 2.0; pd.concat appends the mean
    # row on both old and new pandas versions.
    df_errors = pd.concat([df_errors, mean_errors.to_frame().T], ignore_index=True)
    df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean"
    df_errors.set_index("Scene", inplace=True)
    print(df_errors.to_string(), flush=True)
    df_errors.to_excel(results_file_path)
241 |
242 |
if __name__=="__main__":
    # The input/output paths were previously referenced without ever being
    # defined (NameError at start-up); they are now taken from the command line.
    import argparse

    parser = argparse.ArgumentParser(
        description="Run merged bundle adjustment on a scene npz and export the errors.")
    parser.add_argument("out_npz_path", help="scene .npz file; it is read and then overwritten with the BA results")
    parser.add_argument("out_xlsx_path", help="Excel file the error table is written to")
    args = parser.parse_args()

    refined = True
    repeat = True
    max_iter = 50
    ba_times = 2
    repro_thre = 2
    proj_first = True
    proj_second = True
    outputs, scan_name = _runba(args.out_npz_path, repeat, max_iter, ba_times, repro_thre,
                                refined, proj_first, proj_second)
    _compute_errors(outputs, args.out_xlsx_path, refined=refined)
254 |
--------------------------------------------------------------------------------
/code/utils/ba_functions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from utils import ceres_utils
3 | from utils import geo_utils
4 |
5 |
def euc_ba(xs, raw_xs, Rs, ts, Ks, Xs, Ps=None, Ns=None, repeat=True, max_iter=100, ba_times=2, repro_thre=5, refined=False):
    """
    Computes Euclidean bundle adjustment with the ceres solver.

    :param xs: 2d points [m,n,2]; an observation counts as visible when its x coordinate is non-zero
    :param raw_xs: 2d points used for the second (repeat) pass of the refined path
    :param Rs: rotations [m,3,3]
    :param ts: translations [m,3]
    :param Ks: inner parameters, calibration matrices [m,3,3]
    :param Xs: initial 3d points
    :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]]
    :param Ns: normalization matrices. If Ks are known, Ns = inv(Ks)
    :param repeat: re-triangulate with the adjusted cameras and run ba a second time. default: True
    :param max_iter: forwarded to ceres_utils.run_euc_ceres_iter (refined path only)
    :param ba_times: forwarded to ceres_utils.run_euc_ceres_iter (refined path only)
    :param repro_thre: forwarded to ceres_utils.run_euc_ceres_iter (refined path only)
    :param refined: use the iterative run_euc_ceres_iter solver instead of a plain ceres run
    :return: results dict with 'Rs', 'ts', 'Ps', 'Xs' and, when refined,
        'valid_points', 'new_xs' and 'colidx'.
    """
    results = {}

    visible_points = xs[:, :, 0] != 0  # m,3dptnum
    point_indices = np.stack(np.where(visible_points))  # 2,2dptnum
    visible_xs = xs[visible_points]  # 2dptnum,2

    if refined:
        # Keep every output of the first pass: with repeat=False these are the
        # final results (they used to be discarded, leaving new_Xs / new_xs /
        # colidx unbound). With repeat=True they are overwritten below.
        new_Rs, new_ts, new_Ps, new_Xs, new_xs, visible_points, point_indices, colidx = ceres_utils.run_euc_ceres_iter(
            Ps, Xs, xs, Rs, ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre)
        if repeat:
            # Second pass works on the raw observations with re-triangulated points.
            visible_points = raw_xs[:, :, 0] != 0  # m,3dptnum
            point_indices = np.stack(np.where(visible_points))  # 2,2dptnum
            norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, raw_xs, Ns)
            new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points)
            proj_first = True
            new_Rs, new_ts, new_Ps, new_Xs, new_xs, visible_points, point_indices, colidx = ceres_utils.run_euc_ceres_iter(
                new_Ps, new_Xs, raw_xs, new_Rs, new_ts, Ks, point_indices, visible_points, max_iter, ba_times, repro_thre, proj_first)
    else:
        new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres(
            Xs, visible_xs, Rs, ts, Ks, point_indices)
        if repeat:
            norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, xs, Ns)
            new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, visible_points)
            new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres(
                new_Xs, visible_xs, new_Rs, new_ts, Ks, point_indices)
        # NOTE(review): the homogeneous column is assumed to be appended on the
        # non-refined path only (the refined solver output is used as [n,4] by
        # callers) - confirm against the original indentation.
        new_Xs = np.concatenate([new_Xs, np.ones([new_Xs.shape[0],1])], axis=1)

    results['Rs'] = new_Rs
    results['ts'] = new_ts
    results['Ps'] = new_Ps
    results['Xs'] = new_Xs
    if refined:
        results['valid_points'] = visible_points
        results['new_xs'] = new_xs
        results['colidx'] = colidx

    return results
71 |
72 |
def merged_ba(xs, raw_xs, Rs, ts, Ks, Xs, Ps=None, Ns=None, repeat=True, max_iter=100, ba_times=2, repro_thre=5, refined=False,
              proj_first=False, proj_second=True):
    """
    Computes Euclidean bundle adjustment with the ceres solver.

    :param xs: 2d points [m,n,2]; an observation counts as visible when its x coordinate is non-zero
    :param raw_xs: 2d points used for the second (repeat) pass of the refined path
    :param Rs: rotations [m,3,3]
    :param ts: translations [m,3]
    :param Ks: inner parameters, calibration matrices [m,3,3]
    :param Xs: initial 3d points
    :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]]
    :param Ns: normalization matrices. If Ks are known, Ns = inv(Ks)
    :param repeat: re-triangulate with the adjusted cameras and run ba a second time. default: True
    :param max_iter, ba_times, repro_thre: forwarded to ceres_utils.run_euc_ceres_iter (refined path only)
    :param refined: use the iterative run_euc_ceres_iter solver instead of a plain ceres run
    :param proj_first: last argument of the first run_euc_ceres_iter call
    :param proj_second: last argument of the second run_euc_ceres_iter call
    :return: results dict with 'Rs', 'ts', 'Ps', 'Xs' and, when refined,
        'valid_points', 'new_xs' and 'colidx'.
    """
    observed = xs[:, :, 0] != 0  # m,3dptnum
    obs_indices = np.stack(np.where(observed))  # 2,2dptnum

    if not refined:
        observed_xs = xs[observed]  # 2dptnum,2
        new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres(
            Xs, observed_xs, Rs, ts, Ks, obs_indices)
        if repeat:
            norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, xs, Ns)
            retriangulated = geo_utils.dlt_triangulation(norm_P, norm_x, observed)
            new_Rs, new_ts, new_Ps, new_Xs = ceres_utils.run_euclidean_python_ceres(
                retriangulated, observed_xs, new_Rs, new_ts, Ks, obs_indices)
        # NOTE(review): the homogeneous column is assumed to be appended on the
        # non-refined path only - confirm against the original indentation.
        ones_col = np.ones([new_Xs.shape[0], 1])
        new_Xs = np.concatenate([new_Xs, ones_col], axis=1)
        return {'Rs': new_Rs, 'ts': new_ts, 'Ps': new_Ps, 'Xs': new_Xs}

    new_Rs, new_ts, new_Ps, new_Xs, new_xs, observed, obs_indices, colidx = ceres_utils.run_euc_ceres_iter(
        Ps, Xs, xs, Rs, ts, Ks, obs_indices, observed, max_iter, ba_times, repro_thre, proj_first)
    if repeat:
        # Second pass works on the raw observations with re-triangulated points.
        observed = raw_xs[:, :, 0] != 0  # m,3dptnum
        obs_indices = np.stack(np.where(observed))  # 2,2dptnum
        norm_P, norm_x = geo_utils.normalize_points_cams(new_Ps, raw_xs, Ns)
        new_Xs = geo_utils.dlt_triangulation(norm_P, norm_x, observed)
        new_Rs, new_ts, new_Ps, new_Xs, new_xs, observed, obs_indices, colidx = ceres_utils.run_euc_ceres_iter(
            new_Ps, new_Xs, raw_xs, new_Rs, new_ts, Ks, obs_indices, observed, max_iter, ba_times, repro_thre, proj_second)

    return {
        'Rs': new_Rs,
        'ts': new_ts,
        'Ps': new_Ps,
        'Xs': new_Xs,
        'valid_points': observed,
        'new_xs': new_xs,
        'colidx': colidx,
    }
139 |
140 |
def proj_ba(Ps, xs, Xs_our=None, Ns=None, repeat=True, triangulation=False, return_repro=True,normalize_in_tri=True):
    """
    Computes projective bundle adjustment with the ceres solver.

    :param Ps: cameras [m,3,4]. Ps[i] = Ks[i] @ Rs[i].T @ [I, -ts[i]]
    :param xs: 2d points [m,n,2]
    :param Xs_our: initial 3d points [n,3] or None if triangulation needed
    :param Ns: normalization matrices; computed from xs when needed and not given
    :param repeat: run ba twice. default: True
    :param triangulation: For initial point run triangulation. default: False
    :param return_repro: compute and return the reprojection errors before and after.
    :param normalize_in_tri: Normalize the points and the cameras when computing triangulation. default: True
    :return: results. The new camera parameters, 3d points, and if requested the reprojection errors.
    """
    results = {}

    visible_points = xs[:, :, 0] != 0
    point_indices = np.stack(np.where(visible_points))
    visible_xs = xs[visible_points]

    def _triangulate(cams):
        # DLT triangulation, optionally on normalised cameras/points.
        nonlocal Ns
        if not normalize_in_tri:
            return geo_utils.dlt_triangulation(cams, xs, visible_points)
        if Ns is None:
            Ns = geo_utils.batch_get_normalization_matrices(xs)
        norm_P, norm_x = geo_utils.normalize_points_cams(cams, xs, Ns)
        return geo_utils.dlt_triangulation(norm_P, norm_x, visible_points)

    def _mean_repro(cams, pts):
        # NaN-ignoring mean reprojection error over the visible observations.
        return np.nanmean(geo_utils.reprojection_error_with_points(cams, pts, xs, visible_points))

    Xs = _triangulate(Ps) if triangulation else Xs_our

    if return_repro:
        results['repro_before'] = _mean_repro(Ps, Xs)

    new_Ps, new_Xs = ceres_utils.run_projective_python_ceres(Ps, Xs, visible_xs, point_indices)

    if repeat:
        if return_repro:
            results['repro_middle'] = _mean_repro(new_Ps, new_Xs)

        new_Xs = _triangulate(new_Ps)
        new_Ps, new_Xs = ceres_utils.run_projective_python_ceres(new_Ps, new_Xs, visible_xs, point_indices)

    if return_repro:
        results['repro_after'] = _mean_repro(new_Ps, new_Xs)

    # Append a column of ones (homogeneous coordinates).
    ones_col = np.ones([new_Xs.shape[0], 1])
    new_Xs = np.concatenate([new_Xs, ones_col], axis=1)

    results['Ps'] = new_Ps
    results['Xs'] = new_Xs
    return results
198 |
199 |
--------------------------------------------------------------------------------
/code/datasets/SceneData.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import geo_utils, dataset_utils, sparse_utils
3 | from datasets import Projective, Euclidean
4 | import os.path
5 | from pyhocon import ConfigFactory
6 | import numpy as np
7 | import warnings
8 | from pytorch3d import transforms as py3d_trans
9 |
10 |
class SceneData:
    """Container for one scene: track matrix, cameras and derived sparse views."""

    def __init__(self, M, Ns, Rs, ts, quats, mask, scan_name, gpss, color, scale=1.0, dilute_M=False,
                 use_spatial_encoder=True, dsc_idx=None, dsc_data=None, dsc_shape=None):
        """Store the raw inputs and build the sparse / normalised representations.

        ``Ks`` is obtained as ``inverse(Ns)``. ``mask`` has two rows per image;
        every second row is used for the sparse mask. ``gpss`` is indexed per
        camera to fill the sparse GPS matrix. The descriptor matrix ``dsc`` is
        only created when ``use_spatial_encoder`` is true.
        """
        n_images = Ns.shape[0]

        # Plain attributes
        self.scan_name = scan_name
        self.M = M
        self.Ns = Ns
        self.Ks = torch.inverse(Ns)
        self.mask = mask
        self.gpss = gpss
        self.Rs = Rs
        self.ts = ts
        self.quats = quats
        self.color = color
        self.scale = scale

        # NOTE(review): the diluted matrix is stored on self.M, but everything
        # below is still built from the undiluted argument M - confirm intent.
        if dilute_M:
            self.M = geo_utils.dilutePoint(M)

        # Sparse view of M; its layout is shared by all sparse matrices below.
        self.x = dataset_utils.M2sparse(M, normalize=True, Ns=Ns)
        sparse_idx = self.x.indices
        n_rows, n_cols = self.x.shape[0], self.x.shape[1]

        # Sparse mask: one value per stored observation.
        per_image_mask = mask[::2].unsqueeze(2)  # 2m,n -> m,n,1
        self.mask_sparse = sparse_utils.SparseMat(
            per_image_mask[sparse_idx[0], sparse_idx[1], :], sparse_idx,
            self.x.cam_per_pts, self.x.pts_per_cam, (n_rows, n_cols, 1))

        # Sparse GPS: each observation carries the gps row of its camera.
        self.egps_sparse = sparse_utils.SparseMat(
            self.gpss[sparse_idx[0], :], sparse_idx,
            self.x.cam_per_pts, self.x.pts_per_cam, (n_rows, n_cols, 3))

        if use_spatial_encoder:
            self.dsc = sparse_utils.SparseMat(dsc_data, dsc_idx, self.x.cam_per_pts, self.x.pts_per_cam, dsc_shape)

        # Valid points and the normalised track matrix
        self.valid_pts = dataset_utils.get_M_valid_points(M)
        self.norm_M = geo_utils.normalize_M(M, Ns, self.valid_pts).transpose(1, 2).reshape(n_images * 2, -1)

    def to(self, *args, **kwargs):
        """Apply ``.to(*args, **kwargs)`` to every tensor / ``SparseMat`` attribute and return ``self``."""
        for name in self.__dict__:
            if name.startswith('__'):
                continue
            value = getattr(self, name)
            if isinstance(value, sparse_utils.SparseMat) or torch.is_tensor(value):
                setattr(self, name, value.to(*args, **kwargs))
        return self
73 |
74 |
def create_scene_data(conf, flag):
    """Load the scan named in ``conf`` and wrap it in a ``SceneData``.

    Calibrated scans are read through ``Euclidean.get_raw_data`` (which also
    receives ``flag``), uncalibrated ones through ``Projective.get_raw_data``.

    NOTE(review): the ``SceneData(...)`` call passes 7 positional arguments
    while the constructor visible in this file requires 9, and ``gpss`` is
    only bound on the calibrated path - this looks stale; confirm before use.
    """
    scan = conf.get_string('dataset.scan')
    calibrated = conf.get_bool('dataset.calibrated')
    dilute_M = conf.get_bool('dataset.diluteM', default=False)

    if not calibrated:
        M, Ns, Ps_gt, mask = Projective.get_raw_data(conf, scan)
    else:
        M, Ns, Ps_gt, mask, gpss = Euclidean.get_raw_data(conf, scan, flag)

    return SceneData(M, Ns, Ps_gt, mask, scan, gpss, dilute_M)
88 |
89 |
def sample_data(data, num_samples, adjacent=True):
    """Sample ``num_samples`` cameras from ``data`` and build a reduced ``SceneData``.

    Cameras are picked with ``dataset_utils.radius_sample`` on the camera
    translations. Only track columns with more than two non-zero entries in
    the sampled rows are kept, and the remaining columns are randomly permuted.
    ``adjacent`` is currently not used.

    NOTE(review): the ``SceneData(...)`` call omits the ``gpss`` and ``color``
    arguments required by the constructor visible in this file - confirm.
    """
    cam_ids = dataset_utils.radius_sample(data.ts.numpy(), num_samples)
    cam_ids, row_ids = dataset_utils.order_indices(cam_ids, shuffle=True)

    cam_ids = torch.from_numpy(cam_ids).squeeze()
    row_ids = torch.from_numpy(row_ids).squeeze()

    # Per-camera quantities
    Rs = data.Rs[cam_ids]
    ts = data.ts[cam_ids]
    quats = data.quats[cam_ids]
    Ns = data.Ns[cam_ids]

    # Track matrix / mask rows of the sampled cameras
    M = data.M[row_ids]
    mask = data.mask[row_ids]
    keep_cols = (M != 0).sum(dim=0) > 2
    mask = mask[:, keep_cols]
    M = M[:, keep_cols]

    # Shuffle the columns
    perm = torch.randperm(M.shape[1])
    M = M[:, perm]
    mask = mask[:, perm]

    sampled_data = SceneData(M, Ns, Rs, ts, quats, mask, data.scan_name)
    if (sampled_data.x.pts_per_cam == 0).any():
        warnings.warn('Cameras with no points for dataset '+ data.scan_name)

    return sampled_data
120 |
121 |
122 | # flag=0 means train; flag=1 means val; flag=2 means test
def create_scene_data_from_list(scan_names_list, conf, flag):
    """Create one scene per name in ``scan_names_list``.

    For every name, ``conf["dataset"]["scan"]`` is overwritten before
    ``create_scene_data(conf, flag)`` is called, so ``conf`` is left pointing
    at the last scan.
    """
    scenes = []
    for name in scan_names_list:
        conf["dataset"]["scan"] = name
        scenes.append(create_scene_data(conf, flag))
    return scenes
131 |
132 |
def create_scene_data_from_dir(conf, flag):
    """Create a ``SceneData`` for every file found under the split directory.

    ``flag`` selects the directory: 0 -> ``dataset.trainset_path``,
    1 -> ``dataset.valset_path``, anything else -> ``dataset.testset_path``.
    The scene name is the file name up to its first dot.

    NOTE(review): the ``SceneData(...)`` call passes ``dilute_M`` in the
    position of ``gpss`` and omits ``color`` relative to the constructor
    visible in this file - confirm.
    """
    path_key = (
        "dataset.trainset_path" if flag == 0
        else "dataset.valset_path" if flag == 1
        else "dataset.testset_path"
    )
    datadir = conf.get_string(path_key)

    dilute_M = conf.get_bool('dataset.diluteM', default=False)
    scenes = []
    for _, _, files in os.walk(datadir):
        for filename in files:
            scene_name = filename.split('.')[0]  # name only
            M, Ns, Rs, ts, quats, mask = Euclidean.get_raw_data(conf, scene_name, flag)
            scenes.append(SceneData(M, Ns, Rs, ts, quats, mask, scene_name, dilute_M))
    return scenes
151 |
152 |
def get_data_list(conf, flag):
    """List the scene names (file name up to the first dot) under the split directory.

    ``flag`` selects the directory: 0 -> ``dataset.trainset_path``,
    1 -> ``dataset.valset_path``, anything else -> ``dataset.testset_path``.
    Sub-directories are walked as well.
    """
    path_key = (
        "dataset.trainset_path" if flag == 0
        else "dataset.valset_path" if flag == 1
        else "dataset.testset_path"
    )
    datadir = conf.get_string(path_key)

    return [
        filename.split('.')[0]
        for _, _, files in os.walk(datadir)
        for filename in files
    ]
166 |
167 |
def test_dataset():
    """Ad-hoc smoke test: build scenes for three configurations and evaluate losses.

    NOTE(review): ``create_scene_data`` is called without its ``flag``
    argument and the images path is machine-specific - this looks stale.
    """
    # Prepare configuration
    dataset_dict = {"images_path": "/home/labs/waic/hodaya/PycharmProjects/GNN-for-SFM/datasets/images/",
                    "normalize_pts": True,
                    "normalize_f": True,
                    "use_gt": False,
                    "calibrated": False,
                    "scan": "Alcatraz Courtyard",
                    "edge_min_inliers": 30,
                    "use_all_edges": True,
                    }

    train_dict = {"infinity_pts_margin": 1e-4,
                  "hinge_loss_weight": 1,
                  }
    loss_dict = {"infinity_pts_margin": 1e-4,
                 "normalize_grad": False,
                 "hinge_loss": True,
                 "hinge_loss_weight" : 1
                 }
    conf_dict = {"dataset": dataset_dict, "loss": loss_dict}

    def _build_and_check(overrides):
        # Fresh config per run, with the given dataset options overridden.
        conf = ConfigFactory.from_dict(conf_dict)
        for option, value in overrides.items():
            conf["dataset"][option] = value
        scene = create_scene_data(conf)
        test_data(scene, conf)
        return scene

    print("Test projective")
    data = _build_and_check({})

    print('Test move to device')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    new_data = data.to(device)

    print(os.linesep)
    print("Test Euclidean")
    data = _build_and_check({"calibrated": True})

    print(os.linesep)
    print("Test use_gt GT")
    data = _build_and_check({"use_gt": True})
212 |
213 |
def test_data(data, conf):
    """Print the ``ESFMLoss`` value for the ground-truth cameras and for random cameras.

    NOTE(review): reads ``data.y``, which the ``SceneData`` constructor in
    this file no longer sets - confirm before use.
    """
    import loss_functions

    # Test Losses of GT and random on data
    repLoss = loss_functions.ESFMLoss(conf)
    cams_gt = prepare_cameras_for_loss_func(data.y, data)
    random_Ps = torch.rand(data.y.shape)
    cams_rand = prepare_cameras_for_loss_func(random_Ps, data)

    gt_loss = repLoss(cams_gt, data)
    print("Loss for GT: Reprojection = {}".format(gt_loss))
    rand_loss = repLoss(cams_rand, data)
    print("Loss for rand: Reprojection = {}".format(rand_loss))
224 |
225 |
def prepare_cameras_for_loss_func(Ps, data):
    """Package cameras and triangulated points in the dict format the loss expects.

    Returns ``{"Ps": Ns @ Ps, "pts3D": ...}`` where the points come from
    ``geo_utils.n_view_triangulation`` on ``Ps``, ``data.M`` and ``data.Ns``.
    """
    left_block = Ps[:, 0:3, 0:3]
    last_col = Ps[:, 0:3, 3]
    # Vs / ts are not part of the result; the computation is kept as in the
    # original (it still requires the left 3x3 blocks to be invertible).
    Vs = torch.inverse(left_block).transpose(1, 2)
    ts = torch.bmm(-Vs.transpose(1, 2), last_col.unsqueeze(dim=-1)).squeeze()
    triangulated = geo_utils.n_view_triangulation(Ps.numpy(), data.M.numpy(), data.Ns.numpy())
    pts_3D = torch.from_numpy(triangulated).float()
    return {"Ps": torch.bmm(data.Ns, Ps), "pts3D": pts_3D}
232 |
233 |
def get_subset(data, subset_size):
    """Greedily pick ``subset_size`` cameras and build a reduced ``SceneData``.

    The first camera is the one with the most valid points; each following
    camera is the unused one sharing the most points with the union of the
    points seen so far. Only track columns with more than two non-zero
    entries in the selected rows are kept.

    NOTE(review): the ``SceneData(...)`` call omits the ``gpss`` and ``color``
    arguments required by the constructor visible in this file - confirm.
    """
    # Get subset indices
    valid_pts = dataset_utils.get_M_valid_points(data.M)
    n_cams = valid_pts.shape[0]

    first_idx = valid_pts.sum(dim=1).argmax().item()
    curr_pts = valid_pts[first_idx].clone()
    valid_pts[first_idx] = False  # never pick the same camera twice
    indices = [first_idx]

    for _ in range(subset_size - 1):
        shared_pts = curr_pts.expand(n_cams, -1) & valid_pts
        next_idx = shared_pts.sum(dim=1).argmax().item()
        curr_pts = curr_pts | valid_pts[next_idx]
        valid_pts[next_idx] = False
        indices.append(next_idx)

    print("Cameras are:")
    print(indices)

    indices = torch.tensor(indices)
    # Each camera owns rows 2*i and 2*i+1 of M / mask.
    M_indices = torch.sort(torch.cat((2 * indices, 2 * indices + 1)))[0]
    Rs = data.Rs[indices]
    ts = data.ts[indices]
    quats = data.quats[indices]
    Ns = data.Ns[indices]
    M = data.M[M_indices]
    mask = data.mask[M_indices]
    # Compute the column filter once, from the unfiltered M, and apply it to
    # both matrices; filtering M first used to yield a filter of the wrong
    # length for mask.
    keep_cols = (M != 0).sum(dim=0) > 2
    M = M[:, keep_cols]
    mask = mask[:, keep_cols]
    return SceneData(M, Ns, Rs, ts, quats, mask, data.scan_name + "_{}".format(subset_size))
265 |
266 |
# Run the ad-hoc dataset smoke test when this module is executed as a script.
if __name__ == "__main__":
    test_dataset()
269 |
270 |
--------------------------------------------------------------------------------
/code/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import plotly
2 | import os
3 | import utils.path_utils
4 | import plotly.express as px
5 | from plotly.subplots import make_subplots
6 | import numpy as np
7 | import torch
8 | import plotly.graph_objects as go
9 | from utils import geo_utils
10 | from matplotlib import image
11 |
12 |
def plot_img_sets_error_bar(err_list, imgs_sets, path, title):
    """Write a bar chart of ``err_list`` per image set to ``path``.

    Values above 20 (or infinite) are clipped to 20 and drawn in crimson;
    all other bars use the default colour. ``title`` labels the y axis.
    """
    max_val = 20

    # Clip out-of-range values and remember which bars were clipped.
    out_of_range = torch.logical_or((err_list > max_val), (err_list == float("inf")))
    clipped = err_list.clone()
    clipped[out_of_range] = max_val
    bar_colors = np.array(['#636efa', ] * clipped.shape[0])
    bar_colors[out_of_range] = 'crimson'

    # Create figure
    fig = px.bar(x=imgs_sets, y=clipped.tolist())
    fig.update_xaxes(title='Images sets')
    fig.update_yaxes(title=title, range=[0, max_val])
    fig.update_traces(marker_color=bar_colors.tolist())
    fig.update_layout(xaxis_type='category')

    plotly.offline.plot(fig, filename=path)
35 |
36 |
def plot_img_reprojection_error_bar(err_list, img_list):
    """Return a ``go.Bar`` trace with one bar per image (``err_list`` converted via ``tolist()``)."""
    heights = err_list.tolist()
    return go.Bar(x=img_list, y=heights)
50 |
51 |
def plot_error_per_images_bar(repreoj_err, symetric_epipolar_dist, imgs_sets, conf, sub_name=""):
    """Write the reprojection-error and symmetric-epipolar-distance bar charts.

    Both files go to the experiment directory of ``conf``; ``sub_name`` is
    appended to the base file names.
    """
    charts = (
        ('reprojection_err', repreoj_err, 'Mean Reprojection Error'),
        ('SymEpDist', symetric_epipolar_dist, 'Mean Symmetric Epipolar Distance'),
    )
    for base_name, errors, axis_title in charts:
        path = os.path.join(utils.path_utils.path_to_exp(conf), base_name + sub_name)
        plot_img_sets_error_bar(errors, imgs_sets, path, axis_title)
58 |
59 |
def plot_matrix_heatmap(data_matrix, indices, zmax):
    """Return a ``go.Heatmap`` of ``data_matrix`` with zero entries blanked out.

    ``indices`` labels both axes; the colour scale runs from 0 to ``zmax``.
    """
    zero_entries = data_matrix == 0
    values = data_matrix.clone().numpy()
    values[zero_entries] = None  # blank cells instead of zeros
    return go.Heatmap(z=values, x=indices, y=indices, zmin=0, zmax=zmax)
68 |
69 |
def plot_heatmaps(repreoj_err, symetric_epipolar_dist, global_reprojection_error, edges, img_list, conf, path=None, static_path=None):
    """Plot pairwise error heatmaps plus a per-image reprojection bar chart.

    The top row shows the reprojection error and the symmetric epipolar
    distance, each for all pairs and restricted to the pairs in ``edges``;
    the bottom row shows ``global_reprojection_error`` per image. The figure
    is written as HTML to ``path`` (default: ``errors_heatmap.html`` in the
    experiment directory) and, if ``static_path`` is given, also via
    ``fig.write_image``.
    """
    def _keep_edges(full):
        # Copy of `full` that is zero everywhere except at the edge entries.
        kept = torch.zeros(full.shape)
        kept[edges[0], edges[1]] = full[edges[0], edges[1]]
        return kept

    repreoj_err_edges = _keep_edges(repreoj_err)
    symetric_epipolar_dist_edges = _keep_edges(symetric_epipolar_dist)

    zmax = conf.get_int('plot.color_bar_max', default=5)

    def _heatmap(matrix):
        return plot_matrix_heatmap(matrix, list(map(str, img_list)), zmax)

    hm_rep_err = _heatmap(repreoj_err)
    hm_rep_err_edges = _heatmap(repreoj_err_edges)
    hm_sed = _heatmap(symetric_epipolar_dist)
    hm_sed_edges = _heatmap(symetric_epipolar_dist_edges)

    bar_global_rep_err = plot_img_reprojection_error_bar(global_reprojection_error, img_list)

    titles = ['Reprojection Error', 'Symmetric Epipolar Distance',
              'Reprojection Error - Triplets', 'Symmetric Epipolar Distance - Triplets',
              'Global Reprojection Error']
    layout = [[{}, {}, {}, {}], [{"colspan": 4}, None, None, None]]
    fig = make_subplots(2, 4, subplot_titles=titles, specs=layout)

    # Top row: one heatmap per column, in title order.
    top_row = (hm_rep_err, hm_sed, hm_rep_err_edges, hm_sed_edges)
    for col, heatmap in enumerate(top_row, start=1):
        fig.add_trace(heatmap, 1, col)
        fig.update_xaxes(type='category', title='Image', row=1, col=col)
        fig.update_yaxes(type='category', title='Image', row=1, col=col)

    # Bottom row: global reprojection error per image.
    fig.add_trace(bar_global_rep_err, 2, 1)
    max_rep_err = conf.get_int('plot.reproj_err_bar_max', default=20)
    fig.update_xaxes(type='category', title='Image', row=2, col=1)
    fig.update_yaxes(title='Reprojection Error', range=[0, max_rep_err], row=2, col=1)

    if path is None:
        path = os.path.join(utils.path_utils.path_to_exp(conf), 'errors_heatmap.html')
    plotly.offline.plot(fig, filename=path)
    if static_path is not None:
        fig.write_image(static_path)
118 |
119 |
def plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan, epoch=None, bundle_adjustment=False):
    """
    Plot the predicted cameras against the ground truth, optionally followed by
    a second plot of the bundle-adjusted result.

    :param outputs: mapping read for 'Rs_gt', 'ts_gt', 'Rs', 'ts' and 'pts3D_pred';
        with bundle adjustment also 'Rs_ba_fixed', 'ts_ba_fixed' and 'Xs_ba_fixed'
    :param errors: mapping read for 'Rs_mean' and 'ts_mean'; with bundle adjustment
        also 'Rs_ba_mean' and 'ts_ba_mean'
    :param conf: forwarded to plot_cameras
    :param phase: forwarded to plot_cameras
    :param scan: scan name; '_ba' is appended to it for the bundle-adjusted plot
    :param epoch: forwarded to plot_cameras
    :param bundle_adjustment: when true, the second (bundle-adjusted) plot is produced
    """
    gt_rotations = outputs['Rs_gt']
    gt_translations = outputs['ts_gt']

    # Raw network prediction; only the first three rows of the points are plotted.
    plot_cameras(
        outputs['Rs'],
        outputs['ts'],
        outputs['pts3D_pred'][:3, :],
        gt_rotations,
        gt_translations,
        errors['Rs_mean'],
        errors['ts_mean'],
        conf,
        phase,
        scan=scan,
        epoch=epoch,
    )

    if not bundle_adjustment:
        return

    # Same plot for the result after bundle adjustment.
    plot_cameras(
        outputs['Rs_ba_fixed'],
        outputs['ts_ba_fixed'],
        outputs['Xs_ba_fixed'][:3, :],
        gt_rotations,
        gt_translations,
        errors['Rs_ba_mean'],
        errors['ts_ba_mean'],
        conf,
        phase,
        scan=scan + '_ba',
        epoch=epoch,
    )
141 |
def get_points_colors(images_path, image_names, xs, first_occurence=False):
    """
    Look up a color for every track from the images it is observed in.

    :param images_path: directory the image files are read from
    :param image_names: per-view entries; the file name is taken as the part after
        the first '/' of ``str(image_names[i][0])``
    :param xs: [m, n, k] image observations; components 0 and 1 are used as the
        (column, row) pixel position after truncation to int
    :param first_occurence: if true, each point takes its color from a single view
        (the ``argmax`` over views of ``geo_utils.xs_valid_points(xs)``);
        otherwise the colors of all valid views are averaged
    :return: [n, 3] array of colors
    """
    num_views, num_points, _ = xs.shape
    points_colors = np.zeros([num_points, 3])

    def _read_view(view_idx):
        # Image names are stored with a leading directory component that is dropped.
        file_name = str(image_names[view_idx][0]).split('/')[1]
        return image.imread(os.path.join(images_path, file_name))

    if first_occurence:
        source_view = (geo_utils.xs_valid_points(xs)).argmax(axis=0)
        for view_idx in np.unique(source_view):
            pixels = _read_view(view_idx)
            for pt_idx in np.where(view_idx == source_view)[0]:
                pixel = xs[view_idx, pt_idx].astype(int)
                points_colors[pt_idx] = pixels[pixel[1], pixel[0]]
        return points_colors

    visible = geo_utils.xs_valid_points(xs)
    per_view_colors = np.zeros([num_views, num_points, 3])
    for view_idx in range(num_views):
        pixels = _read_view(view_idx)
        for pt_idx in np.where(visible[view_idx])[0]:
            pixel = xs[view_idx, pt_idx].astype(int)
            per_view_colors[view_idx, pt_idx] = pixels[pixel[1], pixel[0]]

    # Average each point's color over the views in which it is valid.
    for pt_idx in range(num_points):
        points_colors[pt_idx] = np.mean(per_view_colors[visible[:, pt_idx], pt_idx], axis=0)

    return points_colors
170 |
def plot_cameras(Rs_pred, ts_pred, pts3D, Rs_gt, ts_gt, Rs_error, ts_error, conf, phase, scan=None, epoch=None):
    """
    Write an offline plotly figure with ground-truth cameras, predicted cameras
    and the 3D points.

    Cameras are drawn as cones placed at ``ts`` and pointing along the third
    column of the corresponding rotation (``Rs[:, :3, 2]``).

    :param Rs_pred: predicted rotations, indexed as [:, :3, 2]
    :param ts_pred: predicted camera positions, [m, 3]
    :param pts3D: [3, n] points
    :param Rs_gt: ground-truth rotations, indexed as [:, :3, 2]
    :param ts_gt: ground-truth camera positions, [m, 3]
    :param Rs_error: rotation errors; their mean is shown in the title
    :param ts_error: translation errors; their mean is shown in the title
    :param conf, phase, scan, epoch: forwarded to utils.path_utils.path_to_plots
        to determine the output file
    """
    traces = [
        get_3D_quiver_trace(ts_gt, Rs_gt[:, :3, 2], color='#86CE00', name='cam_gt', cam_size=2),
        get_3D_quiver_trace(ts_pred, Rs_pred[:, :3, 2], color='#C4451C', name='cam_learn', cam_size=2),
        get_3D_scater_trace(pts3D, '#3366CC', '3D points', size=0.5),
    ]

    fig = go.Figure(data=traces)
    title = 'Cameras: Rotation Mean = {:.5f}, Translation Mean = {:.5f}, Points num = {}'.format(
        Rs_error.mean(), ts_error.mean(), pts3D.shape[1])
    fig.update_layout(title=title, showlegend=True)

    out_path = utils.path_utils.path_to_plots(conf, phase, epoch=epoch, scan=scan)
    plotly.offline.plot(fig, filename=out_path, auto_open=False)
186 |
187 |
def get_3D_quiver_trace(points, directions, color='#bd1540', name='', cam_size=1):
    """
    Build a plotly cone trace with one single-colored cone per point.

    :param points: [k, 3] cone anchor positions
    :param directions: [k, 3] cone directions
    :param color: color used for both ends of the colorscale
    :param name: trace name
    :param cam_size: passed as ``sizeref`` with ``sizemode='absolute'``
    :return: a ``go.Cone`` trace
    :raises AssertionError: if either input is not a 2-D array with 3 columns
    """
    # Check the rank before touching shape[1]; otherwise a 1-D input fails with an
    # IndexError instead of the intended assertion message.
    assert len(points.shape) == 2, "3d cone plot input points are not correctely shaped "
    assert points.shape[1] == 3, "3d cone plot input points are not correctely shaped "
    assert len(directions.shape) == 2, "3d cone plot input directions are not correctely shaped "
    assert directions.shape[1] == 3, "3d cone plot input directions are not correctely shaped "

    trace = go.Cone(
        name=name,
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        u=directions[:, 0],
        v=directions[:, 1],
        w=directions[:, 2],
        sizemode='absolute',
        sizeref=cam_size,
        showscale=False,
        # Identical colors at both ends give every cone the same flat color.
        colorscale=[[0, color], [1, color]],
        anchor="tail"
    )

    return trace
210 |
211 |
def get_3D_scater_trace(points, color, name,size=0.5):
    """
    Build a plotly 3D scatter trace (markers only) for a set of points.

    :param points: [3, n] array; rows are the x, y and z coordinates
    :param color: marker color
    :param name: trace name
    :param size: marker size
    :return: a ``go.Scatter3d`` trace
    """
    shape = points.shape
    assert shape[0] == 3, "3d plot input points are not correctely shaped "
    assert len(shape) == 2, "3d plot input points are not correctely shaped "

    marker_style = dict(size=size, color=color)
    return go.Scatter3d(
        name=name,
        x=points[0, :],
        y=points[1, :],
        z=points[2, :],
        mode='markers',
        marker=marker_style,
    )
229 |
230 |
231 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dsfm
2 | channels:
3 | - bottler
4 | - comet_ml
5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
6 | - conda-forge
7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
9 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
10 | - defaults
11 | dependencies:
12 | - _libgcc_mutex=0.1=conda_forge
13 | - _openmp_mutex=4.5=1_llvm
14 | - aadict=0.2.3=pyh9f0ad1d_0
15 | - alsa-lib=1.2.3=h516909a_0
16 | - asset=0.6.13=pyh9f0ad1d_0
17 | - attrs=21.2.0=pyhd8ed1ab_0
18 | - binutils_impl_linux-64=2.36.1=h193b22a_2
19 | - binutils_linux-64=2.36=hf3e587d_1
20 | - blas=2.111=mkl
21 | - blas-devel=3.9.0=11_linux64_mkl
22 | - bokeh=2.4.0=py38h578d9bd_0
23 | - brotlipy=0.7.0=py38h497a2fe_1001
24 | - bzip2=1.0.8=h7f98852_4
25 | - c-ares=1.17.2=h7f98852_0
26 | - ca-certificates=2024.7.2=h06a4308_0
27 | - cairo=1.16.0=h6cf1ce9_1008
28 | - certifi=2024.7.4=py38h06a4308_0
29 | - cffi=1.14.6=py38h3931269_1
30 | - chardet=4.0.0=py38h578d9bd_1
31 | - charset-normalizer=2.0.0=pyhd8ed1ab_0
32 | - click=8.0.3=py38h578d9bd_0
33 | - cloudpickle=2.0.0=pyhd8ed1ab_0
34 | - colorama=0.4.4=pyh9f0ad1d_0
35 | - comet_ml=3.18.1=py38
36 | - configobj=5.0.6=py_0
37 | - cryptography=3.4.8=py38ha5dfef3_0
38 | - cudatoolkit=10.2.89=h8f6ccaa_9
39 | - cvxpy=1.1.15=py38h578d9bd_0
40 | - cvxpy-base=1.1.15=py38h43a58ef_0
41 | - cycler=0.10.0=py_2
42 | - cytoolz=0.11.0=py38h497a2fe_3
43 | - dask=2021.9.1=pyhd8ed1ab_0
44 | - dask-core=2021.9.1=pyhd8ed1ab_0
45 | - dataclasses=0.8=pyhc8e2a94_3
46 | - dbus=1.13.6=h48d8840_2
47 | - distributed=2021.9.1=py38h578d9bd_0
48 | - dulwich=0.20.25=py38h497a2fe_0
49 | - ecos=2.0.8=py38h6c62de6_1
50 | - eigen=3.4.0=h4bd325d_0
51 | - et_xmlfile=1.0.1=py_1001
52 | - everett=2.0.0=pyhd8ed1ab_0
53 | - expat=2.4.1=h9c3ff4c_0
54 | - ffmpeg=4.3.2=hca11adc_0
55 | - fontconfig=2.13.1=hba837de_1005
56 | - freeglut=3.2.1=h9c3ff4c_2
57 | - freetype=2.10.4=h0708190_1
58 | - fsspec=2021.10.0=pyhd8ed1ab_0
59 | - fvcore=0.1.5.post20210804=pyhd8ed1ab_0
60 | - gcc_impl_linux-64=11.2.0=h82a94d6_10
61 | - gcc_linux-64=11.2.0=h39a9532_1
62 | - gettext=0.19.8.1=h73d1719_1008
63 | - gflags=2.2.2=he1b5a44_1004
64 | - glib=2.68.4=h9c3ff4c_1
65 | - glib-tools=2.68.4=h9c3ff4c_1
66 | - globre=0.1.5=pyh9f0ad1d_0
67 | - glog=0.5.0=h48cff8f_0
68 | - gmp=6.2.1=h58526e2_0
69 | - gnutls=3.6.13=h85f3911_1
70 | - graphite2=1.3.13=h58526e2_1001
71 | - gst-plugins-base=1.18.5=hf529b03_0
72 | - gstreamer=1.18.5=h76c114f_0
73 | - gxx_impl_linux-64=11.2.0=h82a94d6_10
74 | - gxx_linux-64=11.2.0=hacbe6df_1
75 | - harfbuzz=3.0.0=h83ec7ef_1
76 | - hdf5=1.12.1=nompi_h2750804_101
77 | - heapdict=1.0.1=py_0
78 | - icu=68.1=h58526e2_0
79 | - idna=3.1=pyhd3deb0d_0
80 | - importlib-metadata=4.8.1=py38h578d9bd_0
81 | - iopath=0.1.8=pyhd8ed1ab_0
82 | - jasper=2.0.14=ha77e612_2
83 | - jbig=2.1=h7f98852_2003
84 | - jinja2=3.0.2=pyhd8ed1ab_0
85 | - jpeg=9d=h36c2ea0_0
86 | - jsonschema=4.1.0=pyhd8ed1ab_0
87 | - kernel-headers_linux-64=2.6.32=he073ed8_14
88 | - kiwisolver=1.3.2=py38h1fd1430_0
89 | - krb5=1.19.2=hcc1bbae_2
90 | - lame=3.100=h7f98852_1001
91 | - lcms2=2.12=hddcbb42_0
92 | - ld_impl_linux-64=2.36.1=hea4e1c9_2
93 | - lerc=2.2.1=h9c3ff4c_0
94 | - libblas=3.9.0=11_linux64_mkl
95 | - libcblas=3.9.0=11_linux64_mkl
96 | - libclang=11.1.0=default_ha53f305_1
97 | - libcurl=7.79.1=h2574ce0_1
98 | - libdeflate=1.7=h7f98852_5
99 | - libedit=3.1.20191231=he28a2e2_2
100 | - libev=4.33=h516909a_1
101 | - libevent=2.1.10=h9b69904_4
102 | - libffi=3.4.2=h9c3ff4c_4
103 | - libgcc-devel_linux-64=11.2.0=h0952999_10
104 | - libgcc-ng=11.2.0=h1d223b6_10
105 | - libgfortran-ng=11.2.0=h69a702a_10
106 | - libgfortran5=11.2.0=h5c6108e_10
107 | - libglib=2.68.4=h174f98d_1
108 | - libglu=9.0.0=he1b5a44_1001
109 | - libgomp=11.2.0=h1d223b6_10
110 | - libiconv=1.16=h516909a_0
111 | - liblapack=3.9.0=11_linux64_mkl
112 | - liblapacke=3.9.0=11_linux64_mkl
113 | - libllvm11=11.1.0=hf817b99_2
114 | - libnghttp2=1.43.0=h812cca2_1
115 | - libogg=1.3.4=h7f98852_1
116 | - libopencv=4.5.3=py38h66b0e6b_5
117 | - libopus=1.3.1=h7f98852_1
118 | - libpng=1.6.37=h21135ba_2
119 | - libpq=13.3=hd57d9b9_1
120 | - libprotobuf=3.18.1=h780b84a_0
121 | - libsanitizer=11.2.0=he4da1e4_10
122 | - libssh2=1.10.0=ha56f1ee_2
123 | - libstdcxx-devel_linux-64=11.2.0=h0952999_10
124 | - libstdcxx-ng=11.2.0=he4da1e4_10
125 | - libtiff=4.3.0=hf544144_1
126 | - libuuid=2.32.1=h7f98852_1000
127 | - libuv=1.42.0=h7f98852_0
128 | - libvorbis=1.3.7=h9c3ff4c_0
129 | - libwebp-base=1.2.1=h7f98852_0
130 | - libxcb=1.13=h7f98852_1003
131 | - libxkbcommon=1.0.3=he3ba5ed_0
132 | - libxml2=2.9.12=h72842e0_0
133 | - libzlib=1.2.11=h36c2ea0_1013
134 | - llvm-openmp=12.0.1=h4bd325d_1
135 | - locket=0.2.0=py_2
136 | - lz4-c=1.9.3=h9c3ff4c_1
137 | - matplotlib=3.4.3=py38h578d9bd_1
138 | - matplotlib-base=3.4.3=py38hf4fb855_1
139 | - metis=5.1.0=h58526e2_1006
140 | - mkl=2021.3.0=h726a3e6_557
141 | - mkl-devel=2021.3.0=ha770c72_558
142 | - mkl-include=2021.3.0=h726a3e6_557
143 | - mpfr=4.1.0=h9202a9a_1
144 | - msgpack-python=1.0.2=py38h1fd1430_1
145 | - mysql-common=8.0.25=ha770c72_2
146 | - mysql-libs=8.0.25=hfa10184_2
147 | - ncurses=6.2=h58526e2_4
148 | - nettle=3.6=he412f7d_0
149 | - ninja=1.10.2=h4bd325d_1
150 | - nspr=4.30=h9c3ff4c_0
151 | - nss=3.69=hb5efdd6_1
152 | - numpy=1.21.2=py38he2449b9_0
153 | - nvidia-ml=7.352.0=py_0
154 | - nvidiacub=1.10.0=0
155 | - olefile=0.46=pyh9f0ad1d_1
156 | - opencv=4.5.3=py38h578d9bd_5
157 | - openh264=2.1.1=h780b84a_0
158 | - openjpeg=2.4.0=hb52868f_1
159 | - openpyxl=3.0.9=pyhd8ed1ab_0
160 | - openssl=1.1.1w=h7f8727e_0
161 | - osqp=0.6.2=py38h43a58ef_2
162 | - packaging=21.0=pyhd8ed1ab_0
163 | - pandas=1.3.3=py38h43a58ef_0
164 | - partd=1.2.0=pyhd8ed1ab_0
165 | - pcre=8.45=h9c3ff4c_0
166 | - pip=21.3=pyhd8ed1ab_0
167 | - pixman=0.40.0=h36c2ea0_0
168 | - plotly=5.3.1=pyhd8ed1ab_0
169 | - portalocker=2.3.2=py38h578d9bd_0
170 | - psutil=5.8.0=py38h497a2fe_1
171 | - pthread-stubs=0.4=h36c2ea0_1001
172 | - py-opencv=4.5.3=py38he5a9106_5
173 | - pycparser=2.20=pyh9f0ad1d_2
174 | - pyhocon=0.3.58=pyhd8ed1ab_0
175 | - pyopenssl=21.0.0=pyhd8ed1ab_0
176 | - pyparsing=2.4.7=pyh9f0ad1d_0
177 | - pyqt=5.12.3=py38h578d9bd_7
178 | - pyqt-impl=5.12.3=py38h7400c14_7
179 | - pyqt5-sip=4.19.18=py38h709712a_7
180 | - pyqtchart=5.12=py38h7400c14_7
181 | - pyqtwebengine=5.12.1=py38h7400c14_7
182 | - pyrsistent=0.17.3=py38h497a2fe_2
183 | - pysocks=1.7.1=py38h578d9bd_3
184 | - python-dateutil=2.8.2=pyhd8ed1ab_0
185 | - pytz=2021.3=pyhd8ed1ab_0
186 | - pyyaml=5.4.1=py38h497a2fe_1
187 | - qdldl-python=0.1.5=py38h43a58ef_1
188 | - qt=5.12.9=hda022c4_4
189 | - readline=8.1=h46c0cb4_0
190 | - requests=2.26.0=pyhd8ed1ab_0
191 | - requests-toolbelt=0.9.1=py_0
192 | - rhash=1.4.1=h7f98852_0
193 | - scipy=1.7.1=py38h56a6a73_0
194 | - scs=2.1.4=py38h6afa1d1_0
195 | - semantic_version=2.8.5=pyh9f0ad1d_0
196 | - setuptools=58.2.0=py38h578d9bd_0
197 | - six=1.16.0=pyh6c4a22f_0
198 | - sortedcontainers=2.4.0=pyhd8ed1ab_0
199 | - sqlite=3.36.0=h9cd32fc_2
200 | - suitesparse=5.10.1=h9e50725_1
201 | - sysroot_linux-64=2.12=he073ed8_14
202 | - tabulate=0.8.9=pyhd8ed1ab_0
203 | - tbb=2021.3.0=h4bd325d_0
204 | - tblib=1.7.0=pyhd8ed1ab_0
205 | - tenacity=8.0.1=pyhd8ed1ab_0
206 | - termcolor=1.1.0=py_2
207 | - tk=8.6.11=h27826a3_1
208 | - toolz=0.11.1=py_0
209 | - tornado=6.1=py38h497a2fe_1
210 | - tqdm=4.62.3=pyhd8ed1ab_0
211 | - urllib3=1.26.7=pyhd8ed1ab_0
212 | - websocket-client=0.57.0=py38h578d9bd_4
213 | - wheel=0.37.0=pyhd8ed1ab_1
214 | - wrapt=1.13.1=py38h497a2fe_0
215 | - wurlitzer=3.0.2=py38h578d9bd_0
216 | - x264=1!161.3030=h7f98852_1
217 | - xlrd=2.0.1=pyhd8ed1ab_3
218 | - xorg-fixesproto=5.0=h7f98852_1002
219 | - xorg-inputproto=2.3.2=h7f98852_1002
220 | - xorg-kbproto=1.0.7=h7f98852_1002
221 | - xorg-libice=1.0.10=h7f98852_0
222 | - xorg-libsm=1.2.3=hd9c2040_1000
223 | - xorg-libx11=1.7.2=h7f98852_0
224 | - xorg-libxau=1.0.9=h7f98852_0
225 | - xorg-libxdmcp=1.1.3=h7f98852_0
226 | - xorg-libxext=1.3.4=h7f98852_1
227 | - xorg-libxfixes=5.0.3=h7f98852_1004
228 | - xorg-libxi=1.7.10=h7f98852_0
229 | - xorg-libxrender=0.9.10=h7f98852_1003
230 | - xorg-renderproto=0.11.1=h7f98852_1002
231 | - xorg-xextproto=7.3.0=h7f98852_1002
232 | - xorg-xproto=7.0.31=h7f98852_1007
233 | - xz=5.2.5=h516909a_1
234 | - yacs=0.1.6=py_0
235 | - yaml=0.2.5=h516909a_0
236 | - zict=2.0.0=py_0
237 | - zipp=3.6.0=pyhd8ed1ab_0
238 | - zlib=1.2.11=h36c2ea0_1013
239 | - zstd=1.5.0=ha95c52a_0
240 | - pip:
241 | - absl-py==2.1.0
242 | - aiohappyeyeballs==2.4.0
243 | - aiohttp==3.10.5
244 | - aiosignal==1.3.1
245 | - async-timeout==4.0.3
246 | - cachetools==5.4.0
247 | - cmake==3.30.2
248 | - easydict==1.13
249 | - filelock==3.15.4
250 | - frozenlist==1.4.1
251 | - google-auth==2.33.0
252 | - google-auth-oauthlib==1.0.0
253 | - grpcio==1.65.4
254 | - joblib==1.4.2
255 | - lit==18.1.8
256 | - markdown==3.6
257 | - markupsafe==2.1.5
258 | - mpmath==1.3.0
259 | - multidict==6.0.5
260 | - networkx==3.1
261 | - nvidia-cublas-cu11==11.10.3.66
262 | - nvidia-cuda-cupti-cu11==11.7.101
263 | - nvidia-cuda-nvrtc-cu11==11.7.99
264 | - nvidia-cuda-runtime-cu11==11.7.99
265 | - nvidia-cudnn-cu11==8.5.0.96
266 | - nvidia-cufft-cu11==10.9.0.58
267 | - nvidia-curand-cu11==10.2.10.91
268 | - nvidia-cusolver-cu11==11.4.0.1
269 | - nvidia-cusparse-cu11==11.7.4.91
270 | - nvidia-nccl-cu11==2.14.3
271 | - nvidia-nvtx-cu11==11.7.91
272 | - oauthlib==3.2.2
273 | - pillow==10.4.0
274 | - protobuf==5.27.3
275 | - pyasn1==0.6.0
276 | - pyasn1-modules==0.4.0
277 | - pyceres==2.3
278 | - pytorch3d==0.3.0
279 | - requests-oauthlib==2.0.0
280 | - rsa==4.9
281 | - scikit-learn==1.3.2
282 | - sympy==1.13.2
283 | - tensorboard==2.14.0
284 | - tensorboard-data-server==0.7.2
285 | - threadpoolctl==3.5.0
286 | - torch==2.0.0
287 | - torch-cluster==1.6.3
288 | - torch-geometric==2.5.3
289 | - torch-scatter==2.1.2
290 | - torch-sparse==0.6.18
291 | - torch-spline-conv==1.2.2
292 | - torchaudio==2.0.1
293 | - torchvision==0.15.1
294 | - triton==2.0.0
295 | - typing-extensions==4.12.2
296 | - werkzeug==3.0.3
297 | - yarl==1.9.4
298 |
--------------------------------------------------------------------------------
/code/utils/ceres_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("/path/to/ceres-solver/ceres-bin/lib/") # so
3 |
4 | import PyCeres
5 | import numpy as np
6 | import scipy.io as sio
7 | import cv2
8 | from utils import geo_utils
9 |
10 |
def order_cam_param_for_c(Rs, ts, Ks):
    """
    Pack the camera parameters into the [m, 12] layout consumed by the ceres functions:
        Ps_for_c[i, 0:3]   rotation vector of Rs[i].T (via cv2.Rodrigues)
        Ps_for_c[i, 3:6]   -Rs[i].T @ ts[i]
        Ps_for_c[i, 6:11]  K[0,0], K[0,1], K[0,2], K[1,1], K[1,2] of Ks[i]
        Ps_for_c[i, 11]    constant 1.0
    :param Rs: [m,3,3]
    :param ts: [m,3]
    :param Ks: [m,3,3]
    :return: Ps_for_c [m, 12]
    """
    num_cams = len(Rs)
    packed = np.zeros([num_cams, 12])
    for cam, R in enumerate(Rs):
        R_transposed = R.T
        K = Ks[cam]
        row = packed[cam]  # view into `packed`, so the writes below fill it in place
        row[0:3] = cv2.Rodrigues(R_transposed)[0].T
        row[3:6] = (-R_transposed @ ts[cam].reshape([3, 1])).T
        row[6:11] = [K[0, 0], K[0, 1], K[0, 2], K[1, 1], K[1, 2]]
        row[-1] = 1.0
    return packed
30 |
31 |
def reorder_from_c_to_py(Ps_for_c, Ks):
    """
    Unpack camera parameters from the [m, 12] ceres layout.

    For every row, the rotation is the transpose of the Rodrigues matrix of
    columns 0:3, the translation is ``-R @ row[3:6]``, and the camera matrix is
    built by geo_utils.get_camera_matrix from that R, t and the matching K.
    Columns 6:12 of the input are not read.

    :param Ps_for_c: [m, 12]
    :param Ks: [m,3,3]
    :return: Rs [m,3,3], ts [m,3], Ps [m,3,4]
    """
    num_cams = len(Ps_for_c)
    Rs = np.zeros([num_cams, 3, 3])
    ts = np.zeros([num_cams, 3])
    Ps = np.zeros([num_cams, 3, 4])
    for cam in range(num_cams):
        rot_vec = Ps_for_c[cam, 0:3]
        location = Ps_for_c[cam, 3:6].reshape([3, 1]).flatten()
        Rs[cam] = cv2.Rodrigues(rot_vec)[0].T
        ts[cam] = -Rs[cam] @ location
        Ps[cam] = geo_utils.get_camera_matrix(R=Rs[cam], t=ts[cam], K=Ks[cam])
    return Rs, ts, Ps
47 |
48 |
def run_euclidean_ceres(Xs, xs, Rs, ts, Ks, point_indices):
    """
    Run the compiled Euclidean bundle adjustment (PyCeres.eucPythonFunctionOursBA).

    The inputs are flattened in C order and handed to the c++ function together
    with two zero-initialised buffers. After the call those buffers are added to
    the packed camera parameters and to the points, i.e. they are treated as
    additive corrections.

    :param Xs: [n, 3] (a 4th column, if present, is dropped)
    :param xs: [v,2]
    :param Rs: [m,3,3]
    :param ts: [m,3]
    :param Ks: [m,3,3]
    :param point_indices: [2,v]
    :return:
        new_Rs, new_ts, new_Ps, new_Xs
    """
    if Xs.shape[-1] == 4:
        Xs = Xs[:, :3]
    assert Xs.shape[-1] == 3
    assert xs.shape[-1] == 2
    num_cams, num_points, num_obs = len(Rs), Xs.shape[0], xs.shape[0]

    cam_params = order_cam_param_for_c(Rs, ts, Ks).astype(np.double)
    points_flat = Xs.flatten("C").astype(np.double)
    cams_flat = cam_params.flatten("C").astype(np.double)
    obs_flat = xs.flatten("C").astype(np.double)
    index_flat = point_indices.flatten("C")

    # Output buffers for the corrections.
    points_delta = np.zeros_like(points_flat)
    cams_delta = np.zeros_like(cams_flat)

    PyCeres.eucPythonFunctionOursBA(points_flat, obs_flat, cams_flat, index_flat,
                                    points_delta, cams_delta, num_cams, num_points, num_obs)

    refined_cam_params = cam_params + cams_delta.reshape([num_cams, 12], order="C")
    new_Rs, new_ts, new_Ps = reorder_from_c_to_py(refined_cam_params, Ks)
    new_Xs = Xs + points_delta.reshape([num_points, 3], order="C")

    return new_Rs, new_ts, new_Ps, new_Xs
86 |
87 |
def run_projective_ceres(Ps, Xs, xs, point_indices):
    """
    Run the compiled projective bundle adjustment (PyCeres.pythonFunctionOursBA).

    Per the original note, the c++ function loops over the observations as:
        for i in range(v):
            xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex)

    :param Ps: [m, 3, 4]
    :param Xs: [n, 3] (a 4th column, if present, is dropped)
    :param xs: [v, 2]
    :param point_indices: [2,v]
    :return: new_Ps: [m, 3, 4]
             new_Xs: [n,3]
    """
    if Xs.shape[-1] == 4:
        Xs = Xs[:, :3]
    assert Xs.shape[-1] == 3
    assert xs.shape[-1] == 2
    num_cams = Ps.shape[0]
    num_points = Xs.shape[0]
    num_obs = point_indices.shape[1]

    # Each camera is laid out in *column* major order (as in matlab) because the
    # cpp code expects it; the cameras themselves are then stacked row major.
    cams_flat = Ps.reshape([-1, 12], order="F").flatten("C")
    points_flat = Xs.flatten("C")
    obs_flat = xs.flatten("C")
    index_flat = point_indices.flatten("C")

    # Output buffers for the corrections.
    cams_delta = np.zeros_like(cams_flat)
    points_delta = np.zeros_like(points_flat)

    PyCeres.pythonFunctionOursBA(points_flat, obs_flat, cams_flat, index_flat,
                                 points_delta, cams_delta, num_cams, num_points, num_obs)

    # Undo the packing: rows per camera first, then column-major back to [3, 4].
    cams_delta = cams_delta.reshape([num_cams, 12], order="C").reshape([num_cams, 3, 4], order="F")
    points_delta = points_delta.reshape([num_points, 3])

    return Ps + cams_delta, Xs + points_delta
126 |
def run_euclidean_python_ceres(Xs, xs, Rs, ts, Ks, point_indices, print_out=True, max_iter=100):
    """
    Bundle-adjust calibrated cameras and 3D points, building the Ceres problem from Python.

    One residual block (PyCeres.eucReprojectionError with a Huber loss of 0.1) is
    added per observation. The optimised variables are zero-initialised
    corrections that are added to the inputs once the solver returns.

    :param Xs: [n, 3] (a 4th column, if present, is dropped)
    :param xs: [v,2]
    :param Rs: [m,3,3]
    :param ts: [m,3]
    :param Ks: [m,3,3]
    :param point_indices: [2,v]; row 0 is the camera index and row 1 the 3D point
        index of each observation
    :param print_out: when False the solver is silenced and no report is printed
    :param max_iter: maximal number of solver iterations
    :return:
        new_Rs, new_ts, new_Ps, new_Xs
    """
    if Xs.shape[-1] == 4:
        Xs = Xs[:, :3]
    assert Xs.shape[-1] == 3
    assert xs.shape[-1] == 2
    n_cam = len(Rs)
    n_pts = Xs.shape[0]
    n_observe = xs.shape[0]

    Ps_for_c = order_cam_param_for_c(Rs, ts, Ks).astype(np.double)
    Xs_flat = Xs.flatten("C").astype(np.double)
    Ps_for_c_flat = Ps_for_c.flatten("C").astype(np.double)
    xs_flat = xs.flatten("C").astype(np.double)
    point_indices = point_indices.flatten("C")

    # Corrections optimised by the solver (the parameter blocks are views into these).
    Xsu = np.zeros_like(Xs_flat)
    Psu = np.zeros_like(Ps_for_c_flat)

    problem = PyCeres.Problem()
    for i in range(n_observe):  # loop over the observations
        camIndex = int(point_indices[i])
        point3DIndex = int(point_indices[i + n_observe])
        cam_block = slice(12 * camIndex, 12 * (camIndex + 1))
        point_block = slice(3 * point3DIndex, 3 * (point3DIndex + 1))

        cost_function = PyCeres.eucReprojectionError(xs_flat[2 * i], xs_flat[2 * i + 1],
                                                     Ps_for_c_flat[cam_block],
                                                     Xs_flat[point_block])

        loss_function = PyCeres.HuberLoss(0.1)
        problem.AddResidualBlock(cost_function, loss_function, Psu[cam_block], Xsu[point_block])

    options = PyCeres.SolverOptions()

    options.function_tolerance = 0.0001
    options.max_num_iterations = max_iter
    options.num_threads = 8

    options.linear_solver_type = PyCeres.LinearSolverType.DENSE_SCHUR
    # Honour print_out: the previous code always printed the progress and, instead of
    # configuring the solver, overwrote the module attribute PyCeres.LoggingType.
    options.minimizer_progress_to_stdout = print_out
    if not print_out:
        options.logging_type = PyCeres.LoggingType.SILENT

    summary = PyCeres.Summary()
    PyCeres.Solve(options, problem, summary)
    if print_out:
        print(summary.FullReport())

    if not Psu.any():
        print('Warning no change to Ps')
    if not Xsu.any():
        print('Warning no change to Xs')

    new_Ps_for_c = Ps_for_c + Psu.reshape([n_cam, 12], order="C")

    new_Rs, new_ts, new_Ps = reorder_from_c_to_py(new_Ps_for_c, Ks)
    new_Xs = Xs + Xsu.reshape([n_pts, 3], order="C")

    return new_Rs, new_ts, new_Ps, new_Xs
196 |
def get_new_scenespt(Ps, Xs, xs, valid_points, repro_thre):
    """
    Drop every 3D point that has at least one reprojection error above the threshold.

    :param Ps: cameras, forwarded to geo_utils.reprojection_error_with_points
    :param Xs: 3D points, one per row
    :param xs: observations, indexed as [view, point, coordinate]
    :param valid_points: validity mask, forwarded to the reprojection error helper
    :param repro_thre: reprojection error threshold
    :return: filtered Xs, filtered xs, the recomputed validity mask (an observation
        counts as valid when its coordinates do not sum to zero), the stacked
        np.where indices of that mask, and the indices of the kept points
    """
    reproj_errors = geo_utils.reprojection_error_with_points(Ps, Xs, xs, valid_points)
    above_threshold = reproj_errors[:] > repro_thre
    # A point is kept only if none of its views exceeds the threshold.
    keep = above_threshold.sum(axis=0) == 0

    kept_Xs = Xs[keep, :]
    kept_xs = xs[:, keep, :]
    kept_valid = kept_xs.sum(axis=-1) != 0
    kept_indices = np.stack(np.where(kept_valid))
    return kept_Xs, kept_xs, kept_valid, kept_indices, np.where(keep)[0]
207 |
def run_euc_ceres_iter(Ps, Xs, xs, Rs, ts, Ks, point_indices, valid_points, max_iter=100, ba_times=2, repro_thre=5, proj_first=False):
    """
    Alternate Euclidean bundle adjustment with outlier removal.

    Each of the ``ba_times`` rounds runs run_euclidean_python_ceres on the valid
    observations and then drops, through get_new_scenespt, the points whose
    reprojection error exceeds ``repro_thre``. With ``proj_first`` the same
    filtering is also applied once before the first round.

    :param max_iter: solver iteration limit for every round
    :param ba_times: number of adjust-then-filter rounds
    :param repro_thre: reprojection error threshold used for filtering
    :param proj_first: filter the input scene before the first round
    :return: Rs, ts, Ps, Xs, xs, valid_points, point_indices and the column
        indices of the points that survived
    """
    # Columns that have at least one valid observation.
    kept_cols = np.where(valid_points.sum(axis=0) != 0)[0]
    if proj_first:
        Xs, xs, valid_points, point_indices, kept_cols = get_new_scenespt(
            Ps, Xs, xs, valid_points, repro_thre)

    for _ in range(ba_times):
        Rs, ts, Ps, Xs = run_euclidean_python_ceres(
            Xs, xs[valid_points], Rs, ts, Ks, point_indices, max_iter=max_iter)
        Xs, xs, valid_points, point_indices, survivors = get_new_scenespt(
            Ps, Xs, xs, valid_points, repro_thre)
        # Map the surviving positions back through the previous selection.
        kept_cols = kept_cols[survivors]

    return Rs, ts, Ps, Xs, xs, valid_points, point_indices, kept_cols
221 |
222 |
def run_projective_python_ceres(Ps, Xs, xs, point_indices, print_out=True):
    """
    Projective bundle adjustment, building the Ceres problem from Python.

    One residual block (PyCeres.projReprojectionError with a Huber loss of 0.1) is
    added per observation i, using
        xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex)
    The optimised variables are zero-initialised corrections that are added to
    the inputs once the solver returns.

    :param Ps: [m, 3, 4]
    :param Xs: [n, 3] (a 4th column, if present, is dropped)
    :param xs: [v, 2]
    :param point_indices: [2,v]
    :param print_out: print the solver's full report when true
    :return: new_Ps: [m, 3, 4]
             new_Xs: [n,3]
    """
    if Xs.shape[-1] == 4:
        Xs = Xs[:, :3]
    assert Xs.shape[-1] == 3
    assert xs.shape[-1] == 2
    num_cams = Ps.shape[0]
    num_points = Xs.shape[0]
    num_obs = point_indices.shape[1]

    # Each camera is laid out in *column* major order (as in matlab) because the
    # cpp code expects it; the cameras themselves are then stacked row major.
    cams_flat = Ps.reshape([-1, 12], order="F").flatten("C").astype(np.double)
    points_flat = Xs.flatten("C").astype(np.double)
    obs_flat = xs.flatten("C")
    index_flat = point_indices.flatten("C")

    cams_delta = np.zeros_like(cams_flat)
    points_delta = np.zeros_like(points_flat)

    problem = PyCeres.Problem()
    for obs in range(num_obs):  # loop over the observations
        cam = int(index_flat[obs])
        point = int(index_flat[obs + num_obs])
        cam_block = slice(12 * cam, 12 * (cam + 1))
        point_block = slice(3 * point, 3 * (point + 1))

        cost_function = PyCeres.projReprojectionError(
            obs_flat[2 * obs], obs_flat[2 * obs + 1], cams_flat[cam_block], points_flat[point_block])
        loss_function = PyCeres.HuberLoss(0.1)
        problem.AddResidualBlock(cost_function, loss_function,
                                 cams_delta[cam_block], points_delta[point_block])

    options = PyCeres.SolverOptions()
    options.function_tolerance = 0.0001
    options.max_num_iterations = 100
    options.num_threads = 8
    options.linear_solver_type = PyCeres.LinearSolverType.DENSE_SCHUR
    options.minimizer_progress_to_stdout = True

    summary = PyCeres.Summary()
    PyCeres.Solve(options, problem, summary)
    if print_out:
        print(summary.FullReport())

    # Undo the packing: rows per camera first, then column-major back to [3, 4].
    cams_delta = cams_delta.reshape([num_cams, 12], order="C").reshape([num_cams, 3, 4], order="F")
    points_delta = points_delta.reshape([num_points, 3])

    return Ps + cams_delta, Xs + points_delta
--------------------------------------------------------------------------------
/bundle_adjustment/custom_cpp_cost_functions.cpp:
--------------------------------------------------------------------------------
1 | #include
2 | #include
3 | #include
4 | #include
5 | #include
6 | #include
7 |
8 |
9 | namespace py = pybind11;
10 |
11 |
12 |
13 | // Parses a numpy array and extracts the pointer to the first element.
14 | // Requires that the numpy array be either an array or a row/column vector
15 | double* _ParseNumpyData(py::array_t& np_buf) {
16 | py::buffer_info info = np_buf.request();
17 | // This is essentially just all error checking. As it will always be the info
18 | // ptr
19 | if (info.ndim > 2) {
20 | std::string error_msg("Number of dimensions must be <=2. This function"
21 | "only allows either an array or row/column vector(2D matrix) "
22 | + std::to_string(info.ndim));
23 | throw std::runtime_error(
24 | error_msg);
25 | }
26 | if (info.ndim == 2) {
27 | // Row or Column Vector. Represents 1 parameter
28 | if (info.shape[0] == 1 || info.shape[1] == 1) {
29 | }
30 | else {
31 | std::string error_msg
32 | ("Matrix is not a row or column vector and instead has size "
33 | + std::to_string(info.shape[0]) + "x"
34 | + std::to_string(info.shape[1]));
35 | throw std::runtime_error(
36 | error_msg);
37 | }
38 | }
39 | return (double*)info.ptr;
40 | }
41 |
42 | struct ExampleFunctor {
43 | template
44 | bool operator()(const T *const x, T *residual) const {
45 | residual[0] = T(10.0) - x[0];
46 | return true;
47 | }
48 |
49 | static ceres::CostFunction *Create() {
50 | return new ceres::AutoDiffCostFunction(new ExampleFunctor);
53 | }
54 | };
55 |
56 | struct projReprojectionError {
57 | projReprojectionError(double observed_x, double observed_y, double* Porig, double* Xorig, double blockNum)
58 | : _observed_x(observed_x), _observed_y(observed_y), _Porig(Porig), _Xorig(Xorig), _blockNum(blockNum) {}
59 |
60 | template
61 | bool operator()(const T* const P, const T* const X,
62 | T* residuals) const {
63 |
64 | T Pnow[12];
65 | T Xnow[4];
66 | T projection[3];
67 | for (int i = 0; i < 12; i++) {
68 | Pnow[i] = P[i] + _Porig[i];
69 | }
70 | for (int i = 0; i < 3; i++) {
71 | Xnow[i] = X[i] + _Xorig[i];
72 | }
73 | Xnow[3] = T(1.0);
74 | projection[0] = Pnow[0] * Xnow[0] + Pnow[3] * Xnow[1] + Pnow[6] * Xnow[2] + Pnow[9] * Xnow[3];
75 | projection[1] = Pnow[1] * Xnow[0] + Pnow[4] * Xnow[1] + Pnow[7] * Xnow[2] + Pnow[10] * Xnow[3];
76 | projection[2] = Pnow[2] * Xnow[0] + Pnow[5] * Xnow[1] + Pnow[8] * Xnow[2] + Pnow[11] * Xnow[3];
77 |
78 | projection[0] = projection[0] / projection[2];
79 | projection[1] = projection[1] / projection[2];
80 |
81 | residuals[0] = (projection[0] - _observed_x) / _blockNum;
82 | residuals[1] = (projection[1] - _observed_y) / _blockNum;
83 |
84 |
85 | return true;
86 | }
87 |
88 | static ceres::CostFunction* CreateMyRep(const double observed_x,
89 | const double observed_y, double* Porig, double* Xorig, double blocksNum) {
90 | return (new ceres::AutoDiffCostFunction(
91 | new projReprojectionError(observed_x, observed_y, Porig, Xorig, blocksNum)));
92 | }
93 |
94 | double _observed_x;
95 | double _observed_y;
96 |
97 | double _blockNum;
98 | double* _Porig;
99 | double* _Xorig;
100 |
101 |
102 | };
103 |
104 |
105 | struct eucReprojectionError {
106 | eucReprojectionError(double observed_x, double observed_y, double * Porig, double * Xorig,double blockNum)
107 | : _observed_x(observed_x), _observed_y(observed_y), _Porig(Porig), _Xorig(Xorig), _blockNum(blockNum){}
108 |
109 | template
110 | bool operator()(const T* const P, const T* const X,
111 | T* residuals) const {
112 | // camera[0,1,2] are the angle-axis rotation.
113 |
114 | T Rnow[3];
115 | T tnow[3];
116 |
117 |
118 | T Xnow[4];
119 | T Xrot[3];
120 |
121 | T projection[3];
122 | for (int i = 0; i < 3; i++){
123 | tnow[i] = P[i+3] + _Porig[i+3]; // P[3:6] camera location
124 | Rnow[i] = P[i] + _Porig[i]; // P[0:3] camera rotation
125 | }
126 | for (int i = 0; i < 3; i++){
127 | Xnow[i] = X[i] + _Xorig[i];
128 | }
129 | ceres::AngleAxisRotatePoint(Rnow, Xnow, Xrot);
130 | Xrot[0]+=tnow[0];
131 | Xrot[1]+=tnow[1];
132 | Xrot[2]+=tnow[2];
133 |
134 | projection[0]=(Xrot[0]*_Porig[6]+Xrot[1]*_Porig[7]+Xrot[2]*_Porig[8])/Xrot[2]; // P[6:5] K
135 | projection[1]=(Xrot[1]*_Porig[9]+Xrot[2]*_Porig[10])/Xrot[2];
136 |
137 | residuals[0] = (projection[0] - _observed_x) ;
138 | residuals[1] = (projection[1] - _observed_y) ;
139 | //residuals[0] = ceres::sqrt((projection[0] - _observed_x)*(projection[0] - _observed_x) + (projection[1] - _observed_y)*(projection[1] - _observed_y));
140 | //residuals[1] = T(0.0);
141 |
142 | return true;
143 | }
144 |
145 | static ceres::CostFunction* CreateMyRepEuc(const double observed_x,
146 | const double observed_y, double * Porig, double * Xorig, double blocksNum) {
147 | return (new ceres::AutoDiffCostFunction(
148 | new eucReprojectionError(observed_x, observed_y, Porig, Xorig, blocksNum)));
149 | }
150 | double _observed_x;
151 | double _observed_y;
152 | double _blockNum;
153 | double * _Porig;
154 | double * _Xorig;
155 | };
156 |
157 |
158 |
159 | //****** Complete function here
160 | /* The gateway function */
161 | void pythonFunctionOursBA(double* Xs, double* xs, double* Ps, double* camPointmap, double* Xsu, double* Psu, int n_cam, int n_pts, int n_observe)
162 | {
163 |
164 | ceres::Problem problem;
165 | for (int i = 0; i < n_observe; i++) {
166 |
167 | int camIndex = int(camPointmap[i]);
168 | int point3DIndex = int(camPointmap[i + n_observe]);
169 |
170 | ceres::CostFunction* cost_function =
171 | projReprojectionError::CreateMyRep(xs[2*i], xs[2*i + 1], Ps + 12 * (camIndex), Xs + 3 * (point3DIndex), 1);
172 |
173 | ceres::LossFunction* loss_function = new ceres::HuberLoss(0.1);
174 | problem.AddResidualBlock(cost_function,
175 | loss_function,
176 | Psu + 12 * (camIndex), Xsu + 3 * (point3DIndex));
177 | }
178 |
179 | ceres::Solver::Options options;
180 | options.function_tolerance = 0.0001;
181 | options.max_num_iterations = 100;
182 | options.num_threads = 24;
183 |
184 | options.linear_solver_type = ceres::DENSE_SCHUR;
185 | options.minimizer_progress_to_stdout = true;
186 |
187 | ceres::Solver::Summary summary;
188 | ceres::Solve(options, &problem, &summary);
189 | std::cout << summary.FullReport().c_str() << "\n";
190 | }
191 | void eucPythonFunctionOursBA(double* Xs, double* xs, double* Ps, double* camPointmap , double* Xsu, double* Psu, int nrows, int ncols, int n_observe)
192 | {
193 |
194 | ceres::Problem problem;
195 | for (int i = 0; i < n_observe; i++){
196 |
197 | int camIndex = int(camPointmap[i]);
198 | int point3DIndex = int(camPointmap[i+n_observe]);
199 |
200 | ceres::CostFunction* cost_function =
201 | eucReprojectionError::CreateMyRepEuc(xs[2*i], xs[2*i + 1], Ps+(12*camIndex), Xs+(3*point3DIndex),1);
202 | ceres::LossFunction * loss_function = new ceres::HuberLoss(0.1);
203 |
204 | problem.AddResidualBlock(cost_function,
205 | loss_function,
206 | Psu+(12*camIndex), Xsu+(3*point3DIndex));
207 | }
208 |
209 | ceres::Solver::Options options;
210 | options.function_tolerance = 0.0001;
211 | //options.max_num_iterations = 100;
212 | options.max_num_iterations = 100;
213 | //options.num_threads = 8;
214 | options.num_threads = 24;
215 |
216 | options.linear_solver_type = ceres::DENSE_SCHUR;
217 | options.minimizer_progress_to_stdout = true;
218 |
219 | ceres::Solver::Summary summary;
220 | ceres::Solve(options, &problem, &summary);
221 | std::cout << summary.FullReport() << "\n";
222 | }
223 |
224 |
225 | //***** complete function ends here
226 |
// Registers the custom cost-function factories and the two bundle adjustment
// drivers on the pybind11 module `m`.
//
// Exposed Python names:
//   projReprojectionError   -> projReprojectionError::CreateMyRep(..., 1)
//   eucReprojectionError    -> eucReprojectionError::CreateMyRepEuc(..., 1)
//   pythonFunctionOursBA    -> pythonFunctionOursBA(...)
//   eucPythonFunctionOursBA -> eucPythonFunctionOursBA(...)
//
// Every wrapper converts its numpy arguments with _ParseNumpyData while the
// GIL is still held and only then releases the GIL for the native call.
// NOTE(review): _ParseNumpyData is defined elsewhere; presumably it returns a
// pointer into the array's own buffer (no copy) -- confirm, since the BA
// drivers rely on writing results back through Xsu / Psu.
// NOTE(review): py::array_t appears without template arguments in this copy
// of the file -- verify against the build.
void add_custom_cost_functions(py::module &m) {

    // Use pybind11 code to wrap your own cost function which is defined in C++.


    // Here is an example:
    // m.def("CreateCustomExampleCostFunction", &ExampleFunctor::Create);

    // Factory for a single projective reprojection residual.
    m.def("projReprojectionError", [](
        double observed_x, double observed_y, py::array_t& _Porig, py::array_t& _Xorig
    ) {
        double* Porig = _ParseNumpyData(_Porig);
        double* Xorig = _ParseNumpyData(_Xorig);
        // Released for the remainder of the lambda; re-acquired on scope exit.
        py::gil_scoped_release release;

        return projReprojectionError::CreateMyRep(observed_x, observed_y, Porig, Xorig, 1);
    } , py::return_value_policy::reference);  // Python does not take ownership of the returned object
    // Factory for a single Euclidean reprojection residual.
    m.def("eucReprojectionError", [](
        double observed_x, double observed_y, py::array_t& _Porig, py::array_t& _Xorig
    ) {
        double* Porig = _ParseNumpyData(_Porig);
        double* Xorig = _ParseNumpyData(_Xorig);
        // Released for the remainder of the lambda; re-acquired on scope exit.
        py::gil_scoped_release release;

        return eucReprojectionError::CreateMyRepEuc(observed_x, observed_y, Porig, Xorig, 1);
    } , py::return_value_policy::reference);  // Python does not take ownership of the returned object


    // Full projective bundle adjustment; returns nothing to Python.
    m.def("pythonFunctionOursBA", [](
        py::array_t& _Xs,
        py::array_t& _xs,
        py::array_t& _Ps,
        py::array_t& _camPointmap,
        py::array_t& _Xsu,
        py::array_t& _Psu,
        int n_cam, int n_pts, int n_observe
    ) {
        double* Xs = _ParseNumpyData(_Xs);
        double* xs = _ParseNumpyData(_xs);
        double* Ps = _ParseNumpyData(_Ps);
        double* camPointmap = _ParseNumpyData(_camPointmap);
        double* Xsu = _ParseNumpyData(_Xsu);
        double* Psu = _ParseNumpyData(_Psu);
        // The solve runs without the GIL so other Python threads can progress.
        py::gil_scoped_release release;

        pythonFunctionOursBA(Xs, xs, Ps, camPointmap, Xsu, Psu, n_cam, n_pts, n_observe);

    });

    // Full Euclidean bundle adjustment; returns nothing to Python.
    m.def("eucPythonFunctionOursBA", [](
        py::array_t& _Xs,
        py::array_t& _xs,
        py::array_t& _Ps,
        py::array_t& _camPointmap,
        py::array_t& _Xsu,
        py::array_t& _Psu,
        int n_cam, int n_pts, int n_observe
    ) {
        double* Xs = _ParseNumpyData(_Xs);
        double* xs = _ParseNumpyData(_xs);
        double* Ps = _ParseNumpyData(_Ps);
        double* camPointmap = _ParseNumpyData(_camPointmap);
        double* Xsu = _ParseNumpyData(_Xsu);
        double* Psu = _ParseNumpyData(_Psu);
        // The solve runs without the GIL so other Python threads can progress.
        py::gil_scoped_release release;

        // n_cam / n_pts land in the driver's nrows / ncols parameters.
        eucPythonFunctionOursBA(Xs, xs, Ps, camPointmap, Xsu, Psu, n_cam, n_pts, n_observe);

    });
}
297 |
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | import loss_functions
5 | import evaluation
6 | import copy
7 | from utils import path_utils, dataset_utils, plot_utils, geo_utils
8 | from time import time
9 | import pandas as pd
10 | from utils.Phases import Phases
11 | from utils.path_utils import path_to_exp
12 | import os
13 | from torch.utils.data.dataloader import DataLoader
14 | from torch.utils.tensorboard import SummaryWriter
15 |
16 | # import fitlog
17 |
18 | # for each epoch
def epoch_train(train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts):
    """Train ``model`` for one epoch and report the mean per-sample losses.

    Gradients are accumulated over all samples of a batch and applied with a
    single optimizer step per batch; the scheduler is stepped once at the end
    of the epoch.

    Args:
        train_data: Iterable of batches (each an iterable of samples) that
            also exposes a ``device`` attribute.
        model: Callable returning ``(pred_mask, pred_cam)`` for one sample.
        loss_func1: ``(pred_mask, sample, epoch) -> loss`` for the mask.
        loss_func2: ``(pred_cam, sample, epoch) -> (loss, orient_loss,
            trans_loss)`` for the cameras.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped once per call.
        epoch: Current epoch index, forwarded to both loss functions.
        min_valid_pts: Unused; kept for interface compatibility.

    Returns:
        Tuple of 0-d tensors ``(mean_loss, mean_bce_loss, mean_orient_loss,
        mean_trans_loss)`` averaged over every sample seen in the epoch.
    """
    model.train()
    history = {"total": [], "bce": [], "orient": [], "trans": []}

    for scene_batch in train_data:
        optimizer.zero_grad()
        accumulated = torch.tensor([0.0], device=train_data.device)

        for sample in scene_batch:
            mask_pred, cam_pred = model(sample)
            bce_term = loss_func1(mask_pred, sample, epoch)
            cam_term, orient_term, trans_term = loss_func2(cam_pred, sample, epoch)
            sample_loss = bce_term + cam_term
            accumulated += sample_loss

            history["total"].append(sample_loss.item())
            history["bce"].append(bce_term.item())
            history["orient"].append(orient_term.item())
            history["trans"].append(trans_term.item())

        # Only back-propagate when the batch produced a positive loss.
        if accumulated.item() > 0:
            accumulated.backward()
            optimizer.step()

    scheduler.step()
    return tuple(
        torch.tensor(history[key]).mean()
        for key in ("total", "bce", "orient", "trans")
    )
54 |
55 |
def sim_epoch_train(conf, tbwriter, train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts, phase, validation_data, best_validation_metric):
    """Train for one pass over ``train_data`` with periodic logging and evaluation.

    Losses are accumulated per batch and applied with one optimizer step per
    batch. A global step counter ``epoch0 = epoch * len(train_data) + i`` is
    maintained, where ``i`` counts the samples processed in this call; it is
    passed to the loss functions and used for logging, evaluation scheduling
    and checkpoint naming. The scheduler is stepped once per call.

    Args:
        conf: Configuration object; reads ``train.eval_intervals``,
            ``train.validation_metric``, ``ba.only_last_eval``,
            ``train.train_trans`` and ``train.save_predictions``.
        tbwriter: TensorBoard writer receiving the running mean losses.
        train_data: Iterable of batches; must support ``len()`` and expose a
            ``device`` attribute.
        model: Callable returning ``(pred_mask, pred_cam)`` for one sample.
        loss_func1: Mask loss, called as ``(pred_mask, sample, step)``.
        loss_func2: Camera loss, called as ``(pred_cam, sample, step)``. It
            must return ``(loss, orient_loss, trans_loss)`` when
            ``train.train_trans`` is true and a single orientation loss
            otherwise.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler.
        epoch: Index of the current epoch.
        min_valid_pts: Only referenced by commented-out code below.
        phase: Current phase; evaluation runs on ``validation_data`` when it
            is ``Phases.TRAINING`` and on ``train_data`` otherwise.
        validation_data: Data used for evaluation during training.
        best_validation_metric: Best (lowest) metric seen so far.

    Returns:
        ``(best_epoch, best_validation_metric, best_model)``. If no evaluation
        in this call improved the metric, ``best_epoch`` is 0 and
        ``best_model`` is an empty tensor.
    """
    model.train()
    train_losses = []
    bce_losses = []
    orient_losses = []
    trans_losses = []

    i = 0
    epoch_size = len(train_data)
    print(f"there are {epoch_size} data in each epoch")
    eval_intervals = conf.get_int('train.eval_intervals', default=500)
    validation_metric = conf.get_list('train.validation_metric', default=["our_repro"])
    # Despite its name, this flag is passed as ``bundle_adjustment=`` below:
    # BA runs during these evaluations when ``ba.only_last_eval`` is false.
    no_ba_during_training = not conf.get_bool('ba.only_last_eval')
    train_trans = conf.get_bool('train.train_trans')
    save_predictions = conf.get_bool('train.save_predictions')

    best_epoch = 0
    best_model = torch.empty(0)  # placeholder until a better model is found

    for train_batch in train_data:  # Loop over all sets
        batch_loss = torch.tensor([0.0], device=train_data.device)
        optimizer.zero_grad()

        for curr_data in train_batch:
            i+=1
            epoch0 = epoch*epoch_size+i
            # if not dataset_utils.is_valid_sample(curr_data, min_valid_pts):
            #     print('{} {} has a camera with not enough points'.format(epoch0, curr_data.scan_name))
            #     continue

            pred_mask, pred_cam = model(curr_data)
            bceloss = loss_func1(pred_mask, curr_data, epoch0)
            if train_trans:
                gtloss, orient_loss, trans_loss = loss_func2(pred_cam, curr_data, epoch0)
                loss = bceloss + gtloss
                batch_loss += loss
                train_losses.append(loss.item())
                bce_losses.append(bceloss.item())
                orient_losses.append(orient_loss.item())
                trans_losses.append(trans_loss.item())
            else:
                # Orientation-only training: loss_func2 yields a single loss.
                orient_loss = loss_func2(pred_cam, curr_data, epoch0)
                loss = bceloss + orient_loss
                batch_loss += loss
                train_losses.append(loss.item())
                bce_losses.append(bceloss.item())
                orient_losses.append(orient_loss.item())
            # if torch.isnan(gtloss):
            #     print(epoch0)
            #     continue
            # print(bceloss.item())
            # print(gtloss.item())

        # print(batch_loss.item())
        # Only back-propagate when the batch produced a positive loss.
        if batch_loss.item()>0:
            batch_loss.backward()
            optimizer.step()

        # The loss lists are never reset, so these are running means over
        # every sample processed so far in this call.
        mean_loss = torch.tensor(train_losses).mean()
        mean_bceloss = torch.tensor(bce_losses).mean()
        mean_oriloss = torch.tensor(orient_losses).mean()
        if train_trans: mean_trloss = torch.tensor(trans_losses).mean()

        epoch0 = epoch*epoch_size+i
        # NOTE(review): confirm that the TensorBoard writes are meant to be
        # throttled together with the print.
        if epoch0 % 100 == 0:
            print('{} Train Loss: {}'.format(epoch0, mean_loss))
            tbwriter.add_scalar("Total_Loss", mean_loss, epoch0)
            tbwriter.add_scalar("BCE_Loss", mean_bceloss, epoch0)
            tbwriter.add_scalar("Orient_Loss", mean_oriloss, epoch0)
            if train_trans: tbwriter.add_scalar("Trans_Loss", mean_trloss, epoch0)
        # if epoch0 % 20 == 0:
        #     scheduler.step()

        if epoch0!=0 and (epoch0 % eval_intervals == 0):  # Eval current results
            if phase is Phases.TRAINING:
                validation_errors = epoch_evaluation(validation_data, model, conf, epoch0, Phases.VALIDATION, save_predictions=save_predictions,bundle_adjustment=no_ba_during_training)
            else:
                validation_errors = epoch_evaluation(train_data, model, conf, epoch0, phase, save_predictions=save_predictions,bundle_adjustment=no_ba_during_training)

            # Sum the configured metric columns of the "Mean" row into one score.
            metric = validation_errors.loc[["Mean"], validation_metric].sum(axis=1).values.item()
            # fitlog.add_metric({"dev":{"Acc":metric}}, step=epoch)

            if metric < best_validation_metric:
                # Lower is better: snapshot the model and checkpoint its weights.
                best_validation_metric = metric
                best_epoch = epoch0
                best_model = copy.deepcopy(model)
                print('Updated best validation metric: {}'.format(best_validation_metric))
                path = path_utils.path_to_model(conf, phase, epoch=epoch0)
                torch.save(best_model.state_dict(), path)

    scheduler.step()
    return best_epoch, best_validation_metric, best_model
148 |
149 |
def epoch_evaluation(data_loader, model, conf, epoch, phase, save_predictions=False, bundle_adjustment=True):
    """Evaluate ``model`` on every sample of ``data_loader`` and tabulate the errors.

    The model is switched to eval mode for the duration of the call and back
    to train mode before returning. The resulting table is also printed.

    Args:
        data_loader: Iterable of batches, each an iterable of samples that
            expose ``scan_name``.
        model: Callable returning ``(pred_mask, pred_cam)`` for one sample.
        conf: Configuration object; reads ``ba.refined`` and, when saving
            predictions, ``dataset.calibrated``.
        epoch: Step/epoch identifier forwarded to the prediction and saving
            helpers. When ``None``, per-scene data statistics are added to
            the error table as well.
        phase: Phase forwarded to the saving and plotting helpers.
        save_predictions: When true, cameras are saved for every scene and,
            for calibrated datasets, before/after-BA plots are produced.
        bundle_adjustment: Forwarded to the prediction and error helpers.

    Returns:
        pandas.DataFrame indexed by scene name, rounded to 3 decimals, whose
        last row ``"Mean"`` holds the column-wise mean of the numeric columns.
    """
    refined = conf.get_bool("ba.refined")
    errors_list = []
    model.eval()
    with torch.no_grad():
        for batch_data in data_loader:
            for curr_data in batch_data:
                # Get predictions (timed to report the inference cost).
                begin_time = time()
                pred_mask, pred_cam = model(curr_data)
                pred_time = time() - begin_time

                # Eval results
                outputs = evaluation.prepare_predictions_2(curr_data, pred_mask, pred_cam, conf, epoch, bundle_adjustment, refined=refined)
                errors = evaluation.compute_errors(outputs, conf, bundle_adjustment, refined=refined)

                errors['Inference time'] = pred_time
                errors['Scene'] = curr_data.scan_name

                # Get scene statistics on final evaluation
                if epoch is None:
                    stats = dataset_utils.get_data_statistics2(curr_data, outputs)
                    errors.update(stats)

                errors_list.append(errors)

                if save_predictions:
                    dataset_utils.save_cameras(outputs, conf, curr_epoch=epoch, phase=phase)
                    if conf.get_bool('dataset.calibrated'):
                        plot_utils.plot_cameras_before_and_after_ba(outputs, errors, conf, phase, scan=curr_data.scan_name, epoch=epoch, bundle_adjustment=bundle_adjustment)

    df_errors = pd.DataFrame(errors_list)
    mean_errors = df_errors.mean(numeric_only=True)
    # DataFrame.append was removed in pandas 2.0; concatenate the mean as a
    # one-row frame instead (non-numeric columns are left as NaN in that row).
    df_errors = pd.concat([df_errors, mean_errors.to_frame().T], ignore_index=True)
    df_errors.at[df_errors.last_valid_index(), "Scene"] = "Mean"
    df_errors.set_index("Scene", inplace=True)
    df_errors = df_errors.round(3)
    print(df_errors.to_string(), flush=True)
    model.train()

    return df_errors
194 |
195 |
196 | def train(conf, train_data, model, phase, validation_data=None, test_data=None):
197 | # fitlog.set_log_dir(os.path.join(conf.get_string('dataset.results_path'), 'logs'))
198 | tbwriter = SummaryWriter(log_dir=os.path.join(path_to_exp(conf), 'tb'), flush_secs=60)
199 |
200 | num_of_epochs = conf.get_int('train.num_of_epochs')
201 | eval_intervals = conf.get_int('train.eval_intervals', default=500)
202 | validation_metric = conf.get_list('train.validation_metric', default=["our_repro"])
203 |
204 | # Loss functions
205 | loss_func2 = getattr(loss_functions, conf.get_string('loss.func'))(conf)
206 | loss_func1 = loss_functions.BCEWithLogitLoss()
207 |
208 | # Optimizer params
209 | lr = conf.get_float('train.lr')
210 | scheduler_milestone = conf.get_list('train.scheduler_milestone')
211 | gamma = conf.get_float('train.gamma', default=0.1)
212 |
213 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
214 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=scheduler_milestone, gamma=gamma)
215 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
216 |
217 | best_validation_metric = math.inf
218 | best_epoch = 0
219 | best_model = torch.empty(0)
220 | converge_time = -1
221 | begin_time = time()
222 |
223 | no_ba_during_training = not conf.get_bool('ba.only_last_eval')
224 | min_valid_pts = conf.get_int('train.min_valid_pts', default=20)
225 |
226 | if not conf.get_bool("dataset.sample"): # using preprocessed dataset
227 | for epoch in range(num_of_epochs):
228 | print(f"epoch {epoch} now ...")
229 | bepoch, bmetric, bmodel = sim_epoch_train(conf, tbwriter, train_data, model, loss_func1, loss_func2, optimizer, scheduler, epoch, min_valid_pts, phase, validation_data, best_validation_metric)
230 | if bepoch>best_epoch:
231 | best_epoch=bepoch
232 | best_model=bmodel
233 | if bmetricthred
249 | prediction = premask.values.numpy().squeeze()
250 | outputs['prediction'] = prediction
251 | premask = prediction>thred
252 | gtmask = data.mask_sparse.values.cpu().numpy().squeeze()
253 | outputs['premask'] = premask+0
254 | outputs['gtmask'] = gtmask.astype(int)
255 |
256 | # precisions, recalls, thresholds = precision_recall_curve(gtmask.reshape(-1), prediction.reshape(-1))
257 | # plt.plot(precisions, recalls)
258 | # plt.xlabel('Recall')
259 | # plt.ylabel('Precision')
260 | # print(f"pr thresholds: {thresholds}")
261 | # prpath = os.path.join(conf.get_string('dataset.results_path'), "pr_img")
262 | # if not os.path.exists(prpath): os.mkdir(prpath)
263 | # plt.savefig(prpath+"/"+str(epoch)+".jpg")
264 |
265 | # processing M
266 | M = data.M.cpu().numpy()
267 | mask = np.zeros(M.shape)
268 | mask[0:mask.shape[0]:2]=predensemask
269 | mask[1:mask.shape[0]:2]=predensemask
270 | preM = M*mask
271 | colidx = (mask!=0).sum(axis=0)>2
272 | preM = preM[:, colidx]
273 | precolor = color[:, colidx]
274 | outputs['precolor'] = precolor
275 | outputs['raw_color'] = color
276 | # M = M[(mask!=0).sum(axis=1)!=0, :]
277 | # print(M.shape)
278 | # input()
279 | raw_xs = geo_utils.M_to_xs(M)
280 | xs = geo_utils.M_to_xs(preM)
281 |
282 | Rs_pred = py3d_trans.quaternion_to_matrix(geo_utils.norm_quats(pred_cam["quats"])).cpu().numpy()
283 | if train_trans:
284 | ts_pred = (data.gpss.cpu().numpy() + pred_cam["ts"].cpu().numpy()) * data.scale
285 | # ts_pred = pred_cam["ts"].cpu().numpy() * data.scale
286 | else:
287 | ts_pred = data.gpss.cpu().numpy() * data.scale
288 |
289 | #Rs_gt, ts_gt = geo_utils.decompose_camera_matrix(data.y.cpu().numpy(), Ks) # For alignment and R,t errors
290 | Rs_gt, ts_gt = data.Rs.cpu().numpy(), data.ts.cpu().numpy()
291 | outputs['Rs_gt'] = Rs_gt
292 | outputs['Rs'] = Rs_pred
293 | outputs['ts_gt'] = ts_gt * data.scale
294 | outputs['ts'] = ts_pred
295 |
296 | Ks = data.Ks.cpu().numpy() # data.Ns.inverse().cpu().numpy()
297 | outputs['Ks'] = Ks
298 | Ps = geo_utils.batch_get_camera_matrix_from_rtk(Rs_pred, ts_pred, Ks)
299 | pts3D_triangulated = geo_utils.n_view_triangulation(Ps, M=preM, Ns=Ns)
300 | outputs['xs'] = xs # to compute reprojection error later
301 | outputs['raw_xs'] = raw_xs
302 | outputs['Ps'] = Ps # Ps = K@(R|t)
303 | outputs['pts3D_pred'] = pts3D_triangulated # 4,n
304 |
305 | #Rs_pred, ts_pred = geo_utils.decompose_camera_matrix(Ps_norm)
306 |
307 | # Rs_fixed, ts_fixed, similarity_mat = geo_utils.align_cameras(Rs_pred, Rs_gt, ts_pred, ts_gt, return_alignment=True) # Align Rs_fixed, tx_fixed
308 | # outputs['Rs_fixed'] = Rs_fixed
309 | # outputs['ts_fixed'] = ts_fixed
310 | # outputs['pts3D_pred_fixed'] = (similarity_mat @ pts3D_triangulated)
311 | # outputs['Rs_fixed'] = Rs_pred
312 | # outputs['ts_fixed'] = ts_pred
313 | # outputs['pts3D_pred_fixed'] = pts3D_triangulated
314 |
315 | if bundle_adjustment:
316 | repeat = conf.get_bool('ba.repeat')
317 | max_iter = conf.get_int('ba.max_iter')
318 | ba_times = conf.get_int('ba.ba_times')
319 | repro_thre = conf.get_float('ba.repro_thre')
320 | ba_res = ba_functions.euc_ba(xs, raw_xs, Rs=Rs_pred, ts=ts_pred, Ks=np.linalg.inv(Ns),
321 | Xs=pts3D_triangulated.T, Ps=Ps, Ns=Ns,
322 | repeat=repeat, max_iter=max_iter, ba_times=ba_times,
323 | repro_thre=repro_thre, refined=refined) # Rs, ts, Ps, Xs
324 | outputs['Rs_ba'] = ba_res['Rs']
325 | outputs['ts_ba'] = ba_res['ts']
326 | outputs['Xs_ba'] = ba_res['Xs'].T # 4,n
327 | outputs['Ps_ba'] = ba_res['Ps']
328 | if refined:
329 | outputs['valid_points'] = ba_res['valid_points']
330 | outputs['new_xs'] = ba_res['new_xs']
331 |
332 | # R_ba_fixed, t_ba_fixed, similarity_mat = geo_utils.align_cameras(ba_res['Rs'], Rs_gt, ba_res['ts'], ts_gt,
333 | # return_alignment=True) # Align Rs_fixed, tx_fixed
334 | # outputs['Rs_ba_fixed'] = R_ba_fixed
335 | # outputs['ts_ba_fixed'] = t_ba_fixed
336 | # outputs['Xs_ba_fixed'] = (similarity_mat @ outputs['Xs_ba'])
337 | outputs['Rs_ba_fixed'] = ba_res['Rs']
338 | outputs['ts_ba_fixed'] = ba_res['ts']
339 | outputs['Xs_ba_fixed'] = ba_res['Xs'].T
340 |
341 | # else:
342 | # if bundle_adjustment:
343 | # repeat = conf.get_bool('ba.repeat')
344 | # triangulation = conf.get_bool('ba.triangulation')
345 | # ba_res = ba_functions.proj_ba(Ps=Ps, xs=xs, Xs_our=pts3D_triangulated.T, Ns=Ns, repeat=repeat,
346 | # triangulation=triangulation, return_repro=True, normalize_in_tri=True) # Ps, Xs
347 | # outputs['Xs_ba'] = ba_res['Xs'].T # 4,n
348 | # outputs['Ps_ba'] = ba_res['Ps']
349 |
350 | return outputs
351 |
352 |
def prepare_ptpredictions(data, pred_mask, thred=0.8):
    """Convert a predicted mask into numpy arrays for evaluation.

    Args:
        data: Sample exposing ``scan_name`` and ``mask_sparse.values`` (the
            ground-truth mask as a tensor).
        pred_mask: Predicted mask; must support ``.to('cpu')`` and expose the
            scores as a tensor in ``.values``.
        thred: Scores strictly greater than this value are classified as
            positive. Defaults to 0.8, the previously hard-coded threshold.

    Returns:
        dict with keys ``'scan_name'``, ``'prediction'`` (raw scores),
        ``'premask'`` (thresholded scores as 0/1 integers) and ``'gtmask'``
        (ground truth cast to int).
    """
    outputs = {}
    outputs['scan_name'] = data.scan_name

    prediction = pred_mask.to('cpu').values.numpy().squeeze()
    outputs['prediction'] = prediction

    # Adding 0 turns the boolean comparison result into an integer array.
    outputs['premask'] = (prediction > thred) + 0

    gtmask = data.mask_sparse.values.cpu().numpy().squeeze()
    outputs['gtmask'] = gtmask.astype(int)

    return outputs
--------------------------------------------------------------------------------