├── LICENSE ├── PureACL ├── __init__.py ├── evaluation.py ├── localization │ ├── __init__.py │ ├── base_refiner.py │ ├── feature_extractor.py │ ├── localizer.py │ ├── model3d.py │ ├── refiners.py │ └── tracker.py ├── pixlib │ ├── README.md │ ├── __init__.py │ ├── configs │ │ └── train_pixloc_kitti.yaml │ ├── datasets │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── ford.py │ │ ├── transformations.py │ │ └── view.py │ ├── geometry │ │ ├── __init__.py │ │ ├── check_jacobians.py │ │ ├── costs.py │ │ ├── interpolation.py │ │ ├── losses.py │ │ ├── optimization.py │ │ ├── utils.py │ │ └── wrappers.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── base_optimizer.py │ │ ├── classic_optimizer.py │ │ ├── gaussiannet.py │ │ ├── gnnet.py │ │ ├── learned_optimizer.py │ │ ├── s2dnet.py │ │ ├── two_view_refiner.py │ │ ├── unet.py │ │ └── utils.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── experiments.py │ │ ├── stdout_capturing.py │ │ ├── tensor.py │ │ └── tools.py ├── settings.py ├── utils │ ├── colmap.py │ ├── data.py │ ├── eval.py │ ├── io.py │ ├── quaternions.py │ └── tools.py └── visualization │ ├── animation.py │ ├── viz_2d.py │ └── viz_3d.py ├── README.md ├── architecture.jpg ├── ford_data_process ├── Project_grd2sat.py ├── SuperPoint_gen.py ├── angle_func.py ├── avi_gener.py ├── check_cross_view_center_distance.py ├── check_crossview_corres.py ├── check_groundImg_orientation.py ├── downloading_satellite_images.py ├── filelist_txt_gener.py ├── get_gps_coverage.py ├── gps_coord_func.py ├── input_libs.py ├── orb_points_gen.py ├── other_data_downloader.sh ├── pose_func.py ├── project_grd_images.py ├── project_lidar_to_camera.py ├── project_pointcloud_to_camera.py ├── raw_data_downloader.sh ├── shift_to_pose.py ├── show_all_logs_traj.py ├── superpoint.py ├── transformations.py ├── vel_npy_gener.py └── weights │ └── superpoint_v1.pth ├── requirements.txt └── setup.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 The Australian National University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /PureACL/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | formatter = logging.Formatter( 4 | fmt='[%(asctime)s %(name)s %(levelname)s] %(message)s', 5 | datefmt='%m/%d/%Y %H:%M:%S') 6 | handler = logging.StreamHandler() 7 | handler.setFormatter(formatter) 8 | handler.setLevel(logging.INFO) 9 | 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | logger.addHandler(handler) 13 | logger.propagate = False 14 | 15 | 16 | def set_logging_debug(mode: bool): 17 | if mode: 18 | logger.setLevel(logging.DEBUG) 19 | -------------------------------------------------------------------------------- /PureACL/localization/__init__.py: -------------------------------------------------------------------------------- 1 | from .model3d import Model3D # noqa 2 | from .localizer import PoseLocalizer, RetrievalLocalizer # noqa 3 | from .refiners import PoseRefiner, RetrievalRefiner # noqa 4 | from .tracker import SimpleTracker # noqa 5 | -------------------------------------------------------------------------------- /PureACL/localization/feature_extractor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | from omegaconf import DictConfig, OmegaConf as oc 3 | import numpy as np 4 | import torch 5 | 6 | from ..pixlib.datasets.view import resize, numpy_image_to_torch 7 | 8 | 9 | class FeatureExtractor(torch.nn.Module): 10 | default_conf: Dict = dict( 11 | resize=1024, 12 | resize_by='max', 13 | ) 14 | 15 | def __init__(self, model: torch.nn.Module, device: torch.device, 16 | conf: Union[Dict, DictConfig]): 17 | super().__init__() 18 | self.conf = oc.merge(oc.create(self.default_conf), oc.create(conf)) 19 | self.device = device 20 | self.model = model 21 | 22 | assert hasattr(self.model, 'scales') 23 | assert self.conf.resize_by in ['max', 'max_force'], self.conf.resize_by 24 | self.to(device) 25 | self.eval() 26 | 27 | def prepare_input(self, image: np.array) -> torch.Tensor: 28 | return numpy_image_to_torch(image).to(self.device).unsqueeze(0) 29 | 30 | @torch.no_grad() 31 | def __call__(self, image: np.array, scale_image: int = 1): 32 | """Extract feature-maps for a given image. 33 | Args: 34 | image: input image (H, W, C) 35 | """ 36 | image = image.astype(np.float32) # better for resizing 37 | scale_resize = (1., 1.) 
38 | if self.conf.resize is not None: 39 | target_size = self.conf.resize // scale_image 40 | if (max(image.shape[:2]) > target_size or 41 | self.conf.resize_by == 'max_force'): 42 | image, scale_resize = resize(image, target_size, max, 'linear') 43 | 44 | image_tensor = self.prepare_input(image) 45 | pred = self.model({'image': image_tensor}) 46 | features = pred['feature_maps'] 47 | assert len(self.model.scales) == len(features) 48 | 49 | features = [feat.squeeze(0) for feat in features] # remove batch dim 50 | confidences = pred.get('confidences') 51 | if confidences is not None: 52 | confidences = [c.squeeze(0) for c in confidences] 53 | 54 | scales = [(scale_resize[0]/s, scale_resize[1]/s) 55 | for s in self.model.scales] 56 | 57 | return features, scales, confidences 58 | -------------------------------------------------------------------------------- /PureACL/localization/localizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from typing import Optional, Dict, Tuple, Union 4 | from omegaconf import DictConfig, OmegaConf as oc 5 | from tqdm import tqdm 6 | import torch 7 | 8 | from .model3d import Model3D 9 | from .feature_extractor import FeatureExtractor 10 | from .refiners import PoseRefiner, RetrievalRefiner 11 | 12 | from ..utils.data import Paths 13 | from ..utils.io import parse_image_lists, parse_retrieval, load_hdf5 14 | from ..utils.quaternions import rotmat2qvec 15 | from ..pixlib.utils.experiments import load_experiment 16 | from ..pixlib.models import get_model 17 | from ..pixlib.geometry import Camera 18 | 19 | logger = logging.getLogger(__name__) 20 | # TODO: despite torch.no_grad in BaseModel, requires_grad flips in ref interp 21 | torch.set_grad_enabled(False) 22 | 23 | 24 | class Localizer: 25 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 26 | device: Optional[torch.device] = None): 27 | if device is None: 28 | if torch.cuda.is_available(): 29 | device = torch.device('cuda:0') 30 | else: 31 | device = torch.device('cpu') 32 | 33 | self.model3d = Model3D(paths.reference_sfm) 34 | cameras = parse_image_lists(paths.query_list, with_intrinsics=True) 35 | self.queries = {n: c for n, c in cameras} 36 | 37 | # Loading feature extractor and optimizer from experiment or scratch 38 | conf = oc.create(conf) 39 | conf_features = conf.features.get('conf', {}) 40 | conf_optim = conf.get('optimizer', {}) 41 | if conf.get('experiment'): 42 | pipeline = load_experiment( 43 | conf.experiment, 44 | {'extractor': conf_features, 'optimizer': conf_optim}) 45 | pipeline = pipeline.to(device) 46 | logger.debug( 47 | 'Use full pipeline from experiment %s with config:\n%s', 48 | conf.experiment, oc.to_yaml(pipeline.conf)) 49 | extractor = pipeline.extractor 50 | optimizer = pipeline.optimizer 51 | if isinstance(optimizer, torch.nn.ModuleList): 52 | optimizer = list(optimizer) 53 | else: 54 | assert 'name' in conf.features 55 | extractor = get_model(conf.features.name)(conf_features) 56 | optimizer = get_model(conf.optimizer.name)(conf_optim) 57 | 58 | self.paths = paths 59 | self.conf = conf 60 | self.device = device 61 | self.optimizer = optimizer 62 | self.extractor = FeatureExtractor( 63 | extractor, device, conf.features.get('preprocessing', {})) 64 | 65 | def run_query(self, name: str, camera: Camera): 66 | raise NotImplementedError 67 | 68 | def run_batched(self, skip: Optional[int] = None, 69 | ) -> Tuple[Dict[str, Tuple], Dict]: 70 | output_poses = {} 71 | output_logs = { 72 | 
'paths': self.paths.asdict(), 73 | 'configuration': oc.to_yaml(self.conf), 74 | 'localization': {}, 75 | } 76 | 77 | logger.info('Starting the localization process...') 78 | query_names = list(self.queries.keys())[::skip or 1] 79 | for name in tqdm(query_names): 80 | camera = Camera.from_colmap(self.queries[name]) 81 | try: 82 | ret = self.run_query(name, camera) 83 | except RuntimeError as e: 84 | if 'CUDA out of memory' in str(e): 85 | logger.info('Out of memory') 86 | torch.cuda.empty_cache() 87 | ret = {'success': False} 88 | else: 89 | raise 90 | output_logs['localization'][name] = ret 91 | if ret['success']: 92 | R, tvec = ret['T_refined'].numpy() 93 | elif 'T_init' in ret: 94 | R, tvec = ret['T_init'].numpy() 95 | else: 96 | continue 97 | output_poses[name] = (rotmat2qvec(R), tvec) 98 | 99 | return output_poses, output_logs 100 | 101 | 102 | class RetrievalLocalizer(Localizer): 103 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 104 | device: Optional[torch.device] = None): 105 | super().__init__(paths, conf, device) 106 | 107 | if paths.global_descriptors is not None: 108 | global_descriptors = load_hdf5(paths.global_descriptors) 109 | else: 110 | global_descriptors = None 111 | 112 | self.refiner = RetrievalRefiner( 113 | self.device, self.optimizer, self.model3d, self.extractor, paths, 114 | self.conf.refinement, global_descriptors=global_descriptors) 115 | 116 | if paths.hloc_logs is not None: 117 | logger.info('Reading hloc logs...') 118 | with open(paths.hloc_logs, 'rb') as f: 119 | self.logs = pickle.load(f)['loc'] 120 | self.retrieval = {q: [self.model3d.dbs[i].name for i in loc['db']] 121 | for q, loc in self.logs.items()} 122 | elif paths.retrieval_pairs is not None: 123 | self.logs = None 124 | self.retrieval = parse_retrieval(paths.retrieval_pairs) 125 | else: 126 | raise ValueError 127 | 128 | def run_query(self, name: str, camera: Camera): 129 | dbs = [self.model3d.name2id[r] for r in self.retrieval[name]] 130 | loc = None if self.logs is None else self.logs[name] 131 | ret = self.refiner.refine(name, camera, dbs, loc=loc) 132 | return ret 133 | 134 | 135 | class PoseLocalizer(Localizer): 136 | def __init__(self, paths: Paths, conf: Union[DictConfig, Dict], 137 | device: Optional[torch.device] = None): 138 | super().__init__(paths, conf, device) 139 | 140 | self.refiner = PoseRefiner( 141 | device, self.optimizer, self.model3d, self.extractor, paths, 142 | self.conf.refinement) 143 | 144 | logger.info('Reading hloc logs...') 145 | with open(paths.log_path, 'rb') as f: 146 | self.logs = pickle.load(f)['loc'] 147 | 148 | def run_query(self, name: str, camera: Camera): 149 | loc = self.logs[name] 150 | if loc['PnP_ret']['success']: 151 | ret = self.refiner.refine(name, camera, loc) 152 | else: 153 | ret = {'success': False} 154 | return ret 155 | -------------------------------------------------------------------------------- /PureACL/localization/model3d.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict 3 | from typing import Dict, List, Optional 4 | import numpy as np 5 | 6 | from ..utils.colmap import read_model 7 | from ..utils.quaternions import weighted_pose 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Model3D: 13 | def __init__(self, path): 14 | logger.info('Reading COLMAP model %s.', path) 15 | self.cameras, self.dbs, self.points3D = read_model(path) 16 | self.name2id = {i.name: i.id for i in self.dbs.values()} 17 | 18 | def 
covisbility_filtering(self, dbids): 19 | clusters = do_covisibility_clustering(dbids, self.dbs, self.points3D) 20 | dbids = clusters[0] 21 | return dbids 22 | 23 | def pose_approximation(self, qname, dbids, global_descriptors, alpha=8): 24 | """Described in: 25 | Benchmarking Image Retrieval for Visual Localization. 26 | Noé Pion, Martin Humenberger, Gabriela Csurka, 27 | Yohann Cabon, Torsten Sattler. 3DV 2020. 28 | """ 29 | dbs = [self.dbs[i] for i in dbids] 30 | 31 | dbdescs = np.stack([global_descriptors[im.name] for im in dbs]) 32 | qdesc = global_descriptors[qname] 33 | sim = dbdescs @ qdesc 34 | weights = sim**alpha 35 | weights /= weights.sum() 36 | 37 | tvecs = [im.tvec for im in dbs] 38 | qvecs = [im.qvec for im in dbs] 39 | return weighted_pose(tvecs, qvecs, weights) 40 | 41 | def get_dbid_to_p3dids(self, p3did_to_dbids): 42 | """Link the database images to selected 3D points.""" 43 | dbid_to_p3dids = defaultdict(list) 44 | for p3id, obs_dbids in p3did_to_dbids.items(): 45 | for obs_dbid in obs_dbids: 46 | dbid_to_p3dids[obs_dbid].append(p3id) 47 | return dict(dbid_to_p3dids) 48 | 49 | def get_p3did_to_dbids(self, dbids: List, loc: Optional[Dict] = None, 50 | inliers: Optional[List] = None, 51 | point_selection: str = 'all', 52 | min_track_length: int = 3): 53 | """Return a dictionary mapping 3D point ids to their covisible dbids. 54 | This function can use hloc sfm logs to only select inliers. 55 | Which can be further used to select top reference images / in 56 | sufficient track length selection of points. 57 | """ 58 | p3did_to_dbids = defaultdict(set) 59 | if point_selection == 'all': 60 | for dbid in dbids: 61 | p3dids = self.dbs[dbid].point3D_ids 62 | for p3did in p3dids[p3dids != -1]: 63 | p3did_to_dbids[p3did].add(dbid) 64 | elif point_selection in ['inliers', 'matched']: 65 | if loc is None: 66 | raise ValueError('"{point_selection}" point selection requires' 67 | ' localization logs.') 68 | 69 | # The given SfM model must match the localization SfM model! 
70 | for (p3did, dbidxs), inlier in zip(loc["keypoint_index_to_db"][1], 71 | inliers): 72 | if inlier or point_selection == 'matched': 73 | obs_dbids = set(loc["db"][dbidx] for dbidx in dbidxs) 74 | obs_dbids &= set(dbids) 75 | if len(obs_dbids) > 0: 76 | p3did_to_dbids[p3did] |= obs_dbids 77 | else: 78 | raise ValueError(f"{point_selection} point selection not defined.") 79 | 80 | # Filter unstable points (min track length) 81 | p3did_to_dbids = { 82 | i: v 83 | for i, v in p3did_to_dbids.items() 84 | if len(self.points3D[i].image_ids) >= min_track_length 85 | } 86 | 87 | return p3did_to_dbids 88 | 89 | def rerank_and_filter_db_images(self, dbids: List, ninl_dbs: List, 90 | num_dbs: int, min_matches_db: int = 0): 91 | """Re-rank the images by inlier count and filter invalid images.""" 92 | dbids = [dbids[i] for i in np.argsort(-ninl_dbs) 93 | if ninl_dbs[i] > min_matches_db] 94 | # Keep top num_images matched image images 95 | dbids = dbids[:num_dbs] 96 | return dbids 97 | 98 | def get_db_inliers(self, loc: Dict, dbids: List, inliers: List): 99 | """Get the number of inliers for each db.""" 100 | inliers = loc["PnP_ret"]["inliers"] 101 | dbids = loc["db"] 102 | ninl_dbs = np.zeros(len(dbids)) 103 | for (_, dbidxs), inl in zip(loc["keypoint_index_to_db"][1], inliers): 104 | if not inl: 105 | continue 106 | for dbidx in dbidxs: 107 | ninl_dbs[dbidx] += 1 108 | return ninl_dbs 109 | 110 | 111 | def do_covisibility_clustering(frame_ids, all_images, points3D): 112 | clusters = [] 113 | visited = set() 114 | 115 | for frame_id in frame_ids: 116 | # Check if already labeled 117 | if frame_id in visited: 118 | continue 119 | 120 | # New component 121 | clusters.append([]) 122 | queue = {frame_id} 123 | while len(queue): 124 | exploration_frame = queue.pop() 125 | 126 | # Already part of the component 127 | if exploration_frame in visited: 128 | continue 129 | visited.add(exploration_frame) 130 | clusters[-1].append(exploration_frame) 131 | 132 | observed = all_images[exploration_frame].point3D_ids 133 | connected_frames = set( 134 | j for i in observed if i != -1 for j in points3D[i].image_ids) 135 | connected_frames &= set(frame_ids) 136 | connected_frames -= visited 137 | queue |= connected_frames 138 | 139 | clusters = sorted(clusters, key=len, reverse=True) 140 | return clusters 141 | -------------------------------------------------------------------------------- /PureACL/localization/refiners.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Optional, List 3 | 4 | from .base_refiner import BaseRefiner 5 | from ..pixlib.geometry import Pose, Camera 6 | from ..utils.colmap import qvec2rotmat 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PoseRefiner(BaseRefiner): 12 | default_config = dict( 13 | min_matches_total=10, 14 | ) 15 | 16 | def refine(self, qname: str, qcamera: Camera, loc: Dict) -> Dict: 17 | # Unpack initial query pose 18 | T_init = Pose.from_Rt(qvec2rotmat(loc["PnP_ret"]["qvec"]), 19 | loc["PnP_ret"]["tvec"]) 20 | fail = {'success': False, 'T_init': T_init} 21 | 22 | num_inliers = loc["PnP_ret"]["num_inliers"] 23 | if num_inliers < self.conf.min_matches_total: 24 | logger.debug(f"Too few inliers: {num_inliers}") 25 | return fail 26 | 27 | # Fetch database inlier matches count 28 | dbids = loc["db"] 29 | inliers = loc["PnP_ret"]["inliers"] 30 | ninl_dbs = self.model3d.get_db_inliers(loc, dbids, inliers) 31 | 32 | # Re-rank and filter database images 33 | dbids = 
self.model3d.rerank_and_filter_db_images( 34 | dbids, ninl_dbs, self.conf.num_dbs, self.conf.min_matches_db) 35 | 36 | # Abort if no image matches the minimum number of inliers criterion 37 | if len(dbids) == 0: 38 | logger.debug("No DB image with min num matches") 39 | return fail 40 | 41 | # Select the 3D points and collect their observations 42 | p3did_to_dbids = self.model3d.get_p3did_to_dbids( 43 | dbids, loc, inliers, self.conf.point_selection, 44 | self.conf.min_track_length) 45 | 46 | # Abort if there are not enough 3D points after filtering 47 | if len(p3did_to_dbids) < self.conf.min_points_opt: 48 | logger.debug("Not enough valid 3D points to optimize") 49 | return fail 50 | 51 | ret = self.refine_query_pose(qname, qcamera, T_init, p3did_to_dbids) 52 | ret = {**ret, 'dbids': dbids} 53 | return ret 54 | 55 | 56 | class RetrievalRefiner(BaseRefiner): 57 | default_config = dict( 58 | multiscale=None, 59 | filter_covisibility=False, 60 | do_pose_approximation=False, 61 | do_inlier_ranking=False, 62 | ) 63 | 64 | def __init__(self, *args, **kwargs): 65 | self.global_descriptors = kwargs.pop('global_descriptors', None) 66 | super().__init__(*args, **kwargs) 67 | 68 | def refine(self, qname: str, qcamera: Camera, dbids: List[int], 69 | loc: Optional[Dict] = None) -> Dict: 70 | 71 | if self.conf.do_inlier_ranking: 72 | assert loc is not None 73 | 74 | if self.conf.do_inlier_ranking and loc['PnP_ret']['success']: 75 | inliers = loc['PnP_ret']['inliers'] 76 | ninl_dbs = self.model3d.get_db_inliers(loc, dbids, inliers) 77 | dbids = self.model3d.rerank_and_filter_db_images( 78 | dbids, ninl_dbs, self.conf.num_dbs, 79 | self.conf.min_matches_db) 80 | else: 81 | assert self.conf.point_selection == 'all' 82 | dbids = dbids[:self.conf.num_dbs] 83 | if self.conf.do_pose_approximation or self.conf.filter_covisibility: 84 | dbids = self.model3d.covisbility_filtering(dbids) 85 | inliers = None 86 | 87 | if self.conf.do_pose_approximation: 88 | if self.global_descriptors is None: 89 | raise RuntimeError( 90 | 'Pose approximation requires global descriptors') 91 | Rt_init = self.model3d.pose_approximation( 92 | qname, dbids, self.global_descriptors) 93 | else: 94 | id_init = dbids[0] 95 | image_init = self.model3d.dbs[id_init] 96 | Rt_init = (image_init.qvec2rotmat(), image_init.tvec) 97 | T_init = Pose.from_Rt(*Rt_init) 98 | fail = {'success': False, 'T_init': T_init, 'dbids': dbids} 99 | 100 | p3did_to_dbids = self.model3d.get_p3did_to_dbids( 101 | dbids, loc, inliers, self.conf.point_selection, 102 | self.conf.min_track_length) 103 | 104 | # Abort if there are not enough 3D points after filtering 105 | if len(p3did_to_dbids) < self.conf.min_points_opt: 106 | logger.debug("Not enough valid 3D points to optimize") 107 | return fail 108 | 109 | ret = self.refine_query_pose(qname, qcamera, T_init, p3did_to_dbids, 110 | self.conf.multiscale) 111 | ret = {**ret, 'dbids': dbids} 112 | return ret 113 | -------------------------------------------------------------------------------- /PureACL/localization/tracker.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | class BaseTracker: 5 | def __init__(self, refiner): 6 | # attach the tracker to the refiner 7 | refiner.tracker = self 8 | 9 | # attach the tracker to the optimizer(s) 10 | opts = refiner.optimizer 11 | opts = opts if isinstance(opts, (tuple, list)) else [opts] 12 | for opt in opts: 13 | opt.logging_fn = self.log_optim_iter 14 | 15 | def log_dense(self, **args): 16 
| raise NotImplementedError 17 | 18 | def log_optim_done(self, **args): 19 | raise NotImplementedError 20 | 21 | def log_optim_iter(self, **args): 22 | raise NotImplementedError 23 | 24 | 25 | class SimpleTracker(BaseTracker): 26 | def __init__(self, refiner): 27 | super().__init__(refiner) 28 | 29 | self.dense = defaultdict(dict) 30 | self.costs = [] 31 | self.T = [] 32 | self.dt = [] 33 | self.p3d = None 34 | self.p3d_ids = None 35 | self.num_iters = [] 36 | 37 | def log_dense(self, **args): 38 | feats = [f.cpu() for f in args['features']] 39 | weights = [w.cpu()[0] for w in args['weight']] 40 | data = (args['image'], feats, weights) 41 | self.dense[args['name']][args['image_scale']] = data 42 | 43 | def log_optim_done(self, **args): 44 | self.p3d = args['p3d'] 45 | self.p3d_ids = args['p3d_ids'] 46 | 47 | def log_optim_iter(self, **args): 48 | if args['i'] == 0: # new scale or level 49 | self.costs.append([]) 50 | self.T.append(args['T_init'].cpu()) 51 | self.num_iters.append(None) 52 | 53 | valid = args['valid'].float() 54 | cost = ((valid*args['cost']).sum(-1)/valid.sum(-1)) 55 | 56 | self.costs[-1].append(cost.cpu().numpy()) 57 | self.dt.append(args['T_delta'].magnitude()[1].cpu().numpy()) 58 | self.num_iters[-1] = args['i']+1 59 | self.T.append(args['T'].cpu()) 60 | -------------------------------------------------------------------------------- /PureACL/pixlib/README.md: -------------------------------------------------------------------------------- 1 | # PixLib - training library 2 | 3 | `pixlib` is built on top of a framework whose core principles are: 4 | 5 | - modularity: it is easy to add a new dataset or model with custom loss and metrics; 6 | - reusability: components like geometric primitives, training loop, or experiment tools are reused across projects; 7 | - reproducibility: a training run is parametrized by a configuration, which is saved and reused for evaluation; 8 | - simplicity: it has few external dependencies, and can be easily grasped by a new user. 9 | 10 | ## Framework structure 11 | `pixlib` includes of the following components: 12 | - [`datasets/`](./datasets) contains the dataloaders, all inherited from [`BaseDataset`](./datasets/base_dataset.py). Each loader is configurable and produces a set of batched data dictionaries. 13 | - [`models/`](./models) contains the deep networks and learned blocks, all inherited from [`BaseModel`](./models/base_model.py). Each model is configurable, takes as input data, and outputs predictions. It also exposes its own loss and evaluation metrics. 14 | - [`geometry/`](pixlib/geometry) groups Numpy/PyTorch primitives for 3D vision: poses and camera models, linear algebra, optimization, etc. 15 | - [`utils/`](./utils) contains various utilities, for example to [manage experiments](./utils/experiments.py). 16 | 17 | Datasets, models, and training runs are parametrized by [omegaconf](https://github.com/omry/omegaconf) configurations. See examples of training configurations in [`configs/`](./configs/) as `.yaml` files. 18 | 19 | ## Workflow 20 |
21 | ### Training
22 | 23 | The following command starts a new training run: 24 | ```bash 25 | python3 -m sidfm.pixlib.train experiment_name \ 26 | --conf sidfm/pixlib/configs/config_name.yaml 27 | ``` 28 | 29 | It creates a new directory `experiment_name/` in `TRAINING_PATH` and dumps the configuration, model checkpoints, logs of stdout, and [Tensorboard](https://pytorch.org/docs/stable/tensorboard.html) summaries. 30 | 31 | Extra flags can be given: 32 | 33 | - `--overfit` loops the training and validation sets on a single batch ([useful to test losses and metrics](http://karpathy.github.io/2019/04/25/recipe/)). 34 | - `--restore` restarts the training from the last checkpoint (last epoch) of the same experiment. 35 | - `--distributed` uses all GPUs available with multiple processes and batch norm synchronization. 36 | - individual configuration entries to overwrite the YAML entries. Examples: `train.lr=0.001` or `data.batch_size=8`. 37 | 38 | **Monitoring the training:** Launch a Tensorboard session with `tensorboard --logdir=path/to/TRAINING_PATH` to visualize losses and metrics, and compare them across experiments. Press `Ctrl+C` to gracefully interrupt the training. 39 |
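Purely as an illustration, a run that restores from the last checkpoint and overrides a couple of configuration entries could be launched as below. `my_experiment` is a placeholder name, and appending the dotlist overrides at the end of the command assumes `train.py` accepts them as trailing arguments, as the examples above suggest.

```bash
# Hypothetical invocation: restart `my_experiment` from its last checkpoint
# and override two configuration entries from the command line.
python3 -m sidfm.pixlib.train my_experiment \
    --conf sidfm/pixlib/configs/train_pixloc_kitti.yaml \
    --restore train.lr=0.001 data.batch_size=8
```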
40 | 41 |
42 | ### Inference with a trained model
43 | 44 | After training, you can easily load a model to evaluate it: 45 | ```python 46 | from sidfm.pixlib.utils.experiments import load_experiment 47 | 48 | test_conf = {} # will overwrite the training and default configurations 49 | model = load_experiment('name_of_my_experiment', test_conf) 50 | model = model.eval().cuda() # optionally move the model to GPU 51 | predictions = model(data) # data is a dictionary of tensors 52 | ``` 53 | 54 |
55 | 56 |
57 | ### Adding new datasets or models
58 | 59 | We simply need to create a new file in [`datasets/`](./datasets/) or [`models/`](./models/). This makes it easy to collaborate on the same codebase. Each class should inherit from the base class, declare a `default_conf`, and define some specific methods. Have a look at the base files [`BaseDataset`](./datasets/base_dataset.py) and [`BaseModel`](./models/base_model.py) for more details. Please follow [PEP 8](https://www.python.org/dev/peps/pep-0008/) and use relative imports. 60 | 61 |
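As an illustration only (the file, class, and configuration names below are made up), a new dataset would follow the `BaseDataset` contract shown further down in `datasets/base_dataset.py`: declare a `default_conf`, implement `_init`, and return a `torch.utils.data.Dataset` from `get_dataset`.

```python
# datasets/my_dataset.py -- illustrative sketch only: `MyDataset`, `_Split`
# and the `root` entry are placeholders; the interface (default_conf, _init,
# get_dataset) is the one required by BaseDataset in base_dataset.py.
from torch.utils.data import Dataset

from .base_dataset import BaseDataset


class MyDataset(BaseDataset):
    default_conf = {
        'root': '???',       # '???' marks a mandatory omegaconf entry
        'grayscale': False,
    }

    def _init(self, conf):
        # called once by BaseDataset.__init__ with the merged configuration
        self.root = conf.root

    def get_dataset(self, split):
        assert split in ['train', 'val', 'test']
        return _Split(self.conf, split)


class _Split(Dataset):
    def __init__(self, conf, split):
        self.conf, self.split = conf, split
        self.items = []  # list the samples of this split here

    def __getitem__(self, idx):
        # each item should be a dictionary of tensors; batching is handled
        # by the custom `collate` function in base_dataset.py
        return self.items[idx]

    def __len__(self):
        return len(self.items)
```

Once registered this way, `get_data_loader('train')` from the base class wraps the split with the `collate` function and worker seeding defined in `base_dataset.py`.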
62 | -------------------------------------------------------------------------------- /PureACL/pixlib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanWang-Shan/PureACL/21a4c2e64f0eeafaa09117b6bc8aef40d9cdf4e3/PureACL/pixlib/__init__.py -------------------------------------------------------------------------------- /PureACL/pixlib/configs/train_pixloc_kitti.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | name: kitti 3 | max_num_points3D: 10000 4 | force_num_points3D: true 5 | batch_size: 1 6 | num_workers: 0 7 | seed: 1 8 | model: 9 | name: two_view_refiner 10 | success_thresh: 3 11 | normalize_features: true 12 | duplicate_optimizer_per_scale: true 13 | normalize_dt: false 14 | extractor: 15 | name: unet 16 | encoder: vgg16 17 | decoder: [64, 64, 64, 32] 18 | output_scales: [0, 2, 4] 19 | output_dim: [32, 128, 128] 20 | freeze_batch_normalization: false 21 | do_average_pooling: false 22 | compute_uncertainty: true 23 | checkpointed: true 24 | optimizer: 25 | name: learned_optimizer 26 | num_iters: 15 27 | pad: 3 28 | lambda_: 0.01 29 | verbose: false 30 | loss_fn: scaled_barron(0, 0.1) 31 | jacobi_scaling: false 32 | learned_damping: true 33 | damping: 34 | type: constant 35 | train: 36 | seed: 0 37 | epochs: 200 38 | log_every_iter: 50 39 | eval_every_iter: 500 40 | dataset_callback_fn: sample_new_items 41 | lr: 1.0e-05 42 | lr_scaling: [[100, ['dampingnet.const']]] 43 | median_metrics: 44 | - loss/reprojection_error 45 | - loss/reprojection_error/init 46 | clip_grad: 1.0 47 | -------------------------------------------------------------------------------- /PureACL/pixlib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from PureACL.pixlib.utils.tools import get_class 2 | from PureACL.pixlib.datasets.base_dataset import BaseDataset 3 | 4 | 5 | def get_dataset(name): 6 | return get_class(name, __name__, BaseDataset) 7 | -------------------------------------------------------------------------------- /PureACL/pixlib/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for dataset. 3 | See mnist.py for an example of dataset. 
4 | """ 5 | 6 | from abc import ABCMeta, abstractmethod 7 | import collections 8 | import logging 9 | from omegaconf import OmegaConf 10 | import omegaconf 11 | import torch 12 | #from torch._six import string_classes 13 | from torch.utils.data import DataLoader, Sampler, get_worker_info 14 | from torch.utils.data._utils.collate import (default_collate_err_msg_format, 15 | np_str_obj_array_pattern) 16 | 17 | from ..utils.tools import set_num_threads, set_seed 18 | 19 | logger = logging.getLogger(__name__) 20 | string_classes = (str, bytes) 21 | 22 | 23 | class LoopSampler(Sampler): 24 | def __init__(self, loop_size, total_size=None): 25 | self.loop_size = loop_size 26 | self.total_size = total_size - (total_size % loop_size) 27 | 28 | def __iter__(self): 29 | return (i % self.loop_size for i in range(self.total_size)) 30 | 31 | def __len__(self): 32 | return self.total_size 33 | 34 | 35 | def worker_init_fn(i): 36 | info = get_worker_info() 37 | if hasattr(info.dataset, 'conf'): 38 | conf = info.dataset.conf 39 | set_seed(info.id + conf.seed) 40 | set_num_threads(conf.num_threads) 41 | else: 42 | set_num_threads(1) 43 | 44 | 45 | def collate(batch): 46 | """Difference with PyTorch default_collate: it can stack of other objects. 47 | """ 48 | if not isinstance(batch, list): # no batching 49 | return batch 50 | elem = batch[0] 51 | elem_type = type(elem) 52 | if isinstance(elem, torch.Tensor): 53 | out = None 54 | if torch.utils.data.get_worker_info() is not None: 55 | # If we're in a background process, concatenate directly into a 56 | # shared memory tensor to avoid an extra copy 57 | numel = sum([x.numel() for x in batch]) 58 | storage = elem.storage()._new_shared(numel) 59 | out = elem.new(storage) 60 | return torch.stack(batch, 0, out=out) 61 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 62 | and elem_type.__name__ != 'string_': 63 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 64 | # array of string classes and object 65 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 66 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 67 | 68 | return collate([torch.as_tensor(b) for b in batch]) 69 | elif elem.shape == (): # scalars 70 | return torch.as_tensor(batch) 71 | elif isinstance(elem, float): 72 | return torch.tensor(batch, dtype=torch.float64) 73 | elif isinstance(elem, int): 74 | return torch.tensor(batch) 75 | elif isinstance(elem, string_classes): 76 | return batch 77 | elif isinstance(elem, collections.abc.Mapping): 78 | return {key: collate([d[key] for d in batch]) for key in elem} 79 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 80 | return elem_type(*(collate(samples) for samples in zip(*batch))) 81 | elif isinstance(elem, collections.abc.Sequence): 82 | # check to make sure that the elements in batch have consistent size 83 | it = iter(batch) 84 | elem_size = len(next(it)) 85 | if not all(len(elem) == elem_size for elem in it): 86 | raise RuntimeError('each element in list of batch should be of equal size') 87 | transposed = zip(*batch) 88 | return [collate(samples) for samples in transposed] 89 | else: 90 | # try to stack anyway in case the object implements stacking. 91 | return torch.stack(batch, 0) 92 | 93 | 94 | class BaseDataset(metaclass=ABCMeta): 95 | """ 96 | What the dataset model is expect to declare: 97 | default_conf: dictionary of the default configuration of the dataset. 
98 | It overwrites base_default_conf in BaseModel, and it is overwritten by 99 | the user-provided configuration passed to __init__. 100 | Configurations can be nested. 101 | 102 | _init(self, conf): initialization method, where conf is the final 103 | configuration object (also accessible with `self.conf`). Accessing 104 | unkown configuration entries will raise an error. 105 | 106 | get_dataset(self, split): method that returns an instance of 107 | torch.utils.data.Dataset corresponding to the requested split string, 108 | which can be `'train'`, `'val'`, or `'test'`. 109 | """ 110 | base_default_conf = { 111 | 'name': 'ford', 112 | 'num_workers': 0, 113 | 'train_batch_size': 1, 114 | 'val_batch_size': 1, 115 | 'test_batch_size': 1, 116 | 'shuffle_training': True, 117 | 'batch_size': 1, 118 | 'num_threads': 1, 119 | 'seed': 0, 120 | 'mul_query': False, 121 | 'two_view': True, 122 | } 123 | default_conf = {} 124 | 125 | def __init__(self, conf): 126 | """Perform some logic and call the _init method of the child model.""" 127 | default_conf = OmegaConf.merge( 128 | OmegaConf.create(self.base_default_conf), 129 | OmegaConf.create(self.default_conf)) 130 | OmegaConf.set_struct(default_conf, False) 131 | if isinstance(conf, dict): 132 | conf = OmegaConf.create(conf) 133 | self.conf = OmegaConf.merge(default_conf, conf) 134 | OmegaConf.set_readonly(self.conf, True) 135 | logger.info(f'Creating dataset {self.__class__.__name__}') 136 | self._init(self.conf) 137 | 138 | @abstractmethod 139 | def _init(self, conf): 140 | """To be implemented by the child class.""" 141 | raise NotImplementedError 142 | 143 | @abstractmethod 144 | def get_dataset(self, split): 145 | """To be implemented by the child class.""" 146 | raise NotImplementedError 147 | 148 | def get_data_loader(self, split, shuffle=None, pinned=True, 149 | distributed=False): 150 | """Return a data loader for a given split.""" 151 | assert split in ['train', 'val', 'test'] 152 | dataset = self.get_dataset(split) 153 | try: 154 | batch_size = self.conf[split+'_batch_size'] 155 | except omegaconf.MissingMandatoryValue: 156 | batch_size = self.conf.batch_size 157 | num_workers = self.conf.get('num_workers', batch_size) 158 | if distributed: 159 | shuffle = False 160 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) 161 | else: 162 | sampler = None 163 | if shuffle is None: 164 | shuffle = (split == 'train' and self.conf.shuffle_training) 165 | return DataLoader( 166 | dataset, batch_size=batch_size, shuffle=shuffle, 167 | sampler=sampler, pin_memory=pinned, collate_fn=collate, 168 | num_workers=num_workers, worker_init_fn=worker_init_fn) 169 | 170 | def get_overfit_loader(self, split): 171 | """Return an overfit data loader. 172 | The training set is composed of a single duplicated batch, while 173 | the validation and test sets contain a single copy of this same batch. 174 | This is useful to debug a model and make sure that losses and metrics 175 | correlate well. 
176 | """ 177 | assert split in ['train', 'val', 'test'] 178 | dataset = self.get_dataset('train') 179 | sampler = LoopSampler( 180 | self.conf.batch_size, 181 | len(dataset) if split == 'train' else self.conf.batch_size) 182 | num_workers = self.conf.get('num_workers', self.conf.batch_size) 183 | return DataLoader(dataset, batch_size=self.conf.batch_size, 184 | pin_memory=True, num_workers=num_workers, 185 | sampler=sampler, worker_init_fn=worker_init_fn) 186 | -------------------------------------------------------------------------------- /PureACL/pixlib/datasets/view.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import cv2 4 | # TODO: consider using PIL instead of OpenCV as it is heavy and only used here 5 | import torch 6 | 7 | from ..geometry import Camera, Pose 8 | 9 | 10 | def numpy_image_to_torch(image): 11 | """Normalize the image tensor and reorder the dimensions.""" 12 | if image.ndim == 3: 13 | image = image.transpose((2, 0, 1)) # HxWxC to CxHxW 14 | elif image.ndim == 2: 15 | image = image[None] # add channel axis 16 | else: 17 | raise ValueError(f'Not an image: {image.shape}') 18 | return torch.from_numpy(image / 255.).float() 19 | 20 | 21 | def read_image(path, grayscale=False): 22 | mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR 23 | image = cv2.imread(str(path), mode) 24 | if image is None: 25 | raise IOError(f'Could not read image at {path}.') 26 | if not grayscale: 27 | image = image[..., ::-1] 28 | return image 29 | 30 | 31 | def resize(image, size, fn=None, interp='linear'): 32 | """Resize an image to a fixed size, or according to max or min edge.""" 33 | h, w = image.shape[:2] 34 | if isinstance(size, int): 35 | scale = size / fn(h, w) 36 | h_new, w_new = int(round(h*scale)), int(round(w*scale)) 37 | # TODO: we should probably recompute the scale like in the second case 38 | scale = (scale, scale) 39 | elif isinstance(size, (tuple, list)): 40 | h_new, w_new = size 41 | scale = (w_new / w, h_new / h) 42 | else: 43 | raise ValueError(f'Incorrect new size: {size}') 44 | mode = { 45 | 'linear': cv2.INTER_LINEAR, 46 | 'cubic': cv2.INTER_CUBIC, 47 | 'nearest': cv2.INTER_NEAREST}[interp] 48 | return cv2.resize(image, (w_new, h_new), interpolation=mode), scale 49 | 50 | 51 | def crop(image, size, *, random=True, other=None, camera=None, 52 | return_bbox=False, centroid=None): 53 | """Random or deterministic crop of an image, adjust depth and intrinsics. 
54 | """ 55 | h, w = image.shape[:2] 56 | h_new, w_new = (size, size) if isinstance(size, int) else size 57 | if random: 58 | top = np.random.randint(0, h - h_new + 1) 59 | left = np.random.randint(0, w - w_new + 1) 60 | elif centroid is not None: 61 | x, y = centroid 62 | top = np.clip(int(y) - h_new // 2, 0, h - h_new) 63 | left = np.clip(int(x) - w_new // 2, 0, w - w_new) 64 | else: 65 | top = left = 0 66 | 67 | image = image[top:top+h_new, left:left+w_new] 68 | ret = [image] 69 | if other is not None: 70 | ret += [other[top:top+h_new, left:left+w_new]] 71 | if camera is not None: 72 | ret += [camera.crop((left, top), (w_new, h_new))] 73 | if return_bbox: 74 | ret += [(top, top+h_new, left, left+w_new)] 75 | return ret 76 | 77 | 78 | def zero_pad(size, *images): 79 | ret = [] 80 | for image in images: 81 | h, w = image.shape[:2] 82 | padded = np.zeros((size, size)+image.shape[2:], dtype=image.dtype) 83 | padded[:h, :w] = image 84 | ret.append(padded) 85 | return ret 86 | 87 | 88 | def read_view(conf, image_path: Path, camera: Camera, T_w2cam: Pose, 89 | p3D: np.ndarray, p3D_idxs: np.ndarray, *, 90 | rotation=0, random=False): 91 | 92 | img = read_image(image_path, conf.grayscale) 93 | img = img.astype(np.float32) 94 | name = image_path.name 95 | 96 | # we assume that the pose and camera were already rotated during preprocess 97 | if rotation != 0: 98 | img = np.rot90(img, rotation) 99 | 100 | if conf.resize: 101 | scales = (1, 1) 102 | if conf.resize_by == 'max': 103 | img, scales = resize(img, conf.resize, fn=max) 104 | elif (conf.resize_by == 'min' or 105 | (conf.resize_by == 'min_if' 106 | and min(*img.shape[:2]) < conf.resize)): 107 | img, scales = resize(img, conf.resize, fn=min) 108 | if scales != (1, 1): 109 | camera = camera.scale(scales) 110 | 111 | if conf.crop: 112 | if conf.optimal_crop: 113 | p2D, valid = camera.world2image(T_w2cam * p3D[p3D_idxs]) 114 | p2D = p2D[valid].numpy() 115 | centroid = tuple(p2D.mean(0)) if len(p2D) > 0 else None 116 | random = False 117 | else: 118 | centroid = None 119 | img, camera, bbox = crop( 120 | img, conf.crop, random=random, 121 | camera=camera, return_bbox=True, centroid=centroid) 122 | elif conf.pad: 123 | img, = zero_pad(conf.pad, img) 124 | # we purposefully do not update the image size in the camera object 125 | 126 | data = { 127 | 'name': name, 128 | 'image': numpy_image_to_torch(img), 129 | 'camera': camera.float(), 130 | 'T_w2cam': T_w2cam.float(), 131 | } 132 | return data 133 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/__init__.py: -------------------------------------------------------------------------------- 1 | from .wrappers import Pose, Camera # noqa 2 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/check_jacobians.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | 4 | from . 
import Pose, Camera 5 | from .costs import DirectAbsoluteCost 6 | from .interpolation import Interpolator 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def compute_J(fn_J, inp): 12 | with torch.enable_grad(): 13 | return torch.autograd.functional.jacobian(fn_J, inp) 14 | 15 | 16 | def compute_J_batched(fn, inp): 17 | inp_ = inp.reshape(-1) 18 | fn_ = lambda x: fn(x.reshape(inp.shape)) # noqa 19 | J = compute_J(fn_, inp_) 20 | if len(J.shape) != 3: 21 | raise ValueError('Only supports a single leading batch dimension.') 22 | J = J.reshape(J.shape[:-1] + inp.shape) 23 | J = J.diagonal(dim1=0, dim2=-2).permute(2, 0, 1) 24 | return J 25 | 26 | 27 | def local_param(delta): 28 | dt, dw = delta.split(3, dim=-1) 29 | return Pose.from_aa(dw, dt) 30 | 31 | 32 | def toy_problem(seed=0, n_points=500): 33 | torch.random.manual_seed(seed) 34 | aa = torch.randn(3) / 10 35 | t = torch.randn(3) / 5 36 | T_w2q = Pose.from_aa(aa, t) 37 | 38 | w, h = 640, 480 39 | fx, fy = 300., 350. 40 | cx, cy = w/2, h/2 41 | radial = [0.1, 0.01] 42 | camera = Camera(torch.tensor([w, h, fx, fy, cx, cy] + radial)).float() 43 | torch.testing.assert_allclose((w, h), camera.size.long()) 44 | torch.testing.assert_allclose((fx, fy), camera.f) 45 | torch.testing.assert_allclose((cx, cy), camera.c) 46 | 47 | p3D = torch.randn(n_points, 3) 48 | p3D[:, -1] += 2 49 | 50 | dim = 16 51 | F_ref = torch.randn(n_points, dim) 52 | F_query = torch.randn(dim, h, w) 53 | 54 | return T_w2q, camera, p3D, F_ref, F_query 55 | 56 | 57 | def print_J_diff(prefix, J, J_auto): 58 | logger.info('Check J %s: pass=%r, max_diff=%e, shape=%r', 59 | prefix, 60 | torch.allclose(J, J_auto), 61 | torch.abs(J-J_auto).max(), 62 | tuple(J.shape)) 63 | 64 | 65 | def test_J_pose(T: Pose, p3D: torch.Tensor): 66 | J = T.J_transform(T * p3D) 67 | fn = lambda d: (local_param(d) @ T) * p3D # noqa 68 | delta = torch.zeros(6).to(p3D) 69 | J_auto = compute_J(fn, delta) 70 | print_J_diff('pose transform', J, J_auto) 71 | 72 | 73 | def test_J_undistort(camera: Camera, p3D: torch.Tensor): 74 | p2D, valid = camera.project(p3D) 75 | J = camera.J_undistort(p2D) 76 | J_auto = compute_J_batched(camera.undistort, p2D) 77 | J, J_auto = J[valid], J_auto[valid] 78 | print_J_diff('undistort', J, J_auto) 79 | 80 | 81 | def test_J_world2image(camera: Camera, p3D: torch.Tensor): 82 | _, valid = camera.world2image(p3D) 83 | J, _ = camera.J_world2image(p3D) 84 | J_auto = compute_J_batched(lambda x: camera.world2image(x)[0], p3D) 85 | J, J_auto = J[valid], J_auto[valid] 86 | print_J_diff('world2image', J, J_auto) 87 | 88 | 89 | def test_J_geometric_cost(T_w2q: Pose, camera: Camera, p3D: torch.Tensor): 90 | def forward(T): 91 | p3D_q = T * p3D 92 | p2D, visible = camera.world2image(p3D_q) 93 | return p2D, visible, p3D_q 94 | 95 | _, valid, p3D_q = forward(T_w2q) 96 | J = camera.J_world2image(p3D_q)[0] @ T_w2q.J_transform(p3D_q) 97 | delta = torch.zeros(6).to(p3D) 98 | fn = lambda d: forward(local_param(d) @ T_w2q)[0] # noqa 99 | J_auto = compute_J(fn, delta) 100 | J, J_auto = J[valid], J_auto[valid] 101 | print_J_diff('geometric cost', J, J_auto) 102 | 103 | 104 | def test_J_direct_absolute_cost(T_w2q: Pose, camera: Camera, p3D: torch.Tensor, 105 | F_ref, F_query): 106 | interpolator = Interpolator(mode='cubic', pad=2) 107 | cost = DirectAbsoluteCost(interpolator, normalize=True) 108 | 109 | args = (camera, p3D, F_ref, F_query) 110 | res, valid, weight, F_q_p2D, info = cost.residuals( 111 | T_w2q, *args, do_gradients=True) 112 | J, _ = cost.jacobian(T_w2q, camera, *info) 113 | 114 
| delta = torch.zeros(6).to(p3D) 115 | fn = lambda d: cost.residuals(local_param(d) @ T_w2q, *args)[0] # noqa 116 | J_auto = compute_J(fn, delta) 117 | 118 | J, J_auto = J[valid], J_auto[valid] 119 | print_J_diff('direct absolute cost', J, J_auto) 120 | 121 | 122 | def main(): 123 | T_w2q, camera, p3D, F_ref, F_query = toy_problem() 124 | test_J_pose(T_w2q, p3D) 125 | test_J_undistort(camera, p3D) 126 | test_J_world2image(camera, p3D) 127 | 128 | # perform the checsk in double precision to factor out numerical errors 129 | T_w2q, camera, p3D, F_ref, F_query = ( 130 | x.to(torch.double) for x in (T_w2q, camera, p3D, F_ref, F_query)) 131 | 132 | test_J_geometric_cost(T_w2q, camera, p3D) 133 | test_J_direct_absolute_cost(T_w2q, camera, p3D, F_ref, F_query) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/costs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional, Tuple 3 | from torch import Tensor 4 | 5 | from . import Pose, Camera 6 | from .optimization import J_normalization 7 | from .interpolation import Interpolator 8 | 9 | 10 | class DirectAbsoluteCost: 11 | def __init__(self, interpolator: Interpolator, normalize: bool = True): 12 | self.interpolator = interpolator 13 | self.normalize = normalize 14 | 15 | # def residuals( 16 | # self, T_w2q: Pose, camera: Camera, p3D: Tensor, 17 | # F_ref: Tensor, F_query: Tensor, 18 | # confidences: Optional[Tuple[Tensor, Tensor]] = None, 19 | # do_gradients: bool = False): 20 | # 21 | # p3D_q = T_w2q * p3D 22 | # p2D, visible = camera.world2image(p3D_q) 23 | # F_p2D_raw, valid, gradients = self.interpolator( 24 | # F_query, p2D, return_gradients=do_gradients) 25 | # valid = valid & visible 26 | # 27 | # if confidences is not None: 28 | # C_ref, C_query = confidences 29 | # C_query_p2D, _, _ = self.interpolator( 30 | # C_query, p2D, return_gradients=False) 31 | # if C_ref is not None: 32 | # weight = C_ref * C_query_p2D 33 | # else: 34 | # weight = C_query_p2D 35 | # weight = weight.squeeze(-1).masked_fill(~valid, 0.) 
36 | # else: 37 | # weight = None 38 | # 39 | # if self.normalize: 40 | # F_p2D = torch.nn.functional.normalize(F_p2D_raw, dim=-1) 41 | # else: 42 | # F_p2D = F_p2D_raw 43 | # 44 | # res = F_p2D - F_ref 45 | # info = (p3D_q, F_p2D_raw, gradients) 46 | # return res, valid, weight, F_p2D, info 47 | def residuals( 48 | self, T_q2r: Pose, camera: Camera, p3D: Tensor, 49 | F_ref: Tensor, F_query: Tensor, 50 | confidences: Optional[Tuple[Tensor, Tensor, int]] = None, 51 | do_gradients: bool = False): 52 | 53 | p3D_r = T_q2r * p3D # q_3d to ref_3d 54 | p2D, visible = camera.world2image(p3D_r) # ref_3d to ref_2d 55 | F_p2D_raw, valid, gradients = self.interpolator( 56 | F_ref, p2D, return_gradients=do_gradients) # get ref 2d features 57 | valid = valid & visible 58 | 59 | C_ref, C_query, C_count = confidences 60 | 61 | C_ref_p2D, _, _ = self.interpolator(C_ref, p2D, return_gradients=False) # get ref 2d confidence 62 | 63 | # the first confidence 64 | weight = C_ref_p2D[:, :, 0] * C_query[:, :, 0] 65 | if C_count > 1: 66 | grd_weight = C_ref_p2D[:, :, 1].detach() * C_query[:, :, 1] 67 | weight = weight * grd_weight 68 | # if C2_start == 0: 69 | # # only grd confidence: 70 | # # do not gradiant back to ref confidence 71 | # weight = C_ref_p2D[:, :, 0].detach() * C_query[:, :, 0] 72 | # else: 73 | # weight = C_ref_p2D[:,:,0] * C_query[:,:,0] 74 | # # the second confidence 75 | # if C_query.shape[-1] > 1: 76 | # grd_weight = C_ref_p2D[:, :, 1].detach() * C_query[:, :, 1] 77 | # grd_weight = torch.cat([torch.ones_like(grd_weight[:, :C2_start]), grd_weight[:, C2_start:]], dim=1) 78 | # weight = weight * grd_weight 79 | 80 | if weight != None: 81 | weight = weight.masked_fill(~(valid), 0.) 82 | #weight = torch.nn.functional.normalize(weight, p=float('inf'), dim=1) #?? 
83 | 84 | if self.normalize: # huge memory 85 | F_p2D = torch.nn.functional.normalize(F_p2D_raw, dim=-1) 86 | else: 87 | F_p2D = F_p2D_raw 88 | 89 | res = F_p2D - F_query 90 | info = (p3D_r, F_p2D, gradients) # ref information 91 | return res, valid, weight, F_p2D, info 92 | 93 | # def jacobian( 94 | # self, T_w2q: Pose, camera: Camera, 95 | # p3D_q: Tensor, F_p2D_raw: Tensor, J_f_p2D: Tensor): 96 | # 97 | # J_p3D_T = T_w2q.J_transform(p3D_q) 98 | # J_p2D_p3D, _ = camera.J_world2image(p3D_q) 99 | # 100 | # if self.normalize: 101 | # J_f_p2D = J_normalization(F_p2D_raw) @ J_f_p2D 102 | # 103 | # J_p2D_T = J_p2D_p3D @ J_p3D_T 104 | # J = J_f_p2D @ J_p2D_T 105 | # return J, J_p2D_T 106 | def jacobian( 107 | self, T_q2r: Pose, camera: Camera, 108 | p3D_r: Tensor, F_p2D_raw: Tensor, J_f_p2D: Tensor): 109 | 110 | J_p3D_T = T_q2r.J_transform(p3D_r) 111 | J_p2D_p3D, _ = camera.J_world2image(p3D_r) 112 | 113 | if self.normalize: 114 | J_f_p2D = J_normalization(F_p2D_raw) @ J_f_p2D 115 | 116 | J_p2D_T = J_p2D_p3D @ J_p3D_T 117 | J = J_f_p2D @ J_p2D_T 118 | return J, J_p2D_T 119 | 120 | # def residual_jacobian( 121 | # self, T_w2q: Pose, camera: Camera, p3D: Tensor, 122 | # F_ref: Tensor, F_query: Tensor, 123 | # confidences: Optional[Tuple[Tensor, Tensor]] = None): 124 | # 125 | # res, valid, weight, F_p2D, info = self.residuals( 126 | # T_w2q, camera, p3D, F_ref, F_query, confidences, True) 127 | # J, _ = self.jacobian(T_w2q, camera, *info) 128 | # return res, valid, weight, F_p2D, J 129 | def residual_jacobian( 130 | self, T_q2r: Pose, camera: Camera, p3D: Tensor, 131 | F_ref: Tensor, F_query: Tensor, 132 | confidences: Optional[Tuple[Tensor, Tensor]] = None): 133 | 134 | res, valid, weight, F_p2D, info = self.residuals( 135 | T_q2r, camera, p3D, F_ref, F_query, confidences, True) 136 | J, _ = self.jacobian(T_q2r, camera, *info) 137 | return res, valid, weight, F_p2D, J 138 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/interpolation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | from typing import Tuple 5 | 6 | 7 | @torch.jit.script 8 | def interpolate_tensor_bicubic(tensor, pts, return_gradients: bool = False): 9 | # According to R. Keys "Cubic convolution interpolation for digital image processing". 10 | # references: 11 | # https://github.com/ceres-solver/ceres-solver/blob/master/include/ceres/cubic_interpolation.h 12 | # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/UpSample.h#L285 13 | # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/UpSampleBicubic2d.cpp#L63 14 | # https://github.com/ceres-solver/ceres-solver/blob/master/include/ceres/cubic_interpolation.h 15 | spline_base = torch.tensor([[-1, 2, -1, 0], 16 | [3, -5, 0, 2], 17 | [-3, 4, 1, 0], 18 | [1, -1, 0, 0]]).float() / 2 19 | 20 | # This is the original written by Måns, does not seem consistent with OpenCV remap 21 | # spline_base = torch.tensor([[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 0, 3, 0], [1, 4, 1, 0]]).float().T / 6 22 | spline_base = spline_base.to(tensor) 23 | 24 | pts_0 = torch.floor(pts) 25 | res = pts - pts_0 26 | x, y = pts_0[:, 0], pts_0[:, 1] 27 | 28 | c, h, w = tensor.shape 29 | f_patches = torch.zeros((c, len(pts), 4, 4)).to(tensor) 30 | # TODO: could we do this faster with gather or grid_sampler nearest? 
31 | for i in [-1, 0, 1, 2]: 32 | for j in [-1, 0, 1, 2]: 33 | x_ = (x+j).long().clamp(min=0, max=w-1).long() 34 | y_ = (y+i).long().clamp(min=0, max=h-1).long() 35 | f_patches[:, :, i+1, j+1] = tensor[:, y_, x_] 36 | 37 | t = torch.stack([res**3, res**2, res, torch.ones_like(res)], -1) 38 | coeffs = torch.einsum('mk,nck->cnm', spline_base, t) 39 | coeffs_x, coeffs_y = coeffs[0], coeffs[1] 40 | interp = torch.einsum('ni,nj,cnij->nc', coeffs_y, coeffs_x, f_patches) 41 | 42 | if return_gradients: 43 | dt_xy = torch.stack([ 44 | 3*res**2, 2*res, torch.ones_like(res), torch.zeros_like(res)], -1) 45 | B_dt_xy = torch.einsum('mk,nck->cnm', spline_base, dt_xy) 46 | B_dt_x, B_dt_y = B_dt_xy[0], B_dt_xy[1] 47 | 48 | J_out_x = torch.einsum('ni,nj,cnij->nc', coeffs_y, B_dt_x, f_patches) 49 | J_out_y = torch.einsum('ni,nj,cnij->nc', B_dt_y, coeffs_x, f_patches) 50 | J_out_xy = torch.stack([J_out_x, J_out_y], -1) 51 | else: 52 | J_out_xy = torch.zeros(len(pts), c, 2).to(interp) 53 | 54 | return interp, J_out_xy 55 | 56 | 57 | @torch.jit.script 58 | def interpolate_tensor_bilinear(tensor, pts, return_gradients: bool = False): 59 | if tensor.dim() == 3: 60 | assert pts.dim() == 2 61 | batched = False 62 | tensor, pts = tensor[None], pts[None] 63 | else: 64 | batched = True 65 | 66 | b, c, h, w = tensor.shape 67 | scale = torch.tensor([w-1, h-1]).to(pts) 68 | pts = (pts / scale) * 2 - 1 69 | pts = pts.clamp(min=-2, max=2) # ideally use the mask instead 70 | interpolated = torch.nn.functional.grid_sample( 71 | tensor, pts[:, None], mode='bilinear', align_corners=True) 72 | interpolated = interpolated.reshape(b, c, -1).transpose(-1, -2) 73 | 74 | if return_gradients: 75 | dxdy = torch.tensor([[1, 0], [0, 1]])[:, None].to(pts) / scale * 2 76 | dx, dy = dxdy.chunk(2, dim=0) 77 | pts_d = torch.cat([pts-dx, pts+dx, pts-dy, pts+dy], 1) 78 | tensor_d = torch.nn.functional.grid_sample( 79 | tensor, pts_d[:, None], mode='bilinear', align_corners=True) 80 | tensor_d = tensor_d.reshape(b, c, -1).transpose(-1, -2) 81 | tensor_x0, tensor_x1, tensor_y0, tensor_y1 = tensor_d.chunk(4, dim=1) 82 | gradients = torch.stack([ 83 | (tensor_x1 - tensor_x0)/2, (tensor_y1 - tensor_y0)/2], dim=-1) 84 | else: 85 | gradients = torch.zeros(b, pts.shape[1], c, 2).to(tensor) 86 | 87 | if not batched: 88 | interpolated, gradients = interpolated[0], gradients[0] 89 | return interpolated, gradients 90 | 91 | 92 | def mask_in_image(pts, image_size: Tuple[int, int], pad: int = 1): 93 | w, h = image_size 94 | image_size_ = torch.tensor([w-pad-1, h-pad-1]).to(pts) 95 | return torch.all((pts >= pad) & (pts <= image_size_), -1) 96 | 97 | 98 | @torch.jit.script 99 | def interpolate_tensor(tensor, pts, mode: str = 'linear', 100 | pad: int = 1, return_gradients: bool = False): 101 | '''Interpolate a 3D tensor at given 2D locations. 102 | Args: 103 | tensor: with shape (C, H, W) or (B, C, H, W). 104 | pts: points with shape (N, 2) or (B, N, 2) 105 | mode: interpolation mode, `'linear'` or `'cubic'` 106 | pad: padding for the returned mask of valid keypoints 107 | return_gradients: whether to return the first derivative 108 | of the interpolated values (currentl only in cubic mode). 
109 | Returns: 110 | tensor: with shape (N, C) or (B, N, C) 111 | mask: boolean mask, true if pts are in [pad, W-1-pad] x [pad, H-1-pad] 112 | gradients: (N, C, 2) or (B, N, C, 2), 0-filled if not return_gradients 113 | ''' 114 | h, w = tensor.shape[-2:] 115 | if mode == 'cubic': 116 | pad += 1 # bicubic needs one more pixel on each side 117 | mask = mask_in_image(pts, (w, h), pad=pad) 118 | # Ideally we want to use mask to clamp outlier pts before interpolationm 119 | # but this line throws some obscure errors about inplace ops. 120 | # pts = pts.masked_fill(mask.unsqueeze(-1), 0.) 121 | 122 | if mode == 'cubic': 123 | interpolated, gradients = interpolate_tensor_bicubic( 124 | tensor, pts, return_gradients) 125 | elif mode == 'linear': 126 | interpolated, gradients = interpolate_tensor_bilinear( 127 | tensor, pts, return_gradients) 128 | else: 129 | raise NotImplementedError(mode) 130 | return interpolated, mask, gradients 131 | 132 | 133 | class Interpolator: 134 | def __init__(self, mode: str = 'linear', pad: int = 1): 135 | self.mode = mode 136 | self.pad = pad 137 | 138 | def __call__(self, tensor: torch.Tensor, pts: torch.Tensor, 139 | return_gradients: bool = False): 140 | return interpolate_tensor( 141 | tensor, pts, self.mode, self.pad, return_gradients) 142 | 143 | 144 | def test_interpolate_cubic_opencv(f, pts): 145 | interp = interpolate_tensor_bicubic(f, pts)[0].cpu().numpy() 146 | interp_linear = interpolate_tensor(f, pts)[0].cpu().numpy() 147 | 148 | pts_ = pts.cpu().numpy() 149 | interp_cv2_cubic = [] 150 | interp_cv2_linear = [] 151 | for f_i in f.cpu().numpy(): 152 | interp_i = cv2.remap(f_i, pts_[None], None, cv2.INTER_CUBIC)[0] 153 | interp_cv2_cubic.append(interp_i) 154 | interp_i = cv2.remap(f_i, pts_[None], None, cv2.INTER_LINEAR)[0] 155 | interp_cv2_linear.append(interp_i) 156 | interp_cv2_cubic = np.stack(interp_cv2_cubic, -1) 157 | interp_cv2_linear = np.stack(interp_cv2_linear, -1) 158 | 159 | diff = np.abs(interp - interp_cv2_cubic) 160 | print('OpenCV cubic vs custom cubic:') 161 | print('Mean/med/max abs diff', np.mean(diff), np.median(diff), np.max(diff)) 162 | print('Rel diff', np.median(diff/np.abs(interp_cv2_cubic))*100, '%') 163 | 164 | diff = np.abs(interp_cv2_linear - interp_cv2_cubic) 165 | print('OpenCV cubic vs linear:') 166 | print('Mean/med/max abs diff', np.mean(diff), np.median(diff), np.max(diff)) 167 | print('Rel diff', np.median(diff/np.abs(interp_cv2_cubic))*100, '%') 168 | 169 | diff = np.abs(interp_linear - interp_cv2_linear) 170 | print('OpenCV linear vs grid sample:') 171 | print('Mean/med/max abs diff', np.mean(diff), np.median(diff), np.max(diff)) 172 | print('Rel diff', np.median(diff/np.abs(interp_cv2_linear))*100, '%') 173 | 174 | 175 | def test_interpolate_cubic_gradients(tensor, pts): 176 | def compute_J(fn_J, inp): 177 | with torch.enable_grad(): 178 | return torch.autograd.functional.jacobian(fn_J, inp) 179 | 180 | tensor, pts = tensor.double(), pts.double() 181 | 182 | _, J_analytical = interpolate_tensor_bicubic( 183 | tensor, pts, return_gradients=True) 184 | 185 | J = compute_J( 186 | lambda xy: interpolate_tensor_bicubic(tensor, xy.reshape(-1, 2))[0], 187 | pts.reshape(-1)) 188 | J = J.reshape(J.shape[:2]+(-1, 2)) 189 | J = J[range(len(pts)), :, range(len(pts)), :] 190 | 191 | print('Gradients consistent with autograd:', 192 | torch.allclose(J_analytical, J)) 193 | 194 | 195 | def test_run_all(seed=0): 196 | torch.random.manual_seed(seed) 197 | w, h = 480, 240 198 | 199 | pts = torch.rand(1000, 2) * torch.tensor([w-1, h-1]) 
200 | tensor = torch.rand(16, h, w)*100 201 | 202 | test_interpolate_cubic_opencv(tensor, pts) 203 | test_interpolate_cubic_gradients(tensor, pts) 204 | 205 | 206 | if __name__ == '__main__': 207 | test_run_all() 208 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generic losses and error functions for optimization or training deep networks. 3 | """ 4 | 5 | import torch 6 | 7 | 8 | def scaled_loss(x, fn, a): 9 | """Apply a loss function to a tensor and pre- and post-scale it. 10 | Args: 11 | x: the data tensor, should already be squared: `x = y**2`. 12 | fn: the loss function, with signature `fn(x) -> y`. 13 | a: the scale parameter. 14 | Returns: 15 | The value of the loss, and its first and second derivatives. 16 | """ 17 | a2 = a**2 18 | loss, loss_d1, loss_d2 = fn(x/a2) 19 | return loss*a2, loss_d1, loss_d2/a2 20 | 21 | 22 | def squared_loss(x): 23 | """A dummy squared loss.""" 24 | return x, torch.ones_like(x), torch.zeros_like(x) 25 | 26 | 27 | def huber_loss(x): 28 | """The classical robust Huber loss, with first and second derivatives.""" 29 | mask = x <= 1 30 | sx = torch.sqrt(x) 31 | isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1/sx) 32 | loss = torch.where(mask, x, 2*sx-1) 33 | loss_d1 = torch.where(mask, torch.ones_like(x), isx) 34 | loss_d2 = torch.where(mask, torch.zeros_like(x), -isx/(2*x)) 35 | return loss, loss_d1, loss_d2 36 | 37 | 38 | def barron_loss(x, alpha, derivatives: bool = True, eps: float = 1e-7): 39 | """Parameterized & adaptive robust loss function. 40 | Described in: 41 | A General and Adaptive Robust Loss Function, Barron, CVPR 2019 42 | 43 | Contrary to the original implementation, assume the the input is already 44 | squared and scaled (basically scale=1). Computes the first derivative, but 45 | not the second (TODO if needed). 46 | """ 47 | loss_two = x 48 | loss_zero = 2 * torch.log1p(torch.clamp(0.5*x, max=33e37)) 49 | 50 | # The loss when not in one of the above special cases. 51 | # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by. 52 | beta_safe = torch.abs(alpha - 2.).clamp(min=eps) 53 | # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. 54 | alpha_safe = torch.where( 55 | alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha)) 56 | alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps) 57 | 58 | loss_otherwise = 2 * (beta_safe / alpha_safe) * ( 59 | torch.pow(x / beta_safe + 1., 0.5 * alpha) - 1.) 60 | 61 | # Select which of the cases of the loss to return. 62 | loss = torch.where( 63 | alpha == 0, loss_zero, 64 | torch.where(alpha == 2, loss_two, loss_otherwise)) 65 | dummy = torch.zeros_like(x) 66 | 67 | if derivatives: 68 | loss_two_d1 = torch.ones_like(x) 69 | loss_zero_d1 = 2 / (x + 2) 70 | loss_otherwise_d1 = torch.pow(x / beta_safe + 1., 0.5 * alpha - 1.) 
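# pick the derivative that matches the branch chosen for the loss value above:
# d/dx of x is 1, d/dx of 2*log1p(x/2) is 2/(x+2), and the general case follows from the power rule.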
71 | loss_d1 = torch.where( 72 | alpha == 0, loss_zero_d1, 73 | torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1)) 74 | 75 | return loss, loss_d1, dummy 76 | else: 77 | return loss, dummy, dummy 78 | 79 | 80 | def scaled_barron(a, c): 81 | return lambda x: scaled_loss( 82 | x, lambda y: barron_loss(y, y.new_tensor(a)), c) 83 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/optimization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from packaging import version 4 | 5 | logger = logging.getLogger(__name__) 6 | 7 | if version.parse(torch.__version__) >= version.parse('1.9'): 8 | cholesky = torch.linalg.cholesky 9 | else: 10 | cholesky = torch.cholesky 11 | 12 | def optimizer_step(g, H, lambda_=0, mute=False, mask=None, eps=1e-6): 13 | """One optimization step with Gauss-Newton or Levenberg-Marquardt. 14 | Args: 15 | g: batched gradient tensor of size (..., N). 16 | H: batched hessian tensor of size (..., N, N). 17 | lambda_: damping factor for LM (use GN if lambda_=0). 18 | mask: denotes valid elements of the batch (optional). 19 | """ 20 | if lambda_ is 0: # noqa 21 | diag = torch.zeros_like(g) 22 | else: 23 | # select 3DOF lambda, tx,ty,Rz 24 | # idx = torch.tensor([0,2,4], device=g.device) 25 | # lambda_ = torch.index_select(lambda_, 0, idx) 26 | 27 | diag = H.diagonal(dim1=-2, dim2=-1) * lambda_ 28 | H = H + diag.clamp(min=eps).diag_embed() 29 | 30 | if mask is not None: 31 | # make sure that masked elements are not singular 32 | H = torch.where(mask[..., None, None], H, torch.eye(H.shape[-1]).to(H)) 33 | # set g to 0 to delta is 0 for masked elements 34 | g = g.masked_fill(~mask[..., None], 0.) 35 | 36 | # add by shan 37 | if torch.isnan(g).any() or torch.isnan(H).any(): 38 | print('nan in g or H, return 0 delta') 39 | delta = torch.zeros_like(g) 40 | return delta.to(H.device) 41 | 42 | H_, g_ = H.cpu(), g.cpu() 43 | 44 | try: 45 | #U = torch.linalg.cholesky(H_.transpose(-2, -1).conj()).transpose(-2, -1).conj() 46 | U = cholesky(H_) 47 | except RuntimeError as e: 48 | if 'singular U' in str(e): 49 | if not mute: 50 | logger.debug( 51 | 'Cholesky decomposition failed, fallback to LU.') 52 | #delta = -torch.solve(g_[..., None], H_)[0][..., 0] 53 | delta = -torch.linalg.solve(H_, g_[..., None])[0][..., 0] 54 | else: 55 | raise 56 | else: 57 | delta = -torch.cholesky_solve(g_[..., None], U)[..., 0] 58 | 59 | return delta.to(H.device) 60 | 61 | 62 | def skew_symmetric(v): 63 | """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). 64 | """ 65 | z = torch.zeros_like(v[..., 0]) 66 | M = torch.stack([ 67 | z, -v[..., 2], v[..., 1], 68 | v[..., 2], z, -v[..., 0], 69 | -v[..., 1], v[..., 0], z, 70 | ], dim=-1).reshape(v.shape[:-1]+(3, 3)) 71 | return M 72 | 73 | 74 | def so3exp_map(w, eps: float = 1e-7): 75 | """Compute rotation matrices from batched twists. 76 | Args: 77 | w: batched 3D axis-angle vectors of size (..., 3). 78 | Returns: 79 | A batch of rotation matrices of size (..., 3, 3). 80 | """ 81 | theta = w.norm(p=2, dim=-1, keepdim=True) 82 | small = theta < eps 83 | div = torch.where(small, torch.ones_like(theta), theta) 84 | W = skew_symmetric(w / div) 85 | theta = theta[..., None] # ... 
x 1 x 1 86 | res = W * torch.sin(theta) + (W @ W) * (1 - torch.cos(theta)) 87 | res = torch.where(small[..., None], W, res) # first-order Taylor approx 88 | return torch.eye(3).to(W) + res 89 | 90 | 91 | def J_normalization(x): 92 | """Jacobian of the L2 normalization, assuming that we normalize 93 | along the last dimension. 94 | """ 95 | x_normed = torch.nn.functional.normalize(x, dim=-1) 96 | norm = torch.norm(x, p=2, dim=-1, keepdim=True) 97 | 98 | Id = torch.diag_embed(torch.ones_like(x_normed)) 99 | J = (Id - x_normed.unsqueeze(-1) @ x_normed.unsqueeze(-2)) 100 | J = J / norm.unsqueeze(-1) 101 | return J 102 | -------------------------------------------------------------------------------- /PureACL/pixlib/geometry/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of geometry tools for PyTorch tensors and sometimes NumPy arrays. 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | 8 | 9 | def to_homogeneous(points): 10 | """Convert N-dimensional points to homogeneous coordinates. 11 | Args: 12 | points: torch.Tensor or numpy.ndarray with size (..., N). 13 | Returns: 14 | A torch.Tensor or numpy.ndarray with size (..., N+1). 15 | """ 16 | if isinstance(points, torch.Tensor): 17 | pad = points.new_ones(points.shape[:-1]+(1,)) 18 | return torch.cat([points, pad], dim=-1) 19 | elif isinstance(points, np.ndarray): 20 | pad = np.ones((points.shape[:-1]+(1,)), dtype=points.dtype) 21 | return np.concatenate([points, pad], axis=-1) 22 | else: 23 | raise ValueError 24 | 25 | 26 | def from_homogeneous(points): 27 | """Remove the homogeneous dimension of N-dimensional points. 28 | Args: 29 | points: torch.Tensor or numpy.ndarray with size (..., N+1). 30 | Returns: 31 | A torch.Tensor or numpy ndarray with size (..., N). 32 | """ 33 | return points[..., :-1] / points[..., -1:] 34 | 35 | 36 | @torch.jit.script 37 | def undistort_points(pts, dist): 38 | '''Undistort normalized 2D coordinates 39 | and check for validity of the distortion model. 40 | ''' 41 | dist = dist.unsqueeze(-2) # add point dimension 42 | ndist = dist.shape[-1] 43 | undist = pts 44 | valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool) 45 | if ndist > 0: 46 | k1, k2 = dist[..., :2].split(1, -1) 47 | r2 = torch.sum(pts**2, -1, keepdim=True) 48 | radial = k1*r2 + k2*r2**2 49 | undist = undist + pts * radial 50 | 51 | # The distortion model is supposedly only valid within the image 52 | # boundaries. Because of the negative radial distortion, points that 53 | # are far outside of the boundaries might actually be mapped back 54 | # within the image. To account for this, we discard points that are 55 | # beyond the inflection point of the distortion model, 56 | # e.g. 
such that d(r + k_1 r^3 + k2 r^5)/dr = 0 57 | limited = ((k2 > 0) & ((9*k1**2-20*k2) > 0)) | ((k2 <= 0) & (k1 > 0)) 58 | limit = torch.abs(torch.where( 59 | k2 > 0, (torch.sqrt(9*k1**2-20*k2)-3*k1)/(10*k2), 1/(3*k1))) 60 | valid = valid & torch.squeeze(~limited | (r2 < limit), -1) 61 | 62 | if ndist > 2: 63 | p12 = dist[..., 2:] 64 | p21 = p12.flip(-1) 65 | uv = torch.prod(pts, -1, keepdim=True) 66 | undist = undist + 2*p12*uv + p21*(r2 + 2*pts**2) 67 | # TODO: handle tangential boundaries 68 | 69 | return undist, valid 70 | 71 | 72 | @torch.jit.script 73 | def J_undistort_points(pts, dist): 74 | dist = dist.unsqueeze(-2) # add point dimension 75 | ndist = dist.shape[-1] 76 | 77 | J_diag = torch.ones_like(pts) 78 | J_cross = torch.zeros_like(pts) 79 | if ndist > 0: 80 | k1, k2 = dist[..., :2].split(1, -1) 81 | r2 = torch.sum(pts**2, -1, keepdim=True) 82 | uv = torch.prod(pts, -1, keepdim=True) 83 | radial = k1*r2 + k2*r2**2 84 | d_radial = (2*k1 + 4*k2*r2) 85 | J_diag += radial + (pts**2)*d_radial 86 | J_cross += uv*d_radial 87 | 88 | if ndist > 2: 89 | p12 = dist[..., 2:] 90 | p21 = p12.flip(-1) 91 | J_diag += 2*p12*pts.flip(-1) + 6*p21*pts 92 | J_cross += 2*p12*pts + 2*p21*pts.flip(-1) 93 | 94 | J = torch.diag_embed(J_diag) + torch.diag_embed(J_cross).flip(-1) 95 | return J 96 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from PureACL.pixlib.utils.tools import get_class 2 | from PureACL.pixlib.models.base_model import BaseModel 3 | 4 | 5 | def get_model(name): 6 | return get_class(name, __name__, BaseModel) 7 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base class for trainable models. 3 | """ 4 | 5 | from abc import ABCMeta, abstractmethod 6 | import omegaconf 7 | from omegaconf import OmegaConf 8 | from torch import nn 9 | from copy import copy 10 | 11 | 12 | class MetaModel(ABCMeta): 13 | def __prepare__(name, bases, **kwds): 14 | total_conf = OmegaConf.create() 15 | for base in bases: 16 | for key in ('base_default_conf', 'default_conf'): 17 | update = getattr(base, key, {}) 18 | if isinstance(update, dict): 19 | update = OmegaConf.create(update) 20 | total_conf = OmegaConf.merge(total_conf, update) 21 | return dict(base_default_conf=total_conf) 22 | 23 | 24 | class BaseModel(nn.Module, metaclass=MetaModel): 25 | """ 26 | What the child model is expect to declare: 27 | default_conf: dictionary of the default configuration of the model. 28 | It recursively updates the default_conf of all parent classes, and 29 | it is updated by the user-provided configuration passed to __init__. 30 | Configurations can be nested. 31 | 32 | required_data_keys: list of expected keys in the input data dictionary. 33 | 34 | strict_conf (optional): boolean. If false, BaseModel does not raise 35 | an error when the user provides an unknown configuration entry. 36 | 37 | _init(self, conf): initialization method, where conf is the final 38 | configuration object (also accessible with `self.conf`). Accessing 39 | unkown configuration entries will raise an error. 40 | 41 | _forward(self, data): method that returns a dictionary of batched 42 | prediction tensors based on a dictionary of batched input data tensors. 
43 | 44 | loss(self, pred, data): method that returns a dictionary of losses, 45 | computed from model predictions and input data. Each loss is a batch 46 | of scalars, i.e. a torch.Tensor of shape (B,). 47 | The total loss to be optimized has the key `'total'`. 48 | 49 | metrics(self, pred, data): method that returns a dictionary of metrics, 50 | each as a batch of scalars. 51 | """ 52 | default_conf = { 53 | 'name': None, 54 | 'trainable': True, # if false: do not optimize this model parameters 55 | 'freeze_batch_normalization': False, # use test-time statistics 56 | } 57 | required_data_keys = [] 58 | strict_conf = True 59 | 60 | def __init__(self, conf): 61 | """Perform some logic and call the _init method of the child model.""" 62 | super().__init__() 63 | default_conf = OmegaConf.merge( 64 | self.base_default_conf, OmegaConf.create(self.default_conf)) 65 | if self.strict_conf: 66 | OmegaConf.set_struct(default_conf, True) 67 | 68 | # fixme: backward compatibility 69 | if 'pad' in conf and 'pad' not in default_conf: # backward compat. 70 | with omegaconf.read_write(conf): 71 | with omegaconf.open_dict(conf): 72 | conf['interpolation'] = {'pad': conf.pop('pad')} 73 | 74 | if isinstance(conf, dict): 75 | conf = OmegaConf.create(conf) 76 | self.conf = conf = OmegaConf.merge(default_conf, conf) 77 | OmegaConf.set_readonly(conf, True) 78 | OmegaConf.set_struct(conf, True) 79 | self.required_data_keys = copy(self.required_data_keys) 80 | self._init(conf) 81 | 82 | if not conf.trainable: 83 | for p in self.parameters(): 84 | p.requires_grad = False 85 | 86 | def train(self, mode=True): 87 | super().train(mode) 88 | 89 | def freeze_bn(module): 90 | if isinstance(module, nn.modules.batchnorm._BatchNorm): 91 | module.eval() 92 | if self.conf.freeze_batch_normalization: 93 | self.apply(freeze_bn) 94 | 95 | return self 96 | 97 | def forward(self, data): 98 | """Check the data and call the _forward method of the child model.""" 99 | def recursive_key_check(expected, given): 100 | for key in expected: 101 | assert key in given, f'Missing key {key} in data' 102 | if isinstance(expected, dict): 103 | recursive_key_check(expected[key], given[key]) 104 | 105 | recursive_key_check(self.required_data_keys, data) 106 | return self._forward(data) 107 | 108 | @abstractmethod 109 | def _init(self, conf): 110 | """To be implemented by the child class.""" 111 | raise NotImplementedError 112 | 113 | @abstractmethod 114 | def _forward(self, data): 115 | """To be implemented by the child class.""" 116 | raise NotImplementedError 117 | 118 | @abstractmethod 119 | def loss(self, pred, data): 120 | """To be implemented by the child class.""" 121 | raise NotImplementedError 122 | 123 | @abstractmethod 124 | def metrics(self, pred, data): 125 | """To be implemented by the child class.""" 126 | raise NotImplementedError 127 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/base_optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements a simple differentiable optimizer based on Levenberg-Marquardt 3 | with a constant, scalar damping factor and a fixed number of iterations. 
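
One iteration, in sketch form (mirroring the calls made in `_run` below):

    res, valid, w_unc, _, J = cost_fn.residual_jacobian(T, ...)  # residuals + Jacobians
    g, H = build_system(J, res, weights)   # weighted gradient and Gauss-Newton Hessian
    delta = optimizer_step(g, H, lambda_)  # solves (H + lambda * diag(H)) * delta = -g
    dt, dw = delta.split([3, 3], dim=-1)
    T = Pose.from_aa(dw, dt) @ T           # left-multiplicative pose update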
4 | """ 5 | 6 | import logging 7 | from typing import Tuple, Dict, Optional 8 | import torch 9 | from torch import Tensor 10 | 11 | from .base_model import BaseModel 12 | from .utils import masked_mean 13 | from ..geometry import Camera, Pose 14 | from ..geometry.optimization import optimizer_step 15 | from ..geometry.interpolation import Interpolator 16 | from ..geometry.costs import DirectAbsoluteCost 17 | from ..geometry import losses # noqa 18 | from ...utils.tools import torchify 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class BaseOptimizer(BaseModel): 24 | default_conf = dict( 25 | num_iters=15, #100, 26 | loss_fn='scaled_barron(0, 0.1)', #'squared_loss', 27 | jacobi_scaling=False, 28 | normalize_features=False, 29 | lambda_=0.01, #0, # Gauss-Newton 30 | interpolation=dict( 31 | mode='linear', 32 | pad=4, 33 | ), 34 | grad_stop_criteria=1e-4, 35 | dt_stop_criteria=5e-3, # in meters 36 | dR_stop_criteria=5e-2, # in degrees 37 | 38 | # deprecated entries 39 | sqrt_diag_damping=False, 40 | bound_confidence=True, 41 | no_conditions=True, 42 | verbose=False, 43 | ) 44 | logging_fn = None 45 | 46 | def _init(self, conf): 47 | self.loss_fn = eval('losses.' + conf.loss_fn) 48 | self.interpolator = Interpolator(**conf.interpolation) 49 | self.cost_fn = DirectAbsoluteCost(self.interpolator, 50 | normalize=conf.normalize_features) 51 | assert conf.lambda_ >= 0. 52 | # deprecated entries 53 | assert not conf.sqrt_diag_damping 54 | assert conf.bound_confidence 55 | assert conf.no_conditions 56 | assert not conf.verbose 57 | 58 | def log(self, **args): 59 | if self.logging_fn is not None: 60 | self.logging_fn(**args) 61 | 62 | def early_stop(self, **args): 63 | stop = False 64 | if not self.training and (args['i'] % 10) == 0: 65 | T_delta, grad = args['T_delta'], args['grad'] 66 | grad_norm = torch.norm(grad.detach(), dim=-1) 67 | small_grad = grad_norm < self.conf.grad_stop_criteria 68 | dR, dt = T_delta.magnitude() 69 | small_step = ((dt < self.conf.dt_stop_criteria) 70 | & (dR < self.conf.dR_stop_criteria)) 71 | #if torch.all(small_step | small_grad): 72 | if torch.all(small_step): 73 | stop = True 74 | return stop 75 | 76 | def J_scaling(self, J: Tensor, J_scaling: Tensor, valid: Tensor): 77 | if J_scaling is None: 78 | J_norm = torch.norm(J.detach(), p=2, dim=(-2)) 79 | J_norm = masked_mean(J_norm, valid[..., None], -2) 80 | J_scaling = 1 / (1 + J_norm) 81 | J = J * J_scaling[..., None, None, :] 82 | return J, J_scaling 83 | 84 | def build_system(self, J: Tensor, res: Tensor, weights: Tensor): 85 | grad = torch.einsum('...ndi,...nd->...ni', J, res) # ... x N x 6 86 | grad = weights[..., None] * grad 87 | grad = grad.sum(-2) # ... x 6 88 | 89 | Hess = torch.einsum('...ijk,...ijl->...ikl', J, J) # ... x N x 6 x 6 90 | Hess = weights[..., None, None] * Hess 91 | Hess = Hess.sum(-3) # ... 
x 6 x6 92 | 93 | return grad, Hess 94 | 95 | def _forward(self, data: Dict): 96 | return self._run( 97 | data['p3D'], data['F_ref'], data['F_q'], data['T_init'], 98 | data['camera'], data['mask'], data.get('W_ref_q')) 99 | 100 | @torchify 101 | def run(self, *args, **kwargs): 102 | return self._run(*args, **kwargs) 103 | 104 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 105 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 106 | W_ref_query: Optional[Tuple[Tensor, Tensor, int]] = None): 107 | 108 | T = T_init 109 | J_scaling = None 110 | if self.conf.normalize_features: 111 | F_ref = torch.nn.functional.normalize(F_ref, dim=-1) 112 | args = (camera, p3D, F_ref, F_query, W_ref_query) 113 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 114 | 115 | for i in range(self.conf.num_iters): 116 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(T, *args) 117 | if mask is not None: 118 | valid &= mask 119 | failed = failed | (valid.long().sum(-1) < 10) # too few points 120 | 121 | # compute the cost and aggregate the weights 122 | cost = (res**2).sum(-1) 123 | cost, w_loss, _ = self.loss_fn(cost) 124 | weights = w_loss * valid.float() 125 | if w_unc is not None: 126 | weights *= w_unc 127 | if self.conf.jacobi_scaling: 128 | J, J_scaling = self.J_scaling(J, J_scaling, valid) 129 | 130 | # solve the linear system 131 | g, H = self.build_system(J, res, weights) 132 | delta = optimizer_step(g, H, self.conf.lambda_, mask=~failed) 133 | if self.conf.jacobi_scaling: 134 | delta = delta * J_scaling 135 | 136 | # compute the pose update 137 | dt, dw = delta.split([3, 3], dim=-1) 138 | T_delta = Pose.from_aa(dw, dt) 139 | T = T_delta @ T 140 | 141 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 142 | valid=valid, w_unc=w_unc, w_loss=w_loss, H=H, J=J) 143 | if self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost): 144 | break 145 | 146 | if failed.any(): 147 | logger.debug('One batch element had too few valid points.') 148 | 149 | return T, failed 150 | 151 | def loss(self, pred, data): 152 | raise NotImplementedError 153 | 154 | def metrics(self, pred, data): 155 | raise NotImplementedError 156 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/classic_optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Optional 3 | import torch 4 | from torch import Tensor 5 | 6 | from .base_optimizer import BaseOptimizer 7 | from .utils import masked_mean 8 | from ..geometry import Camera, Pose 9 | from ..geometry.optimization import optimizer_step 10 | from ..geometry import losses # noqa 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class ClassicOptimizer(BaseOptimizer): 16 | default_conf = dict( 17 | lambda_=1e-2, 18 | lambda_max=1e4, 19 | ) 20 | 21 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 22 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 23 | W_ref_query: Optional[Tuple[Tensor, Tensor]] = None): 24 | 25 | T = T_init 26 | J_scaling = None 27 | if self.conf.normalize_features: 28 | F_ref = torch.nn.functional.normalize(F_ref, dim=-1) 29 | args = (camera, p3D, F_ref, F_query, W_ref_query) 30 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 31 | 32 | lambda_ = torch.full_like(failed, self.conf.lambda_, dtype=T.dtype) 33 | mult = torch.full_like(lambda_, 10) 34 | recompute = True 35 | 36 | # compute the initial cost 37 | 
with torch.no_grad(): 38 | res, valid_i, w_unc_i = self.cost_fn.residuals(T_init, *args)[:3] 39 | cost_i = self.loss_fn((res.detach()**2).sum(-1))[0] 40 | if w_unc_i is not None: 41 | cost_i *= w_unc_i.detach() 42 | valid_i &= mask 43 | cost_best = masked_mean(cost_i, valid_i, -1) 44 | 45 | for i in range(self.conf.num_iters): 46 | if recompute: 47 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian( 48 | T, *args) 49 | if mask is not None: 50 | valid &= mask 51 | failed = failed | (valid.long().sum(-1) < 10) # too few points 52 | 53 | cost = (res**2).sum(-1) 54 | cost, w_loss, _ = self.loss_fn(cost) 55 | weights = w_loss * valid.float() 56 | if w_unc is not None: 57 | weights *= w_unc 58 | if self.conf.jacobi_scaling: 59 | J, J_scaling = self.J_scaling(J, J_scaling, valid) 60 | g, H = self.build_system(J, res, weights) 61 | 62 | delta = optimizer_step(g, H, lambda_.unqueeze(-1), mask=~failed) 63 | if self.conf.jacobi_scaling: 64 | delta = delta * J_scaling 65 | 66 | dt, dw = delta.split([3, 3], dim=-1) 67 | T_delta = Pose.from_aa(dw, dt) 68 | T_new = T_delta @ T 69 | 70 | # compute the new cost and update if it decreased 71 | with torch.no_grad(): 72 | res = self.cost_fn.residual(T_new, *args)[0] 73 | cost_new = self.loss_fn((res**2).sum(-1))[0] 74 | cost_new = masked_mean(cost_new, valid, -1) 75 | accept = cost_new < cost_best 76 | lambda_ = lambda_ * torch.where(accept, 1/mult, mult) 77 | lambda_ = lambda_.clamp(max=self.conf.lambda_max, min=1e-7) 78 | T = Pose(torch.where(accept[..., None], T_new._data, T._data)) 79 | cost_best = torch.where(accept, cost_new, cost_best) 80 | recompute = accept.any() 81 | 82 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 83 | valid=valid, w_unc=w_unc, w_loss=w_loss, accept=accept, 84 | lambda_=lambda_, H=H, J=J) 85 | 86 | stop = self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost) 87 | if self.conf.lambda_ == 0: # Gauss-Newton 88 | stop |= (~recompute) 89 | else: # LM saturates 90 | stop |= bool(torch.all(lambda_ >= self.conf.lambda_max)) 91 | if stop: 92 | break 93 | 94 | if failed.any(): 95 | logger.debug('One batch element had too few valid points.') 96 | 97 | return T, failed 98 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/gaussiannet.py: -------------------------------------------------------------------------------- 1 | """ 2 | A dummy model that computes an image pyramid with appropriate blurring. 
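Each scale in `output_scales` is obtained by Gaussian-blurring the previous level
(with sigma equal to the relative scale change) and then bilinearly downsampling it,
so `feature_maps` holds progressively blurred, subsampled copies of the input image.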
3 | """ 4 | 5 | import torch 6 | import kornia 7 | 8 | from .base_model import BaseModel 9 | 10 | 11 | class GaussianNet(BaseModel): 12 | default_conf = { 13 | 'output_scales': [1, 4, 16], # what scales to adapt and output 14 | 'kernel_size_factor': 3, 15 | } 16 | 17 | def _init(self, conf): 18 | self.scales = conf['output_scales'] 19 | 20 | def _forward(self, data): 21 | image = data['image'] 22 | scale_prev = 1 23 | pyramid = [] 24 | for scale in self.conf.output_scales: 25 | sigma = scale / scale_prev 26 | ksize = self.conf.kernel_size_factor * sigma 27 | image = kornia.filter.gaussian_blur2d( 28 | image, kernel_size=ksize, sigma=sigma) 29 | if sigma != 1: 30 | image = torch.nn.functional.interpolate( 31 | image, scale_factor=1/sigma, mode='bilinear', 32 | align_corners=False) 33 | pyramid.append(image) 34 | scale_prev = scale 35 | return {'feature_maps': pyramid} 36 | 37 | def loss(self, pred, data): 38 | raise NotImplementedError 39 | 40 | def metrics(self, pred, data): 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/learned_optimizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Tuple, Optional 3 | import torch 4 | from torch import nn, Tensor 5 | 6 | from .base_optimizer import BaseOptimizer 7 | from ..geometry import Camera, Pose 8 | from ..geometry.optimization import optimizer_step 9 | from ..geometry import losses # noqa 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DampingNet(nn.Module): 15 | def __init__(self, conf, num_params=6): 16 | super().__init__() 17 | self.conf = conf 18 | if conf.type == 'constant': 19 | const = torch.zeros(num_params) 20 | self.register_parameter('const', torch.nn.Parameter(const)) 21 | else: 22 | raise ValueError(f'Unsupported type of damping: {conf.type}.') 23 | 24 | def forward(self): 25 | min_, max_ = self.conf.log_range 26 | lambda_ = 10.**(min_ + self.const.sigmoid()*(max_ - min_)) 27 | return lambda_ 28 | 29 | 30 | class LearnedOptimizer(BaseOptimizer): 31 | default_conf = dict( 32 | damping=dict( 33 | type='constant', 34 | log_range=[-6, 5], 35 | ), 36 | feature_dim=None, 37 | 38 | # deprecated entries 39 | lambda_=0., 40 | learned_damping=True, 41 | ) 42 | 43 | def _init(self, conf): 44 | self.dampingnet = DampingNet(conf.damping) 45 | assert conf.learned_damping 46 | super()._init(conf) 47 | 48 | # def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 49 | # T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 50 | # W_ref_query: Optional[Tuple[Tensor, Tensor]] = None): 51 | # 52 | # T = T_init 53 | # J_scaling = None 54 | # if self.conf.normalize_features: 55 | # F_ref = torch.nn.functional.normalize(F_ref, dim=-1) 56 | # args = (camera, p3D, F_ref, F_query, W_ref_query) 57 | # failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 58 | # 59 | # lambda_ = self.dampingnet() 60 | # 61 | # for i in range(self.conf.num_iters): 62 | # res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(T, *args) 63 | # if mask is not None: 64 | # valid &= mask 65 | # failed = failed | (valid.long().sum(-1) < 10) # too few points 66 | # 67 | # # compute the cost and aggregate the weights 68 | # cost = (res**2).sum(-1) 69 | # cost, w_loss, _ = self.loss_fn(cost) 70 | # weights = w_loss * valid.float() 71 | # if w_unc is not None: 72 | # weights *= w_unc 73 | # if self.conf.jacobi_scaling: 74 | # J, J_scaling = self.J_scaling(J, 
J_scaling, valid) 75 | # 76 | # # solve the linear system 77 | # g, H = self.build_system(J, res, weights) 78 | # delta = optimizer_step(g, H, lambda_, mask=~failed) 79 | # if self.conf.jacobi_scaling: 80 | # delta = delta * J_scaling 81 | # 82 | # # compute the pose update 83 | # dt, dw = delta.split([3, 3], dim=-1) 84 | # T_delta = Pose.from_aa(dw, dt) 85 | # T = T_delta @ T 86 | # 87 | # self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 88 | # valid=valid, w_unc=w_unc, w_loss=w_loss, H=H, J=J) 89 | # if self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost): 90 | # break 91 | # 92 | # if failed.any(): 93 | # logger.debug('One batch element had too few valid points.') 94 | # 95 | # return T, failed 96 | def _run(self, p3D: Tensor, F_ref: Tensor, F_query: Tensor, 97 | T_init: Pose, camera: Camera, mask: Optional[Tensor] = None, 98 | W_ref_query: Optional[Tuple[Tensor, Tensor, int]] = None): 99 | 100 | T = T_init 101 | J_scaling = None 102 | if self.conf.normalize_features: 103 | F_query = torch.nn.functional.normalize(F_query, dim=-1) 104 | args = (camera, p3D, F_ref, F_query, W_ref_query) 105 | failed = torch.full(T.shape, False, dtype=torch.bool, device=T.device) 106 | 107 | lambda_ = self.dampingnet() 108 | 109 | for i in range(self.conf.num_iters): 110 | res, valid, w_unc, _, J = self.cost_fn.residual_jacobian(T, *args) 111 | if mask is not None: 112 | valid &= mask 113 | failed = failed | (valid.long().sum(-1) < 10) # too few points 114 | 115 | # compute the cost and aggregate the weights 116 | cost = (res**2).sum(-1) 117 | cost, w_loss, _ = self.loss_fn(cost) 118 | weights = w_loss * valid.float() 119 | if w_unc is not None: 120 | weights = weights*w_unc 121 | if self.conf.jacobi_scaling: 122 | J, J_scaling = self.J_scaling(J, J_scaling, valid) 123 | 124 | # solve the linear system 125 | g, H = self.build_system(J, res, weights) 126 | delta = optimizer_step(g, H, lambda_, mask=~failed) 127 | if self.conf.jacobi_scaling: 128 | delta = delta * J_scaling 129 | 130 | # compute the pose update 131 | dt, dw = delta.split([3, 3], dim=-1) 132 | # dt, dw = delta.split([2, 1], dim=-1) 133 | # fix z trans, roll and pitch 134 | zeros = torch.zeros_like(dw[:,-1:]) 135 | dw = torch.cat([zeros,zeros,dw[:,-1:]], dim=-1) 136 | dt = torch.cat([dt[:,0:2],zeros], dim=-1) 137 | 138 | T_delta = Pose.from_aa(dw, dt) 139 | T = T_delta @ T 140 | 141 | self.log(i=i, T_init=T_init, T=T, T_delta=T_delta, cost=cost, 142 | valid=valid, w_unc=w_unc, w_loss=w_loss, H=H, J=J) 143 | if self.early_stop(i=i, T_delta=T_delta, grad=g, cost=cost): 144 | break 145 | 146 | if failed.any(): 147 | logger.debug('One batch element had too few valid points.') 148 | 149 | return T, failed 150 | 151 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/s2dnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | An implementation of 3 | S2DNet: Learning Image Features for Accurate Sparse-to-Dense Matching 4 | Hugo Germain, Guillaume Bourmaud, Vincent Lepetit 5 | European Conference on Computer Vision (ECCV) 2020 6 | 7 | Adapted from https://github.com/germain-hug/S2DNet-Minimal 8 | """ 9 | 10 | from typing import List 11 | import torch 12 | import torch.nn as nn 13 | from torchvision import models 14 | import logging 15 | 16 | from .base_model import BaseModel 17 | from ...settings import DATA_PATH 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | # VGG-16 Layer Names and Channels 23 | vgg16_layers = { 24 | 
"conv1_1": 64, 25 | "relu1_1": 64, 26 | "conv1_2": 64, 27 | "relu1_2": 64, 28 | "pool1": 64, 29 | "conv2_1": 128, 30 | "relu2_1": 128, 31 | "conv2_2": 128, 32 | "relu2_2": 128, 33 | "pool2": 128, 34 | "conv3_1": 256, 35 | "relu3_1": 256, 36 | "conv3_2": 256, 37 | "relu3_2": 256, 38 | "conv3_3": 256, 39 | "relu3_3": 256, 40 | "pool3": 256, 41 | "conv4_1": 512, 42 | "relu4_1": 512, 43 | "conv4_2": 512, 44 | "relu4_2": 512, 45 | "conv4_3": 512, 46 | "relu4_3": 512, 47 | "pool4": 512, 48 | "conv5_1": 512, 49 | "relu5_1": 512, 50 | "conv5_2": 512, 51 | "relu5_2": 512, 52 | "conv5_3": 512, 53 | "relu5_3": 512, 54 | "pool5": 512, 55 | } 56 | 57 | 58 | class AdapLayers(nn.Module): 59 | """Small adaptation layers. 60 | """ 61 | 62 | def __init__(self, hypercolumn_layers: List[str], output_dim: int = 128): 63 | """Initialize one adaptation layer for every extraction point. 64 | 65 | Args: 66 | hypercolumn_layers: The list of the hypercolumn layer names. 67 | output_dim: The output channel dimension. 68 | """ 69 | super(AdapLayers, self).__init__() 70 | self.layers = [] 71 | channel_sizes = [vgg16_layers[name] for name in hypercolumn_layers] 72 | for i, l in enumerate(channel_sizes): 73 | layer = nn.Sequential( 74 | nn.Conv2d(l, 64, kernel_size=1, stride=1, padding=0), 75 | nn.ReLU(), 76 | nn.Conv2d(64, output_dim, kernel_size=5, stride=1, padding=2), 77 | nn.BatchNorm2d(output_dim), 78 | ) 79 | self.layers.append(layer) 80 | self.add_module("adap_layer_{}".format(i), layer) 81 | 82 | def forward(self, features: List[torch.tensor]): 83 | """Apply adaptation layers. 84 | """ 85 | for i, _ in enumerate(features): 86 | features[i] = getattr(self, "adap_layer_{}".format(i))(features[i]) 87 | return features 88 | 89 | 90 | class S2DNet(BaseModel): 91 | default_conf = { 92 | 'hypercolumn_layers': ["conv1_2", "conv3_3", "conv5_3"], 93 | 'checkpointing': None, 94 | 'output_dim': 128, 95 | 'pretrained': 's2dnet', 96 | } 97 | mean = [0.485, 0.456, 0.406] 98 | std = [0.229, 0.224, 0.225] 99 | 100 | def _init(self, conf): 101 | assert conf.pretrained in ['s2dnet', 'imagenet', None] 102 | 103 | self.layer_to_index = {k: v for v, k in enumerate(vgg16_layers.keys())} 104 | self.hypercolumn_indices = [ 105 | self.layer_to_index[n] for n in conf.hypercolumn_layers] 106 | num_layers = self.hypercolumn_indices[-1] + 1 107 | 108 | # Initialize architecture 109 | vgg16 = models.vgg16(pretrained=conf.pretrained == 'imagenet') 110 | layers = list(vgg16.features.children())[:num_layers] 111 | self.encoder = nn.ModuleList(layers) 112 | 113 | self.scales = [] 114 | current_scale = 0 115 | for i, layer in enumerate(layers): 116 | if isinstance(layer, torch.nn.MaxPool2d): 117 | current_scale += 1 118 | if i in self.hypercolumn_indices: 119 | self.scales.append(2**current_scale) 120 | 121 | self.adaptation_layers = AdapLayers( 122 | conf.hypercolumn_layers, conf.output_dim) 123 | 124 | if conf.pretrained == 's2dnet': 125 | path = DATA_PATH / 's2dnet_weights.pth' 126 | logger.info(f'Loading S2DNet checkpoint at {path}.') 127 | state_dict = torch.load(path, map_location='cpu')['state_dict'] 128 | params = self.state_dict() 129 | state_dict = {k: v for k, v in state_dict.items() 130 | if v.shape == params[k].shape} 131 | self.load_state_dict(state_dict, strict=False) 132 | 133 | def _forward(self, data): 134 | image = data['image'] 135 | mean, std = image.new_tensor(self.mean), image.new_tensor(self.std) 136 | image = (image - mean[:, None, None]) / std[:, None, None] 137 | 138 | feature_map = image 139 | feature_maps = [] 140 | 
start = 0 141 | for idx in self.hypercolumn_indices: 142 | if self.conf.checkpointing: 143 | blocks = list(range(start, idx+2, self.conf.checkpointing)) 144 | if blocks[-1] != idx+1: 145 | blocks.append(idx+1) 146 | for start_, end_ in zip(blocks[:-1], blocks[1:]): 147 | feature_map = torch.utils.checkpoint.checkpoint( 148 | nn.Sequential(*self.encoder[start_:end_]), feature_map) 149 | else: 150 | for i in range(start, idx + 1): 151 | feature_map = self.encoder[i](feature_map) 152 | feature_maps.append(feature_map) 153 | start = idx + 1 154 | 155 | feature_maps = self.adaptation_layers(feature_maps) 156 | return {'feature_maps': feature_maps} 157 | 158 | def loss(self, pred, data): 159 | raise NotImplementedError 160 | 161 | def metrics(self, pred, data): 162 | raise NotImplementedError 163 | -------------------------------------------------------------------------------- /PureACL/pixlib/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def masked_mean(x, mask, dim): 5 | mask = mask.float() 6 | return (mask * x).sum(dim) / mask.sum(dim).clamp(min=1) 7 | 8 | 9 | def checkpointed(cls, do=True): 10 | '''Adapted from the DISK implementation of Michał Tyszkiewicz.''' 11 | assert issubclass(cls, torch.nn.Module) 12 | 13 | class Checkpointed(cls): 14 | def forward(self, *args, **kwargs): 15 | super_fwd = super(Checkpointed, self).forward 16 | if any((torch.is_tensor(a) and a.requires_grad) for a in args): 17 | return torch.utils.checkpoint.checkpoint( 18 | super_fwd, *args, **kwargs) 19 | else: 20 | return super_fwd(*args, **kwargs) 21 | 22 | return Checkpointed if do else cls 23 | 24 | # shan add for key points extraction, from super point 25 | def merge_confidence_map(confidence, number): 26 | """extrac key ponts from confidence map. 27 | Args: 28 | confidence: torch.Tensor with size (B,C,H,W). 29 | number: number of confidence map 30 | Returns: 31 | merged confidence map: torch.Tensor with size (B,H,W). 32 | """ 33 | B,C,H,W = confidence[0].size() 34 | for level in range(len(confidence)): 35 | if number == 2: 36 | c_cur = confidence[level][:,:1]*confidence[level][:,1:] 37 | else: 38 | c_cur = confidence[level][:,:1] 39 | if level > 0: 40 | c_cur = torch.nn.functional.interpolate(c_cur, size=(H,W), mode='bilinear') 41 | max, _ = torch.max(c_cur.flatten(-2), dim=-1) 42 | c_cur = c_cur / (max[:,:,None,None] + 1e-8) 43 | #c_cur = torch.nn.functional.normalize(c_cur.flatten(-2), p=float('inf'), dim=-1) # normalize in 2d Plane #[b,c,H,W] 44 | #c_cur = c_cur.view(B, 1, H, W) 45 | c_last = 0.8*c_last + c_cur 46 | else: 47 | max, _ = torch.max(c_cur.flatten(-2), dim=-1) 48 | c_cur = c_cur / (max[:,:,None,None] + 1e-8) 49 | #c_cur = torch.nn.functional.normalize(c_cur.flatten(-2), p=float('inf'), dim=-1) # normalize in 2d Plane #[b,c,H,W] 50 | #c_cur = c_cur.view(B, 1, H, W) 51 | c_last = c_cur 52 | return c_last 53 | 54 | # shan add for key points extraction, from super point 55 | def extract_keypoints(confidence, topk=1024, start_ratio=0.6): 56 | """extrac key ponts from confidence map. 57 | Args: 58 | confidence: torch.Tensor with size (B,C,H,W). 59 | topk: extract topk points each confidence map 60 | start_ratio: extract close to ground part (start_ratio*H:) 61 | Returns: 62 | A torch.Tensor of index where the key points are. 
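Concretely (assuming a single-channel confidence map): a tensor of shape
(B, topk, 2) holding (u, v) coordinates, with v shifted back to the
uncropped confidence-map frame.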
63 | """ 64 | assert (start_ratio < 1 and start_ratio >= 0) 65 | 66 | 67 | w_end = -1 68 | h_end = -1 69 | radius = 4 70 | if confidence.size(-1) > 1200: # KITTI 71 | radius = 5 72 | 73 | # only extract close to ground part (start_H:) 74 | start_H = int(confidence.size(2)*start_ratio) 75 | confidence = confidence[:,:,start_H:h_end,:w_end].detach().clone() 76 | 77 | # fast Non-maximum suppression to remove nearby points 78 | def max_pool(x): 79 | return torch.nn.functional.max_pool2d(x, kernel_size=radius*2+1, stride=1, padding=radius) 80 | 81 | max_mask = (confidence == max_pool(confidence)) 82 | for _ in range(2): 83 | supp_mask = max_pool(max_mask.float()) > 0 84 | supp_confidence = torch.where(supp_mask, torch.zeros_like(confidence), confidence) 85 | new_max_mask = (supp_confidence == max_pool(supp_confidence)) 86 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 87 | confidence = torch.where(max_mask, confidence, torch.zeros_like(confidence)) 88 | 89 | # remove borders 90 | border = radius 91 | confidence[:, :, :border] = 0. 92 | confidence[:, :, -border:] = 0. 93 | confidence[:, :, :, :border] = 0. 94 | confidence[:, :, :, -border:] = 0. 95 | 96 | # confidence topk 97 | _, index = confidence.flatten(1).topk(topk, dim=1, largest=True, sorted=True) 98 | 99 | index_v = torch.div(index, confidence.size(-1) , rounding_mode='trunc') 100 | index_u = index % confidence.size(-1) 101 | # back to original index 102 | index_v += start_H 103 | 104 | return torch.cat([index_u.unsqueeze(-1),index_v.unsqueeze(-1)],dim=-1) 105 | 106 | def camera_to_onground(p3d_c, T_w2cam, camera_h, normal_grd, min=1E-8, max=200.): 107 | # normal from query to camera coordinate 108 | normal = torch.einsum('...ij,...cj->...ci', T_w2cam.R, normal_grd) 109 | normal = normal.squeeze(1) 110 | h = 0 111 | if p3d_c.dim() > 3: 112 | b,h,w,c = p3d_c.shape 113 | p3d_c = p3d_c.flatten(1,2) 114 | depth = camera_h[:,None] / torch.einsum('b...i,b...i->b...', p3d_c, normal) 115 | valid = (depth < max) & (depth >= min) 116 | depth = depth.clamp(min, max) 117 | p3d_grd = depth.unsqueeze(-1) * p3d_c 118 | # each camera coordinate to 'query' coordinate 119 | p3d_grd = T_w2cam.inv()*p3d_grd # camera to query 120 | 121 | # not valid set to far away 122 | p3d_grd[~valid] = torch.tensor(max).to(p3d_grd) 123 | if h > 0: 124 | p3d_grd = p3d_grd.reshape(b,h,w,c) 125 | return p3d_grd 126 | -------------------------------------------------------------------------------- /PureACL/pixlib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanWang-Shan/PureACL/21a4c2e64f0eeafaa09117b6bc8aef40d9cdf4e3/PureACL/pixlib/utils/__init__.py -------------------------------------------------------------------------------- /PureACL/pixlib/utils/experiments.py: -------------------------------------------------------------------------------- 1 | """ 2 | A set of utilities to manage and load checkpoints of training experiments. 
3 | """ 4 | 5 | from pathlib import Path 6 | import logging 7 | import re 8 | from omegaconf import OmegaConf 9 | import torch 10 | import os 11 | 12 | from ...settings import TRAINING_PATH 13 | from ..models import get_model 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def list_checkpoints(dir_): 19 | """List all valid checkpoints in a given directory.""" 20 | checkpoints = [] 21 | for p in dir_.glob('checkpoint_*.tar'): 22 | numbers = re.findall(r'(\d+)', p.name) 23 | if len(numbers) == 0: 24 | continue 25 | assert len(numbers) == 1 26 | checkpoints.append((int(numbers[0]), p)) 27 | return checkpoints 28 | 29 | 30 | def get_last_checkpoint(exper, allow_interrupted=True): 31 | """Get the last saved checkpoint for a given experiment name.""" 32 | ckpts = list_checkpoints(Path(TRAINING_PATH, exper)) 33 | if not allow_interrupted: 34 | ckpts = [(n, p) for (n, p) in ckpts if '_interrupted' not in p.name] 35 | assert len(ckpts) > 0 36 | return sorted(ckpts)[-1][1] 37 | 38 | 39 | def get_best_checkpoint(exper): 40 | """Get the checkpoint with the best loss, for a given experiment name.""" 41 | p = Path(TRAINING_PATH, exper, 'checkpoint_best.tar') 42 | return p 43 | 44 | 45 | def delete_old_checkpoints(dir_, num_keep): 46 | """Delete all but the num_keep last saved checkpoints.""" 47 | ckpts = list_checkpoints(dir_) 48 | ckpts = sorted(ckpts)[::-1] 49 | kept = 0 50 | for ckpt in ckpts: 51 | if ('_interrupted' in str(ckpt[1]) and kept > 0) or kept >= num_keep: 52 | logger.info(f'Deleting checkpoint {ckpt[1].name}') 53 | ckpt[1].unlink() 54 | else: 55 | kept += 1 56 | 57 | 58 | def load_experiment(exper, conf={}, get_last=False): 59 | """Load and return the model of a given experiment.""" 60 | if get_last: 61 | ckpt = get_last_checkpoint(exper) 62 | else: 63 | ckpt = get_best_checkpoint(exper) 64 | logger.info(f'Loading checkpoint {ckpt.name}') 65 | ckpt = torch.load(str(ckpt), map_location='cpu') 66 | 67 | loaded_conf = OmegaConf.create(ckpt['conf']) 68 | OmegaConf.set_struct(loaded_conf, False) 69 | conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf)) 70 | model = get_model(conf.name)(conf).eval() 71 | 72 | state_dict = ckpt['model'] 73 | dict_params = set(state_dict.keys()) 74 | model_params = set(map(lambda n: n[0], model.named_parameters())) 75 | diff = model_params - dict_params 76 | if len(diff) > 0: 77 | subs = os.path.commonprefix(list(diff)).rstrip('.') 78 | logger.warning(f'Missing {len(diff)} parameters in {subs}') 79 | model.load_state_dict(state_dict, strict=False) 80 | return model 81 | 82 | 83 | def flexible_load(state_dict, model): 84 | """TODO: fix a probable nasty bug, and move to BaseModel.""" 85 | dict_params = set(state_dict.keys()) 86 | model_params = set(map(lambda n: n[0], model.named_parameters())) 87 | 88 | if dict_params == model_params: # prefect fit 89 | logger.info('Loading all parameters of the checkpoint.') 90 | model.load_state_dict(state_dict, strict=True) 91 | return 92 | elif len(dict_params & model_params) == 0: # perfect mismatch 93 | strip_prefix = lambda x: '.'.join(x.split('.')[:1]+x.split('.')[2:]) 94 | state_dict = {strip_prefix(n): p for n, p in state_dict.items()} 95 | dict_params = set(state_dict.keys()) 96 | if len(dict_params & model_params) == 0: 97 | raise ValueError('Could not manage to load the checkpoint with' 98 | 'parameters:' + '\n\t'.join(sorted(dict_params))) 99 | common_params = dict_params & model_params 100 | left_params = dict_params - model_params 101 | logger.info('Loading 
parameters:\n\t'+'\n\t'.join(sorted(common_params))) 102 | if len(left_params) > 0: 103 | logger.info('Could not load parameters:\n\t' 104 | + '\n\t'.join(sorted(left_params))) 105 | model.load_state_dict(state_dict, strict=False) 106 | -------------------------------------------------------------------------------- /PureACL/pixlib/utils/stdout_capturing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on sacred/stdout_capturing.py in project Sacred 3 | https://github.com/IDSIA/sacred 4 | """ 5 | 6 | from __future__ import division, print_function, unicode_literals 7 | import os 8 | import sys 9 | import subprocess 10 | from threading import Timer 11 | from contextlib import contextmanager 12 | 13 | 14 | def apply_backspaces_and_linefeeds(text): 15 | """ 16 | Interpret backspaces and linefeeds in text like a terminal would. 17 | Interpret text like a terminal by removing backspace and linefeed 18 | characters and applying them line by line. 19 | If final line ends with a carriage it keeps it to be concatenable with next 20 | output chunk. 21 | """ 22 | orig_lines = text.split('\n') 23 | orig_lines_len = len(orig_lines) 24 | new_lines = [] 25 | for orig_line_idx, orig_line in enumerate(orig_lines): 26 | chars, cursor = [], 0 27 | orig_line_len = len(orig_line) 28 | for orig_char_idx, orig_char in enumerate(orig_line): 29 | if orig_char == '\r' and (orig_char_idx != orig_line_len - 1 or 30 | orig_line_idx != orig_lines_len - 1): 31 | cursor = 0 32 | elif orig_char == '\b': 33 | cursor = max(0, cursor - 1) 34 | else: 35 | if (orig_char == '\r' and 36 | orig_char_idx == orig_line_len - 1 and 37 | orig_line_idx == orig_lines_len - 1): 38 | cursor = len(chars) 39 | if cursor == len(chars): 40 | chars.append(orig_char) 41 | else: 42 | chars[cursor] = orig_char 43 | cursor += 1 44 | new_lines.append(''.join(chars)) 45 | return '\n'.join(new_lines) 46 | 47 | 48 | def flush(): 49 | """Try to flush all stdio buffers, both from python and from C.""" 50 | try: 51 | sys.stdout.flush() 52 | sys.stderr.flush() 53 | except (AttributeError, ValueError, IOError): 54 | pass # unsupported 55 | 56 | 57 | # Duplicate stdout and stderr to a file. 
Inspired by: 58 | # http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ 59 | # http://stackoverflow.com/a/651718/1388435 60 | # http://stackoverflow.com/a/22434262/1388435 61 | @contextmanager 62 | def capture_outputs(filename): 63 | """Duplicate stdout and stderr to a file on the file descriptor level.""" 64 | with open(str(filename), 'a+') as target: 65 | original_stdout_fd = 1 66 | original_stderr_fd = 2 67 | target_fd = target.fileno() 68 | 69 | # Save a copy of the original stdout and stderr file descriptors 70 | saved_stdout_fd = os.dup(original_stdout_fd) 71 | saved_stderr_fd = os.dup(original_stderr_fd) 72 | 73 | tee_stdout = subprocess.Popen( 74 | ['tee', '-a', '-i', '/dev/stderr'], start_new_session=True, 75 | stdin=subprocess.PIPE, stderr=target_fd, stdout=1) 76 | tee_stderr = subprocess.Popen( 77 | ['tee', '-a', '-i', '/dev/stderr'], start_new_session=True, 78 | stdin=subprocess.PIPE, stderr=target_fd, stdout=2) 79 | 80 | flush() 81 | os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd) 82 | os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd) 83 | 84 | try: 85 | yield 86 | finally: 87 | flush() 88 | 89 | # then redirect stdout back to the saved fd 90 | tee_stdout.stdin.close() 91 | tee_stderr.stdin.close() 92 | 93 | # restore original fds 94 | os.dup2(saved_stdout_fd, original_stdout_fd) 95 | os.dup2(saved_stderr_fd, original_stderr_fd) 96 | 97 | # wait for completion of the tee processes with timeout 98 | # implemented using a timer because timeout support is py3 only 99 | def kill_tees(): 100 | tee_stdout.kill() 101 | tee_stderr.kill() 102 | 103 | tee_timer = Timer(1, kill_tees) 104 | try: 105 | tee_timer.start() 106 | tee_stdout.wait() 107 | tee_stderr.wait() 108 | finally: 109 | tee_timer.cancel() 110 | 111 | os.close(saved_stdout_fd) 112 | os.close(saved_stderr_fd) 113 | 114 | # Cleanup log file 115 | with open(str(filename), 'r') as target: 116 | text = target.read() 117 | text = apply_backspaces_and_linefeeds(text) 118 | with open(str(filename), 'w') as target: 119 | target.write(text) 120 | -------------------------------------------------------------------------------- /PureACL/pixlib/utils/tensor.py: -------------------------------------------------------------------------------- 1 | #from torch._six import string_classes 2 | import collections.abc as collections 3 | 4 | string_classes = (str, bytes) 5 | 6 | 7 | def map_tensor(input_, func): 8 | if isinstance(input_, string_classes): 9 | return input_ 10 | elif isinstance(input_, collections.Mapping): 11 | return {k: map_tensor(sample, func) for k, sample in input_.items()} 12 | elif isinstance(input_, collections.Sequence): 13 | return [map_tensor(sample, func) for sample in input_] 14 | else: 15 | return func(input_) 16 | 17 | 18 | def batch_to_numpy(batch): 19 | return map_tensor(batch, lambda tensor: tensor.cpu().numpy()) 20 | 21 | 22 | def batch_to_device(batch, device, non_blocking=True): 23 | def _func(tensor): 24 | return tensor.to(device=device, non_blocking=non_blocking) 25 | 26 | return map_tensor(batch, _func) 27 | -------------------------------------------------------------------------------- /PureACL/pixlib/utils/tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various handy Python and PyTorch utils. 
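
For instance (illustrative):

    set_seed(0)                 # seed python, numpy and torch at once
    with fork_rng(seed=42):
        pass                    # temporary RNG state, restored on exit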
3 | """ 4 | 5 | import time 6 | import inspect 7 | import numpy as np 8 | import os 9 | import torch 10 | import random 11 | from contextlib import contextmanager 12 | 13 | 14 | class AverageMetric: 15 | def __init__(self): 16 | self._sum = 0 17 | self._num_examples = 0 18 | 19 | def update(self, tensor): 20 | assert tensor.dim() == 1 21 | tensor = tensor[~torch.isnan(tensor)] 22 | self._sum += tensor.sum().item() 23 | self._num_examples += len(tensor) 24 | 25 | def compute(self): 26 | if self._num_examples == 0: 27 | return np.nan 28 | else: 29 | return self._sum / self._num_examples 30 | 31 | 32 | class MedianMetric: 33 | def __init__(self): 34 | self._elements = [] 35 | 36 | def update(self, tensor): 37 | assert tensor.dim() == 1 38 | self._elements += tensor.cpu().numpy().tolist() 39 | 40 | def compute(self): 41 | if len(self._elements) == 0: 42 | return np.nan 43 | else: 44 | return np.nanmedian(self._elements) 45 | 46 | 47 | def get_class(mod_name, base_path, BaseClass): 48 | """Get the class object which inherits from BaseClass and is defined in 49 | the module named mod_name, child of base_path. 50 | """ 51 | mod_path = '{}.{}'.format(base_path, mod_name) 52 | mod = __import__(mod_path, fromlist=['']) 53 | classes = inspect.getmembers(mod, inspect.isclass) 54 | # Filter classes defined in the module 55 | classes = [c for c in classes if c[1].__module__ == mod_path] 56 | # Filter classes inherited from BaseModel 57 | classes = [c for c in classes if issubclass(c[1], BaseClass)] 58 | assert len(classes) == 1, classes 59 | return classes[0][1] 60 | 61 | 62 | class Timer(object): 63 | """A simpler timer context object. 64 | Usage: 65 | ``` 66 | > with Timer('mytimer'): 67 | > # some computations 68 | [mytimer] Elapsed: X 69 | ``` 70 | """ 71 | def __init__(self, name=None): 72 | self.name = name 73 | 74 | def __enter__(self): 75 | self.tstart = time.time() 76 | return self 77 | 78 | def __exit__(self, type, value, traceback): 79 | self.duration = time.time() - self.tstart 80 | if self.name is not None: 81 | print('[%s] Elapsed: %s' % (self.name, self.duration)) 82 | 83 | 84 | def set_num_threads(nt): 85 | """Force numpy and other libraries to use a limited number of threads.""" 86 | try: 87 | import mkl 88 | except ImportError: 89 | pass 90 | else: 91 | mkl.set_num_threads(nt) 92 | torch.set_num_threads(1) 93 | os.environ['IPC_ENABLE'] = '1' 94 | for o in ['OPENBLAS_NUM_THREADS', 'NUMEXPR_NUM_THREADS', 95 | 'OMP_NUM_THREADS', 'MKL_NUM_THREADS']: 96 | os.environ[o] = str(nt) 97 | 98 | 99 | def set_seed(seed): 100 | random.seed(seed) 101 | torch.manual_seed(seed) 102 | np.random.seed(seed) 103 | if torch.cuda.is_available(): 104 | torch.cuda.manual_seed(seed) 105 | torch.cuda.manual_seed_all(seed) 106 | 107 | 108 | def get_random_state(): 109 | pth_state = torch.get_rng_state() 110 | np_state = np.random.get_state() 111 | py_state = random.getstate() 112 | if torch.cuda.is_available(): 113 | cuda_state = torch.cuda.get_rng_state_all() 114 | else: 115 | cuda_state = None 116 | return pth_state, np_state, py_state, cuda_state 117 | 118 | 119 | def set_random_state(state): 120 | pth_state, np_state, py_state, cuda_state = state 121 | torch.set_rng_state(pth_state) 122 | np.random.set_state(np_state) 123 | random.setstate(py_state) 124 | if (cuda_state is not None 125 | and torch.cuda.is_available() 126 | and len(cuda_state) == torch.cuda.device_count()): 127 | torch.cuda.set_rng_state_all(cuda_state) 128 | 129 | 130 | @contextmanager 131 | def fork_rng(seed=None): 132 | state = 
get_random_state() 133 | if seed is not None: 134 | set_seed(seed) 135 | try: 136 | yield 137 | finally: 138 | set_random_state(state) 139 | -------------------------------------------------------------------------------- /PureACL/settings.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | root = Path(__file__).parent.parent # top-level directory 4 | DATA_PATH = root / 'datasets/' # datasets and pretrained weights 5 | TRAINING_PATH = root / 'outputs/training/' # training checkpoints 6 | LOC_PATH = root / 'outputs/hloc/' # localization logs 7 | EVAL_PATH = root / 'outputs/results/' # evaluation results 8 | -------------------------------------------------------------------------------- /PureACL/utils/data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | from pathlib import Path 4 | from typing import Dict, List, Optional 5 | from omegaconf import DictConfig, OmegaConf as oc 6 | 7 | from .. import settings, logger 8 | 9 | 10 | @dataclasses.dataclass 11 | class Paths: 12 | query_images: Path 13 | reference_images: Path 14 | reference_sfm: Path 15 | query_list: Path 16 | 17 | dataset: Optional[Path] = None 18 | dumps: Optional[Path] = None 19 | 20 | retrieval_pairs: Optional[Path] = None 21 | results: Optional[Path] = None 22 | global_descriptors: Optional[Path] = None 23 | hloc_logs: Optional[Path] = None 24 | ground_truth: Optional[Path] = None 25 | 26 | def interpolate(self, **kwargs) -> 'Paths': 27 | args = {} 28 | for f in dataclasses.fields(self): 29 | val = getattr(self, f.name) 30 | if val is not None: 31 | val = str(val) 32 | for k, v in kwargs.items(): 33 | val = val.replace(f'{{{k}}}', str(v)) 34 | val = Path(val) 35 | args[f.name] = val 36 | return self.__class__(**args) 37 | 38 | def asdict(self) -> Dict[str, Path]: 39 | return dataclasses.asdict(self) 40 | 41 | @classmethod 42 | def fields(cls) -> List[str]: 43 | return [f.name for f in dataclasses.fields(cls)] 44 | 45 | def add_prefixes(self, dataset: Path, dumps: Path, 46 | eval_dir: Optional[Path] = Path('.')) -> 'Paths': 47 | paths = {} 48 | for attr in self.fields(): 49 | val = getattr(self, attr) 50 | if val is not None: 51 | if attr in {'dataset', 'dumps'}: 52 | paths[attr] = val 53 | elif attr in {'query_images', 54 | 'reference_images', 55 | 'ground_truth'}: 56 | paths[attr] = dataset / val 57 | elif attr in {'results'}: 58 | paths[attr] = eval_dir / val 59 | else: # everything else is part of the hloc dumps 60 | paths[attr] = dumps / val 61 | paths['dataset'] = dataset 62 | paths['dumps'] = dumps 63 | return self.__class__(**paths) 64 | 65 | 66 | def create_argparser(dataset: str) -> argparse.ArgumentParser: 67 | parser = argparse.ArgumentParser( 68 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | 70 | parser.add_argument('--results', type=Path) 71 | parser.add_argument('--reference_sfm', type=Path) 72 | parser.add_argument('--retrieval', type=Path) 73 | parser.add_argument('--global_descriptors', type=Path) 74 | parser.add_argument('--hloc_logs', type=Path) 75 | 76 | parser.add_argument('--dataset', type=Path, 77 | default=settings.DATA_PATH / dataset) 78 | parser.add_argument('--dumps', type=Path, 79 | default=settings.LOC_PATH / dataset) 80 | parser.add_argument('--eval_dir', type=Path, 81 | default=settings.EVAL_PATH) 82 | 83 | parser.add_argument('--from_poses', action='store_true') 84 | parser.add_argument('--inlier_ranking', action='store_true') 85 
| parser.add_argument('--skip', type=int) 86 | parser.add_argument('--verbose', action='store_true') 87 | parser.add_argument('dotlist', nargs='*') 88 | 89 | return parser 90 | 91 | 92 | def parse_paths(args, default_paths: Paths) -> Paths: 93 | default_paths = default_paths.add_prefixes( 94 | args.dataset, args.dumps, args.eval_dir) 95 | paths = {} 96 | for attr in Paths.fields(): 97 | val = getattr(args, attr, None) 98 | if val is None: 99 | val = getattr(default_paths, attr, None) 100 | if val is None: 101 | continue 102 | paths[attr] = val 103 | return Paths(**paths) 104 | 105 | 106 | def parse_conf(args, default_confs: Dict) -> DictConfig: 107 | conf = default_confs['from_poses' if args.from_poses else 'from_retrieval'] 108 | conf = oc.merge(oc.create(conf), oc.from_cli(args.dotlist)) 109 | logger.info('Parsed configuration:\n%s', oc.to_yaml(conf)) 110 | return conf 111 | -------------------------------------------------------------------------------- /PureACL/utils/eval.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from pathlib import Path 4 | from typing import Union, Dict, Tuple, Optional 5 | import numpy as np 6 | from .io import parse_image_list 7 | from .colmap import qvec2rotmat, read_images_binary, read_images_text 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def evaluate(gt_sfm_model: Path, predictions: Union[Dict, Path], 13 | test_file_list: Optional[Path] = None, 14 | only_localized: bool = False): 15 | """Compute the evaluation metrics for 7Scenes and Cambridge Landmarks. 16 | The other datasets are evaluated on visuallocalization.net 17 | """ 18 | if not isinstance(predictions, dict): 19 | predictions = parse_image_list(predictions, with_poses=True) 20 | predictions = {n: (im.qvec, im.tvec) for n, im in predictions} 21 | 22 | # ground truth poses from the sfm model 23 | images_bin = gt_sfm_model / 'images.bin' 24 | images_txt = gt_sfm_model / 'images.txt' 25 | if images_bin.exists(): 26 | images = read_images_binary(images_bin) 27 | elif images_txt.exists(): 28 | images = read_images_text(images_txt) 29 | else: 30 | raise ValueError(gt_sfm_model) 31 | name2id = {image.name: i for i, image in images.items()} 32 | 33 | if test_file_list is None: 34 | test_names = list(name2id) 35 | else: 36 | with open(test_file_list, 'r') as f: 37 | test_names = f.read().rstrip().split('\n') 38 | 39 | # translation and rotation errors 40 | errors_t = [] 41 | errors_R = [] 42 | for name in test_names: 43 | if name not in predictions: 44 | if only_localized: 45 | continue 46 | e_t = np.inf 47 | e_R = 180. 48 | else: 49 | image = images[name2id[name]] 50 | R_gt, t_gt = image.qvec2rotmat(), image.tvec 51 | qvec, t = predictions[name] 52 | R = qvec2rotmat(qvec) 53 | e_t = np.linalg.norm(-R_gt.T @ t_gt + R.T @ t, axis=0) 54 | cos = np.clip((np.trace(np.dot(R_gt.T, R)) - 1) / 2, -1., 1.) 
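            # (trace(R_gt^T R) - 1) / 2 is cos(theta) of the relative rotation between the
            # ground-truth and estimated poses; the clip above guards against numerical drift
            # outside [-1, 1], and the arccos / rad2deg below convert it to degrees.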
55 | e_R = np.rad2deg(np.abs(np.arccos(cos))) 56 | errors_t.append(e_t) 57 | errors_R.append(e_R) 58 | 59 | errors_t = np.array(errors_t) 60 | errors_R = np.array(errors_R) 61 | med_t = np.median(errors_t) 62 | med_R = np.median(errors_R) 63 | out = f'\nMedian errors: {med_t:.3f}m, {med_R:.3f}deg' 64 | 65 | out += '\nPercentage of test images localized within:' 66 | threshs_t = [0.01, 0.02, 0.03, 0.05, 0.25, 0.5, 5.0] 67 | threshs_R = [1.0, 2.0, 3.0, 5.0, 2.0, 5.0, 10.0] 68 | for th_t, th_R in zip(threshs_t, threshs_R): 69 | ratio = np.mean((errors_t < th_t) & (errors_R < th_R)) 70 | out += f'\n\t{th_t*100:.0f}cm, {th_R:.0f}deg : {ratio*100:.2f}%' 71 | logger.info(out) 72 | 73 | 74 | def cumulative_recall(errors: np.ndarray) -> Tuple[np.ndarray]: 75 | sort_idx = np.argsort(errors) 76 | errors = np.array(errors.copy())[sort_idx] 77 | recall = (np.arange(len(errors)) + 1) / len(errors) 78 | errors = np.r_[0., errors] 79 | recall = np.r_[0., recall] 80 | return errors, recall*100 81 | -------------------------------------------------------------------------------- /PureACL/utils/io.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, List, Union, Any 3 | from pathlib import Path 4 | from collections import defaultdict 5 | import numpy as np 6 | import h5py 7 | 8 | from .colmap import Camera, Image 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def parse_image_list(path: Path, with_intrinsics: bool = False, 14 | with_poses: bool = False) -> List: 15 | images = [] 16 | with open(path, 'r') as f: 17 | for line in f: 18 | line = line.strip('\n') 19 | if len(line) == 0 or line[0] == '#': 20 | continue 21 | name, *data = line.split() 22 | if with_intrinsics: 23 | camera_model, width, height, *params = data 24 | params = np.array(params, float) 25 | camera = Camera( 26 | None, camera_model, int(width), int(height), params) 27 | images.append((name, camera)) 28 | elif with_poses: 29 | qvec, tvec = np.split(np.array(data, float), [4]) 30 | image = Image( 31 | id=None, qvec=qvec, tvec=tvec, camera_id=None, name=name, 32 | xys=None, point3D_ids=None) 33 | images.append((name, image)) 34 | else: 35 | images.append(name) 36 | 37 | logger.info(f'Imported {len(images)} images from {path.name}') 38 | return images 39 | 40 | 41 | def parse_image_lists(paths: Path, **kwargs) -> List: 42 | images = [] 43 | files = list(Path(paths.parent).glob(paths.name)) 44 | assert len(files) > 0, paths 45 | for lfile in files: 46 | images += parse_image_list(lfile, **kwargs) 47 | return images 48 | 49 | 50 | def parse_retrieval(path: Path) -> Dict[str, List[str]]: 51 | retrieval = defaultdict(list) 52 | with open(path, 'r') as f: 53 | for p in f.read().rstrip('\n').split('\n'): 54 | q, r = p.split() 55 | retrieval[q].append(r) 56 | return dict(retrieval) 57 | 58 | 59 | def load_hdf5(path: Path) -> Dict[str, Any]: 60 | with h5py.File(path, 'r') as hfile: 61 | data = {} 62 | def collect(_, obj): # noqa 63 | if isinstance(obj, h5py.Dataset): 64 | name = obj.parent.name.strip('/') 65 | data[name] = obj.__array__() 66 | hfile.visititems(collect) 67 | return data 68 | 69 | 70 | def write_pose_results(pose_dict: Dict, outfile: Path, 71 | prepend_camera_name: bool = False): 72 | logger.info('Writing the localization results to %s.', outfile) 73 | outfile.parent.mkdir(parents=True, exist_ok=True) 74 | with open(str(outfile), 'w') as f: 75 | for imgname, (qvec, tvec) in pose_dict.items(): 76 | qvec = ' '.join(map(str, qvec)) 77 | tvec = ' 
'.join(map(str, tvec)) 78 | name = imgname.split('/')[-1] 79 | if prepend_camera_name: 80 | name = imgname.split('/')[-2] + '/' + name 81 | f.write(f'{name} {qvec} {tvec}\n') 82 | 83 | 84 | def concat_results(paths: List[Path], names: List[Union[int, str]], 85 | output_path: Path, key: str) -> Path: 86 | results = [] 87 | for path in sorted(paths): 88 | with open(path, 'r') as fp: 89 | results.append(fp.read().rstrip('\n')) 90 | output_path = str(output_path).replace( 91 | f'{{{key}}}', '-'.join(str(n)[:3] for n in names)) 92 | with open(output_path, 'w') as fp: 93 | fp.write('\n'.join(results)) 94 | return Path(output_path) 95 | -------------------------------------------------------------------------------- /PureACL/utils/quaternions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def qvec2rotmat(qvec): 5 | return np.array([ 6 | [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2, 7 | 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3], 8 | 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]], 9 | [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3], 10 | 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2, 11 | 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]], 12 | [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2], 13 | 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1], 14 | 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]]) 15 | 16 | 17 | def rotmat2qvec(R): 18 | Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat 19 | K = np.array([ 20 | [Rxx - Ryy - Rzz, 0, 0, 0], 21 | [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0], 22 | [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0], 23 | [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0 24 | eigvals, eigvecs = np.linalg.eigh(K) 25 | qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)] 26 | if qvec[0] < 0: 27 | qvec *= -1 28 | return qvec 29 | 30 | 31 | def weighted_qvecs(qvecs, weights): 32 | """Adapted from Tolga Birdal: 33 | https://github.com/tolgabirdal/averaging_quaternions/blob/master/wavg_quaternion_markley.m 34 | """ 35 | outer = np.einsum('ni,nj,n->ij', qvecs, qvecs, weights) 36 | avg = np.linalg.eigh(outer)[1][:, -1] # eigenvector of largest eigenvalue 37 | avg *= np.sign(avg[0]) 38 | return avg 39 | 40 | 41 | def weighted_pose(t_w2c, q_w2c, weights): 42 | weights = np.array(weights) 43 | R_w2c = np.stack([qvec2rotmat(q) for q in q_w2c], 0) 44 | 45 | t_c2w = -np.einsum('nij,ni->nj', R_w2c, np.array(t_w2c)) 46 | t_approx_c2w = np.sum(t_c2w * weights[:, None], 0) 47 | 48 | q_c2w = np.array(q_w2c) * np.array([[1, -1, -1, -1]]) # invert 49 | q_c2w *= np.sign(q_c2w[:, 0])[:, None] # handle antipodal 50 | q_approx_c2w = weighted_qvecs(q_c2w, weights) 51 | 52 | # convert back to camera coordinates 53 | R_approx = qvec2rotmat(q_approx_c2w).T 54 | t_approx = -R_approx @ t_approx_c2w 55 | 56 | return R_approx, t_approx 57 | -------------------------------------------------------------------------------- /PureACL/utils/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import functools 4 | 5 | 6 | def torchify(func): 7 | """Extends to NumPy arrays a function written for PyTorch tensors. 8 | 9 | Converts input arrays to tensors and output tensors back to arrays. 10 | Supports hybrid inputs where some are arrays and others are tensors: 11 | - in this case all tensors should have the same device and float dtype; 12 | - the output is not converted. 13 | 14 | No data copy: tensors and arrays share the same underlying storage. 
15 | 16 | Warning: kwargs are currently not supported when using jit. 17 | """ 18 | # TODO: switch to @torch.jit.unused when is_scripting will work 19 | @torch.jit.ignore 20 | @functools.wraps(func) 21 | def wrapped(*args, **kwargs): 22 | device = None 23 | dtype = None 24 | for arg in args: 25 | if isinstance(arg, torch.Tensor): 26 | device_ = arg.device 27 | if device is not None and device != device_: 28 | raise ValueError( 29 | 'Two input tensors have different devices: ' 30 | f'{device} and {device_}') 31 | device = device_ 32 | if torch.is_floating_point(arg): 33 | dtype_ = arg.dtype 34 | if dtype is not None and dtype != dtype_: 35 | raise ValueError( 36 | 'Two input tensors have different float dtypes: ' 37 | f'{dtype} and {dtype_}') 38 | dtype = dtype_ 39 | 40 | args_converted = [] 41 | for arg in args: 42 | if isinstance(arg, np.ndarray): 43 | arg = torch.from_numpy(arg).to(device) 44 | if torch.is_floating_point(arg): 45 | arg = arg.to(dtype) 46 | args_converted.append(arg) 47 | 48 | rets = func(*args_converted, **kwargs) 49 | 50 | def convert_back(ret): 51 | if isinstance(ret, torch.Tensor): 52 | if device is None: # no input was torch.Tensor 53 | ret = ret.cpu().numpy() 54 | return ret 55 | 56 | # TODO: handle nested struct with map tensor 57 | if not isinstance(rets, tuple): 58 | rets = convert_back(rets) 59 | else: 60 | rets = tuple(convert_back(ret) for ret in rets) 61 | return rets 62 | 63 | # BUG: is_scripting does not work in 1.6 so wrapped is always called 64 | if torch.jit.is_scripting(): 65 | return func 66 | else: 67 | return wrapped 68 | -------------------------------------------------------------------------------- /PureACL/visualization/animation.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Optional, List 3 | import logging 4 | import shutil 5 | import json 6 | import io 7 | import base64 8 | import cv2 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from .viz_2d import save_plot 13 | from ..localization import Model3D 14 | from ..pixlib.geometry import Pose, Camera 15 | from ..utils.quaternions import rotmat2qvec 16 | 17 | logger = logging.getLogger(__name__) 18 | try: 19 | import ffmpeg 20 | except ImportError: 21 | logger.info('Cannot import ffmpeg.') 22 | 23 | 24 | def subsample_steps(T_w2q: Pose, p2d_q: np.ndarray, mask_q: np.ndarray, 25 | camera_size: np.ndarray, thresh_dt: float = 0.1, 26 | thresh_px: float = 0.005) -> List[int]: 27 | """Subsample steps of the optimization based on camera or point 28 | displacements. Main use case: compress an animation 29 | but keep it smooth and interesting. 
30 | """ 31 | mask = mask_q.any(0) 32 | dp2ds = np.linalg.norm(np.diff(p2d_q, axis=0), axis=-1) 33 | dp2ds = np.median(dp2ds[:, mask], 1) 34 | dts = (T_w2q[:-1] @ T_w2q[1:].inv()).magnitude()[0].numpy() 35 | assert len(dts) == len(dp2ds) 36 | 37 | thresh_dp2 = camera_size.min()*thresh_px # from percent to pixel 38 | 39 | num = len(dp2ds) 40 | keep = [] 41 | count_dp2 = 0 42 | count_dt = 0 43 | for i, dp2 in enumerate(dp2ds): 44 | count_dp2 += dp2 45 | count_dt += dts[i] 46 | if (i == 0 or i == (num-1) 47 | or count_dp2 >= thresh_dp2 or count_dt >= thresh_dt): 48 | count_dp2 = 0 49 | count_dt = 0 50 | keep.append(i) 51 | return keep 52 | 53 | 54 | class VideoWriter: 55 | """Write frames sequentially as images, create a video, and clean up.""" 56 | def __init__(self, tmp_dir: Path, ext='.jpg'): 57 | self.tmp_dir = Path(tmp_dir) 58 | self.ext = ext 59 | self.count = 0 60 | if self.tmp_dir.exists(): 61 | shutil.rmtree(self.tmp_dir) 62 | self.tmp_dir.mkdir(parents=True) 63 | 64 | def add_frame(self): 65 | save_plot(self.tmp_dir / f'{self.count:0>5}{self.ext}') 66 | plt.close() 67 | self.count += 1 68 | 69 | def to_video(self, out_path: Path, duration: Optional[float] = None, 70 | fps: int = 5, crf: int = 23, verbose: bool = False): 71 | assert self.count > 0 72 | if duration is not None: 73 | fps = self.count / duration 74 | frames = self.tmp_dir / f'*{self.ext}' 75 | logger.info('Running ffmpeg.') 76 | ( 77 | ffmpeg 78 | .input(frames, pattern_type='glob', framerate=fps) 79 | .filter('crop', 'trunc(iw/2)*2', 'trunc(ih/2)*2') 80 | .output(out_path, crf=crf, vcodec='libx264', pix_fmt='yuv420p') 81 | .run(overwrite_output=True, quiet=not verbose) 82 | ) 83 | shutil.rmtree(self.tmp_dir) 84 | 85 | 86 | def display_video(path: Path): 87 | from IPython.display import HTML 88 | # prevent jupyter from caching the video file 89 | data = io.open(path, 'r+b').read() 90 | encoded = base64.b64encode(data).decode('ascii') 91 | return HTML(f""" 92 | 95 | """) 96 | 97 | 98 | def frustum_points(camera: Camera) -> np.ndarray: 99 | """Compute the corners of the frustum of a camera object.""" 100 | W, H = camera.size.numpy() 101 | corners = np.array([[0, 0], [W, 0], [W, H], [0, H], 102 | [0, 0], [W/2, -H/5], [W, 0]]) 103 | corners = (corners - camera.c.numpy()) / camera.f.numpy() 104 | return corners 105 | 106 | 107 | def copy_compress_image(source: Path, target: Path, quality: int = 50): 108 | """Read an image and write it to a low-quality jpeg.""" 109 | image = cv2.imread(str(source)) 110 | cv2.imwrite(str(target), image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) 111 | 112 | 113 | def format_json(x, decimals: int = 3): 114 | """Control the precision of numpy float arrays, convert boolean to int.""" 115 | if isinstance(x, np.ndarray): 116 | if np.issubdtype(x.dtype, np.floating): 117 | if x.shape != (4,): # qvec 118 | x = np.round(x, decimals=decimals) 119 | elif x.dtype == np.bool: 120 | x = x.astype(int) 121 | return x.tolist() 122 | if isinstance(x, float): 123 | return round(x, decimals) 124 | if isinstance(x, dict): 125 | return {k: format_json(v) for k, v in x.items()} 126 | if isinstance(x, (list, tuple)): 127 | return [format_json(v) for v in x] 128 | return x 129 | 130 | 131 | def create_viz_dump(assets: Path, paths: Path, cam_q: Camera, name_q: str, 132 | T_w2q: Pose, mask_q: np.ndarray, p2d_q: np.ndarray, 133 | ref_ids: List[int], model3d: Model3D, p3d_ids: np.ndarray, 134 | tfm: np.ndarray = np.eye(3)): 135 | assets.mkdir(parents=True, exist_ok=True) 136 | 137 | dump = { 138 | 'p3d': {}, 139 | 'T': 
{}, 140 | 'camera': {}, 141 | 'image': {}, 142 | 'p2d': {}, 143 | } 144 | 145 | p3d = np.stack([model3d.points3D[i].xyz for i in p3d_ids], 0) 146 | dump['p3d']['colors'] = [model3d.points3D[i].rgb for i in p3d_ids] 147 | dump['p3d']['xyz'] = p3d @ tfm.T 148 | 149 | dump['T']['refs'] = [] 150 | dump['camera']['refs'] = [] 151 | dump['image']['refs'] = [] 152 | dump['p2d']['refs'] = [] 153 | for idx, ref_id in enumerate(ref_ids): 154 | ref = model3d.dbs[ref_id] 155 | cam_r = Camera.from_colmap(model3d.cameras[ref.camera_id]) 156 | T_w2r = Pose.from_colmap(ref) 157 | 158 | qtvec = (rotmat2qvec(T_w2r.R.numpy() @ tfm.T), T_w2r.t.numpy()) 159 | dump['T']['refs'].append(qtvec) 160 | dump['camera']['refs'].append(frustum_points(cam_r)) 161 | 162 | tmp_name = f'ref{idx}.jpg' 163 | dump['image']['refs'].append(tmp_name) 164 | copy_compress_image( 165 | paths.reference_images / ref.name, assets / tmp_name) 166 | 167 | p2d_, valid_ = cam_r.world2image(T_w2r * p3d) 168 | p2d_ = p2d_[valid_ & mask_q.any(0)] / cam_r.size 169 | dump['p2d']['refs'].append(p2d_.numpy()) 170 | 171 | qtvec_q = [(rotmat2qvec(T.R.numpy() @ tfm.T), T.t.numpy()) for T in T_w2q] 172 | dump['T']['query'] = qtvec_q 173 | dump['camera']['query'] = frustum_points(cam_q) 174 | 175 | p2d_q_norm = [np.asarray(p[v]/cam_q.size) for p, v in zip(p2d_q, mask_q)] 176 | dump['p2d']['query'] = p2d_q_norm[-1] 177 | 178 | tmp_name = 'query.jpg' 179 | dump['image']['query'] = tmp_name 180 | copy_compress_image(paths.query_images / name_q, assets / tmp_name) 181 | 182 | with open(assets / 'dump.json', 'w') as fid: 183 | json.dump(format_json(dump), fid, separators=(',', ':')) 184 | 185 | # We dump 2D points as a separate json because it is much heavier 186 | # and thus slower to load. 187 | dump_p2d = { 188 | 'query': p2d_q_norm, 189 | 'masks': np.asarray(mask_q), 190 | } 191 | with open(assets / 'dump_p2d.json', 'w') as fid: 192 | json.dump(format_json(dump_p2d), fid, separators=(',', ':')) 193 | -------------------------------------------------------------------------------- /PureACL/visualization/viz_2d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 2D visualization primitives based on Matplotlib. 3 | 4 | 1) Plot images with `plot_images`. 5 | 2) Call `plot_keypoints` or `plot_matches` any number of times. 6 | 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. 7 | """ 8 | 9 | import matplotlib 10 | import matplotlib.pyplot as plt 11 | import matplotlib.patheffects as path_effects 12 | import numpy as np 13 | 14 | 15 | def cm_RdGn(x): 16 | """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" 17 | x = np.clip(x, 0, 1)[..., None]*2 18 | c = x*np.array([[0, 1., 0]]) + (2-x)*np.array([[1., 0, 0]]) 19 | return np.clip(c, 0, 1) 20 | 21 | 22 | def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5, 23 | adaptive=True, autoscale=True): 24 | """Plot a set of images horizontally. 25 | Args: 26 | imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W). 27 | titles: a list of strings, as titles for each image. 28 | cmaps: colormaps for monochrome images. 29 | adaptive: whether the figure size should fit the image aspect ratios. 
30 | """ 31 | n = len(imgs) 32 | if not isinstance(cmaps, (list, tuple)): 33 | cmaps = [cmaps] * n 34 | 35 | if adaptive: 36 | ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H 37 | else: 38 | ratios = [4/3] * n 39 | figsize = [sum(ratios)*4.5, 4.5] 40 | fig, ax = plt.subplots( 41 | 1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios}) 42 | if n == 1: 43 | ax = [ax] 44 | for i in range(n): 45 | ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i])) 46 | ax[i].get_yaxis().set_ticks([]) 47 | ax[i].get_xaxis().set_ticks([]) 48 | ax[i].set_axis_off() 49 | for spine in ax[i].spines.values(): # remove frame 50 | spine.set_visible(False) 51 | if titles: 52 | ax[i].set_title(titles[i]) 53 | if not autoscale: 54 | ax[i].autoscale(False) 55 | fig.tight_layout(pad=pad) 56 | 57 | 58 | def plot_keypoints(kpts, colors='lime', ps=6): 59 | """Plot keypoints for existing images. 60 | Args: 61 | kpts: list of ndarrays of size (N, 2). 62 | colors: string, or list of list of tuples (one for each keypoints). 63 | ps: size of the keypoints as float. 64 | """ 65 | if not isinstance(colors, list): 66 | colors = [colors] * len(kpts) 67 | axes = plt.gcf().axes 68 | for a, k, c in zip(axes, kpts, colors): 69 | if k is not None: 70 | a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0) 71 | 72 | 73 | def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.): 74 | """Plot matches for a pair of existing images. 75 | Args: 76 | kpts0, kpts1: corresponding keypoints of size (N, 2). 77 | color: color of each match, string or RGB tuple. Random if not given. 78 | lw: width of the lines. 79 | ps: size of the end points (no endpoint if ps=0) 80 | indices: indices of the images to draw the matches on. 81 | a: alpha opacity of the match lines. 82 | """ 83 | fig = plt.gcf() 84 | ax = fig.axes 85 | assert len(ax) > max(indices) 86 | ax0, ax1 = ax[indices[0]], ax[indices[1]] 87 | fig.canvas.draw() 88 | 89 | assert len(kpts0) == len(kpts1) 90 | if color is None: 91 | color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist() 92 | elif len(color) > 0 and not isinstance(color[0], (tuple, list)): 93 | color = [color] * len(kpts0) 94 | 95 | if lw > 0: 96 | # transform the points into the figure coordinate system 97 | transFigure = fig.transFigure.inverted() 98 | fkpts0 = transFigure.transform(ax0.transData.transform(kpts0)) 99 | fkpts1 = transFigure.transform(ax1.transData.transform(kpts1)) 100 | fig.lines += [matplotlib.lines.Line2D( 101 | (fkpts0[i, 0], fkpts1[i, 0]), (fkpts0[i, 1], fkpts1[i, 1]), 102 | zorder=1, transform=fig.transFigure, c=color[i], linewidth=lw, 103 | alpha=a) 104 | for i in range(len(kpts0))] 105 | 106 | # freeze the axes to prevent the transform to change 107 | ax0.autoscale(enable=False) 108 | ax1.autoscale(enable=False) 109 | 110 | if ps > 0: 111 | ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps) 112 | ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps) 113 | 114 | 115 | def add_text(idx, text, pos=(0.01, 0.99), fs=15, color='w', 116 | lcolor='k', lwidth=2): 117 | ax = plt.gcf().axes[idx] 118 | t = ax.text(*pos, text, fontsize=fs, va='top', ha='left', 119 | color=color, transform=ax.transAxes) 120 | if lcolor is not None: 121 | t.set_path_effects([ 122 | path_effects.Stroke(linewidth=lwidth, foreground=lcolor), 123 | path_effects.Normal()]) 124 | 125 | 126 | def save_plot(path, **kw): 127 | """Save the current figure without any white margin.""" 128 | plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw) 129 | 130 | 131 | def features_to_RGB(*Fs, skip=1): 
132 | """Project a list of d-dimensional feature maps to RGB colors using PCA.""" 133 | from sklearn.decomposition import PCA 134 | 135 | def normalize(x): 136 | return x / np.linalg.norm(x, axis=-1, keepdims=True) 137 | flatten = [] 138 | shapes = [] 139 | for F in Fs: 140 | c, h, w = F.shape 141 | F = np.rollaxis(F, 0, 3) 142 | F = F.reshape(-1, c) 143 | flatten.append(F) 144 | shapes.append((h, w)) 145 | flatten = np.concatenate(flatten, axis=0) 146 | 147 | pca = PCA(n_components=3) 148 | if skip > 1: 149 | pca.fit(normalize(flatten[::skip])) 150 | flatten = normalize(pca.transform(normalize(flatten))) 151 | else: 152 | flatten = normalize(pca.fit_transform(normalize(flatten))) 153 | flatten = (flatten + 1) / 2 154 | 155 | Fs = [] 156 | for h, w in shapes: 157 | F, flatten = np.split(flatten, [h*w], axis=0) 158 | F = F.reshape((h, w, 3)) 159 | Fs.append(F) 160 | assert flatten.shape[0] == 0 161 | return Fs 162 | -------------------------------------------------------------------------------- /PureACL/visualization/viz_3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3D visualization primitives based on Plotly. 3 | We might want to instead use a more powerful library like Open3D. 4 | Plotly however supports animations, buttons and sliders. 5 | 6 | 1) Initialize a figure with `fig = init_figure()` 7 | 2) Plot points, cameras, lines, or create a slider animation. 8 | 3) Call `fig.show()` to render the figure. 9 | """ 10 | 11 | import plotly.graph_objects as go 12 | import numpy as np 13 | 14 | from ..pixlib.geometry.utils import to_homogeneous 15 | 16 | 17 | def init_figure(height=800): 18 | """Initialize a 3D figure.""" 19 | fig = go.Figure() 20 | fig.update_layout( 21 | height=height, 22 | scene_camera=dict( 23 | eye=dict(x=0., y=-.1, z=-2), up=dict(x=0, y=-1., z=0)), 24 | scene=dict( 25 | xaxis=dict(showbackground=False), 26 | yaxis=dict(showbackground=False), 27 | aspectmode='data', dragmode='orbit'), 28 | margin=dict(l=0, r=0, b=0, t=0, pad=0)) # noqa E741 29 | return fig 30 | 31 | 32 | def plot_points(fig, pts, color='rgba(255, 0, 0, 1)', ps=2): 33 | """Plot a set of 3D points.""" 34 | x, y, z = pts.T 35 | tr = go.Scatter3d( 36 | x=x, y=y, z=z, mode='markers', marker_size=ps, 37 | marker_color=color, marker_line_width=.2) 38 | fig.add_trace(tr) 39 | 40 | 41 | def plot_camera(fig, R, t, K, color='rgb(0, 0, 255)'): 42 | """Plot a camera as a cone with camera frustum.""" 43 | x, y, z = t 44 | u, v, w = R @ -np.array([0, 0, 1]) 45 | tr = go.Cone( 46 | x=[x], y=[y], z=[z], u=[u], v=[v], w=[w], anchor='tip', 47 | showscale=False, colorscale=[[0, color], [1, color]], 48 | sizemode='absolute') 49 | fig.add_trace(tr) 50 | 51 | W, H = K[0, 2]*2, K[1, 2]*2 52 | corners = np.array([[0, 0], [W, 0], [W, H], [0, H], [0, 0]]) 53 | corners = to_homogeneous(corners) @ np.linalg.inv(K).T 54 | corners = (corners/2) @ R.T + t 55 | x, y, z = corners.T 56 | tr = go.Scatter3d( 57 | x=x, y=y, z=z, line=dict(color='rgba(0, 0, 0, .5)'), 58 | marker=dict(size=0.0001), showlegend=False) 59 | fig.add_trace(tr) 60 | 61 | 62 | def create_slider_animation(fig, traces): 63 | """Create a slider that animates a list of traces (e.g. 
3D points).""" 64 | slider = {'steps': []} 65 | frames = [] 66 | fig.add_trace(traces[0]) 67 | idx = len(fig.data) - 1 68 | for i, tr in enumerate(traces): 69 | frames.append(go.Frame(name=str(i), traces=[idx], data=[tr])) 70 | step = {"args": [ 71 | [str(i)], 72 | {"frame": {"redraw": True}, 73 | "mode": "immediate"}], 74 | "label": i, 75 | "method": "animate"} 76 | slider['steps'].append(step) 77 | fig.frames = tuple(frames) 78 | fig.layout.sliders = (slider,) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # View Consistent Purification for Accurate Cross-View Localization 2 | 3 | View Consistent Purification for Accurate Cross-View Localization, Shan Wang, Yanhao Zhang, Akhil Perincherry, Ankit Vora and Hongdong Li, ICCV 2023 [Paper](https://arxiv.org/abs/2308.08110) 4 | 5 | ## Abstract 6 | This paper proposes a fine-grained self-localization method for outdoor robotics that utilizes a flexible number of onboard cameras and readily accessible satellite images. The proposed method addresses limitations in existing cross-view localization methods that struggle to handle noise sources such as moving objects and seasonal variations. It is the first sparse visual-only method that enhances perception in dynamic environments by detecting view-consistent key points and their corresponding deep features from ground and satellite views, while removing off-the-ground objects and establishing homography transformation between the two views. Moreover, the proposed method incorporates a spatial embedding approach that leverages camera intrinsic and extrinsic information to reduce the ambiguity of purely visual matching, leading to improved feature matching and overall pose estimation accuracy. The method exhibits strong generalization and is robust to environmental changes, requiring only geo-poses as ground truth. Extensive experiments on the KITTI and Ford Multi-AV Seasonal datasets demonstrate that our proposed method outperforms existing state-of-the-art methods, achieving median spatial accuracy errors below $0.5$ meters along the lateral and longitudinal directions, and a median orientation accuracy error below $2^\circ$. 7 | 8 |

9 | 10 |

11 | 
12 | ## Installation 
13 | 
14 | PureACL is built with Python >=3.6 and PyTorch. The package includes code for both training and evaluation. Installing the package locally also installs the minimal dependencies listed in `requirements.txt`: 
15 | 
16 | ``` bash 
17 | git clone https://github.com/ShanWang-Shan/PureACL.git 
18 | cd PureACL/ 
19 | pip install -e . 
20 | ``` 
21 | 
22 | ## Datasets 
23 | 
24 | We construct our Ford-CVL dataset by collecting the spatially consistent satellite counterparts from Google Maps according to the GPS tags of the ground-view images. More specifically, we find a large region covering the vehicle trajectory and uniformly partition it into overlapping satellite image patches. Each satellite image patch has a resolution of $1280\times 1280$ pixels (see the ground-resolution sketch below). A script to download the latest satellite images is provided (ford_data_process/downloading_satellite_images.py). For access to our collected satellite images, please send an inquiry email to shan.wang@anu.edu.au from an academic institution email address. The data is strictly for academic purposes; we will provide you with the download link. 
25 | 
26 | 
27 | Ford-CVL: Please first download the raw data (ground images) from [https://avdata.ford.com/](https://avdata.ford.com/). We provide the script (ford_data_process/raw_data_downloader.sh) for raw data download and the script (ford_data_process/other_data_downloader.sh) for processed data download. Your dataset folder structure should look like this. If the link in a script file has expired or lacks the necessary permissions, please contact us. 
28 | ``` 
29 | FordAV/ 
30 | ├─ 2017-08-04-V2-Log*/ 
31 | │ ├─ 2017-08-04-V2-Log*-FL/ 
32 | │ │ └─ *******.png 
33 | │ ├─ 2017-08-04-V2-Log*-RR/ 
34 | │ ├─ 2017-08-04-V2-Log*-SL/ 
35 | │ ├─ 2017-08-04-V2-Log*-SR/ 
36 | │ ├─ info_files/ 
37 | │ │ ├─ gps.csv 
38 | │ │ ├─ gps_time.csv 
39 | │ │ ├─ imu.csv 
40 | │ │ ├─ pose_ground_truth.csv 
41 | │ │ ├─ pose_localized.csv 
42 | │ │ ├─ pose_raw.csv 
43 | │ │ ├─ pose_tf.csv 
44 | │ │ ├─ velocity_raw.csv 
45 | │ │ ├─ groundview_gps.npy 
46 | │ │ ├─ groundview_NED_pose_gt.npy 
47 | │ │ ├─ groundview_pitchs_pose_gt.npy 
48 | │ │ ├─ groundview_yaws_pose_gt.npy 
49 | │ │ ├─ groundview_satellite_pair.npy 
50 | │ │ ├─ satellite_gps_center.npy 
51 | │ │ ├─ 2017-08-04-V2-Log*-FL-names.txt 
52 | │ │ ├─ 2017-08-04-V2-Log*-RR-names.txt 
53 | │ │ ├─ 2017-08-04-V2-Log*-SL-names.txt 
54 | │ │ └─ 2017-08-04-V2-Log*-SR-names.txt 
55 | │ ├─ Satellit_Image_18 
56 | │ │ └─ satellite_*_lat_*_long_*_zoom_18_size_640x640_scale_2.png 
57 | ├─ 2017-10-26-V2-Log*/ 
58 | └─ V2/ 
59 | ``` 
60 | To update your dataset path, you can modify the "default_conf.dataset_dir" in "PureACL/pixlib/datasets/ford.py" or override it in your training/evaluation script. Additionally, if you wish to change the trajectories used for the Ford-CVL dataset, you can adjust the "log_id_train/val/test" in the same file. 
61 | 
62 | 
63 | ## Models 
64 | Weights of the model trained on *Ford-CVL* are hosted [here](https://drive.google.com/drive/folders/1X8pPmBYfLSYwiklQM_f67rInXAGVkTSQ?usp=sharing). 
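As a quick reference for the satellite patches described in the Datasets section, their ground coverage follows the standard Web-Mercator ground-resolution relation that the preprocessing scripts in `ford_data_process/` also rely on. The snippet below is only an illustrative sketch, not part of the released code: the latitude is the dataset-wide reference value from `ford_data_process/gps_coord_func.py`, and the zoom/scale defaults mirror `ford_data_process/downloading_satellite_images.py`.

``` python
import numpy as np

def satellite_gsd(lat_deg, zoom=18, scale=2):
    """Approximate ground-sample distance (m/pixel) of a Web-Mercator satellite
    tile at the given latitude, zoom level, and download scale."""
    return 156543.03392 * np.cos(np.deg2rad(lat_deg)) / (2 ** zoom) / scale

# Reference latitude of the Ford AV logs (hard-coded in gps_coord_func.py).
gsd = satellite_gsd(42.294319)
print(f"~{gsd:.3f} m/pixel, so a 1280x1280 patch covers roughly {gsd * 1280:.0f} m per side")
```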
65 | 66 | 67 | ## Evaluation 68 | 69 | To perform the PureACL, simply launch the corresponding run script: 70 | 71 | ``` 72 | python -m PureACL.evaluation 73 | ``` 74 | 75 | ## Training 76 | 77 | To train the PureACL, simply launch the corresponding run script: 78 | 79 | ``` 80 | python -m PureACL.pixlib.train 81 | ``` 82 | 83 | ## BibTex Citation 84 | 85 | Please consider citing our work if you use any of the ideas presented in the paper or code from this repo: 86 | 87 | ``` 88 | @misc{wang2023view, 89 | title={View Consistent Purification for Accurate Cross-View Localization}, 90 | author={Shan Wang and Yanhao Zhang and Akhil Perincherry and Ankit Vora and Hongdong Li}, 91 | year={2023}, 92 | eprint={2308.08110}, 93 | archivePrefix={arXiv}, 94 | primaryClass={cs.CV} 95 | } 96 | ``` 97 | 98 | Thanks to the work of [Paul-Edouard Sarlin](psarlin.com/) et al., the code of this repository borrows heavily from their [psarlin.com/pixloc](https://psarlin.com/pixloc), and we follow the same pipeline to verify the effectiveness of our solution. 99 | -------------------------------------------------------------------------------- /architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanWang-Shan/PureACL/21a4c2e64f0eeafaa09117b6bc8aef40d9cdf4e3/architecture.jpg -------------------------------------------------------------------------------- /ford_data_process/SuperPoint_gen.py: -------------------------------------------------------------------------------- 1 | # check the matching relationship between cross-view images 2 | 3 | from input_libs import * 4 | import cv2 as cv 5 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 6 | from superpoint import SuperPoint 7 | import torch 8 | from torchvision import transforms 9 | 10 | root_folder = "/data/dataset/Ford_AV" 11 | 12 | log_id = "2017-10-26-V2-Log5"#"2017-08-04-V2-Log5"# 13 | # size of the satellite image and ground-view query image (left camera) 14 | # satellite_size = 1280 15 | query_size = [1656, 860] 16 | start_ratio = 0.6 17 | 18 | 19 | log_folder = os.path.join(root_folder , log_id, 'info_files') 20 | FL_image_names = read_txt(log_folder, log_id + '-FL-names.txt') 21 | FL_image_names.pop(0) 22 | nb_query_images = len(FL_image_names) 23 | 24 | ToTensor = transforms.Compose([ 25 | transforms.ToTensor()]) 26 | 27 | # init super point------------------------------------------ 28 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 29 | print('Running inference on device \"{}\"'.format(device)) 30 | config = { 31 | 'nms_radius': 4, 32 | 'keypoint_threshold': 0.005, 33 | 'max_keypoints': 256 34 | } 35 | 36 | superpoint = SuperPoint(config).eval().to(device) 37 | 38 | #----------------------------------------------------------- 39 | 40 | # # get the satellite images 41 | # satellite_folder = os.path.join( root_folder, log_id, "Satellite_Images") 42 | # satellite_names = glob.glob(satellite_folder + '/*.png') 43 | # nb_satellite_images = len(satellite_names) 44 | 45 | # satellite_dict = {} 46 | # for i in range(nb_satellite_images): 47 | # sate_img = cv.imread(satellite_names[i]) 48 | # # Initiate ORB detector 49 | # orb = cv.ORB_create(nfeatures=512*4) 50 | # # find the keypoints with ORB 51 | # kp = orb.detect(sate_img, None) 52 | # 53 | # if 0: #debug: 54 | # # draw only keypoints location,not size and orientation 55 | # img2 = cv.drawKeypoints(sate_img, kp, None, color=(0, 255, 0), flags=0) 56 | # 
plt.imshow(img2), plt.show() 57 | # # only save kp 58 | # kp_list = [] 59 | # for p in range(len(kp)): 60 | # kp_list.append(kp[p].pt) 61 | # 62 | # sat_file_name = satellite_names[i].split('/') 63 | # satellite_dict[sat_file_name[-1]] = np.array(kp_list) 64 | # write_numpy(log_folder, 'satellite_kp.npy', satellite_dict) 65 | # print('satellite_kp.npy saved') 66 | 67 | # # 3. read the matching pair 68 | # match_pair = read_numpy(log_folder , 'groundview_satellite_pair.npy') # 'groundview_satellite_pair_2.npy' 69 | 70 | for grd_folder in ('-FL','-RR','-SL','-SR'): 71 | query_image_folder = os.path.join(root_folder, log_id, log_id + grd_folder) 72 | 73 | # crop 74 | H_start = int(query_size[1]*start_ratio) 75 | H_end = query_size[1] 76 | 77 | grd_dict = {} 78 | for i in range(nb_query_images): 79 | grd_img = cv.imread(os.path.join(query_image_folder, FL_image_names[i][:-1]), cv.IMREAD_GRAYSCALE) 80 | if grd_img is None: 81 | print(os.path.join(query_image_folder, FL_image_names[i][:-1])) 82 | 83 | # trun np to tensor 84 | img = ToTensor(grd_img[H_start:H_end]) 85 | img = img.unsqueeze(0).to(device) # add b 86 | 87 | pred = superpoint({'image': img}) 88 | key_points = pred['keypoints'][0].detach().cpu().numpy() #[n,2] 89 | key_points[:, 1] += H_start 90 | if 0: # debug: 91 | grd_img = cv.imread(os.path.join(query_image_folder, FL_image_names[i][:-1])) 92 | for j in range(key_points.shape[0]): 93 | cv.circle(grd_img, (np.int32(key_points[j,0]), np.int32(key_points[j,1])), 2, (255, 0, 0), 94 | -1) 95 | plt.imshow(grd_img), plt.show() 96 | # save kp 97 | grd_dict[FL_image_names[i][:-1]] = key_points 98 | write_numpy(log_folder, grd_folder[1:]+'_kp.npy', grd_dict) 99 | print(grd_folder[1:]+'_kp.npy saved') 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /ford_data_process/angle_func.py: -------------------------------------------------------------------------------- 1 | # functions to process angles 2 | 3 | 4 | def convert_body_yaw_to_360(yaw_body): 5 | yaw_360 = 0 6 | # if (yaw_body >= 0.0) and (yaw_body <=90.0): 7 | # yaw_360 = 90.0 - yaw_body 8 | 9 | if (yaw_body >90.0) and (yaw_body <=180.0): 10 | yaw_360 = 360.0 - yaw_body + 90.0 11 | else: 12 | yaw_360 = 90.0 - yaw_body 13 | 14 | # if (yaw_body >= -90) and (yaw_body <0.0): 15 | # yaw_360 = 90.0 - yaw_body 16 | # 17 | # if (yaw_body >= -180) and (yaw_body < -90): 18 | # yaw_360 = 90.0 - yaw_body 19 | return yaw_360 -------------------------------------------------------------------------------- /ford_data_process/avi_gener.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | 4 | # root_folder = "/home/users/u7094434/projects/SIDFM/visual_kitti/confidence_maps_1" 5 | # 6 | # image_dir = os.path.join(root_folder) 7 | # 8 | # # video parameter 9 | # fps = 30 10 | # size = (708, 218) #(411, 218) 11 | # 12 | # for dir in os.listdir(image_dir): 13 | # # FL/RR~ 14 | # subdir = os.path.join(image_dir, dir) 15 | # if not os.path.isdir(subdir): 16 | # continue 17 | # 18 | # if ('fl' in dir) or ('rr' in dir) or ('sl' in dir) or ('sr' in dir) or ('sat' in dir): 19 | # print('process '+dir) 20 | # else: 21 | # continue 22 | # 23 | # if 'sat' in dir: 24 | # size = (218, 218) 25 | # 26 | # file_list = os.listdir(subdir) 27 | # num_list = [int(i.split('.')[0]) for i in file_list] 28 | # num_list.sort() 29 | # #file_list.sort() 30 | # 31 | # video = cv2.VideoWriter(dir+".avi", cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size) 32 | # 
for num in num_list: 33 | # img = cv2.imread(os.path.join(subdir, str(num)+'.png')) 34 | # video.write(img) 35 | # 36 | # video.release() 37 | # cv2.destroyAllWindows() 38 | 39 | 40 | root_folder = "/home/users/u7094434/projects/SIDFM/visual_ford/pose_refine" 41 | fps = 3 42 | size = (436, 436) 43 | file_list = os.listdir(root_folder) 44 | file_list.sort() 45 | 46 | video = cv2.VideoWriter("pose_refine.avi", cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, size) 47 | for file in file_list: 48 | img = cv2.imread(os.path.join(root_folder, file)) 49 | video.write(img) 50 | 51 | video.release() 52 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /ford_data_process/check_cross_view_center_distance.py: -------------------------------------------------------------------------------- 1 | # for each query image, find the nearest satellite image, and calculate their distance 2 | 3 | from input_libs import * 4 | from angle_func import convert_body_yaw_to_360 5 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 6 | import gps_coord_func as gps_func 7 | 8 | root_folder = "/data/dataset/Ford_AV" 9 | 10 | log_id = "2017-08-04-V2-Log6" #"2017-10-26-V2-Log6"# 11 | info_dir = 'info_files' 12 | 13 | log_folder = os.path.join(root_folder, log_id, info_dir) 14 | 15 | 16 | Geodetic = read_numpy(log_folder, 'satellite_gps_center.npy') 17 | 18 | # first, get the query image location 19 | # ----------------------------------------------- 20 | 21 | # 1. get the image names 22 | imageNames = read_txt(log_folder, log_id + '-FL-names.txt') 23 | imageNames.pop(0) 24 | 25 | image_times = np.zeros((len(imageNames), 1)) 26 | for i in range(len(imageNames)): 27 | image_times[i] = float(imageNames[i][:-5]) 28 | 29 | 30 | # 2. read the gps times and data 31 | gps_data = read_csv(log_folder , "gps.csv") 32 | # remove the headlines 33 | gps_data.pop(0) 34 | # save timestamp -- >> gps 35 | gps_dict = {} 36 | gps_times = np.zeros((len(gps_data), 1)) 37 | gps_lat = np.zeros((len(gps_data))) 38 | gps_long = np.zeros((len(gps_data))) 39 | gps_height = np.zeros((len(gps_data))) 40 | sat_gsd = np.zeros((len(gps_data))) 41 | for i, line in zip(range(len(gps_data)), gps_data): 42 | gps_timeStamp = float(line[0]) 43 | gps_latLongAtt = "%s_%s_%s" % (line[10], line[11], line[12]) 44 | gps_dict[gps_timeStamp] = gps_latLongAtt 45 | gps_times[i] = gps_timeStamp / 1000.0 46 | gps_lat[i] = float(line[10]) 47 | gps_long[i] = float(line[11]) 48 | gps_height[i] = float(line[12]) 49 | sat_gsd[i] = 156543.03392 * np.cos(gps_lat[i]*np.pi/180.0) / np.power(2, 20) / 2.0 # a scale at 2 when downloading the dataset 50 | 51 | # 3. 
for each query image time, find the nearest gps tag 52 | 53 | neigh = NearestNeighbors(n_neighbors=1) 54 | neigh.fit(gps_times) 55 | # KNN search given the image utms 56 | distances, indices = neigh.kneighbors(image_times, return_distance=True) 57 | distances = distances.ravel() 58 | indices = indices.ravel() 59 | 60 | 61 | NED_coords_query = np.zeros((len(imageNames), 3)) 62 | 63 | gps_query = np.zeros((len(imageNames), 3)) 64 | 65 | for i in range(len(imageNames)): 66 | x,y,z = gps_func.GeodeticToEcef(gps_lat[indices[i]]*np.pi/180.0, gps_long[indices[i]]*np.pi/180.0, gps_height[indices[i]]) 67 | xEast,yNorth,zUp = gps_func.EcefToEnu( x, y, z, gps_lat[0]*np.pi/180.0, gps_long[0]*np.pi/180.0, gps_height[0]) 68 | NED_coords_query[i, 0] = xEast 69 | NED_coords_query[i, 1] = yNorth 70 | NED_coords_query[i, 2] = zUp 71 | 72 | gps_query[i, 0] = gps_lat[indices[i]] 73 | gps_query[i, 1] = gps_long[indices[i]] 74 | gps_query[i, 2] = gps_height[indices[i]] 75 | 76 | NED_coords_satellite = np.zeros((Geodetic.shape[0], 3)) 77 | for i in range(Geodetic.shape[0]): 78 | x,y,z = gps_func.GeodeticToEcef(Geodetic[i,0]*np.pi/180.0, Geodetic[i,1]*np.pi/180.0, Geodetic[i,2]) 79 | xEast,yNorth,zUp = gps_func.EcefToEnu( x, y, z, gps_lat[0]*np.pi/180.0, gps_long[0]*np.pi/180.0, gps_height[0]) 80 | NED_coords_satellite[i, 0] = xEast 81 | NED_coords_satellite[i, 1] = yNorth 82 | NED_coords_satellite[i, 2] = zUp 83 | 84 | 85 | # for each query, find the nearest satellite 86 | 87 | neigh = NearestNeighbors(n_neighbors=1) 88 | neigh.fit(NED_coords_satellite) 89 | # KNN search given the image utms 90 | distances, indices = neigh.kneighbors(NED_coords_query, return_distance=True) 91 | distances = distances.ravel() 92 | indices = indices.ravel() 93 | 94 | print("max distance: {}; min distance: {} ".format(np.amax(distances), np.amin(distances))) 95 | 96 | # save the ground-view query to satellite matching pair 97 | # save the gps coordinates of query images 98 | write_numpy(log_folder , 'groundview_satellite_pair.npy', indices) 99 | write_numpy(log_folder , 'groundview_gps.npy', gps_query) 100 | 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /ford_data_process/check_groundImg_orientation.py: -------------------------------------------------------------------------------- 1 | from input_libs import * 2 | from pose_func import read_csv, read_txt, write_numpy 3 | root_folder = "/data/dataset/Ford_AV/" 4 | 5 | log_id = "2017-10-26-V2-Log6" #"2017-08-04-V2-Log6" # 6 | info_folder = 'info_files' 7 | log_folder = os.path.join(root_folder , log_id, info_folder) 8 | 9 | 10 | # read the ground-truth yaw angle from the file pose_ground_truth 11 | pose_data = read_csv(log_folder , "pose_ground_truth.csv") 12 | # remove the headlines 13 | pose_data.pop(0) 14 | 15 | pose_times = np.zeros((len(pose_data), 1)) 16 | pose_quat_x = np.zeros((len(pose_data))) 17 | pose_quat_y = np.zeros((len(pose_data))) 18 | pose_quat_z = np.zeros((len(pose_data))) 19 | pose_quat_w = np.zeros((len(pose_data))) 20 | 21 | pose_roll = np.zeros((len(pose_data))) 22 | pose_pitch = np.zeros((len(pose_data))) 23 | pose_yaw = np.zeros((len(pose_data))) 24 | 25 | pose_rotation = np.zeros((4,4,len(pose_data))) 26 | pose_NED = np.zeros((len(pose_data),3)) 27 | 28 | for i, line in zip(range(len(pose_data)), pose_data): 29 | pose_timeStamp = float(line[0]) / 1000.0 30 | pose_times[i] = pose_timeStamp 31 | pose_quat_x[i] = float(line[13]) 32 | pose_quat_y[i] = float(line[14]) 33 | pose_quat_z[i] = 
float(line[15]) 34 | pose_quat_w[i] = float(line[16]) 35 | pose_rotation[:,:,i] = transformations.quaternion_matrix([pose_quat_w[i], pose_quat_x[i], pose_quat_y[i], pose_quat_z[i]]) 36 | 37 | euler_angles = transformations.euler_from_matrix(pose_rotation[:,:,i]) 38 | 39 | pose_roll[i] = euler_angles[0]*180.0/np.pi 40 | pose_pitch[i] = euler_angles[1]*180.0/np.pi 41 | pose_yaw[i] = euler_angles[2] * 180.0 / np.pi 42 | # print(pose_rotation[:,:,i]) 43 | 44 | # NED pose 45 | pose_NED[i,0] = float(line[9]) 46 | pose_NED[i,1] = float(line[10]) 47 | pose_NED[i,2] = float(line[11]) 48 | 49 | # read the time of each query image, and fetch the yaw angle of each query image 50 | 51 | # 1. get the image names 52 | FL_image_names = read_txt(log_folder, log_id + '-FL-names.txt') 53 | FL_image_names.pop(0) 54 | 55 | nb_query_images = len(FL_image_names) 56 | 57 | image_times = np.zeros((len(FL_image_names), 1)) 58 | for i in range(len(FL_image_names)): 59 | image_times[i] = float(FL_image_names[i][:-5]) 60 | 61 | # 3. for each query image time, find the nearest gps tag 62 | 63 | neigh = NearestNeighbors(n_neighbors=1) 64 | neigh.fit(pose_times) 65 | # KNN search given the image utms 66 | distances, indices = neigh.kneighbors(image_times, return_distance=True) 67 | distances = distances.ravel() 68 | indices = indices.ravel() 69 | 70 | query_image_yaws = pose_yaw[indices] 71 | query_image_rolls = pose_roll[indices] 72 | query_image_pitchs = pose_pitch[indices] 73 | query_image_NED = pose_NED[indices] 74 | 75 | # save the yaw angles of qeury images 76 | write_numpy(log_folder, 'groundview_yaws_pose_gt.npy', query_image_yaws) 77 | write_numpy(log_folder, 'groundview_rolls_pose_gt.npy', query_image_rolls) 78 | write_numpy(log_folder, 'groundview_pitchs_pose_gt.npy', query_image_pitchs) 79 | write_numpy(log_folder, 'groundview_NED_pose_gt.npy', query_image_NED) 80 | 81 | 82 | x = np.linspace(0, query_image_yaws.shape[0]-1, query_image_yaws.shape[0]) 83 | plt.plot(x, query_image_yaws) 84 | plt.show() 85 | 86 | print("finish") 87 | 88 | -------------------------------------------------------------------------------- /ford_data_process/downloading_satellite_images.py: -------------------------------------------------------------------------------- 1 | # given the center geodestic coordinate of each satellite patch 2 | # retrieve satellite patchs from the google map server 3 | 4 | # Todo: using the viewing direction of forward camera to move the center of satellite patch 5 | # without this, the satellite patch only share a small common FoV as the ground-view query image 6 | 7 | # NOTE: 8 | # You need to provide a key 9 | keys = [''] 10 | 11 | import requests 12 | from io import BytesIO 13 | import os 14 | import time 15 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 16 | from PIL import Image as PILI 17 | 18 | root_folder = "/data/dataset/Ford_AV" 19 | 20 | log_id = "2017-10-26-V2-Log3" #"2017-08-04-V2-Log6" # 21 | 22 | info_folder = 'info_files' 23 | 24 | log_folder = os.path.join(root_folder, log_id, info_folder) 25 | 26 | Geodetic = read_numpy(log_folder, 'satellite_gps_center.npy') 27 | 28 | 29 | url_head = 'https://maps.googleapis.com/maps/api/staticmap?' 
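# Static Maps request parameters: a 640x640 request at scale=2 returns a 1280x1280 pixel
# patch (the resolution quoted in the README); zoom 18 sets the ground coverage per patch.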
30 | zoom = 18 31 | sat_size = [640, 640] 32 | maptype = 'satellite' 33 | scale = 2 34 | 35 | 36 | nb_keys = len(keys) 37 | 38 | nb_satellites = Geodetic.shape[0] 39 | 40 | satellite_folder = os.path.join(root_folder, log_id, "Satellite_Images_18") 41 | 42 | if not os.path.exists(satellite_folder): 43 | os.makedirs(satellite_folder) 44 | 45 | 46 | for i in range(nb_satellites): 47 | 48 | lat_a, long_a, height_a = Geodetic[i, 0], Geodetic[i, 1], Geodetic[i, 2] 49 | 50 | image_name = satellite_folder + "/satellite_" + str(i) + "_lat_" + str(lat_a) + "_long_" + str( 51 | long_a) + "_zoom_" + str( 52 | zoom) + "_size_" + str(sat_size[0]) + "x" + str(sat_size[0]) + "_scale_" + str(scale) + ".png" 53 | 54 | if os.path.exists(image_name): 55 | continue 56 | 57 | time.sleep(1) 58 | 59 | saturl = url_head + 'center=' + str(lat_a) + ',' + str(long_a) + '&zoom=' + str( 60 | zoom) + '&size=' + str( 61 | sat_size[0]) + 'x' + str(sat_size[1]) + '&maptype=' + maptype + '&scale=' + str( 62 | scale) + '&format=png32' + '&key=' + \ 63 | keys[0] 64 | #f = requests.get(saturl, stream=True) 65 | 66 | try: 67 | f = requests.get(saturl, stream=True) 68 | f.raise_for_status() 69 | except requests.exceptions.HTTPError as err: 70 | raise SystemExit(err) 71 | 72 | bytesio = BytesIO(f.content) 73 | cur_image = PILI.open(bytesio) 74 | 75 | cur_image.save(image_name) 76 | -------------------------------------------------------------------------------- /ford_data_process/filelist_txt_gener.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | from sklearn.neighbors import NearestNeighbors 4 | import os.path 5 | import glob 6 | from matplotlib import pyplot as plt 7 | import matplotlib.image as mpimg 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | import transformations 10 | import yaml 11 | import math 12 | 13 | root_folder = "/data/dataset/Ford_AV" 14 | log_id = "2017-08-04-V2-Log3" #"2017-10-26-V2-Log6" # 15 | 16 | log_folder = os.path.join(root_folder, log_id) 17 | 18 | for dir in os.listdir(log_folder): 19 | # log-FL/RR~ 20 | subdir = os.path.join(log_folder, dir) 21 | if not os.path.isdir(subdir): 22 | continue 23 | 24 | if ('-FL' in dir) or ('-RR' in dir) or ('-SL' in dir) or ('-SR' in dir): 25 | print('process '+dir) 26 | else: 27 | continue 28 | 29 | file_list = os.listdir(subdir) 30 | file_list.sort() 31 | 32 | # # ignore reconstruction images 33 | # if '2017-10-26-V2-Log1' in dir: 34 | # file_list = file_list[2900:4900]+file_list[5200:8300] 35 | # if '2017-08-04-V2-Log1' in dir: 36 | # file_list = file_list[2000:3601] + file_list[4500:7900] 37 | # if '2017-10-26-V2-Log5' in dir: 38 | # file_list = file_list[200:2300] 39 | # if '2017-08-04-V2-Log5' in dir: 40 | # file_list = file_list[:3500] + file_list[7000:] 41 | 42 | txt_file_name = os.path.join(log_folder, 'info_files', dir + '-names.txt') 43 | with open(txt_file_name, 'w') as f: 44 | f.write(str(dir) + '\n') 45 | for name in file_list: 46 | f.write(str(name)+'\n') -------------------------------------------------------------------------------- /ford_data_process/get_gps_coverage.py: -------------------------------------------------------------------------------- 1 | # get the overall gps coverage 2 | from input_libs import * 3 | from angle_func import convert_body_yaw_to_360 4 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 5 | import gps_coord_func as gps_func 6 | 7 | root_folder = "/data/dataset/Ford_AV" 8 | log_id 
= "2017-10-26-V2-Log6"#"2017-08-04-V2-Log6" # 9 | save_dir = 'info_files' 10 | 11 | log_folder = os.path.join(root_folder, log_id, save_dir) 12 | 13 | 14 | gps_data = read_csv(log_folder , "gps.csv") 15 | # remove the headlines 16 | gps_data.pop(0) 17 | # save timestamp -- >> gps 18 | gps_dict = {} 19 | gps_times = np.zeros((len(gps_data), 1)) 20 | gps_lat = np.zeros((len(gps_data))) 21 | gps_long = np.zeros((len(gps_data))) 22 | gps_height = np.zeros((len(gps_data))) 23 | sat_gsd = np.zeros((len(gps_data))) 24 | for i, line in zip(range(len(gps_data)), gps_data): 25 | gps_timeStamp = float(line[0]) 26 | gps_latLongAtt = "%s_%s_%s" % (line[10], line[11], line[12]) 27 | gps_dict[gps_timeStamp] = gps_latLongAtt 28 | gps_times[i] = gps_timeStamp 29 | gps_lat[i] = float(line[10]) 30 | gps_long[i] = float(line[11]) 31 | gps_height[i] = float(line[12]) 32 | sat_gsd[i] = 156543.03392 * np.cos(gps_lat[i]*np.pi/180.0) / np.power(2, 20) / 2.0 # a scale at 2 when downloading the dataset 33 | 34 | # I use a consevative way 35 | sat_gsd_min = np.amin(sat_gsd) 36 | 37 | # convert all gps coordinates to NED coordinates 38 | # use the first gps as reference 39 | NED_coords = np.zeros((len(gps_data), 3)) 40 | for i in range(len(gps_data)): 41 | x,y,z = gps_func.GeodeticToEcef(gps_lat[i]*np.pi/180.0, gps_long[i]*np.pi/180.0, gps_height[i]) 42 | xEast,yNorth,zUp = gps_func.EcefToEnu( x, y, z, gps_lat[0]*np.pi/180.0, gps_long[0]*np.pi/180.0, gps_height[0]) 43 | NED_coords[i,0] = xEast 44 | NED_coords[i, 1] = yNorth 45 | NED_coords[i, 2] = zUp 46 | 47 | NED_coords_max = np.amax(NED_coords, axis=0) + 1.0 48 | NED_coords_min = np.amin(NED_coords, axis=0) - 1.0 49 | 50 | NED_coords_length = NED_coords_max - NED_coords_min 51 | 52 | # the coverage of each satellite image 53 | # I use a scale factor for ensurance 54 | scale_factor = 0.25 55 | imageSize = 1200.0 56 | cell_length = sat_gsd_min * imageSize * scale_factor 57 | 58 | 59 | cell_length_sqare = cell_length * cell_length 60 | delta = cell_length + 1.0 61 | 62 | numRows = math.ceil((NED_coords_max[0] - NED_coords_min[0]) / delta) 63 | numCols = math.ceil((NED_coords_max[1] - NED_coords_min[1]) / delta) 64 | 65 | 66 | hashTab = [] 67 | 68 | for i in range(numRows): 69 | hashTab.append([]) 70 | for j in range(numCols): 71 | hashTab[i].append([]) 72 | 73 | inds= np.ceil((NED_coords[:,:2].T - NED_coords_min[:2,None]) / delta ) 74 | 75 | for i in range(len(gps_data)): 76 | rowIndx = int(inds[0,i]) - 1 77 | colIndx = int(inds[1,i]) - 1 78 | hashTab[rowIndx][colIndx].append(i) 79 | 80 | # check non-empty cells 81 | 82 | numNonEmptyCells = 0 83 | 84 | cellCenters = [] 85 | 86 | for i in range(numRows): 87 | for j in range(numCols): 88 | if len(hashTab[i][j]) > 0: 89 | # this is a valid cell 90 | center_x = i * delta + 0.5 * delta 91 | center_y = j * delta + 0.5 * delta 92 | # all_z = NED_coords[hashTab[i][j],2] 93 | center_z = np.sum(NED_coords[hashTab[i][j],2]) / len(hashTab[i][j]) 94 | 95 | avg_x = np.sum(NED_coords[hashTab[i][j],0]) / len(hashTab[i][j]) - NED_coords_min[0] 96 | avg_y = np.sum(NED_coords[hashTab[i][j], 1]) / len(hashTab[i][j]) - NED_coords_min[1] 97 | 98 | numNonEmptyCells += 1 99 | cellCenters.append([center_x, center_y, center_z]) 100 | 101 | cellCenters = np.asarray(cellCenters) 102 | 103 | cellCenters[:,0] += NED_coords_min[0] 104 | cellCenters[:,1] += NED_coords_min[1] 105 | 106 | # after collect cellCenters, transform back to Geodetic coordinates system 107 | 108 | Geodetic = np.zeros((cellCenters.shape[0], 3)) 109 | 110 | for i in 
range(cellCenters.shape[0]): 111 | x_ecef, y_ecef, z_ecef = gps_func.EnuToEcef( cellCenters[i,0], cellCenters[i,1], cellCenters[i,2], gps_lat[0]*np.pi/180.0, gps_long[0]*np.pi/180.0, gps_height[0]) 112 | lat, lon, h = gps_func.EcefToGeodetic( x_ecef, y_ecef, z_ecef) 113 | Geodetic[i, 0] = lat 114 | Geodetic[i, 1] = lon 115 | Geodetic[i, 2] = h 116 | 117 | # save the gps coordinates of satellite images 118 | write_numpy(log_folder, 'satellite_gps_center.npy', Geodetic) 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /ford_data_process/gps_coord_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | 4 | semi_a = 6378137.0 5 | semi_b = 6356752.31424518 6 | 7 | ratio = semi_b / semi_a 8 | 9 | f = (semi_a - semi_b) / semi_a 10 | # eccentricity = np.sqrt(1.0 - ratio*ratio) 11 | eccentricity_square = 2.0 * f - f * f 12 | eccentricity = np.sqrt(eccentricity_square) 13 | 14 | # this is the reference gps for all ford dataset 15 | gps_ref_lat = 42.294319*np.pi/180.0 16 | gps_ref_long = -83.223275*np.pi/180.0 17 | gps_ref_height = 0.0 18 | 19 | ru_m = semi_a * (1.0-eccentricity_square)/np.power(1.0-eccentricity_square*np.sin(gps_ref_lat)*np.sin(gps_ref_lat), 1.5) 20 | ru_t = semi_a / np.sqrt(1.0-eccentricity_square*np.sin(gps_ref_lat)*np.sin(gps_ref_lat)) 21 | 22 | 23 | # Converts WGS-84 Geodetic point (lat, lon, h) to the 24 | # // Earth-Centered Earth-Fixed (ECEF) coordinates (x, y, z). 25 | def GeodeticToEcef( lat, lon, h): 26 | 27 | sin_lambda = np.sin(lat) 28 | cos_lambda = np.cos(lat) 29 | cos_phi = np.cos(lon) 30 | sin_phi = np.sin(lon) 31 | N = semi_a / np.sqrt(1.0 - eccentricity_square * sin_lambda * sin_lambda) 32 | 33 | x = (h + N) * cos_lambda * cos_phi 34 | y = (h + N) * cos_lambda * sin_phi 35 | z = (h + (1.0 - eccentricity_square) * N) * sin_lambda 36 | return x,y,z 37 | 38 | # Converts the Earth-Centered Earth-Fixed (ECEF) coordinates (x, y, z) to 39 | # East-North-Up coordinates in a Local Tangent Plane that is centered at the 40 | # (WGS-84) Geodetic point (lat0, lon0, h0). 41 | def EcefToEnu( x, y, z, lat0, lon0, h0): 42 | 43 | sin_lambda = np.sin(lat0) 44 | cos_lambda = np.cos(lat0) 45 | cos_phi = np.cos(lon0) 46 | sin_phi = np.sin(lon0) 47 | N = semi_a / np.sqrt(1.0 - eccentricity_square * sin_lambda * sin_lambda) 48 | x0 = (h0 + N) * cos_lambda * cos_phi 49 | y0 = (h0 + N) * cos_lambda * sin_phi 50 | z0 = (h0 + (1.0 - eccentricity_square) * N) * sin_lambda 51 | xd = x - x0 52 | yd = y - y0 53 | zd = z - z0 54 | # This is the matrix multiplication 55 | xEast = -sin_phi * xd + cos_phi * yd 56 | yNorth = -cos_phi * sin_lambda * xd - sin_lambda * sin_phi * yd + cos_lambda * zd 57 | zUp = cos_lambda * cos_phi * xd + cos_lambda * sin_phi * yd + sin_lambda * zd 58 | return xEast,yNorth,zUp 59 | 60 | 61 | # Inverse of EcefToEnu. Converts East-North-Up coordinates (xEast, yNorth, zUp) in a 62 | # Local Tangent Plane that is centered at the (WGS-84) Geodetic point (lat0, lon0, h0) 63 | # to the Earth-Centered Earth-Fixed (ECEF) coordinates (x, y, z). 
64 | def EnuToEcef( xEast, yNorth, zUp, lat0, lon0, h0): 65 | 66 | # Convert to radians in notation consistent with the paper: 67 | sin_lambda = np.sin(lat0) 68 | cos_lambda = np.cos(lat0) 69 | cos_phi = np.cos(lon0) 70 | sin_phi = np.sin(lon0) 71 | N = semi_a / np.sqrt(1.0 - eccentricity_square * sin_lambda * sin_lambda) 72 | 73 | x0 = (h0 + N) * cos_lambda * cos_phi 74 | y0 = (h0 + N) * cos_lambda * sin_phi 75 | z0 = (h0 + (1 - eccentricity_square) * N) * sin_lambda 76 | 77 | xd = -sin_phi * xEast - cos_phi * sin_lambda * yNorth + cos_lambda * cos_phi * zUp 78 | yd = cos_phi * xEast - sin_lambda * sin_phi * yNorth + cos_lambda * sin_phi * zUp 79 | zd = cos_lambda * yNorth + sin_lambda * zUp 80 | 81 | x = xd + x0 82 | y = yd + y0 83 | z = zd + z0 84 | return x,y,z 85 | 86 | 87 | 88 | # Converts the Earth-Centered Earth-Fixed (ECEF) coordinates (x, y, z) to 89 | # (WGS-84) Geodetic point (lat, lon, h). 90 | def EcefToGeodetic( x, y, z): 91 | 92 | eps = eccentricity_square / (1.0 - eccentricity_square) 93 | p = math.sqrt(x * x + y * y) 94 | q = math.atan2((z * semi_a), (p * semi_b)) 95 | sin_q = np.sin(q) 96 | cos_q = np.cos(q) 97 | sin_q_3 = sin_q * sin_q * sin_q 98 | cos_q_3 = cos_q * cos_q * cos_q 99 | phi = math.atan2((z + eps * semi_b * sin_q_3), (p - eccentricity_square * semi_a * cos_q_3)) 100 | lon = math.atan2(y, x) * 180.0 / np.pi 101 | v = semi_a / math.sqrt(1.0 - eccentricity_square * np.sin(phi) * np.sin(phi)) 102 | h = (p / np.cos(phi)) - v 103 | 104 | lat = phi*180.0/np.pi 105 | 106 | return lat, lon, h 107 | 108 | 109 | 110 | def angular_distance_to_xy_distance( lat, long): 111 | # dx = ru_t * np.cos(gps_ref_lat) * (long1 - gps_ref_long) 112 | # dy = ru_m * (lat1 - gps_ref_lat) 113 | 114 | gps_lat = lat * np.pi / 180.0 115 | gps_long = long * np.pi / 180.0 116 | 117 | # ru_m_local = semi_a * (1.0 - eccentricity_square) / np.power( 118 | # 1.0 - eccentricity_square * np.sin(gps_lat) * np.sin(gps_lat), 1.5) 119 | # 120 | # ru_t_local = semi_a / np.sqrt(1.0 - eccentricity_square * np.sin(gps_lat) * np.sin(gps_lat)) 121 | # dx = ru_t_local * np.cos(gps_ref_lat) * (gps_long - gps_ref_long) 122 | # dy = ru_m_local * (gps_lat - gps_ref_lat) 123 | 124 | dx = ru_t * np.cos(gps_ref_lat) * (gps_long - gps_ref_long) 125 | dy = ru_m * (gps_lat - gps_ref_lat) 126 | return dx, dy 127 | 128 | 129 | def angular_distance_to_xy_distance_v2(lat_ref, long_ref, lat1, long1): 130 | 131 | gps_ref_lat = lat_ref * np.pi / 180.0 132 | gps_ref_long = long_ref * np.pi / 180.0 133 | 134 | ru_m_local = semi_a * (1.0 - eccentricity_square) / np.power( 135 | 1.0 - eccentricity_square * np.sin(gps_ref_lat) * np.sin(gps_ref_lat), 1.5) 136 | 137 | ru_t_local = semi_a / np.sqrt(1.0 - eccentricity_square * np.sin(gps_ref_lat) * np.sin(gps_ref_lat)) 138 | 139 | dx = ru_t_local * np.cos(gps_ref_lat) * (long1* np.pi / 180.0 - gps_ref_long) 140 | 141 | dy = ru_m_local * (lat1* np.pi / 180.0 - gps_ref_lat) 142 | 143 | return dx, dy -------------------------------------------------------------------------------- /ford_data_process/input_libs.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import csv 3 | #import pymap3d as pm 4 | from sklearn.neighbors import NearestNeighbors 5 | import os.path 6 | import glob 7 | from matplotlib import pyplot as plt 8 | import matplotlib.image as mpimg 9 | from mpl_toolkits.axes_grid1 import make_axes_locatable 10 | import transformations 11 | import yaml 12 | import math 13 | 14 | 
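The ENU/ECEF helpers in gps_coord_func.py above are used throughout ford_data_process (for example, get_gps_coverage.py converts every GPS fix into a local ENU frame and converts cell centres back to geodetic coordinates). Below is a minimal sketch of how they compose; the Ford reference point from the file above is reused purely for illustration, and the tolerance is an assumption, not a value from the repository:

```python
import numpy as np
import gps_coord_func as gps_func

# illustrative fix (degrees / metres)
lat_deg, lon_deg, height = 42.294319, -83.223275, 150.0
lat_rad, lon_rad = lat_deg * np.pi / 180.0, lon_deg * np.pi / 180.0

# Geodetic -> ECEF -> ENU in a tangent plane centred on the fix itself
x, y, z = gps_func.GeodeticToEcef(lat_rad, lon_rad, height)
e, n, u = gps_func.EcefToEnu(x, y, z, lat_rad, lon_rad, height)   # ~ (0, 0, 0)

# ENU -> ECEF -> Geodetic recovers the fix (EcefToGeodetic returns degrees)
x2, y2, z2 = gps_func.EnuToEcef(e, n, u, lat_rad, lon_rad, height)
lat_back, lon_back, h_back = gps_func.EcefToGeodetic(x2, y2, z2)
assert abs(lat_back - lat_deg) < 1e-6 and abs(lon_back - lon_deg) < 1e-6
```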
-------------------------------------------------------------------------------- /ford_data_process/orb_points_gen.py: -------------------------------------------------------------------------------- 1 | # generate ORB keypoints on the ground-view query images (satellite-image keypoint generation is kept in the commented block below) 2 | 3 | from input_libs import * 4 | import cv2 as cv 5 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 6 | 7 | root_folder = "../" 8 | 9 | log_id = "2017-10-26-V2-Log1"#"2017-08-04-V2-Log1"# 10 | # size of the satellite image and ground-view query image (left camera) 11 | satellite_size = 1280 12 | query_size = [1656, 860] 13 | 14 | 15 | log_folder = os.path.join(root_folder , log_id, 'info_files') 16 | FL_image_names = read_txt(log_folder, log_id + '-FL-names.txt') 17 | FL_image_names.pop(0) 18 | nb_query_images = len(FL_image_names) 19 | 20 | # get the satellite images 21 | satellite_folder = os.path.join( root_folder, log_id, "Satellite_Images") 22 | satellite_names = glob.glob(satellite_folder + '/*.png') 23 | 24 | nb_satellite_images = len(satellite_names) 25 | 26 | # satellite_dict = {} 27 | # for i in range(nb_satellite_images): 28 | # sate_img = cv.imread(satellite_names[i]) 29 | # # Initiate ORB detector 30 | # orb = cv.ORB_create(nfeatures=512*4) 31 | # # find the keypoints with ORB 32 | # kp = orb.detect(sate_img, None) 33 | # 34 | # if 0: #debug: 35 | # # draw only keypoints location, not size and orientation 36 | # img2 = cv.drawKeypoints(sate_img, kp, None, color=(0, 255, 0), flags=0) 37 | # plt.imshow(img2), plt.show() 38 | # # only save kp 39 | # kp_list = [] 40 | # for p in range(len(kp)): 41 | # kp_list.append(kp[p].pt) 42 | # 43 | # sat_file_name = satellite_names[i].split('/') 44 | # satellite_dict[sat_file_name[-1]] = np.array(kp_list) 45 | # write_numpy(log_folder, 'satellite_kp.npy', satellite_dict) 46 | # print('satellite_kp.npy saved') 47 | 48 | # 3.
read the matching pair 49 | match_pair = read_numpy(log_folder , 'groundview_satellite_pair.npy') # 'groundview_satellite_pair_2.npy' 50 | 51 | for grd_folder in ('-SR','-SL'):#('-FL','-RR','-SL','-SR'): 52 | query_image_folder = os.path.join(root_folder, log_id, log_id + grd_folder) 53 | 54 | if grd_folder == '-FL': 55 | H_start = query_size[1]*62//100 56 | H_end = query_size[1]*95//100 57 | elif grd_folder == '-RR': 58 | H_start = query_size[1]*62//100 59 | H_end = query_size[1]*88//100 60 | else: 61 | H_start = query_size[1]*62//100 62 | H_end = query_size[1] 63 | 64 | grd_dict = {} 65 | for i in range(nb_query_images): 66 | grd_img = cv.imread(os.path.join(query_image_folder, FL_image_names[i][:-1])) 67 | if grd_img is None: 68 | print(os.path.join(query_image_folder, FL_image_names[i][:-1])) 69 | # Initiate ORB detector 70 | orb = cv.ORB_create(nfeatures=512) 71 | # find the keypoints with ORB 72 | # turn RGB to HSV-------------- 73 | detect_img = cv.cvtColor(grd_img[H_start:H_end], cv.COLOR_BGR2HSV) 74 | # v to range 50~150 75 | # detect_img[:,:,-1] = np.where(detect_img[:,:,-1]>150, 150, detect_img[:,:,-1]) 76 | # detect_img[:, :, -1] = np.where(detect_img[:, :, -1] < 100, 100, detect_img[:, :, -1]) 77 | detect_img[:, :, -1] = np.clip(detect_img[:, :, -1], 50, 170) 78 | detect_img = cv.cvtColor(detect_img, cv.COLOR_HSV2BGR) 79 | #------------------------------------- 80 | 81 | #detect_img = grd_img[query_size[1]*64//100:] 82 | kp = orb.detect(detect_img, None) 83 | #assert len(kp) >0, f'no orb points detected on {FL_image_names[i][:-1]}' 84 | if len(kp) <=0: 85 | print(f'no orb points detected on {FL_image_names[i][:-1]}') 86 | continue 87 | if 0: # debug: 88 | # draw only keypoints location,not size and orientation 89 | img2 = cv.drawKeypoints(detect_img, kp, None, color=(0, 255, 0), flags=0) 90 | plt.imshow(img2), plt.show() 91 | # only save kp 92 | kp_list = [] 93 | for p in range(len(kp)): 94 | ori_pt = [kp[p].pt[0], kp[p].pt[1]+H_start] 95 | kp_list.append(ori_pt) 96 | 97 | grd_dict[FL_image_names[i][:-1]] = kp_list 98 | write_numpy(log_folder, grd_folder[1:]+'_kp.npy', grd_dict) 99 | print(grd_folder[1:]+'_kp.npy saved') 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /ford_data_process/other_data_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | declare -A file_dict 3 | 4 | file_dict['2017-08-04-V2-Log1-info_files']='Eby22D-rGtZKoXISuPynmg8BtHyu0s5YV0x2wV_6Kef-Rg' 5 | file_dict['2017-08-04-V2-Log3-info_files']='EZbpEOaKIetMjAKUOGCSEjcB-2c_FiPbnuyhSVHvIWaEWg' 6 | file_dict['2017-08-04-V2-Log4-info_files']='EYOqUb_1a8tDsMBoAokV6S8Bln-gtn9fDLaksMuRw0Tm_w' 7 | file_dict['2017-08-04-V2-Log5-info_files']='EYtT85HV3fRFiAFdZWQ7zLMBqqDwvAbTojb2klF63rxp2A' 8 | file_dict['2017-10-26-V2-Log1-info_files']='EWCoEOyzTChMv3HVsmkER5UB3X9IT078rOQz-U7ju8bSCw' 9 | file_dict['2017-10-26-V2-Log3-info_files']='EYFArPf4IOFAgO9ZP-z8tEgBr3FqKz1JhPkdSGWRjlitYA' 10 | file_dict['2017-10-26-V2-Log4-info_files']='EfkruM3-WaVIkvRt7yaMjQcBysrwPIpnh0W9klESmh8DUA' 11 | file_dict['2017-10-26-V2-Log5-info_files']='ESd4FRPRyGtCtbgIUa3wM_kBCzGVSQcseyDdTftfmq0kHg' 12 | 13 | file_dict['2017-08-04-V2-Log1-sat18']='ESKlaMf-O5RHmNPi8okrinABcQnSgExZy-5EbW6fhs0XPQ' 14 | file_dict['2017-08-04-V2-Log3-sat18']='EQLBSX0ZWgxAp1XNpLffMHoBQBiEAUiqQXktJegimvhAwA' 15 | file_dict['2017-08-04-V2-Log4-sat18']='EUVR7aEvh9tKgVJjvo4JPskB53mwunNoTZguHTv9JDmWVQ' 16 | 
file_dict['2017-08-04-V2-Log5-sat18']='EbSMjz-tueRItQszhv5ysQ4B3HrEzHsgWoceKhb5YceXPQ' 17 | file_dict['2017-10-26-V2-Log1-sat18']='EWHc-Sc8LlhAleFK8ZyW9MsBup0TUnELBlg9sim3OJ5LEw' 18 | file_dict['2017-10-26-V2-Log3-sat18']='Eer-nJaT62xPqBGTPsTlUAEBU5R-97rcii2rc3cvN_Sl4w' 19 | file_dict['2017-10-26-V2-Log4-sat18']='EeKCHU9_egVIocQzX-B_FQMBmsDsIVniu-yYaIg2kiXjhA' 20 | file_dict['2017-10-26-V2-Log5-sat18']='Eedfhp8nYh9CqMepag-QgcwBKK_Rao6Q4YGsFhBwplYZ1w' 21 | 22 | for key in "${!file_dict[@]}"; do 23 | link=${file_dict[$key]} 24 | name=$key'.tar.gz' 25 | dir=${key: 0: 18}'/'${i: 19} 26 | if [ ! -d ${i: 0: 18} ]; then 27 | mkdir ${i: 0: 18} 28 | fi 29 | if [ ! -d $dir ]; then 30 | mkdir $dir 31 | fi 32 | echo "Downloading: "$key 33 | #wget --no-check-certificate 'https://anu365-my.sharepoint.com/:u:/g/personal/u7094434_anu_edu_au/'$link'&download=1' 34 | wget --no-check-certificate 'https://anu365-my.sharepoint.com/:u:/g/personal/u7094434_anu_edu_au/'$link'?download=1' 35 | mv $link'?download=1' $name 36 | tar -zxvf $name -C $dir 37 | rm $name 38 | done 39 | -------------------------------------------------------------------------------- /ford_data_process/pose_func.py: -------------------------------------------------------------------------------- 1 | # functions to process poses 2 | 3 | from input_libs import * 4 | 5 | def quat_from_pose(trans): 6 | 7 | w = trans['transform']['rotation']['w'] 8 | x = trans['transform']['rotation']['x'] 9 | y = trans['transform']['rotation']['y'] 10 | z = trans['transform']['rotation']['z'] 11 | 12 | return [w,x,y,z] 13 | 14 | 15 | def read_calib_yaml(calib_folder, file_name): 16 | with open(os.path.join(calib_folder, file_name), 'r') as stream: 17 | try: 18 | cur_yaml = yaml.safe_load(stream) 19 | except yaml.YAMLError as exc: 20 | print(exc) 21 | return cur_yaml 22 | 23 | 24 | def read_txt(root_folder, file_name): 25 | with open(os.path.join(root_folder, file_name)) as f: 26 | cur_file = f.readlines() 27 | return cur_file 28 | 29 | def read_numpy(root_folder, file_name): 30 | with open(os.path.join(root_folder, file_name), 'rb') as f: 31 | cur_file = np.load(f) 32 | return cur_file 33 | 34 | def write_numpy(root_folder, file_name, current_file): 35 | with open(os.path.join(root_folder, file_name), 'wb') as f: 36 | np.save(f, current_file) 37 | 38 | def read_csv(root_folder, file_name): 39 | with open(os.path.join(root_folder, file_name), newline='') as f: 40 | reader = csv.reader(f) 41 | cur_file = list(reader) 42 | return cur_file -------------------------------------------------------------------------------- /ford_data_process/project_lidar_to_camera.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | import glob 3 | import os 4 | from sklearn.neighbors import NearestNeighbors 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import yaml 8 | import transformations 9 | import cv2 10 | 11 | # based on https://github.com/Ford/AVData/issues/26 12 | # https://github.com/Ford/AVData/issues/28 13 | 14 | 15 | def quat_from_pose(trans): 16 | 17 | w = trans['transform']['rotation']['w'] 18 | x = trans['transform']['rotation']['x'] 19 | y = trans['transform']['rotation']['y'] 20 | z = trans['transform']['rotation']['z'] 21 | 22 | return [w,x,y,z] 23 | 24 | def trans_from_pose(trans): 25 | 26 | x = trans['transform']['translation']['x'] 27 | y = trans['transform']['translation']['y'] 28 | z = trans['transform']['translation']['z'] 29 | 30 | return [x,y,z] 31 | 32 | 33 | def 
inverse_pose(pose): 34 | 35 | pose_inv = np.identity(4) 36 | pose_inv[:3,:3] = np.transpose(pose[:3,:3]) 37 | pose_inv[:3, 3] = - pose_inv[:3,:3] @ pose[:3,3] 38 | 39 | return pose_inv 40 | 41 | 42 | # warp lidar to image, project 43 | 44 | lidar_folder = "/media/yanhao/8tb/00_data_ford/2017-08-04/V2/Log1/2017_08_04_V2_Log1_lidar_blue_pointcloud/" 45 | 46 | image_floder = "/media/yanhao/8tb/00_data_ford/2017-08-04/V2/Log1/2017-08-04-V2-Log1-FL/" 47 | 48 | calib_folder = "/media/yanhao/8tb/00_data_ford/Calibration-V2/" 49 | 50 | # get the nb of scans 51 | 52 | pcd_names = glob.glob(lidar_folder + '*.pcd') 53 | image_names = glob.glob(image_floder + '*.png') 54 | 55 | pcd_time_stamps = np.zeros((len(pcd_names),1)) 56 | 57 | for i, pcd_name in zip(range(len(pcd_names)), pcd_names): 58 | pcd_time_stamp = float(os.path.split(pcd_name)[-1][:-4]) 59 | pcd_time_stamps[i] = pcd_time_stamp 60 | 61 | image_time_stamps = np.zeros((len(image_names),1)) 62 | 63 | for i, image_name in zip(range(len(image_names)), image_names): 64 | image_time_stamp = float(os.path.split(image_name)[-1][:-4]) 65 | image_time_stamps[i] = image_time_stamp 66 | 67 | neigh = NearestNeighbors(n_neighbors=1) 68 | 69 | neigh.fit(image_time_stamps) 70 | 71 | # KNN search given the image utms 72 | # find the nearest lidar scan 73 | distances, indices = neigh.kneighbors(pcd_time_stamps, return_distance=True) 74 | distances = distances.ravel() 75 | indices = indices.ravel() 76 | 77 | lidar_matched_images = [image_names[indices[i]] for i in range(indices.shape[0])] 78 | 79 | # read the lidar point clouds and project to the corresponding image 80 | 81 | with open(calib_folder + "cameraFrontLeft_body.yaml", 'r') as stream: 82 | try: 83 | FL_cameraFrontLeft_body = yaml.safe_load(stream) 84 | except yaml.YAMLError as exc: 85 | print(exc) 86 | 87 | with open(calib_folder + "cameraFrontLeftIntrinsics.yaml", 'r') as stream: 88 | try: 89 | FL_cameraFrontLeftIntrinsics = yaml.safe_load(stream) 90 | except yaml.YAMLError as exc: 91 | print(exc) 92 | 93 | 94 | with open(calib_folder + "lidarBlue_body.yaml", 'r') as stream: 95 | try: 96 | lidarBlue_body = yaml.safe_load(stream) 97 | except yaml.YAMLError as exc: 98 | print(exc) 99 | 100 | # transform 3D points to camera coordinate system 101 | FL_relPose_body = transformations.quaternion_matrix(quat_from_pose(FL_cameraFrontLeft_body)) 102 | FL_relTrans_body = trans_from_pose(FL_cameraFrontLeft_body) 103 | FL_relPose_body[0,3] = FL_relTrans_body[0] 104 | FL_relPose_body[1,3] = FL_relTrans_body[1] 105 | FL_relPose_body[2,3] = FL_relTrans_body[2] 106 | 107 | LidarBlue_relPose_body = transformations.quaternion_matrix(quat_from_pose(lidarBlue_body)) 108 | LidarBlue_relTrans_body = trans_from_pose(lidarBlue_body) 109 | LidarBlue_relPose_body[0,3] = LidarBlue_relTrans_body[0] 110 | LidarBlue_relPose_body[1,3] = LidarBlue_relTrans_body[1] 111 | LidarBlue_relPose_body[2,3] = LidarBlue_relTrans_body[2] 112 | 113 | LidarBlue_relPose_FL = inverse_pose(FL_relPose_body) @ LidarBlue_relPose_body 114 | 115 | # camera projection matrix 116 | FL_proj_mat = np.asarray(FL_cameraFrontLeftIntrinsics['P']).reshape(3,4) 117 | 118 | # image size 119 | image_width = 1656 120 | image_height = 860 121 | for i in range(indices.shape[0]): 122 | 123 | # read image 124 | img_0 = cv2.imread(lidar_matched_images[i], cv2.IMREAD_GRAYSCALE) 125 | # read point clouds 126 | pcd1 = o3d.io.read_point_cloud(pcd_names[i]) 127 | # get the transformed 3D points 128 | pcd1.transform(LidarBlue_relPose_FL) 129 | pcd_cam = 
np.transpose(np.asarray(pcd1.points)) 130 | # non0-visible mask 131 | non_visible_mask = pcd_cam[2,:] < 0.0 132 | # project the 3D points to image 133 | pcd_img = FL_proj_mat[:3,:3] @ pcd_cam + FL_proj_mat[:3,3][:,None] 134 | pcd_img = pcd_img / pcd_img[2,:] 135 | 136 | # update non0-visible mask 137 | non_visible_mask = non_visible_mask | (pcd_img[0,:] > image_width) | (pcd_img[0,:] < 0.0) 138 | non_visible_mask = non_visible_mask | (pcd_img[1, :] > image_height) | (pcd_img[1, :] < 0.0) 139 | 140 | # visible mask 141 | visible_mask = np.invert(non_visible_mask) 142 | 143 | # get visible point clouds and their projections 144 | pcd_visible = pcd_cam[:,visible_mask] 145 | pcd_proj_visible = pcd_img[:, visible_mask] 146 | 147 | 148 | # --------------------------- 149 | # show the pcds 150 | pcd_visible_o3d = o3d.geometry.PointCloud() 151 | pcd_visible_o3d.points = o3d.utility.Vector3dVector(pcd_visible.T) 152 | o3d.io.write_point_cloud("temp.ply", pcd_visible_o3d) 153 | 154 | # --------------------------- 155 | # show the visible pcds on image 156 | color_image = cv2.cvtColor(img_0, cv2.COLOR_GRAY2RGB) 157 | for j in range(pcd_proj_visible.shape[1]): 158 | cv2.circle(color_image, (np.int32(pcd_proj_visible[0][j]), np.int32(pcd_proj_visible[1][j])), 2, (255, 0, 0), -1) 159 | plt.imshow(color_image) 160 | plt.title(os.path.split(lidar_matched_images[i])[-1]) 161 | plt.show() 162 | 163 | -------------------------------------------------------------------------------- /ford_data_process/raw_data_downloader.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | files=(2017-10-26/V2/Log3/2017-10-26-V2-Log3-FL.tar.gz 4 | 2017-10-26/V2/Log3/2017-10-26-V2-Log3-RR.tar.gz 5 | 2017-10-26/V2/Log3/2017-10-26-V2-Log3-SL.tar.gz 6 | 2017-10-26/V2/Log3/2017-10-26-V2-Log3-SR.tar.gz 7 | 2017-10-26/V2/Log4/2017-10-26-V2-Log4-FL.tar.gz 8 | 2017-10-26/V2/Log4/2017-10-26-V2-Log4-RR.tar.gz 9 | 2017-10-26/V2/Log4/2017-10-26-V2-Log4-SL.tar.gz 10 | 2017-10-26/V2/Log4/2017-10-26-V2-Log4-SR.tar.gz 11 | 2017-10-26/V2/Log5/2017-10-26-V2-Log5-FL.tar.gz 12 | 2017-10-26/V2/Log5/2017-10-26-V2-Log5-RR.tar.gz 13 | 2017-10-26/V2/Log5/2017-10-26-V2-Log5-SL.tar.gz 14 | 2017-10-26/V2/Log5/2017-10-26-V2-Log5-SR.tar.gz 15 | 2017-10-26/V2/Log6/2017-10-26-V2-Log6-FL.tar.gz 16 | 2017-10-26/V2/Log6/2017-10-26-V2-Log6-RR.tar.gz 17 | 2017-10-26/V2/Log6/2017-10-26-V2-Log6-SL.tar.gz 18 | 2017-10-26/V2/Log6/2017-10-26-V2-Log6-SR.tar.gz 19 | 2017-08-04/V2/Log3/2017-08-04-V2-Log3-FL.tar.gz 20 | 2017-08-04/V2/Log3/2017-08-04-V2-Log3-RR.tar.gz 21 | 2017-08-04/V2/Log3/2017-08-04-V2-Log3-SL.tar.gz 22 | 2017-08-04/V2/Log3/2017-08-04-V2-Log3-SR.tar.gz 23 | 2017-08-04/V2/Log4/2017-08-04-V2-Log4-FL.tar.gz 24 | 2017-08-04/V2/Log4/2017-08-04-V2-Log4-RR.tar.gz 25 | 2017-08-04/V2/Log4/2017-08-04-V2-Log4-SL.tar.gz 26 | 2017-08-04/V2/Log4/2017-08-04-V2-Log4-SR.tar.gz 27 | 2017-08-04/V2/Log5/2017-08-04-V2-Log5-FL.tar.gz 28 | 2017-08-04/V2/Log5/2017-08-04-V2-Log5-RR.tar.gz 29 | 2017-08-04/V2/Log5/2017-08-04-V2-Log5-SL.tar.gz 30 | 2017-08-04/V2/Log5/2017-08-04-V2-Log5-SR.tar.gz 31 | 2017-08-04/V2/Log6/2017-08-04-V2-Log6-FL.tar.gz 32 | 2017-08-04/V2/Log6/2017-08-04-V2-Log6-RR.tar.gz 33 | 2017-08-04/V2/Log6/2017-08-04-V2-Log6-SL.tar.gz 34 | 2017-08-04/V2/Log6/2017-08-04-V2-Log6-SR.tar.gz) 35 | 36 | 37 | wget 'https://ford-multi-av-seasonal.s3-us-west-2.amazonaws.com/Calibration/Calibration-V2.tar.gz' 38 | tar -zxvf 'Calibration-V2.tar.gz' -C './' 39 | rm 'Calibration-V2.tar.gz' 40 | for i in ${files[@]}; do 41 
| shortname=${i: 19} 42 | fullname=$i 43 | dir=${i: 19: 18}'/'${i: 19: 21} 44 | if [ ! -d ${i: 19: 18} ]; then 45 | mkdir ${i: 19: 18} 46 | fi 47 | if [ ! -d $dir ]; then 48 | mkdir $dir 49 | fi 50 | echo "Downloading: "$shortname 51 | wget 'https://ford-multi-av-seasonal.s3-us-west-2.amazonaws.com/'$fullname 52 | tar -zxvf $shortname -C $dir 53 | rm $shortname 54 | done 55 | -------------------------------------------------------------------------------- /ford_data_process/shift_to_pose.py: -------------------------------------------------------------------------------- 1 | # check the matching relationship between cross-view images 2 | 3 | from input_libs import * 4 | from angle_func import convert_body_yaw_to_360 5 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy 6 | import gps_coord_func as gps_func 7 | 8 | root_folder = "/data/dataset/Ford_AV" 9 | log_id = "2017-10-26-V2-Log4" 10 | 11 | log_folder = os.path.join(root_folder, log_id) 12 | info_folder = os.path.join(log_folder,'info_files') 13 | FL_image_names = read_txt(info_folder, log_id + '-FL-names.txt') 14 | FL_image_names.pop(0) 15 | nb_query_images = len(FL_image_names) 16 | groundview_gps = read_numpy(info_folder, 'groundview_gps.npy') # 'groundview_gps_2.npy' 17 | groundview_yaws = read_numpy(info_folder, 'groundview_yaws_pose_gt.npy') # 'groundview_yaws_pose.npy' 18 | shift_pose = read_numpy(log_folder, 'shift_pose.npy') # yaw in pi, north, est 19 | 20 | pre_pose = [] 21 | for i in range(nb_query_images): 22 | shift_pose = read_numpy(log_folder, 'shift_pose.npy') # yaw in pi, north, est 23 | query_gps = groundview_gps[i] 24 | # print(query_gps) 25 | # print(groundview_yaws[i]) 26 | # using the original gps reference and calculate the offset of the ground-view query 27 | east, north = gps_func.angular_distance_to_xy_distance(query_gps[0], query_gps[1]) 28 | east = east + shift_pose[i][2] 29 | north = north + shift_pose[i][1] 30 | yaw = groundview_yaws[i]+ shift_pose[i][0]*180./np.pi 31 | # turn to +- pi 32 | if abs(yaw) > 180.: 33 | yaw = yaw - np.sign(yaw)*180 34 | pre_pose.append([yaw,north,east]) 35 | 36 | with open(os.path.join(log_folder, 'pred_pose.npy'), 'wb') as f: 37 | np.save(f, pre_pose) 38 | -------------------------------------------------------------------------------- /ford_data_process/show_all_logs_traj.py: -------------------------------------------------------------------------------- 1 | # this script shows the trajectories of all logs 2 | from input_libs import * 3 | import gps_coord_func as gps_func 4 | from pose_func import read_numpy 5 | from transformations import euler_matrix 6 | 7 | root_dir = "/data/dataset/Ford_AV/" 8 | Logs = ["2017-10-26-V2-Log4",] 9 | # "2017-10-26-V2-Log1", 10 | # # "2017-10-26-V2-Log2", 11 | # # "2017-10-26-V2-Log3", 12 | # # "2017-10-26-V2-Log4", 13 | # "2017-08-04-V2-Log1", 14 | # "2017-10-26-V2-Log5", 15 | # "2017-08-04-V2-Log5"] 16 | # # "2017-10-26-V2-Log6"] 17 | 18 | nb_logs = len(Logs) 19 | 20 | all_gps_neds = [] 21 | all_slam_neds = [] 22 | 23 | for log_iter in range(nb_logs): 24 | 25 | cur_log_dir = root_dir + Logs[log_iter] 26 | # read the raw gps data 27 | # gps_file = cur_log_dir + "/info_files/gps.csv" 28 | # with open(gps_file, newline='') as f: 29 | # reader = csv.reader(f) 30 | # gps_data = list(reader) 31 | # # remove the headlines 32 | # gps_data.pop(0) 33 | # 34 | # cur_nb_gps = len(gps_data) 35 | # 36 | # gps_geodestic = np.zeros((cur_nb_gps, 3)) 37 | # 38 | # # convert to NED position 39 | # gps_ned = np.zeros((cur_nb_gps, 3)) 
40 | # 41 | # for i, line in zip(range(cur_nb_gps), gps_data): 42 | # gps_geodestic[i, 0] = float(line[10]) 43 | # gps_geodestic[i, 1] = float(line[11]) 44 | # gps_geodestic[i, 2] = float(line[12]) 45 | # x, y, z = gps_func.GeodeticToEcef(gps_geodestic[i, 0] * np.pi / 180.0, gps_geodestic[i, 1] * np.pi / 180.0, 46 | # gps_geodestic[i, 2]) 47 | # xEast, yNorth, zUp = gps_func.EcefToEnu(x, y, z, gps_func.gps_ref_lat, gps_func.gps_ref_long, 48 | # gps_func.gps_ref_height) 49 | # gps_ned[i, 0] = yNorth 50 | # gps_ned[i, 1] = xEast 51 | # gps_ned[i, 2] = -zUp 52 | 53 | # read the ned gt pose 54 | query_ned = read_numpy(cur_log_dir+"/info_files", "groundview_NED_pose_gt.npy") 55 | 56 | # all_gps_neds.append(gps_ned) 57 | all_slam_neds.append(query_ned) 58 | 59 | # # plot the gps trajectories 60 | # color_list = ['g', 'b', 'r', 'k', 'c', 'm', 'y'] 61 | # 62 | # for log_iter in range(nb_logs): 63 | # plt.plot(all_gps_neds[log_iter][:, 1], all_gps_neds[log_iter][:, 0], color=color_list[log_iter], linewidth=1, 64 | # label='log_{}'.format(log_iter)) 65 | # # plt.plot(all_slam_neds[log_iter][:, 1], all_slam_neds[log_iter][:, 0], color=color_list[log_iter], linewidth=1, linestyle="--", 66 | # # label='log_{}'.format(log_iter)) 67 | # 68 | # plt.xlabel("East") 69 | # plt.ylabel("North") 70 | # plt.legend() 71 | # plt.show() 72 | 73 | # read the shift pose 74 | shift_dir = '/home/users/u7094434/projects/SIDFM/pixloc/ford_shift_100/' 75 | shift_R = read_numpy(shift_dir, "pred_R.np") # pre->gt 76 | shift_T = read_numpy(shift_dir, "pred_T.np") # pre->gt 77 | yaws = read_numpy(cur_log_dir+"/info_files", 'groundview_yaws_pose_gt.npy') # 'groundview_yaws_pose.npy' 78 | yaws = yaws*np.pi / 180.0 79 | rolls = read_numpy(cur_log_dir+"/info_files", 'groundview_rolls_pose_gt.npy') # 'groundview_yaws_pose.npy' 80 | rolls = rolls * np.pi / 180.0 81 | pitchs = read_numpy(cur_log_dir+"/info_files", 'groundview_pitchs_pose_gt.npy') # 'groundview_yaws_pose.npy' 82 | pitchs = pitchs * np.pi / 180.0 83 | 84 | shift_ned = [] 85 | for i, line in zip(range(len(query_ned)), query_ned): 86 | body2ned = euler_matrix(rolls[i], pitchs[i], yaws[i]) 87 | shift_ned.append(query_ned[i] + body2ned[:3,:3]@shift_T[i]) 88 | shift_ned = np.array(shift_ned) 89 | 90 | 91 | # plot the trajectories 92 | Min_Pose = 1275#627 93 | Max_Pose = 1430#2596#-1 #2596 94 | plt.plot(query_ned[Min_Pose:Max_Pose, 1], query_ned[Min_Pose:Max_Pose, 0], alpha=0.5, color='r', marker='o', markersize=2) #linewidth=1) 95 | plt.plot(shift_ned[Min_Pose:Max_Pose, 1], shift_ned[Min_Pose:Max_Pose, 0], alpha=0.5, color='b', marker='o', markersize=2)# linewidth=1) 96 | 97 | plt.xlabel("East(m)") 98 | plt.ylabel("North(m)") 99 | plt.legend() 100 | plt.show() 101 | 102 | temp = 0 103 | -------------------------------------------------------------------------------- /ford_data_process/superpoint.py: -------------------------------------------------------------------------------- 1 | # %BANNER_BEGIN% 2 | # --------------------------------------------------------------------- 3 | # %COPYRIGHT_BEGIN% 4 | # 5 | # Magic Leap, Inc. ("COMPANY") CONFIDENTIAL 6 | # 7 | # Unpublished Copyright (c) 2020 8 | # Magic Leap, Inc., All Rights Reserved. 9 | # 10 | # NOTICE: All information contained herein is, and remains the property 11 | # of COMPANY. The intellectual and technical concepts contained herein 12 | # are proprietary to COMPANY and may be covered by U.S. and Foreign 13 | # Patents, patents in process, and are protected by trade secret or 14 | # copyright law. 
Dissemination of this information or reproduction of 15 | # this material is strictly forbidden unless prior written permission is 16 | # obtained from COMPANY. Access to the source code contained herein is 17 | # hereby forbidden to anyone except current COMPANY employees, managers 18 | # or contractors who have executed Confidentiality and Non-disclosure 19 | # agreements explicitly covering such access. 20 | # 21 | # The copyright notice above does not evidence any actual or intended 22 | # publication or disclosure of this source code, which includes 23 | # information that is confidential and/or proprietary, and is a trade 24 | # secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION, 25 | # PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS 26 | # SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS 27 | # STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND 28 | # INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE 29 | # CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS 30 | # TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE, 31 | # USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART. 32 | # 33 | # %COPYRIGHT_END% 34 | # ---------------------------------------------------------------------- 35 | # %AUTHORS_BEGIN% 36 | # 37 | # Originating Authors: Paul-Edouard Sarlin 38 | # 39 | # %AUTHORS_END% 40 | # --------------------------------------------------------------------*/ 41 | # %BANNER_END% 42 | 43 | from pathlib import Path 44 | import torch 45 | from torch import nn 46 | 47 | def simple_nms(scores, nms_radius: int): 48 | """ Fast Non-maximum suppression to remove nearby points """ 49 | assert(nms_radius >= 0) 50 | 51 | def max_pool(x): 52 | return torch.nn.functional.max_pool2d( 53 | x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius) 54 | 55 | zeros = torch.zeros_like(scores) 56 | max_mask = scores == max_pool(scores) 57 | for _ in range(2): 58 | supp_mask = max_pool(max_mask.float()) > 0 59 | supp_scores = torch.where(supp_mask, zeros, scores) 60 | new_max_mask = supp_scores == max_pool(supp_scores) 61 | max_mask = max_mask | (new_max_mask & (~supp_mask)) 62 | return torch.where(max_mask, scores, zeros) 63 | 64 | 65 | def remove_borders(keypoints, scores, border: int, height: int, width: int): 66 | """ Removes keypoints too close to the border """ 67 | mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border)) 68 | mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border)) 69 | mask = mask_h & mask_w 70 | return keypoints[mask], scores[mask] 71 | 72 | 73 | def top_k_keypoints(keypoints, scores, k: int): 74 | if k >= len(keypoints): 75 | return keypoints, scores 76 | scores, indices = torch.topk(scores, k, dim=0) 77 | return keypoints[indices], scores 78 | 79 | 80 | def sample_descriptors(keypoints, descriptors, s: int = 8): 81 | """ Interpolate descriptors at keypoint locations """ 82 | b, c, h, w = descriptors.shape 83 | keypoints = keypoints - s / 2 + 0.5 84 | keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)], 85 | ).to(keypoints)[None] 86 | keypoints = keypoints*2 - 1 # normalize to (-1, 1) 87 | args = {'align_corners': True} if torch.__version__ >= '1.3' else {} 88 | descriptors = torch.nn.functional.grid_sample( 89 | descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args) 90 | descriptors = torch.nn.functional.normalize( 91 | descriptors.reshape(b, c, -1), p=2, dim=1) 92 | return descriptors 93 | 94 | 
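# Note on sample_descriptors above: keypoints arrive in image pixel coordinates while the
# descriptor map has an encoder stride of s (8 by default). The shift by s/2 - 0.5 and the
# division by (w*s - s/2 - 0.5, h*s - s/2 - 0.5) rescale each keypoint to the [-1, 1] grid
# expected by grid_sample (align_corners=True), so descriptors are bilinearly interpolated
# at the keypoint locations and then L2-normalised along the channel dimension.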
95 | class SuperPoint(nn.Module): 96 | """SuperPoint Convolutional Detector and Descriptor 97 | 98 | SuperPoint: Self-Supervised Interest Point Detection and 99 | Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew 100 | Rabinovich. In CVPRW, 2019. https://arxiv.org/abs/1712.07629 101 | 102 | """ 103 | default_config = { 104 | 'descriptor_dim': 256, 105 | 'nms_radius': 4, 106 | 'keypoint_threshold': 0.005, 107 | 'max_keypoints': -1, 108 | 'remove_borders': 4, 109 | } 110 | 111 | def __init__(self, config): 112 | super().__init__() 113 | self.config = {**self.default_config, **config} 114 | 115 | self.relu = nn.ReLU(inplace=True) 116 | self.pool = nn.MaxPool2d(kernel_size=2, stride=2) 117 | c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256 118 | 119 | self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1) 120 | self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1) 121 | self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1) 122 | self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1) 123 | self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1) 124 | self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1) 125 | self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1) 126 | self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1) 127 | 128 | self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 129 | self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0) 130 | 131 | self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1) 132 | self.convDb = nn.Conv2d( 133 | c5, self.config['descriptor_dim'], 134 | kernel_size=1, stride=1, padding=0) 135 | 136 | path = Path(__file__).parent / 'weights/superpoint_v1.pth' 137 | self.load_state_dict(torch.load(str(path))) 138 | 139 | mk = self.config['max_keypoints'] 140 | if mk == 0 or mk < -1: 141 | raise ValueError('\"max_keypoints\" must be positive or \"-1\"') 142 | 143 | print('Loaded SuperPoint model') 144 | 145 | def forward(self, data): 146 | """ Compute keypoints, scores, descriptors for image """ 147 | # Shared Encoder 148 | x = self.relu(self.conv1a(data['image'])) 149 | x = self.relu(self.conv1b(x)) 150 | x = self.pool(x) 151 | x = self.relu(self.conv2a(x)) 152 | x = self.relu(self.conv2b(x)) 153 | x = self.pool(x) 154 | x = self.relu(self.conv3a(x)) 155 | x = self.relu(self.conv3b(x)) 156 | x = self.pool(x) 157 | x = self.relu(self.conv4a(x)) 158 | x = self.relu(self.conv4b(x)) 159 | 160 | # Compute the dense keypoint scores 161 | cPa = self.relu(self.convPa(x)) 162 | scores = self.convPb(cPa) 163 | scores = torch.nn.functional.softmax(scores, 1)[:, :-1] 164 | b, _, h, w = scores.shape 165 | scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8) 166 | scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8) 167 | scores = simple_nms(scores, self.config['nms_radius']) 168 | 169 | # Extract keypoints 170 | keypoints = [ 171 | torch.nonzero(s > self.config['keypoint_threshold']) 172 | for s in scores] 173 | scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)] 174 | 175 | # Discard keypoints near the image borders 176 | keypoints, scores = list(zip(*[ 177 | remove_borders(k, s, self.config['remove_borders'], h*8, w*8) 178 | for k, s in zip(keypoints, scores)])) 179 | 180 | # Keep the k keypoints with highest score 181 | if self.config['max_keypoints'] >= 0: 182 | keypoints, scores = list(zip(*[ 183 | top_k_keypoints(k, s, self.config['max_keypoints']) 184 | for k, s in zip(keypoints, scores)])) 185 | 
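        # At this point each entry of `keypoints` is a tensor of (row, col) indices that
        # survived the score threshold, the border margin, and the optional top-k selection;
        # they are flipped to (x, y) below before descriptors are sampled at those locations.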
186 | # Convert (h, w) to (x, y) 187 | keypoints = [torch.flip(k, [1]).float() for k in keypoints] 188 | 189 | # Compute the dense descriptors 190 | cDa = self.relu(self.convDa(x)) 191 | descriptors = self.convDb(cDa) 192 | descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1) 193 | 194 | # Extract descriptors 195 | descriptors = [sample_descriptors(k[None], d[None], 8)[0] 196 | for k, d in zip(keypoints, descriptors)] 197 | 198 | return { 199 | 'keypoints': keypoints, 200 | 'scores': scores, 201 | 'descriptors': descriptors, 202 | } 203 | -------------------------------------------------------------------------------- /ford_data_process/vel_npy_gener.py: -------------------------------------------------------------------------------- 1 | # for each query image, find the lidar point cloud with the nearest timestamp and save the matched timestamps 2 | 3 | from input_libs import * 4 | from pose_func import quat_from_pose, read_calib_yaml, read_txt, read_numpy, read_csv, write_numpy 5 | 6 | root_folder = "../" 7 | 8 | log_id = "2017-08-04-V2-Log5" #"2017-10-26-V2-Log5" 9 | info_folder = 'info_files' 10 | 11 | log_folder = os.path.join(root_folder, log_id, info_folder) 12 | 13 | # first, get the query image timestamps 14 | # ----------------------------------------------- 15 | 16 | # 1. get the image names 17 | imageNames = read_txt(log_folder, log_id + '-FL-names.txt') 18 | imageNames.pop(0) 19 | 20 | image_times = np.zeros((len(imageNames), 1)) 21 | for i in range(len(imageNames)): 22 | image_times[i] = float(imageNames[i][:-5]) 23 | 24 | 25 | # # 2. read the vel times and data 26 | # vel_data = read_csv(log_folder , "velocity_raw.csv") 27 | # # remove the headlines 28 | # vel_data.pop(0) 29 | # # save timestamp -- >> vel 30 | # vel_dict = {} 31 | # vel_times = np.zeros((len(vel_data), 1)) 32 | # vel_x = np.zeros((len(vel_data))) 33 | # vel_y = np.zeros((len(vel_data))) 34 | # vel_z = np.zeros((len(vel_data))) 35 | # for i, line in zip(range(len(vel_data)), vel_data): 36 | # vel_timeStamp = float(line[0]) 37 | # vel_times[i] = vel_timeStamp / 1000.0 38 | # vel_x[i] = float(line[8]) # !!!! 39 | # vel_y[i] = float(line[9]) 40 | # vel_z[i] = float(line[10]) 41 | # 3.
for each query image time, find the nearest vel tag 42 | # neigh = NearestNeighbors(n_neighbors=1)#20000) # not including every 3D points 43 | # neigh.fit(vel_times) 44 | 45 | lidar_folder = os.path.join(root_folder, log_id, 'lidar_blue_pointcloud') 46 | pcd_names = glob.glob(lidar_folder + '/*.pcd') 47 | pcd_time_stamps = np.zeros((len(pcd_names),1)) 48 | for i, pcd_name in zip(range(len(pcd_names)), pcd_names): 49 | pcd_time_stamp = float(os.path.split(pcd_name)[-1][:-4]) 50 | pcd_time_stamps[i] = pcd_time_stamp 51 | neigh = NearestNeighbors(n_neighbors=1) 52 | neigh.fit(pcd_time_stamps) 53 | 54 | # KNN search given the image utms 55 | distances, indices = neigh.kneighbors(image_times, return_distance=True) 56 | distances = distances.ravel() 57 | indices = indices.ravel() 58 | 59 | # NED_coords_query = np.zeros((len(imageNames), 3)) 60 | # 61 | # vel_query = np.zeros((len(imageNames), 3)) 62 | # 63 | # for i in range(len(imageNames)): 64 | # x = vel_x[indices[i]] 65 | # y = vel_y[indices[i]] 66 | # z = vel_z[indices[i]] 67 | # 68 | # vel_query[i, 0] = x 69 | # vel_query[i, 1] = y 70 | # vel_query[i, 2] = z 71 | 72 | # save the ground-view query to satellite matching pair 73 | # save the gps coordinates of query images 74 | # write_numpy(root_folder , 'vel_per_images.npy', vel_query) 75 | write_numpy(log_folder , 'vel_time_per_images.npy', pcd_time_stamps[indices]) 76 | 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /ford_data_process/weights/superpoint_v1.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShanWang-Shan/PureACL/21a4c2e64f0eeafaa09117b6bc8aef40d9cdf4e3/ford_data_process/weights/superpoint_v1.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7 2 | torchvision>=0.8 3 | numpy 4 | opencv-python 5 | tqdm 6 | matplotlib 7 | scipy 8 | h5py 9 | omegaconf 10 | tensorboard 11 | open3d 12 | scikit-learn 13 | setuptools==59.5.0 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from setuptools import setup 3 | 4 | description = ['Training and evaluation of the PureACL'] 5 | 6 | with open(str(Path(__file__).parent / 'README.md'), 'r', encoding='utf-8') as f: 7 | readme = f.read() 8 | 9 | with open(str(Path(__file__).parent / 'requirements.txt'), 'r') as f: 10 | dependencies = f.read().split('\n') 11 | 12 | extra_dependencies = ['jupyter', 'scikit-learn', 'ffmpeg-python', 'kornia'] 13 | 14 | setup( 15 | name='PureACL', 16 | version='1.0', 17 | packages=['PureACL'], 18 | python_requires='>=3.6', 19 | install_requires=dependencies, 20 | extras_require={'extra': extra_dependencies}, 21 | author='shan wang', 22 | description=description, 23 | long_description=readme, 24 | long_description_content_type="text/markdown", 25 | url='https://github.com/*/', 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: Apache Software License", 29 | "Operating System :: OS Independent", 30 | ], 31 | ) 32 | --------------------------------------------------------------------------------