├── .gitignore ├── LICENCE ├── README.md ├── config.yml ├── descriptor_generation ├── locus_descriptor.py ├── spatial_pooling.py └── temporal_pooling.py ├── evaluation ├── place_recognition.py └── pr_curve.py ├── main.py ├── requirements.txt ├── segmentation ├── extract_segment_features.py ├── extract_segments.py └── segmappy │ ├── README.md │ ├── config │ └── default_training.ini │ ├── core │ ├── config.py │ ├── dataset.py │ ├── generator.py │ └── preprocessor.py │ └── tools │ └── classifiertools.py └── utils ├── augment_scans.py ├── docs ├── pipeline.png └── robustness_tests.png ├── get_segmap_data.bash ├── kitti_dataloader.py ├── misc_utils.py └── setup_python_pcl.txt /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pickle 3 | log* 4 | __pycache__ 5 | slurm* 6 | .vscode 7 | test* 8 | pr_results 9 | segmap_data 10 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 2 | 3 | CSIRO Open Source Software Licence Agreement (variation of the BSD / MIT License) 4 | Copyright (c) 2021, Commonwealth Scientific and Industrial Research Organisation (CSIRO) ABN 41 687 119 230. 5 | All rights reserved. CSIRO is willing to grant you a licence to this software (Locus) on the following terms, except where otherwise indicated for third party material. 6 | Redistribution and use of this software in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 7 | • Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 
8 | • Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 | • Neither the name of CSIRO nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission of CSIRO. 10 | EXCEPT AS EXPRESSLY STATED IN THIS AGREEMENT AND TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, THE SOFTWARE IS PROVIDED "AS-IS". CSIRO MAKES NO REPRESENTATIONS, WARRANTIES OR CONDITIONS OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY REPRESENTATIONS, WARRANTIES OR CONDITIONS REGARDING THE CONTENTS OR ACCURACY OF THE SOFTWARE, OR OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, THE ABSENCE OF LATENT OR OTHER DEFECTS, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. 11 | TO THE FULL EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL CSIRO BE LIABLE ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, IN AN ACTION FOR BREACH OF CONTRACT, NEGLIGENCE OR OTHERWISE) FOR ANY CLAIM, LOSS, DAMAGES OR OTHER LIABILITY HOWSOEVER INCURRED. WITHOUT LIMITING THE SCOPE OF THE PREVIOUS SENTENCE THE EXCLUSION OF LIABILITY SHALL INCLUDE: LOSS OF PRODUCTION OR OPERATION TIME, LOSS, DAMAGE OR CORRUPTION OF DATA OR RECORDS; OR LOSS OF ANTICIPATED SAVINGS, OPPORTUNITY, REVENUE, PROFIT OR GOODWILL, OR OTHER ECONOMIC LOSS; OR ANY SPECIAL, INCIDENTAL, INDIRECT, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES, ARISING OUT OF OR IN CONNECTION WITH THIS AGREEMENT, ACCESS OF THE SOFTWARE OR ANY OTHER DEALINGS WITH THE SOFTWARE, EVEN IF CSIRO HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH CLAIM, LOSS, DAMAGES OR OTHER LIABILITY. 
12 | APPLICABLE LEGISLATION SUCH AS THE AUSTRALIAN CONSUMER LAW MAY APPLY REPRESENTATIONS, WARRANTIES, OR CONDITIONS, OR IMPOSES OBLIGATIONS OR LIABILITY ON CSIRO THAT CANNOT BE EXCLUDED, RESTRICTED OR MODIFIED TO THE FULL EXTENT SET OUT IN THE EXPRESS TERMS OF THIS CLAUSE ABOVE "CONSUMER GUARANTEES". TO THE EXTENT THAT SUCH CONSUMER GUARANTEES CONTINUE TO APPLY, THEN TO THE FULL EXTENT PERMITTED BY THE APPLICABLE LEGISLATION, THE LIABILITY OF CSIRO UNDER THE RELEVANT CONSUMER GUARANTEE IS LIMITED (WHERE PERMITTED AT CSIRO’S OPTION) TO ONE OF FOLLOWING REMEDIES OR SUBSTANTIALLY EQUIVALENT REMEDIES: 13 | (a) THE REPLACEMENT OF THE SOFTWARE, THE SUPPLY OF EQUIVALENT SOFTWARE, OR SUPPLYING RELEVANT SERVICES AGAIN; 14 | (b) THE REPAIR OF THE SOFTWARE; 15 | (c) THE PAYMENT OF THE COST OF REPLACING THE SOFTWARE, OF ACQUIRING EQUIVALENT SOFTWARE, HAVING THE RELEVANT SERVICES SUPPLIED AGAIN, OR HAVING THE SOFTWARE REPAIRED. 16 | IN THIS CLAUSE, CSIRO INCLUDES ANY THIRD PARTY AUTHOR OR OWNER OF ANY PART OF THE SOFTWARE OR MATERIAL DISTRIBUTED WITH IT. CSIRO MAY ENFORCE ANY RIGHTS ON BEHALF OF THE RELEVANT THIRD PARTY. 17 | 18 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 19 | 20 | Third Party Components 21 | The following third party components are distributed with the Software. You agree to comply with the licence terms for these components as part of accessing the Software. Other third party software may also be identified in separate files distributed with the Software. 22 | ___________________________________________________________________ 23 | TensorFLow https://www.tensorflow.org/ 24 | This software is licensed under the Apache License 2.0 license. 
25 | ___________________________________________________________________ 26 | ___________________________________________________________________ 27 | Open3d https://github.com/intel-isl/Open3D 28 | 29 | This software is licensed under the MIT license. 30 | ___________________________________________________________________ 31 | ___________________________________________________________________ 32 | pcl https://pointclouds.org 33 | 34 | This software is licensed under the BSD License. 35 | ___________________________________________________________________ 36 | ___________________________________________________________________ 37 | python-pcl https://github.com/strawlab/python-pcl 38 | 39 | This software is licensed under the BSD License. 40 | ___________________________________________________________________ 41 | ___________________________________________________________________ 42 | SegMap https://github.com/ethz-asl/segmap 43 | 44 | This software is licensed under the BSD 3-Clause License. 45 | ___________________________________________________________________ 46 | 47 | 48 | /////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Locus 2 | 3 | 4 | This repository is an open-source implementation of the ICRA 2021 paper: [Locus: LiDAR-based Place Recognition using Spatiotemporal Higher-Order Pooling](https://arxiv.org/abs/2011.14497). 5 | 6 | More information: https://research.csiro.au/robotics/locus-pr/ 7 | 8 | Paper Pre-print: https://arxiv.org/abs/2011.14497 9 | 10 | ## Method overview. 11 | *Locus* is a global descriptor for large-scale place recognition using sequential 3D LiDAR point clouds. 
It encodes topological relationships and temporal consistency of scene components to obtain a discriminative and viewpoint-invariant scene representation. 12 | 13 | ![](./utils/docs/pipeline.png) 14 | 15 | 16 | 17 | ## Usage 18 | 19 | ### Set up environment 20 | This project has been tested on Ubuntu 18.04 (with [Open3D](http://www.open3d.org/docs/release/) 0.11, [tensorflow](https://www.tensorflow.org/) 1.8.0, [pcl](https://pointclouds.org/) 1.8.1 and [python-pcl](https://github.com/strawlab/python-pcl) 0.3.0). Set up the requirements as follows: 21 | - Create a [conda](https://docs.conda.io/en/latest/) environment with Python 3.6, Open3D and TensorFlow 1.8: 22 | ```bash 23 | conda create --name locus_env python=3.6 24 | conda activate locus_env 25 | pip install -r requirements.txt 26 | ``` 27 | - Set up python-pcl. See ```utils/setup_python_pcl.txt```. For further instructions, see [here](https://github.com/strawlab/python-pcl). 28 | - Segment feature extraction uses the pre-trained model from [ethz-asl/segmap](https://github.com/ethz-asl/segmap). Download the relevant content from [segmap_data](http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/) and copy it into ```~/.segmap/```: 29 | ```bash 30 | ./utils/get_segmap_data.bash 31 | ``` 32 | - Download the [KITTI odometry dataset](http://www.cvlibs.net/datasets/kitti/eval_odometry.php) and set its path in ```config.yml```. 33 | 34 | 35 | ### Descriptor Generation 36 | Segment each scan in a selected sequence (e.g., KITTI sequence 06) and generate its Locus descriptor: 37 | ```bash 38 | python main.py --seq '06' 39 | ``` 40 | The following flags can be used with ```main.py```: 41 | - ```--seq```: KITTI dataset sequence number. 42 | - ```--aug_type```: Scan augmentation type (optional, for robustness tests). 43 | - ```--aug_param```: Parameter of the above augmentation.
44 | 45 | ### Evaluation 46 | Sequence-wise place recognition using the extracted descriptors: 47 | ```bash 48 | python ./evaluation/place_recognition.py --seq '06' 49 | ``` 50 | Evaluation of place-recognition performance using Precision-Recall curves (multiple sequences): 51 | ```bash 52 | python ./evaluation/pr_curve.py 53 | ``` 54 | 55 | ### Additional scripts 56 | 57 | #### Robustness tests: 58 | Code for the robustness tests carried out in Section V.C of the paper. 59 | Extract Locus descriptors from scans with the selected augmentation: 60 | ```bash 61 | python main.py --seq '06' --aug_type 'rot' --aug_param 180 # Rotate about z-axis by random angle between 0-180 degrees. 62 | python main.py --seq '06' --aug_type 'occ' --aug_param 90 # Occlude sector of 90 degrees about random heading. 63 | ``` 64 | Evaluation is done as before, passing the same ```--aug_type``` and ```--aug_param``` flags to ```place_recognition.py``` so that it loads the augmented descriptors. For visualization, set ```config.yml->segmentation->visualize``` to ```True```. 65 | 66 | 67 | 68 | #### Testing individual modules: 69 | 70 | ```bash 71 | python ./segmentation/extract_segments.py # Extract and save Euclidean segments (S). 72 | python ./segmentation/extract_segment_features.py # Extract and save SegMap-CNN features (Fa) for given S. 73 | python ./descriptor_generation/spatial_pooling.py # Generate and save spatial segment features for given S and Fa. 74 | python ./descriptor_generation/temporal_pooling.py # Generate and save temporal segment features for given S and Fa. 75 | python ./descriptor_generation/locus_descriptor.py # Generate and save the Locus global descriptor using the above outputs.
76 | ``` 77 | 78 | ## Citation 79 | 80 | If you find this work useful in your research, please consider citing: 81 | 82 | ``` 83 | @inproceedings{vid2021locus, 84 | title={Locus: LiDAR-based Place Recognition using Spatiotemporal Higher-Order Pooling}, 85 | author={Vidanapathirana, Kavisha and Moghadam, Peyman and Harwood, Ben and Zhao, Muming and Sridharan, Sridha and Fookes, Clinton}, 86 | booktitle={IEEE International Conference on Robotics and Automation (ICRA)}, 87 | year={2021}, 88 | eprint={arXiv preprint arXiv:2011.14497} 89 | } 90 | ``` 91 | 92 | ## Acknowledgment 93 | Third-party functions are acknowledged at their respective function definitions or in the relevant readme files. This project was mainly inspired by the following: [ethz-asl/segmap](https://github.com/ethz-asl/segmap) and [irapkaist/scancontext](https://github.com/irapkaist/scancontext). 94 | 95 | ## Contact 96 | For questions/feedback, 97 | ``` 98 | kavisha.vidanapathirana@data61.csiro.au 99 | ``` 100 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | segmentation: 2 | g_height: 1.8 # Sensor height in meters. 3 | g_dist_thresh: 0.2 # Distance threshold for ground-plane extraction. 4 | g_normal_dist_weight: 0.1 # Normal distance weight for ground-plane extraction. 5 | c_min_size: 100 # Min points per segment. 6 | c_max_size: 15000 # Max points per segment. 7 | c_tolerence: 0.2 # Lower values -> object becomes multiple clusters. 8 | ds_factor: 0.01 # Downsample voxel size. If <= 0.01, skip downsample. 9 | filter_flat_seg: False # Filter out flat segments. 10 | horizontal_ratio: 15 # Identify flat segments. Higher = flatter segment. 11 | enforce_min_seg_count: False # Ensure extraction of a minimum number of segments. 12 | min_seg_count: 7 # Minimum number of segments (if above is True). 13 | visualize: False # Visualize all intermediate steps.
14 | 15 | descriptor_generation: 16 | n_frames_max: 3 # Number of frames to consider in temporal pooling. 17 | spatial_topk: 5 # Number of neighbours to consider in spatial pooling. 18 | PE_alpha: 0.5 # alpha in Power Euclidean transform. 19 | fb_mode: 'spatiotemporal' # Type of complementary feature in O2P. ['structural', 'spatial', 'temporal', 'spatiotemporal'] 20 | 21 | place_recognition: 22 | revisit_criteria: 3 # in meters. 23 | not_revisit_criteria: 20 # in meters. 24 | skip_time: 30 # in seconds. 25 | kdtree_retrieval: False # Use KDTree or exhaustive retrieval. 26 | cd_thresh_min: 0.1 # Thresholds on cosine-distance to top-1. For evaluation. 27 | cd_thresh_max: 0.5 28 | num_thresholds: 200 # Number of thresholds. Number of points on PR curve. 29 | 30 | pr_curve: 31 | log_axis: False # See PR curve with log axis. 32 | introspection_table: True # See table with TP,TN,FP,FN counts at F1max and RP100. 33 | 34 | paths: 35 | KITTI_dataset: '/mnt/088A6CBB8A6CA742/Datasets/Kitti/dataset/' 36 | save_dir: '/mnt/088A6CBB8A6CA742/locus_data/' -------------------------------------------------------------------------------- /descriptor_generation/locus_descriptor.py: -------------------------------------------------------------------------------- 1 | """ Generate Locus descriptor. Spatiotemporal feature pooling followed by O2P + PE. """ 2 | 3 | import numpy as np 4 | from numpy.linalg import norm 5 | from sklearn.preprocessing import normalize 6 | 7 | from spatial_pooling import * 8 | from temporal_pooling import * 9 | 10 | def get_locus_descriptor(idx, config_dict, database_dict): 11 | 12 | features = database_dict['features_database'][idx] 13 | feature_dim = 64 #np.shape(features)[1] 14 | 15 | # Get spatially and temporally pooled features. 
16 | spatial_features = get_spatial_features( 17 | idx, config_dict['spatial_topk'], database_dict) 18 | temporal_features = get_temporal_features( 19 | idx, config_dict['n_frames_max'], [], database_dict) 20 | 21 | if spatial_features == [] or temporal_features == []: 22 | print('Degenerate scene. ID: ', idx) 23 | return [] 24 | 25 | # Second order pooling (O2P) of complementary features. 26 | locus_matrix = np.zeros((feature_dim, feature_dim)) 27 | for feature_idx in range(len(features)): 28 | sa_feature = np.asarray(features[feature_idx]) 29 | spatial_feature = np.asarray(spatial_features[feature_idx]) 30 | temporal_feature = np.asarray(temporal_features[feature_idx]) 31 | spatiotemporal_feature = (spatial_feature + temporal_feature)/2 32 | 33 | if config_dict['fb_mode'] == 'structural': 34 | second_order_feature = np.outer(sa_feature, sa_feature) 35 | elif config_dict['fb_mode'] == 'spatial': 36 | second_order_feature = np.outer(sa_feature, spatial_feature) 37 | elif config_dict['fb_mode'] == 'temporal': 38 | second_order_feature = np.outer(sa_feature, temporal_feature) 39 | else: 40 | second_order_feature = np.outer(sa_feature, spatiotemporal_feature) 41 | 42 | locus_matrix = np.maximum(locus_matrix, second_order_feature) 43 | 44 | # Power Euclidean (PE) non-linear transform. 45 | u_, s_, vh_ = np.linalg.svd(locus_matrix) 46 | s_alpha = np.power(s_, config_dict['PE_alpha']) 47 | locus_matrix_PE = np.dot(u_ * s_alpha, vh_) 48 | 49 | # Flatten and normalize. 
50 | if config_dict['fb_mode'] == 'structural': 51 | locus_descriptor = locus_matrix_PE[np.triu_indices(feature_dim)] 52 | locus_descriptor = locus_descriptor/norm(locus_descriptor) 53 | descriptor_length = int((feature_dim/2)*(feature_dim+1)) 54 | 55 | else: 56 | locus_descriptor = normalize(locus_matrix_PE, norm='l2', axis=1, copy=True, return_norm=False) 57 | locus_descriptor = locus_descriptor/norm(locus_descriptor) 58 | locus_descriptor = np.hstack(locus_descriptor) 59 | descriptor_length = feature_dim*feature_dim 60 | 61 | return locus_descriptor.reshape(-1, descriptor_length) 62 | 63 | 64 | ##################################################################################### 65 | # Test 66 | ##################################################################################### 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | import sys 72 | import yaml 73 | 74 | seq = '08' 75 | 76 | cfg_file = open('config.yml', 'r') 77 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 78 | desc_params = cfg_params['descriptor_generation'] 79 | 80 | poses_file = cfg_params['paths']['KITTI_dataset'] + 'sequences/' + seq + '/poses.txt' 81 | transforms, _ = load_poses_from_txt(poses_file) 82 | rel_transforms = get_delta_pose(transforms) 83 | 84 | data_dir = cfg_params['paths']['save_dir'] + seq 85 | features_database = load_pickle(data_dir + '/features_database.pickle') 86 | segments_database = load_pickle(data_dir + '/segments_database.pickle') 87 | 88 | num_queries = len(features_database) 89 | seg_corres_database = [] 90 | database_dict = {'segments_database': segments_database, 91 | 'features_database': features_database, 92 | 'seg_corres_database': seg_corres_database, 93 | 'rel_transforms': rel_transforms} 94 | 95 | locus_descriptor_database = [] 96 | 97 | for query_idx in range(num_queries): 98 | locus_descriptor = get_locus_descriptor(query_idx, desc_params, database_dict) 99 | locus_descriptor_database.append(locus_descriptor) 100 | 101 | if (query_idx % 100 
== 0): 102 | print('', query_idx, 'complete:', (query_idx*100)/num_queries, '%') 103 | sys.stdout.flush() 104 | 105 | save_dir = cfg_params['paths']['save_dir'] + seq 106 | save_pickle(locus_descriptor_database, save_dir + 107 | '/second_order/locus_descriptor_database.pickle') 108 | -------------------------------------------------------------------------------- /descriptor_generation/spatial_pooling.py: -------------------------------------------------------------------------------- 1 | """ Topological relationships and spatial feature pooling. """ 2 | 3 | import numpy as np 4 | from sklearn.neighbors import KDTree 5 | from scipy.spatial import distance 6 | import open3d as o3d 7 | import sys 8 | import os 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | from utils.misc_utils import * 12 | 13 | def get_segment_MTD(segments, topk): 14 | """ Topological relationships based on Minimum Translational Distance (MTD)""" 15 | """ Returns IDs and distances to nearest segments for all segments""" 16 | 17 | # Compute convex hulls of segments 18 | hulls = [] 19 | for segment in segments: 20 | pcd = o3d.geometry.PointCloud() 21 | pcd.points = o3d.utility.Vector3dVector(segment) 22 | hull, _ = pcd.compute_convex_hull() 23 | hulls.append(np.asarray(hull.vertices)) 24 | 25 | # Calculate MTDs between all segments 26 | num_points = len(segments) 27 | dist_mat = np.zeros((num_points, num_points)) 28 | for j in range(num_points): 29 | p_j = hulls[j] 30 | for k in range(num_points): 31 | if j >= k: # Only need to calculate upper triangle. 
32 | continue 33 | p_k = hulls[k] 34 | dist = np.min(distance.cdist(p_j, p_k, 'euclidean')) 35 | dist_mat[j][k] = dist 36 | dist_mat[k][j] = dist 37 | 38 | # Find 'topk' closest segments for each segment 39 | min_dists = [] 40 | min_dist_ids = [] 41 | for s in range(num_points): 42 | dist_vec = dist_mat[s] 43 | min_dist_id = dist_vec.argsort()[1:topk+1] 44 | min_dist_ids.append(min_dist_id) 45 | min_dists.append(dist_vec[min_dist_id]) 46 | return np.asarray(min_dists), np.asarray(min_dist_ids) 47 | 48 | 49 | def get_spatial_features(idx, topk, database_dict): 50 | """ Return the pooled feature using topological relationships """ 51 | 52 | features = database_dict['features_database'][idx] 53 | segments = database_dict['segments_database'][idx] 54 | 55 | if len(features) < 7: 56 | return [] 57 | 58 | seg_tdists, seg_tdist_ids = get_segment_MTD(segments, topk) 59 | pooled_softmax_features = np.zeros((len(features), np.shape(features)[1])) 60 | 61 | # For each segment, pool features from related segments 62 | for c in range(len(segments)): 63 | dist = seg_tdists[c] 64 | ind = seg_tdist_ids[c] 65 | exp_dists = np.exp(-0.1*dist) 66 | exp_dists /= np.sum(exp_dists) 67 | 68 | for nn_idx in range(min(topk, len(ind))): 69 | f_vec = features[ind[nn_idx]] 70 | pooled_softmax_features[c] += exp_dists[nn_idx]*f_vec 71 | 72 | return pooled_softmax_features 73 | 74 | 75 | ##################################################################################### 76 | # Test 77 | ##################################################################################### 78 | 79 | if __name__ == "__main__": 80 | 81 | seq = '06' 82 | data_dir = '/mnt/7a46b84a-7d34-49f2-b8f0-00022755f514/seg_test/kitti/' + seq 83 | topk = 5 84 | 85 | features_database = load_pickle(data_dir + '/features_database.pickle') 86 | segments_database = load_pickle(data_dir + '/segments_database.pickle') 87 | database_dict = {'segments_database': segments_database, 88 | 'features_database': features_database} 89 | 
num_queries = len(features_database) 90 | pooled_softmax_features_database = [] 91 | 92 | for query_idx in range(num_queries): 93 | pooled_softmax_features = get_spatial_features( 94 | query_idx, topk, database_dict) 95 | pooled_softmax_features_database.append(pooled_softmax_features) 96 | 97 | if (query_idx % 100 == 0): 98 | print('', query_idx, 'complete:', (query_idx*100)/num_queries, '%') 99 | sys.stdout.flush() 100 | 101 | save_dir = '/mnt/bracewell/seg_test/kitti/' + seq 102 | save_pickle(pooled_softmax_features_database, save_dir + 103 | '/second_order/spatial_features_database.pickle') 104 | 105 | print('Test complete.') 106 | -------------------------------------------------------------------------------- /descriptor_generation/temporal_pooling.py: -------------------------------------------------------------------------------- 1 | """ Temporal segment correspondence estimation and feature pooling. """ 2 | 3 | import numpy as np 4 | from sklearn.neighbors import KDTree 5 | import sys 6 | import os 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | from utils.misc_utils import * 10 | from utils.kitti_dataloader import * 11 | 12 | def get_segment_correspondences(idx, database_dict): 13 | """ Return the corresponding segment IDs from previous frame """ 14 | corres_segment = [] 15 | 16 | # Return nans if previous frame as low number of segments. 17 | if len(database_dict['features_database'][idx-1]) < 7: 18 | for s in range(len(database_dict['segments_database'][idx])): 19 | corres_segment.append(np.nan) 20 | return corres_segment 21 | 22 | # Calculate segment centroids of previous frame. 
23 | prev_centroids = [] 24 | for ps in database_dict['segments_database'][idx-1]: 25 | prev_centroids.append(np.mean(ps, axis=0)) 26 | 27 | segments = database_dict['segments_database'][idx] 28 | features = database_dict['features_database'][idx] 29 | rel_T = database_dict['rel_transforms'][idx-1] 30 | 31 | # KDTrees for feature-space and Euclidean-space 32 | ftree = KDTree(database_dict['features_database'][idx-1]) 33 | ctree = KDTree(np.asarray(prev_centroids)) 34 | 35 | for s in range(len(segments)): 36 | # Cordinates of current segment centroid wrt previous frame. 37 | centroid = np.mean(segments[s], axis=0) 38 | rel_centroid = euclidean_to_homogeneous(centroid) 39 | rel_centroid = np.matmul(rel_T, rel_centroid) 40 | centroid_new = homogeneous_to_euclidean(rel_centroid) 41 | 42 | # Feature-space NNs 43 | distf, indf = ftree.query(features[s].reshape(1, -1), k=5) 44 | # Euclidean-space NNs 45 | indc, distc = ctree.query_radius(centroid_new.reshape( 46 | 1, -1), 2, return_distance=True, count_only=False, sort_results=True) 47 | # Correspondence candidates 48 | ind_common = np.intersect1d(indf, indc[0]) 49 | if len(ind_common) < 1: 50 | # Return nan if zero candidates 51 | corres_segment.append(np.nan) 52 | continue 53 | elif len(ind_common) > 1: 54 | # If more than one candidate, 55 | min_ind_common = ind_common[0] 56 | min_cdist = distc[0][np.where(indc[0] == ind_common[0])] 57 | minf_dist = distf[np.where(indf == ind_common[0])] 58 | # Find the ID of segment which minimizes both feature-space and Euclidean-space distance 59 | for ind_com in range(1, len(ind_common)): 60 | cdist_check = distc[0][np.where( 61 | indc[0] == ind_common[ind_com])] - min_cdist 62 | fdist_check = distf[np.where( 63 | indf == ind_common[ind_com])] - minf_dist 64 | if cdist_check < 0 and fdist_check < 0: 65 | min_ind_common = ind_common[ind_com] 66 | min_cdist = distc[0][np.where( 67 | indc[0] == ind_common[ind_com])] 68 | minf_dist = distf[np.where(indf == ind_common[ind_com])] 69 | 
else: 70 | min_ind_common = ind_common[0] 71 | min_cdist = distc[0][np.where(indc[0] == ind_common[0])] 72 | minf_dist = distf[np.where(indf == ind_common[0])] 73 | corres_segment.append([min_ind_common, min_cdist[0], minf_dist[0]]) 74 | return corres_segment 75 | 76 | 77 | def get_temporal_features(idx, n_frames_max, n_count, database_dict): 78 | """ Return the pooled feature using all temporal correspondences """ 79 | features = database_dict['features_database'][idx] 80 | 81 | # Return nan if low number of segments. 82 | if len(features) < 7: 83 | database_dict['seg_corres_database'].append([]) 84 | return [] 85 | 86 | if(idx > 0): 87 | database_dict['seg_corres_database'].append( 88 | get_segment_correspondences(idx, database_dict)) 89 | 90 | pooled_softmax_features = np.zeros((len(features), np.shape(features)[1])) 91 | n_frames = min(idx, n_frames_max) 92 | 93 | # Segment-wise feature pooling 94 | for s in range(len(features)): 95 | pooled_softmax_features[s] += features[s] 96 | if n_frames < 1: # No correspondences for 0th frame 97 | continue 98 | 99 | seg_ind = s 100 | past_features_database = [] 101 | past_features_dist = [] 102 | 103 | # Get features of all previous correspondences 104 | for n in range(1, n_frames + 1): 105 | past_features = database_dict['features_database'][idx - n] 106 | if len(database_dict['seg_corres_database'][idx - n]) == 0: 107 | break 108 | segment_corres = database_dict['seg_corres_database'][idx - n][seg_ind] 109 | if is_nan(segment_corres): 110 | break 111 | seg_ind = segment_corres[0] 112 | past_features_database.append(past_features[seg_ind]) 113 | past_features_dist.append(np.linalg.norm( 114 | features[s] - past_features[seg_ind])) 115 | 116 | # Calculate pooling weights 117 | exp_dists = np.exp(-0.1*np.asarray(past_features_dist)) 118 | n_count.append(len(past_features_dist)) 119 | exp_dists /= np.sum(exp_dists) 120 | 121 | # Pool features of all previous correspondences 122 | for p in range(len(past_features_database)): 
123 | pooled_softmax_features[s] += exp_dists[p] * \ 124 | past_features_database[p] 125 | return pooled_softmax_features 126 | 127 | ##################################################################################### 128 | # Test 129 | ##################################################################################### 130 | 131 | 132 | if __name__ == "__main__": 133 | 134 | seq = '06' 135 | base_dir = '/mnt/7a46b84a-7d34-49f2-b8f0-00022755f514/' 136 | n_frames_max = 3 137 | 138 | poses_file = base_dir + 'datasets/Kitti/dataset/sequences/' + seq + '/poses.txt' 139 | transforms, _ = load_poses_from_txt(poses_file) 140 | rel_transforms = get_delta_pose(transforms) 141 | 142 | data_dir = base_dir + 'seg_test/kitti/' + seq 143 | features_database = load_pickle(data_dir + '/features_database.pickle') 144 | segments_database = load_pickle(data_dir + '/segments_database.pickle') 145 | 146 | num_queries = len(features_database) 147 | seg_corres_database = [] 148 | database_dict = {'segments_database': segments_database, 149 | 'features_database': features_database, 150 | 'seg_corres_database': seg_corres_database, 151 | 'rel_transforms': rel_transforms} 152 | pooled_softmax_features_database = [] 153 | n_count = [] 154 | 155 | for query_idx in range(num_queries): 156 | pooled_softmax_features = get_temporal_features( 157 | query_idx, n_frames_max, n_count, database_dict) 158 | pooled_softmax_features_database.append(pooled_softmax_features) 159 | 160 | if (query_idx % 100 == 0): 161 | print('', query_idx, 'complete:', (query_idx*100)/num_queries, '%') 162 | sys.stdout.flush() 163 | 164 | save_dir = '/mnt/bracewell/seg_test/kitti/' + seq 165 | save_pickle(pooled_softmax_features_database, save_dir + 166 | '/second_order/temporal_features_database.pickle') 167 | 168 | print('') 169 | print('Avg frames aggregated: ', np.mean(n_count)) 170 | 171 | -------------------------------------------------------------------------------- /evaluation/place_recognition.py: 
-------------------------------------------------------------------------------- 1 | """ Online retrieval-based place-recognition using pre-computed global descriptors. """ 2 | # Based on: https://github.com/irapkaist/scancontext/blob/master/fast_evaluator/main.m 3 | # With updated evaluation criterea as set in: config.yaml->place_recognition 4 | 5 | import numpy as np 6 | import math 7 | import yaml 8 | import sys 9 | import os 10 | import argparse 11 | from tqdm import tqdm 12 | 13 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 14 | from utils.misc_utils import * 15 | from utils.kitti_dataloader import * 16 | 17 | # Load params 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--seq", default='00', help="KITTI sequence number") 20 | parser.add_argument("--aug_type", default='none', help="Scan augmentation type ['occ', 'rot', 'ds']") 21 | parser.add_argument("--aug_param", default=0, type=float, help="Scan augmentation parameter") 22 | args = parser.parse_args() 23 | 24 | test_name = 'initial_' + args.seq 25 | 26 | cfg_file = open('config.yml', 'r') 27 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 28 | pr_params = cfg_params['place_recognition'] 29 | desc_dir = cfg_params['paths']['save_dir'] + args.seq 30 | basedir = cfg_params['paths']['KITTI_dataset'] 31 | sequence_path = basedir + 'sequences/' + args.seq + '/' 32 | revisit_criteria = pr_params['revisit_criteria'] 33 | not_revisit_criteria = pr_params['not_revisit_criteria'] 34 | skip_time = pr_params['skip_time'] 35 | kdtree_retrieval = pr_params['kdtree_retrieval'] 36 | thresholds = np.linspace(pr_params['cd_thresh_min'], pr_params['cd_thresh_max'], pr_params['num_thresholds']) 37 | desc_file_name = '/locus_descriptor_' + cfg_params['descriptor_generation']['fb_mode'] 38 | if args.aug_type != 'none': 39 | desc_file_name += '_' + args.aug_type + str(int(args.aug_param)) 40 | test_name = args.aug_type + str(int(args.aug_param)) + '_' + args.seq 41 | 42 | 
##################################################################################### 43 | 44 | locus_descriptor_database = load_pickle(desc_dir + desc_file_name + '.pickle') 45 | _, positions_database = load_poses_from_txt(sequence_path + 'poses.txt') 46 | timestamps = load_timestamps(sequence_path + '/times.txt') 47 | 48 | num_queries = len(positions_database) -1 49 | num_thresholds = len(thresholds) 50 | 51 | # Databases of previously visited/'seen' places. 52 | seen_poses, seen_descriptors = [], [] 53 | 54 | # Store results of evaluation. 55 | num_true_positive = np.zeros(num_thresholds) 56 | num_false_positive = np.zeros(num_thresholds) 57 | num_true_negative = np.zeros(num_thresholds) 58 | num_false_negative = np.zeros(num_thresholds) 59 | 60 | ret_timer = Timer() 61 | 62 | for query_idx in tqdm(range(num_queries)): 63 | 64 | locus_descriptor = locus_descriptor_database[query_idx] 65 | query_pose = positions_database[query_idx] 66 | query_time = timestamps[query_idx] 67 | 68 | if len(locus_descriptor) < 1: 69 | continue 70 | 71 | seen_descriptors.append(locus_descriptor) 72 | seen_poses.append(query_pose) 73 | 74 | if (query_time - skip_time) < 0: 75 | continue 76 | 77 | # Build retrieval database using entries 30s prior to current query. 
78 | tt = next(x[0] for x in enumerate(timestamps) if x[1] > (query_time - skip_time)) 79 | db_seen_descriptors = np.copy(seen_descriptors) 80 | db_seen_poses = np.copy(seen_poses) 81 | db_seen_poses = db_seen_poses[:tt+1] 82 | db_seen_descriptors = db_seen_descriptors[:tt+1] 83 | db_seen_descriptors = db_seen_descriptors.reshape(-1, np.shape(locus_descriptor)[1]) 84 | 85 | nns = len(db_seen_descriptors) # If exaustive search 86 | if kdtree_retrieval: # If KDTree search 87 | tree = KDTree(db_seen_descriptors) 88 | nn = 50 89 | if (np.shape(db_seen_descriptors)[0] < nn): 90 | nn = np.shape(db_seen_descriptors)[0] 91 | 92 | dist, ind = tree.query(locus_descriptor, k=nn) 93 | nns = np.shape(dist)[1] 94 | 95 | # Find top-1 candidate. 96 | nearest_idx = 0 97 | min_dist = math.inf 98 | ret_timer.tic() 99 | for ith_candidate in range(nns): 100 | candidate_idx = ith_candidate 101 | if kdtree_retrieval: 102 | candidate_idx = ind[0][ith_candidate] 103 | 104 | candidate_descriptor = seen_descriptors[candidate_idx] 105 | distance_to_query = cosine_distance(locus_descriptor, candidate_descriptor) 106 | 107 | if( distance_to_query < min_dist): 108 | nearest_idx = candidate_idx 109 | min_dist = distance_to_query 110 | 111 | ret_timer.toc() 112 | place_candidate = seen_poses[nearest_idx] 113 | 114 | is_revisit = check_if_revisit(query_pose, db_seen_poses, revisit_criteria) 115 | 116 | # Evaluate top-1 candidate. 
117 | for thres_idx in range(num_thresholds): 118 | threshold = thresholds[thres_idx] 119 | 120 | if( min_dist < threshold): # Positive Prediction 121 | p_dist = norm(query_pose - place_candidate) 122 | if p_dist < revisit_criteria: 123 | num_true_positive[thres_idx] += 1 124 | 125 | elif p_dist > not_revisit_criteria: 126 | num_false_positive[thres_idx] += 1 127 | 128 | else: # Negative Prediction 129 | if(is_revisit == 0): 130 | num_true_negative[thres_idx] += 1 131 | else: 132 | num_false_negative[thres_idx] += 1 133 | 134 | 135 | print('Average retrieval time per scan:') 136 | print(f"--- {ret_timer.avg}s---") 137 | 138 | save_dir = cfg_params['paths']['save_dir'] + 'pr_results/' + test_name 139 | if not os.path.exists(save_dir): 140 | os.makedirs(save_dir) 141 | print('Saving pickles: ', test_name) 142 | save_pickle(num_true_positive, save_dir + '/num_true_positive.pickle') 143 | save_pickle(num_false_positive, save_dir + '/num_false_positive.pickle') 144 | save_pickle(num_true_negative, save_dir + '/num_true_negative.pickle') 145 | save_pickle(num_false_negative, save_dir + '/num_false_negative.pickle') 146 | 147 | -------------------------------------------------------------------------------- /evaluation/pr_curve.py: -------------------------------------------------------------------------------- 1 | """ Precision-Recall curves plus introspection tools. 
""" 2 | 3 | import sys 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import yaml 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 10 | from utils.misc_utils import * 11 | 12 | cfg_file = open('config.yml', 'r') 13 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 14 | data_dir = cfg_params['paths']['save_dir'] + 'pr_results/' 15 | 16 | cfg_file = open('config.yml', 'r') 17 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 18 | log_axis = cfg_params['pr_curve']['log_axis'] 19 | introspect = cfg_params['pr_curve']['introspection_table'] 20 | test_name = 'initial_' #'rot180_' 21 | macros = { # folder, label, colour. 22 | 0: [test_name + '00', '00', 'red'], 23 | 1: [test_name + '02', '02', 'blue'], 24 | 2: [test_name + '05', '05', 'green'], 25 | 3: [test_name + '06', '06', 'yellow'], 26 | 4: [test_name + '07', '07', 'purple'], 27 | 5: [test_name + '08', '08', 'grey'], 28 | } 29 | 30 | ######################################################################################################################## 31 | EPS = [] 32 | F1s = [] 33 | F1data = [] 34 | EPdata = [] 35 | table_rows = [] 36 | 37 | for i in range(len(macros)): 38 | folder = macros[i][0] 39 | label = macros[i][1] 40 | colour = macros[i][2] 41 | 42 | _dir = data_dir + folder 43 | if os.path.isdir(_dir): 44 | table_rows.append(label) 45 | 46 | num_true_positive = load_pickle(_dir + '/num_true_positive.pickle') 47 | num_false_positive = load_pickle(_dir + '/num_false_positive.pickle') 48 | num_true_negative = load_pickle(_dir + '/num_true_negative.pickle') 49 | num_false_negative = load_pickle(_dir + '/num_false_negative.pickle') 50 | 51 | Precisions = [] 52 | Recalls = [] 53 | Accuracies = [] 54 | nThres = len(num_true_positive) 55 | 56 | RP100 = 0.0 57 | EP = 0.0 58 | F1max = 0.0 59 | 60 | for ithThres in range(nThres): 61 | nTrueNegative = num_true_negative[ithThres] 62 | nFalsePositive = num_false_positive[ithThres] 63 | nTruePositive = 
num_true_positive[ithThres] 64 | nFalseNegative = num_false_negative[ithThres] 65 | 66 | nTotalTestPlaces = nTrueNegative + nFalsePositive + nTruePositive + nFalseNegative 67 | 68 | Precision = 0.0 69 | Recall = 0.0 70 | F1 = 0.0 71 | Acc = (nTruePositive + nTrueNegative)/nTotalTestPlaces 72 | 73 | if nTruePositive > 0.0: 74 | Precision = nTruePositive / (nTruePositive + nFalsePositive) 75 | Recall = nTruePositive / (nTruePositive + nFalseNegative) 76 | F1 = 2 * Precision * Recall * (1/(Precision + Recall)) 77 | 78 | Precisions.append(Precision) 79 | Recalls.append(Recall) 80 | Accuracies.append(Acc) 81 | 82 | if F1 > F1max: 83 | F1max = F1 84 | f1max_tn = nTrueNegative 85 | f1max_fp = nFalsePositive 86 | f1max_tp = nTruePositive 87 | f1max_fn = nFalseNegative 88 | f1max_id = ithThres 89 | f1max_total = nTotalTestPlaces 90 | 91 | if int(Precision) == 1: 92 | RP100 = Recall 93 | EP_id = ithThres 94 | rp100_tn = nTrueNegative 95 | rp100_fp = nFalsePositive 96 | rp100_tp = nTruePositive 97 | rp100_fn = nFalseNegative 98 | rp100_total = nTotalTestPlaces 99 | 100 | if RP100 == 0.0: 101 | EP = Precisions[1]/2.0 102 | 103 | else: 104 | EP = 0.5 + (RP100/2.0) 105 | 106 | if log_axis: 107 | Precisions = 1- np.asarray(Precisions) 108 | Recalls = 1- np.asarray(Recalls) 109 | 110 | 111 | print('EP: ' , EP) 112 | print('F1max: ' , F1max) 113 | print('f1max_id: ' , f1max_id) 114 | EPS.append(EP) 115 | F1s.append(F1max) 116 | F1data.append([str(val) for val in (f1max_tn, f1max_tp, f1max_fn, f1max_fp, f1max_total, "{:.3f}".format(F1max))]) 117 | EPdata.append([str(val) for val in (rp100_tn, rp100_tp, rp100_fn, rp100_fp, rp100_total, "{:.3f}".format(EP))]) 118 | 119 | label = label + ', EP: ' + "{:.3f}".format(EP) + ', F1max: ' + "{:.3f}".format(F1max) 120 | 121 | plt.plot(Recalls, Precisions, marker='.', color=colour, label=label) 122 | 123 | ########################################################################################################################## 124 | """ Plot 
Precision-Recall curves """ 125 | 126 | plt.legend(prop=font_legend) 127 | plt.title('Locus performance on KITTI') 128 | 129 | if log_axis: 130 | plt.xlabel('Recall (log)', fontdict=font) 131 | plt.ylabel('Precision (log)', fontdict=font) 132 | plt.yscale('log') 133 | plt.xscale('log') 134 | ax = plt.gca() 135 | ax.set_xlim(1, 0.001) 136 | ax.set_ylim(1, 0.001) 137 | plt.xticks([1,0.1,0.01,0.001],['0%', '90%', '99%', '99.9%']) 138 | plt.yticks([1,0.1,0.01,0.001],['0%', '90%', '99%', '99.9%']) 139 | plt.grid(True, which='major') 140 | plt.grid(True, which='minor', color = 'whitesmoke') 141 | plt.show() 142 | else: 143 | plt.xlabel('Recall', fontdict=font) 144 | plt.ylabel('Precision', fontdict=font) 145 | plt.axis([0, 1, 0, 1.1]) 146 | plt.xticks(np.arange(0, 1.01, step=0.1)) 147 | plt.grid(True) 148 | plt.show() 149 | 150 | ########################################################################################################################## 151 | """ Tables for introspection of 'TN', 'TP', 'FN', 'FP' counts. """ 152 | 153 | if introspect: 154 | # rows = [macros[i][1] for i in range(len(macros))] 155 | F1columns = ('TN', 'TP', 'FN', 'FP', 'Total', 'F1max') 156 | EPcolumns = ('TN', 'TP', 'FN', 'FP', 'Total', 'EP') 157 | fig = plt.figure() 158 | ax1 = fig.add_subplot(2,1,1) 159 | table1 = ax1.table(cellText=F1data, rowLabels=table_rows, colLabels=F1columns, loc='center') 160 | table1.scale(0.7,2) 161 | ax1.axis('off') 162 | ax2 = fig.add_subplot(2,1,2) 163 | table2 = ax2.table(cellText=EPdata, rowLabels=table_rows, colLabels=EPcolumns, loc='center') 164 | table2.scale(0.7,2) 165 | ax2.axis('off') 166 | plt.title('Introspection table (Top: F1max, Bot: RP100)') 167 | plt.show() -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ Segment and generate Locus descriptor for each scan in a sequence. 
""" 2 | 3 | import sys 4 | import os 5 | import glob 6 | import yaml 7 | import argparse 8 | from tqdm import tqdm 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), 'segmentation')) 11 | sys.path.append(os.path.join(os.path.dirname(__file__), 'descriptor_generation')) 12 | from utils.kitti_dataloader import * 13 | from utils.augment_scans import * 14 | from segmentation.extract_segments import * 15 | from segmentation.extract_segment_features import * 16 | from descriptor_generation.locus_descriptor import * 17 | 18 | seg_timer, feat_timer, desc_timer = Timer(), Timer(), Timer() 19 | 20 | # Load params 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("--seq", default='02', help="KITTI sequence number") 23 | parser.add_argument("--aug_type", default='none', help="Scan augmentation type ['occ', 'rot', 'ds']") 24 | parser.add_argument("--aug_param", default=0, type=float, help="Scan augmentation parameter") 25 | args = parser.parse_args() 26 | print('Sequence: ', args.seq, ', Augmentation: ', args.aug_type, ', Param: ', args.aug_param) 27 | 28 | cfg_file = open('config.yml', 'r') 29 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 30 | desc_params = cfg_params['descriptor_generation'] 31 | seg_params = cfg_params['segmentation'] 32 | 33 | # Load data 34 | basedir = cfg_params['paths']['KITTI_dataset'] 35 | sequence_path = basedir + 'sequences/' + args.seq + '/' 36 | bin_files = sorted(glob.glob(os.path.join( 37 | sequence_path, 'velodyne', '*.bin'))) 38 | scans = yield_bin_scans(bin_files) 39 | 40 | transforms, _ = load_poses_from_txt(sequence_path + 'poses.txt') 41 | rel_transforms = get_delta_pose(transforms) 42 | 43 | 44 | # Setup database variables 45 | num_queries = len(rel_transforms) 46 | segments_database, features_database = [], [] 47 | seg_corres_database, locus_descriptor_database = [], [] 48 | database_dict = {'segments_database': segments_database, 49 | 'features_database': features_database, 50 | 
'seg_corres_database': seg_corres_database, 51 | 'rel_transforms': rel_transforms} 52 | 53 | 54 | for query_idx in tqdm(range(num_queries)): 55 | # Load LiDAR scan point cloud 56 | scan = next(scans) 57 | scan = scan[:, :-1] 58 | 59 | # Optional scan augmentation for robustness tests 60 | if args.aug_type == 'rot': 61 | scan, rot_mat = augmented_scan(scan, args.aug_type, args.aug_param) 62 | transforms[query_idx][:3,:3] = np.dot(transforms[query_idx][:3,:3], rot_mat) 63 | if query_idx > 0: 64 | database_dict['rel_transforms'][query_idx-1] = get_delta_pose([transforms[query_idx-1], transforms[query_idx]])[0] 65 | elif args.aug_type == 'occ': 66 | scan = augmented_scan(scan, args.aug_type, args.aug_param) 67 | 68 | # Extract segments 69 | seg_timer.tic() 70 | segments = get_segments(scan, seg_params) 71 | segments_database.append(segments) 72 | seg_timer.toc() 73 | 74 | # Extract segment features 75 | feat_timer.tic() 76 | features = get_segment_features(segments) 77 | features_database.append(features) 78 | feat_timer.toc() 79 | 80 | # Generate 'Locus' global descriptor 81 | desc_timer.tic() 82 | locus_descriptor = get_locus_descriptor(query_idx, desc_params, database_dict) 83 | locus_descriptor_database.append(locus_descriptor) 84 | desc_timer.toc() 85 | 86 | print('Average time per scan:') 87 | print(f"--- seg: {seg_timer.avg}s, feat: {feat_timer.avg}s, desc: {desc_timer.avg}s ---") 88 | 89 | save_dir = cfg_params['paths']['save_dir'] + args.seq 90 | if not os.path.exists(save_dir): 91 | os.makedirs(save_dir) 92 | desc_file_name = '/locus_descriptor_' + desc_params['fb_mode'] 93 | if args.aug_type != 'none': 94 | desc_file_name = desc_file_name + '_' + args.aug_type + str(int(args.aug_param)) 95 | save_pickle(locus_descriptor_database, save_dir + 96 | desc_file_name + '.pickle') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | 
open3d 2 | tensorflow==1.8.0 -------------------------------------------------------------------------------- /segmentation/extract_segment_features.py: -------------------------------------------------------------------------------- 1 | """Code for extracting Segmap-CNN features for input segments. """ 2 | # Ref: https://github.com/ethz-asl/segmap/blob/master/segmappy/bin/segmappy_plot_roc_from_matches 3 | import numpy as np 4 | import tensorflow as tf 5 | tf.logging.set_verbosity(tf.logging.ERROR) 6 | 7 | from segmappy.core.config import * 8 | from segmappy.core.dataset import * 9 | from segmappy.core.preprocessor import * 10 | from segmappy.core.generator import * 11 | from segmappy.tools.classifiertools import get_default_preprocessor 12 | 13 | 14 | # read config file 15 | configfile = "default_training.ini" 16 | config = Config(configfile) 17 | 18 | # load preprocessor 19 | preprocessor = get_default_preprocessor(config) 20 | 21 | def get_segment_features(segments): 22 | 23 | # Create dummy inputs 24 | n_classes = len(segments) 25 | classes = np.arange(len(segments), dtype=np.int64) 26 | positions = [map(float, np.zeros(3)) for i in range(len(segments))] 27 | 28 | preprocessor.init_segments(segments, classes, positions=np.asarray(positions)) 29 | gen_test = Generator( 30 | preprocessor, 31 | range(len(segments)), 32 | n_classes, 33 | train=False, 34 | batch_size=config.batch_size, 35 | shuffle=False, 36 | ) 37 | 38 | tf.reset_default_graph() 39 | saver = tf.train.import_meta_graph( 40 | os.path.join(config.cnn_model_folder, "model.ckpt.meta") 41 | ) 42 | 43 | # get key tensorflow variables 44 | cnn_graph = tf.get_default_graph() 45 | cnn_input = cnn_graph.get_tensor_by_name("InputScope/input:0") 46 | scales = cnn_graph.get_tensor_by_name("scales:0") 47 | descriptor = cnn_graph.get_tensor_by_name("OutputScope/descriptor_read:0") 48 | 49 | cnn_features = [] 50 | with tf.Session() as sess: 51 | saver.restore(sess, 
tf.train.latest_checkpoint(config.cnn_model_folder)) 52 | 53 | for batch in range(gen_test.n_batches): 54 | batch_segments, _ = gen_test.next() 55 | batch_descriptors = sess.run( 56 | descriptor, 57 | feed_dict={cnn_input: batch_segments, scales: preprocessor.last_scales}, 58 | ) 59 | for batch_descriptor in batch_descriptors: 60 | cnn_features.append(batch_descriptor) 61 | 62 | return np.array(cnn_features) 63 | 64 | 65 | ##################################################################################### 66 | # Test 67 | ##################################################################################### 68 | 69 | 70 | if __name__ == "__main__": 71 | import os 72 | import sys 73 | import yaml 74 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 75 | from utils.misc_utils import * 76 | 77 | seq = '06' 78 | cfg_file = open('config.yml', 'r') 79 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 80 | 81 | data_dir = cfg_params['paths']['save_dir'] + seq 82 | segments_database = load_pickle(data_dir + '/segments_database.pickle') 83 | 84 | features_database = [] 85 | 86 | for idx in range(len(segments_database)): 87 | segments = segments_database[idx] 88 | features = get_segment_features(segments) 89 | print("cnn_features: ", np.shape(features)) 90 | features_database.append(features) 91 | 92 | save_dir = cfg_params['paths']['save_dir'] + seq 93 | save_pickle(features_database, save_dir + 94 | '/segment_features_database.pickle') -------------------------------------------------------------------------------- /segmentation/extract_segments.py: -------------------------------------------------------------------------------- 1 | """Code for extracting Euclidean segments from a point cloud.""" 2 | 3 | import numpy as np 4 | import pcl 5 | import os 6 | import sys 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 9 | from utils.kitti_dataloader import visualize_scan_open3d 10 | 11 | def extract_cluster_indices(cloud_filtered, 
seg_params): 12 | tree = cloud_filtered.make_kdtree() 13 | ec = cloud_filtered.make_EuclideanClusterExtraction() 14 | ec.set_ClusterTolerance(seg_params['c_tolerence']) 15 | ec.set_MinClusterSize(seg_params['c_min_size']) 16 | ec.set_MaxClusterSize(seg_params['c_max_size']) 17 | ec.set_SearchMethod(tree) 18 | cluster_indices = ec.Extract() 19 | return cluster_indices 20 | 21 | def extract_segments(scan, seg_params): 22 | if seg_params['visualize']: 23 | visualize_scan_open3d(scan) 24 | 25 | # Ground Plane Removal 26 | cloud = pcl.PointCloud(scan) 27 | seg = cloud.make_segmenter_normals(ksearch=50) 28 | seg.set_optimize_coefficients(True) 29 | seg.set_model_type(pcl.SACMODEL_NORMAL_PLANE) 30 | seg.set_method_type(pcl.SAC_RANSAC) 31 | seg.set_distance_threshold(seg_params['g_dist_thresh']) 32 | seg.set_normal_distance_weight(seg_params['g_normal_dist_weight']) 33 | seg.set_max_iterations(100) 34 | indices, coefficients = seg.segment() 35 | 36 | crop_xyz = np.asarray(cloud) 37 | for k, indice in enumerate(indices): 38 | crop_xyz[indice][2] = -20.0 39 | 40 | crop_xyz = crop_xyz[crop_xyz[:, -1] > -seg_params['g_height']] 41 | 42 | # Voxel filter (optional) 43 | ds_f = seg_params['ds_factor'] 44 | if ds_f > 0.01: 45 | cloud = pcl.PointCloud(crop_xyz) 46 | vg = cloud.make_voxel_grid_filter() 47 | vg.set_leaf_size(ds_f, ds_f, ds_f) 48 | cloud_filtered = vg.filter() 49 | else: 50 | cloud_filtered = pcl.PointCloud(crop_xyz) 51 | 52 | if seg_params['visualize']: 53 | visualize_scan_open3d(cloud_filtered) 54 | 55 | # Euclidean Cluster Extraction 56 | cluster_indices = extract_cluster_indices(cloud_filtered, seg_params) 57 | 58 | if seg_params['visualize']: 59 | print('cluster_indices : ' , np.shape(cluster_indices)) 60 | 61 | segments = [] 62 | points_database = [] 63 | colours_database = [] 64 | init = False 65 | 66 | for j, indices in enumerate(cluster_indices): 67 | points = np.zeros((len(indices), 3), dtype=np.float32) 68 | 69 | for k, indice in enumerate(indices): 70 | 
points[k][0] = cloud_filtered[indice][0] 71 | points[k][1] = cloud_filtered[indice][1] 72 | points[k][2] = cloud_filtered[indice][2] 73 | 74 | # Additional filtering step to remove flat(ground-plane) segments 75 | x_diff = (max(points[:,0]) - min(points[:,0])) 76 | y_diff = (max(points[:,1]) - min(points[:,1])) 77 | z_diff = (max(points[:,2]) - min(points[:,2])) 78 | if (not seg_params['filter_flat_seg']) or (max(x_diff,y_diff)/z_diff < seg_params['horizontal_ratio']): 79 | segments.append(points) 80 | colour = np.random.random_sample((3)) 81 | if init: 82 | points_database = np.vstack((points_database, points)) 83 | colour = np.tile(colour, (len(indices), 1)) 84 | colours_database = np.vstack((colours_database, colour)) 85 | else: 86 | points_database = points 87 | colour = np.tile(colour, (len(indices), 1)) 88 | colours_database = colour 89 | init = True 90 | 91 | if seg_params['visualize']: 92 | visualize_scan_open3d(points_database, colours_database) 93 | 94 | return segments 95 | 96 | def get_segments(scan, seg_params): 97 | segments = extract_segments(scan, seg_params) 98 | 99 | # Handling rare degenerate scenes 100 | c_tolerence = seg_params['c_tolerence'] 101 | while seg_params['enforce_min_seg_count'] and len(segments) < seg_params['min_seg_count']: 102 | seg_params['c_tolerence'] -= 0.01 103 | if seg_params['c_tolerence'] < 0.01: 104 | break 105 | segments = extract_segments(scan, seg_params) 106 | seg_params['c_tolerence'] = c_tolerence 107 | return segments 108 | 109 | ##################################################################################### 110 | # Test 111 | ##################################################################################### 112 | 113 | 114 | if __name__ == "__main__": 115 | 116 | import glob 117 | import yaml 118 | from utils.kitti_dataloader import yield_bin_scans 119 | from utils.misc_utils import save_pickle 120 | 121 | seq = '06' 122 | 123 | cfg_file = open('config.yml', 'r') 124 | cfg_params = yaml.load(cfg_file, 
Loader=yaml.FullLoader) 125 | seg_params = cfg_params['segmentation'] 126 | 127 | basedir = cfg_params['paths']['KITTI_dataset'] 128 | sequence_path = basedir + 'sequences/' + seq + '/' 129 | bin_files = sorted(glob.glob(os.path.join( 130 | sequence_path, 'velodyne', '*.bin'))) 131 | scans = yield_bin_scans(bin_files) 132 | 133 | segments_database = [] 134 | 135 | for i in range(10): 136 | scan = next(scans) 137 | segments = extract_segments(scan[:, :-1], seg_params) 138 | print('Extracted segments: ', np.shape(segments)) 139 | segments_database.append(segments) 140 | 141 | save_dir = cfg_params['paths']['save_dir'] + seq 142 | save_pickle(segments_database, save_dir + 143 | '/segments_database.pickle') -------------------------------------------------------------------------------- /segmentation/segmappy/README.md: -------------------------------------------------------------------------------- 1 | # segmappy 2 | 3 | 4 | All content is this directory is from [ethz-asl/.../segmappy](https://github.com/ethz-asl/segmap/tree/master/segmappy/segmappy). -------------------------------------------------------------------------------- /segmentation/segmappy/config/default_training.ini: -------------------------------------------------------------------------------- 1 | [general] 2 | # Database folder 3 | # If no base_dir is given, the default base_dir will be used instead 4 | #base_dir = ... 
5 | cnn_train_folders = dataset18,dataset20 6 | cnn_test_folder = dataset27 7 | semantics_train_folder = dataset18 8 | 9 | # Combine sequences based on merge events triggered in segmatch 10 | use_merges = true 11 | 12 | # Size of the merged sequence compared to the last element in the merged 13 | # sequence to keep matches containing the merged sequence 14 | keep_match_thresh = 0.3 15 | 16 | # Combine the views based on the segmatch matches 17 | use_matches = true 18 | 19 | # Discard classes of segments that are smaller than min_class_size 20 | min_class_size = 2 21 | 22 | # The relative size of a segment compared to the last segment in the sequence 23 | # so that it is still considered relevant and kept 24 | require_relevance = 0.05 25 | 26 | # The number of points that must be different so that two segments are 27 | # considered different. Similar segments are removed in chronological order 28 | require_diff_points = 0 29 | 30 | [augment] 31 | # Generate new samples by randomly rotating each sample by 32 | # [-augment_angle, augment_angle] degrees. 
33 | augment_angle = 180 34 | 35 | # Augment by randomly removing a percentage of points from each sample 36 | augment_remove_random_min = 0.0 37 | augment_remove_random_max = 0.1 38 | augment_remove_plane_min = 0.0 39 | augment_remove_plane_max = 0.5 40 | 41 | # Augment by randomly jittering the segment after centering 42 | augment_jitter = 0.0 43 | 44 | [normalize] 45 | # Align the segments (robot/eigen/none) 46 | align = eigen 47 | 48 | # Which type of scaling to use 49 | # - fixed: use a fixed scale 50 | # - aspect: scale, but maintain aspect ratio 51 | # - fit: scale each dimension independently 52 | scale_method = fit 53 | 54 | # How to center the segment 55 | # - mean: based on the segment's mean; some points will be out of bounds 56 | # - min_max: centers based on the min and max of each dimension 57 | # - none: no centering 58 | center_method = mean 59 | 60 | # Size of the voxel parallelepiped in meters 61 | scale_x = 8 62 | scale_y = 8 63 | scale_z = 4 64 | 65 | # Number of voxels in the rectangular parallelepiped into 66 | # which to normalize each segment 67 | voxels_x = 32 68 | voxels_y = 32 69 | voxels_z = 16 70 | 71 | # Remove the mean and std 72 | remove_mean = false 73 | remove_std = false 74 | 75 | [train] 76 | # Folder into which to save the model after training 77 | # If no model_base_dir is given, the default model_base_dir will be used instead 78 | #model_base_dir = ...
79 | cnn_model_folder = segmap64 80 | semantics_model_folder = segmap64_semantics 81 | 82 | # Percentage of match sequences to put in the test set 83 | test_size = 0.3 84 | 85 | # Number of epochs to train for 86 | n_epochs = 256 87 | 88 | # Batch size 89 | batch_size = 64 90 | 91 | # Root path to save tensorboard logs 92 | log_path = /tmp/segmap/tensorboard/ 93 | 94 | # Directory where to save debug outputs 95 | debug_path = /tmp/segmap/debug 96 | -------------------------------------------------------------------------------- /segmentation/segmappy/core/config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | 4 | def get_segmap_home_dir(): 5 | segmap_home = os.path.abspath( 6 | os.path.join(os.path.expanduser("~"), ".segmap/") 7 | ) 8 | 9 | # If home directory doesn't exist create 10 | if not os.path.isdir(segmap_home): 11 | try: 12 | os.mkdir(segmap_home) 13 | except OSError as e: 14 | print('Error: Could not create SegMap home directory') 15 | raise 16 | 17 | # Copy config into the new home directory 18 | config_src = os.path.abspath(os.path.join( 19 | os.path.dirname(__file__), 20 | '../config/default_training.ini') 21 | ) 22 | 23 | import shutil 24 | shutil.copy(config_src, segmap_home) 25 | 26 | return segmap_home 27 | 28 | def get_config_dir(): 29 | # Returns the package-wide config directory. 30 | return get_segmap_home_dir() 31 | 32 | def get_default_model_dir(): 33 | # Returns the package-wide default trained model directory. 34 | return os.path.join(get_segmap_home_dir(), "trained_models/") 35 | 36 | def get_default_dataset_dir(): 37 | # Returns the package-wide default datasets directory. 
38 | return os.path.join(get_segmap_home_dir(), "training_datasets/") 39 | 40 | class Config(object): 41 | def __init__(self, name="train.ini"): 42 | path = os.path.join(get_config_dir(), name) 43 | if not os.path.isfile(path): 44 | raise IOError("Config file '{}' not found.".format(path)) 45 | 46 | config = configparser.ConfigParser() 47 | config.read(path) 48 | 49 | # general 50 | try: 51 | self.base_dir = os.path.abspath( 52 | os.path.join( 53 | os.path.dirname(__file__), config.get("general", "base_dir") 54 | ) 55 | ) 56 | except: 57 | self.base_dir = get_default_dataset_dir() 58 | print( 59 | "No dataset base directory provided, defaulting to {}.".format( 60 | self.base_dir 61 | ) 62 | ) 63 | 64 | self.cnn_train_folders = config.get("general", "cnn_train_folders") 65 | self.cnn_test_folder = config.get("general", "cnn_test_folder") 66 | self.semantics_train_folder = config.get("general", "semantics_train_folder") 67 | self.use_merges = config.getboolean("general", "use_merges") 68 | self.keep_match_thresh = config.getfloat("general", "keep_match_thresh") 69 | self.use_matches = config.getboolean("general", "use_matches") 70 | self.min_class_size = config.getint("general", "min_class_size") 71 | self.require_relevance = config.getfloat("general", "require_relevance") 72 | self.require_diff_points = config.getint("general", "require_diff_points") 73 | 74 | # augment 75 | self.augment_angle = config.getfloat("augment", "augment_angle") 76 | self.augment_remove_random_min = config.getfloat( 77 | "augment", "augment_remove_random_min" 78 | ) 79 | self.augment_remove_random_max = config.getfloat( 80 | "augment", "augment_remove_random_max" 81 | ) 82 | assert self.augment_remove_random_max >= self.augment_remove_random_min 83 | self.augment_remove_plane_min = config.getfloat( 84 | "augment", "augment_remove_plane_min" 85 | ) 86 | self.augment_remove_plane_max = config.getfloat( 87 | "augment", "augment_remove_plane_max" 88 | ) 89 | assert 
self.augment_remove_plane_max >= self.augment_remove_plane_min 90 | self.augment_jitter = config.getfloat("augment", "augment_jitter") 91 | 92 | # normalize 93 | self.align = config.get("normalize", "align") 94 | assert self.align in ("none", "eigen", "robot") 95 | self.scale_method = config.get("normalize", "scale_method") 96 | assert self.scale_method in ("fixed", "aspect", "fit") 97 | self.center_method = config.get("normalize", "center_method") 98 | assert self.center_method in ("mean", "min_max", "none") 99 | self.scale = tuple( 100 | config.getint("normalize", "scale_" + axis) for axis in ("x", "y", "z") 101 | ) 102 | self.voxels = tuple( 103 | config.getint("normalize", "voxels_" + axis) for axis in ("x", "y", "z") 104 | ) 105 | self.remove_mean = config.getboolean("normalize", "remove_mean") 106 | self.remove_std = config.getboolean("normalize", "remove_std") 107 | 108 | # train 109 | try: 110 | self.model_base_dir = os.path.abspath( 111 | os.path.join(os.path.dirname(__file__), config.get("train", "model_base_dir")) 112 | ) 113 | except: 114 | self.model_base_dir = get_default_model_dir() 115 | print( 116 | "No model base directory provided, defaulting to {}.".format( 117 | self.model_base_dir 118 | ) 119 | ) 120 | self.cnn_model_folder = os.path.abspath( 121 | os.path.join(self.model_base_dir, config.get("train", "cnn_model_folder")) 122 | ) 123 | try: 124 | self.semantics_folder_name = config.get("train", "semantics_model_folder") 125 | except: 126 | self.semantics_folder_name = "semantics_nn" 127 | self.semantics_model_folder = os.path.abspath( 128 | os.path.join(self.model_base_dir, self.semantics_folder_name) 129 | ) 130 | self.test_size = config.getfloat("train", "test_size") 131 | self.n_epochs = config.getint("train", "n_epochs") 132 | self.batch_size = config.getint("train", "batch_size") 133 | self.log_path = config.get("train", "log_path") 134 | self.debug_path = config.get("train", "debug_path") 135 | 
-------------------------------------------------------------------------------- /segmentation/segmappy/core/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | import os 4 | 5 | from .config import get_default_dataset_dir 6 | 7 | class Dataset(object): 8 | # load config values 9 | def __init__( 10 | self, 11 | folder="dataset", 12 | base_dir=get_default_dataset_dir(), 13 | require_change=0.0, 14 | use_merges=True, 15 | keep_match_thresh=0.0, 16 | use_matches=True, 17 | min_class_size=1, 18 | require_relevance=0.0, 19 | require_diff_points=0, 20 | normalize_classes=True, 21 | ): 22 | abs_folder = os.path.abspath(os.path.join(base_dir, folder)) 23 | try: 24 | assert os.path.isdir(abs_folder) 25 | except AssertionError: 26 | raise IOError("Dataset folder {} not found.".format(abs_folder)) 27 | 28 | self.folder = abs_folder 29 | self.require_change = require_change 30 | self.use_merges = use_merges 31 | self.keep_match_thresh = keep_match_thresh 32 | self.use_matches = use_matches 33 | self.min_class_size = min_class_size 34 | self.require_relevance = require_relevance 35 | self.require_diff_points = require_diff_points 36 | self.normalize_classes = normalize_classes 37 | 38 | # load the segment dataset 39 | def load(self, preprocessor=None): 40 | from ..tools.import_export import load_segments, load_positions, load_features 41 | 42 | # load all the csv files 43 | self.segments, sids, duplicate_sids = load_segments(folder=self.folder) 44 | self.positions, pids, duplicate_pids = load_positions(folder=self.folder) 45 | self.features, self.feature_names, fids, duplicate_fids = load_features( 46 | folder=self.folder 47 | ) 48 | 49 | self.classes = np.array(sids) 50 | self.duplicate_classes = self.classes.copy() 51 | self.positions = np.array(self.positions) 52 | self.features = np.array(self.features) 53 | self.duplicate_ids = np.array(duplicate_sids) 54 | 55 | # 
load labels 56 | from ..tools.import_export import load_labels 57 | 58 | self.labels, self.lids = load_labels(folder=self.folder) 59 | self.labels = np.array(self.labels) 60 | self.labels_dict = dict(zip(self.lids, self.labels)) 61 | 62 | # load matches 63 | from ..tools.import_export import load_matches 64 | 65 | self.matches = load_matches(folder=self.folder) 66 | 67 | if self.require_change > 0.0: 68 | self._remove_unchanged() 69 | 70 | # combine sequences that are part of a merger 71 | if self.use_merges: 72 | from ..tools.import_export import load_merges 73 | 74 | merges, _ = load_merges(folder=self.folder) 75 | self._combine_sequences(merges) 76 | self.duplicate_classes = self.classes.copy() 77 | 78 | # remove small irrelevant segments 79 | if self.require_relevance > 0: 80 | self._remove_irrelevant() 81 | 82 | # only use segments that are different enough 83 | if self.require_diff_points > 0: 84 | assert preprocessor is not None 85 | self._remove_similar(preprocessor) 86 | 87 | # combine classes based on matches 88 | if self.use_matches: 89 | self._combine_classes() 90 | 91 | # normalize ids and remove small classes 92 | self._normalize_classes() 93 | 94 | print( 95 | " Found", 96 | self.n_classes, 97 | "valid classes with", 98 | len(self.segments), 99 | "segments", 100 | ) 101 | 102 | self._sort_ids() 103 | 104 | return ( 105 | self.segments, 106 | self.positions, 107 | self.classes, 108 | self.n_classes, 109 | self.features, 110 | self.matches, 111 | self.labels_dict, 112 | ) 113 | 114 | def _remove_unchanged(self): 115 | keep = np.ones(self.classes.size).astype(np.bool) 116 | for cls in np.unique(self.classes): 117 | class_ids = np.where(self.classes == cls)[0] 118 | 119 | prev_size = self.segments[class_ids[0]].shape[0] 120 | for class_id in class_ids[1:]: 121 | size = self.segments[class_id].shape[0] 122 | if size < prev_size * (1.0 + self.require_change): 123 | keep[class_id] = False 124 | else: 125 | prev_size = size 126 | 127 | self._trim_data(keep) 
128 | 129 | print(" Found %d segments that changed enough" % len(self.segments)) 130 | 131 | # list of sequence pairs to merge and correct from the matches table 132 | def _combine_sequences(self, merges): 133 | # calculate the size of each sequence based on the last element 134 | last_sizes = {} 135 | subclasses = {} 136 | for cls in np.unique(self.classes): 137 | class_ids = np.where(self.classes == cls)[0] 138 | last_id = class_ids[np.argmax(self.duplicate_ids[class_ids])] 139 | last_sizes[cls] = len(self.segments[last_id]) 140 | subclasses[cls] = [] 141 | 142 | # make merges and keep a list of the merged sequences for each class 143 | for merge in merges: 144 | merge_sequence, target_sequence = merge 145 | 146 | merge_ids = np.where(self.classes == merge_sequence)[0] 147 | target_ids = np.where(self.classes == target_sequence)[0] 148 | 149 | self.classes[merge_ids] = target_sequence 150 | self.duplicate_ids[target_ids] += merge_ids.size 151 | 152 | subclasses[target_sequence].append(merge_sequence) 153 | subclasses[target_sequence] += subclasses[merge_sequence] 154 | del subclasses[merge_sequence] 155 | 156 | # calculate how relevant the merges are based on size 157 | relevant = {} 158 | new_class = {} 159 | for main_class in subclasses: 160 | relevant[main_class] = True 161 | new_class[main_class] = main_class 162 | 163 | main_size = last_sizes[main_class] 164 | for sub_class in subclasses[main_class]: 165 | new_class[sub_class] = main_class 166 | sub_size = last_sizes[sub_class] 167 | if float(sub_size) / main_size < self.keep_match_thresh: 168 | relevant[sub_class] = False 169 | else: 170 | relevant[sub_class] = True 171 | 172 | # ignore non-relevant merges and for the relevant merges replace 173 | # the merged class with the new class name 174 | new_matches = [] 175 | for match in self.matches: 176 | new_match = [] 177 | for cls in match: 178 | if relevant[cls]: 179 | new_match.append(new_class[cls]) 180 | 181 | if len(new_match) > 1: 182 | 
new_matches.append(new_match) 183 | 184 | print(" Found %d matches that are relevant after merges" % len(new_matches)) 185 | 186 | self.matches = new_matches 187 | 188 | # combine the classes in a 1d vector of labeled classes based on a 2d 189 | # listing of segments that should share the same class 190 | def _combine_classes(self): 191 | # filtered out non-unique matches 192 | unique_matches = set() 193 | for match in self.matches: 194 | unique_match = [] 195 | for cls in match: 196 | if cls not in unique_match: 197 | unique_match.append(cls) 198 | 199 | if len(unique_match) > 1: 200 | unique_match = tuple(sorted(unique_match)) 201 | if unique_match not in unique_matches: 202 | unique_matches.add(unique_match) 203 | 204 | unique_matches = [list(match) for match in unique_matches] 205 | print(" Found %d matches that are unique" % len(unique_matches)) 206 | 207 | # combine matches with classes in common 208 | groups = {} 209 | class_group = {} 210 | 211 | for i, cls in enumerate(np.unique(unique_matches)): 212 | groups[i] = [cls] 213 | class_group[cls] = i 214 | 215 | for match in unique_matches: 216 | main_group = class_group[match[0]] 217 | 218 | for cls in match: 219 | other_group = class_group[cls] 220 | if other_group != main_group: 221 | for other_class in groups[other_group]: 222 | if other_class not in groups[main_group]: 223 | groups[main_group].append(other_class) 224 | class_group[other_class] = main_group 225 | 226 | del groups[other_group] 227 | 228 | self.matches = [groups[i] for i in groups] 229 | print(" Found %d matches after grouping" % len(self.matches)) 230 | 231 | # combine the sequences into the same class 232 | for match in self.matches: 233 | assert len(match) > 1 234 | for other_class in match[1:]: 235 | self.classes[self.classes == other_class] = match[0] 236 | 237 | # make class ids sequential and remove classes that are too small 238 | def _normalize_classes(self): 239 | # mask of segments to keep 240 | keep = 
np.ones(self.classes.size).astype(np.bool) 241 | 242 | # number of classes and current class counter 243 | self.n_classes = 0 244 | for i in np.unique(self.classes): 245 | # find the elements in the class 246 | idx = self.classes == i 247 | if np.sum(idx) >= self.min_class_size: 248 | # if class is large enough keep and relabel 249 | if self.normalize_classes: 250 | self.classes[idx] = self.n_classes 251 | 252 | # found one more class 253 | self.n_classes = self.n_classes + 1 254 | else: 255 | # mark class for removal and delete label information 256 | keep = np.logical_and(keep, np.logical_not(idx)) 257 | 258 | # remove data on the removed classes 259 | self._trim_data(keep) 260 | 261 | # remove segments that are too small compared to the last 262 | # element in the sequence 263 | def _remove_irrelevant(self): 264 | keep = np.ones(self.classes.size).astype(np.bool) 265 | for cls in np.unique(self.classes): 266 | class_ids = np.where(self.classes == cls)[0] 267 | last_id = class_ids[np.argmax(self.duplicate_ids[class_ids])] 268 | last_size = len(self.segments[last_id]) 269 | 270 | for class_id in class_ids: 271 | segment_size = len(self.segments[class_id]) 272 | if float(segment_size) / last_size < self.require_relevance: 273 | keep[class_id] = False 274 | 275 | self._trim_data(keep) 276 | 277 | print(" Found %d segments that are relevant" % len(self.segments)) 278 | 279 | # remove segments that are too similar based on hamming distance 280 | def _remove_similar(self, preprocessor): 281 | keep = np.ones(self.classes.size).astype(np.bool) 282 | for c in np.unique(self.classes): 283 | class_ids = np.where(self.classes == c)[0] 284 | 285 | # sort duplicates in chronological order 286 | class_ids = class_ids[np.argsort(self.duplicate_ids[class_ids])] 287 | 288 | segments_class = [self.segments[i] for i in class_ids] 289 | segments_class = preprocessor._rescale_coordinates(segments_class) 290 | segments_class = preprocessor._voxelize(segments_class) 291 | 292 | for i, 
segment_1 in enumerate(segments_class): 293 | for segment_2 in segments_class[i + 1 :]: 294 | diff = np.sum(np.abs(segment_1 - segment_2)) 295 | 296 | if diff < self.require_diff_points: 297 | keep[class_ids[i]] = False 298 | break 299 | 300 | self._trim_data(keep) 301 | 302 | print(" Found %d segments that are dissimilar" % len(self.segments)) 303 | 304 | def _sort_ids(self): 305 | ordered_ids = [] 306 | for cls in np.unique(self.classes): 307 | class_ids = np.where(self.classes == cls)[0] 308 | class_sequences = self.duplicate_classes[class_ids] 309 | unique_sequences = np.unique(class_sequences) 310 | 311 | for unique_sequence in unique_sequences: 312 | sequence_ids = np.where(class_sequences == unique_sequence)[0] 313 | sequence_ids = class_ids[sequence_ids] 314 | sequence_frame_ids = self.duplicate_ids[sequence_ids] 315 | 316 | # order chronologically according to frame id 317 | sequence_ids = sequence_ids[np.argsort(sequence_frame_ids)] 318 | 319 | ordered_ids += sequence_ids.tolist() 320 | 321 | ordered_ids = np.array(ordered_ids) 322 | 323 | self.segments = [self.segments[i] for i in ordered_ids] 324 | self.classes = self.classes[ordered_ids] 325 | 326 | if self.positions.size > 0: 327 | self.positions = self.positions[ordered_ids] 328 | if self.features.size > 0: 329 | self.features = self.features[ordered_ids] 330 | 331 | self.duplicate_ids = self.duplicate_ids[ordered_ids] 332 | self.duplicate_classes = self.duplicate_classes[ordered_ids] 333 | 334 | # keep only segments and corresponding data where the keep parameter is true 335 | def _trim_data(self, keep): 336 | self.segments = [segment for (k, segment) in zip(keep, self.segments) if k] 337 | self.classes = self.classes[keep] 338 | 339 | if self.positions.size > 0: 340 | self.positions = self.positions[keep] 341 | if self.features.size > 0: 342 | self.features = self.features[keep] 343 | 344 | self.duplicate_ids = self.duplicate_ids[keep] 345 | self.duplicate_classes = self.duplicate_classes[keep] 346 
| -------------------------------------------------------------------------------- /segmentation/segmappy/core/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | from ..tools.classifiertools import to_onehot 5 | 6 | 7 | class Generator(object): 8 | def __init__( 9 | self, 10 | preprocessor, 11 | segment_ids, 12 | n_classes, 13 | train=True, 14 | batch_size=16, 15 | shuffle=False, 16 | ): 17 | self.preprocessor = preprocessor 18 | self.segment_ids = segment_ids 19 | self.n_classes = n_classes 20 | self.train = train 21 | self.batch_size = batch_size 22 | self.shuffle = shuffle 23 | 24 | self.n_segments = len(self.segment_ids) 25 | self.n_batches = int(np.ceil(float(self.n_segments) / batch_size)) 26 | 27 | self._i = 0 28 | 29 | def __iter__(self): 30 | return self 31 | 32 | def next(self): 33 | if self.shuffle and self._i == 0: 34 | np.random.shuffle(self.segment_ids) 35 | 36 | self.batch_ids = self.segment_ids[self._i : self._i + self.batch_size] 37 | 38 | self._i = self._i + self.batch_size 39 | if self._i >= self.n_segments: 40 | self._i = 0 41 | 42 | batch_segments, batch_classes = self.preprocessor.get_processed( 43 | self.batch_ids, train=self.train 44 | ) 45 | 46 | batch_segments = batch_segments[:, :, :, :, None] 47 | batch_classes = to_onehot(batch_classes, self.n_classes) 48 | 49 | return batch_segments, batch_classes 50 | 51 | 52 | class GeneratorFeatures(object): 53 | def __init__(self, features, classes, n_classes=2, batch_size=16, shuffle=True): 54 | self.features = features 55 | self.classes = np.asarray(classes) 56 | self.n_classes = n_classes 57 | self.batch_size = batch_size 58 | self.shuffle = shuffle 59 | self.n_samples = features.shape[0] 60 | self.n_batches = int(np.ceil(float(self.n_samples) / batch_size)) 61 | self._i = 0 62 | 63 | self.sample_ids = list(range(self.n_samples)) 64 | if shuffle: 65 | 
np.random.shuffle(self.sample_ids) 66 | 67 | def next(self): 68 | batch_ids = self.sample_ids[self._i : self._i + self.batch_size] 69 | 70 | self._i = self._i + self.batch_size 71 | if self._i >= self.n_samples: 72 | self._i = 0 73 | 74 | batch_features = self.features[batch_ids, :] 75 | batch_classes = self.classes[batch_ids] 76 | batch_classes = to_onehot(batch_classes, self.n_classes) 77 | 78 | return batch_features, batch_classes 79 | -------------------------------------------------------------------------------- /segmentation/segmappy/core/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | 5 | class Preprocessor(object): 6 | def __init__( 7 | self, 8 | augment_angle=0.0, 9 | augment_remove_random_min=0.0, 10 | augment_remove_random_max=0.0, 11 | augment_remove_plane_min=0.0, 12 | augment_remove_plane_max=0.0, 13 | augment_jitter=0.0, 14 | align="none", 15 | voxelize=True, 16 | scale_method="fixed", 17 | center_method="none", 18 | scale=(1, 1, 1), 19 | voxels=(1, 1, 1), 20 | remove_mean=False, 21 | remove_std=False, 22 | batch_size=16, 23 | scaler_train_passes=1, 24 | ): 25 | self.augment_remove_random_min = augment_remove_random_min 26 | self.augment_remove_random_max = augment_remove_random_max 27 | self.augment_remove_plane_min = augment_remove_plane_min 28 | self.augment_remove_plane_max = augment_remove_plane_max 29 | self.augment_angle = augment_angle 30 | self.augment_jitter = augment_jitter 31 | 32 | self.align = align 33 | self.voxelize = voxelize 34 | self.scale_method = scale_method 35 | self.center_method = center_method 36 | self.scale = np.array(scale) 37 | self.voxels = np.array(voxels) 38 | self.remove_mean = remove_mean 39 | self.remove_std = remove_std 40 | self.batch_size = batch_size 41 | self.scaler_train_passes = scaler_train_passes 42 | self._scaler_exists = False 43 | 44 | min_voxel_side_length_m = 0.1 45 | 
self.min_scale = self.voxels * min_voxel_side_length_m 46 | 47 | self.last_scales = [] 48 | 49 | def init_segments( 50 | self, segments, classes, positions=None, train_ids=None, scaler_path=None 51 | ): 52 | 53 | self.segments = segments 54 | self.classes = np.array(classes) 55 | 56 | if self.align == "robot": 57 | assert positions is not None 58 | self.segments = self._align_robot(self.segments, positions) 59 | 60 | # check if we need to train a scaler 61 | if self.remove_mean or self.remove_std: 62 | if scaler_path is None: 63 | assert train_ids is not None 64 | self._train_scaler(train_ids) 65 | else: 66 | self.load_scaler(scaler_path) 67 | 68 | def get_processed(self, segment_ids, train=True, normalize=True): 69 | batch_segments = [] 70 | for i in segment_ids: 71 | batch_segments.append(self.segments[i]) 72 | 73 | batch_segments = self.process(batch_segments, train, normalize) 74 | batch_classes = self.classes[segment_ids] 75 | 76 | return batch_segments, batch_classes 77 | 78 | def process(self, segments, train=True, normalize=True): 79 | # augment through distorsions 80 | if train and self.augment_remove_random_max > 0: 81 | segments = self._augment_remove_random(segments) 82 | 83 | if train and self.augment_remove_plane_max > 0: 84 | segments = self._augment_remove_plane(segments) 85 | 86 | # align after distorsions 87 | if self.align == "eigen": 88 | segments = self._align_eigen(segments) 89 | 90 | # augment rotation 91 | if train and self.augment_angle > 0: 92 | segments = self._augment_rotation(segments) 93 | 94 | if self.voxelize: 95 | # rescale coordinates and center 96 | segments = self._rescale_coordinates(segments) 97 | 98 | # randomly displace the segment 99 | if train and self.augment_jitter > 0: 100 | segments = self._augment_jitter(segments) 101 | 102 | # insert into voxel grid 103 | segments = self._voxelize(segments) 104 | 105 | # remove mean and/or std 106 | if normalize and self._scaler_exists: 107 | segments = 
self._normalize_voxel_matrix(segments) 108 | 109 | return segments 110 | 111 | def get_n_batches(self, train=True): 112 | if train: 113 | return self.n_batches_train 114 | else: 115 | return self.n_batches_test 116 | 117 | # create rotation matrix that rotates point around 118 | # the origin by an angle theta, expressed in radians 119 | def _get_rotation_matrix_z(self, theta): 120 | R_z = [ 121 | [np.cos(theta), -np.sin(theta), 0], 122 | [np.sin(theta), np.cos(theta), 0], 123 | [0, 0, 1], 124 | ] 125 | 126 | return np.array(R_z) 127 | 128 | # align according to the robot's position 129 | def _align_robot(self, segments, positions): 130 | aligned_segments = [] 131 | for i, seg in enumerate(segments): 132 | center = np.mean(seg, axis=0) 133 | 134 | robot_pos = positions[i] - center 135 | seg = seg - center 136 | 137 | # angle between robot and x-axis 138 | angle = np.arctan2(robot_pos[1], robot_pos[0]) 139 | 140 | # align the segment so the robots perspective is along the x-axis 141 | inv_rotation_matrix_z = self._get_rotation_matrix_z(angle) 142 | aligned_seg = np.dot(seg, inv_rotation_matrix_z) 143 | 144 | aligned_segments.append(aligned_seg) 145 | 146 | return aligned_segments 147 | 148 | def _align_eigen(self, segments): 149 | aligned_segments = [] 150 | for segment in segments: 151 | # Calculate covariance 152 | center = np.mean(segment, axis=0) 153 | 154 | covariance_2d = np.cov(segment[:, :2] - center[:2], rowvar=False, bias=True) 155 | 156 | eigenvalues, eigenvectors = np.linalg.eig(covariance_2d) 157 | alignment_rad = np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]) 158 | 159 | if eigenvalues[0] < eigenvalues[1]: 160 | alignment_rad = alignment_rad + np.pi / 2 161 | 162 | inv_rotation_matrix_z = self._get_rotation_matrix_z(alignment_rad) 163 | aligned_segment = np.dot(segment, inv_rotation_matrix_z) 164 | 165 | y_center = np.mean(segment[:, 1]) 166 | n_below = np.sum(segment[:, 1] < y_center) 167 | 168 | if n_below < segment.shape[0] / 2: 169 | 
alignment_rad = alignment_rad + np.pi 170 | inv_rotation_matrix_z = self._get_rotation_matrix_z(np.pi) 171 | aligned_segment = np.dot(aligned_segment, inv_rotation_matrix_z) 172 | 173 | aligned_segments.append(aligned_segment) 174 | 175 | return aligned_segments 176 | 177 | # augment with multiple rotation of the same segment 178 | def _augment_rotation(self, segments): 179 | angle_rad = self.augment_angle * np.pi / 180 180 | 181 | augmented_segments = [] 182 | for segment in segments: 183 | rotation = np.random.uniform(-angle_rad, angle_rad) 184 | segment = np.dot(segment, self._get_rotation_matrix_z(rotation)) 185 | augmented_segments.append(segment) 186 | 187 | return augmented_segments 188 | 189 | def _augment_remove_random(self, segments): 190 | augmented_segments = [] 191 | for segment in segments: 192 | # percentage of points to remove 193 | remove = ( 194 | np.random.random() 195 | * (self.augment_remove_random_max - self.augment_remove_random_min) 196 | + self.augment_remove_random_min 197 | ) 198 | 199 | # randomly choose the points 200 | idx = np.arange(segment.shape[0]) 201 | np.random.shuffle(idx) 202 | idx = idx[int(idx.size * remove) :] 203 | 204 | segment = segment[idx] 205 | augmented_segments.append(segment) 206 | 207 | return augmented_segments 208 | 209 | def _augment_remove_plane(self, segments): 210 | augmented_segments = [] 211 | for segment in segments: 212 | # center segment 213 | center = np.mean(segment, axis=0) 214 | segment = segment - center 215 | 216 | # slice off a section of the segment 217 | while True: 218 | # generate random plane 219 | plane_norm = np.random.random(3) - 0.5 220 | plane_norm = plane_norm / np.sqrt(np.sum(plane_norm ** 2)) 221 | 222 | # on which side of the plane each point is 223 | sign = np.dot(segment, plane_norm) 224 | 225 | # find an offset that removes a desired amount of points 226 | found = False 227 | plane_offsets = np.linspace( 228 | -np.max(self.scale), np.max(self.scale), 100 229 | ) 230 | 
np.random.shuffle(plane_offsets) 231 | for plane_offset in plane_offsets: 232 | keep = sign + plane_offset > 0 233 | remove_percentage = 1 - (np.sum(keep) / float(keep.size)) 234 | 235 | if ( 236 | remove_percentage > self.augment_remove_plane_min 237 | and remove_percentage < self.augment_remove_plane_max 238 | ): 239 | segment = segment[keep] 240 | found = True 241 | break 242 | 243 | if found: 244 | break 245 | 246 | segment = segment + center 247 | augmented_segments.append(segment) 248 | 249 | return augmented_segments 250 | 251 | def _augment_jitter(self, segments): 252 | jitter_segments = [] 253 | for segment in segments: 254 | jitter = np.random.random(3) * 2 - 1 255 | jitter = jitter * self.augment_jitter * self.voxels / 2 256 | 257 | segment = segment + jitter 258 | jitter_segments.append(segment) 259 | 260 | return jitter_segments 261 | 262 | def _rescale_coordinates(self, segments): 263 | # center corner to origin 264 | centered_segments = [] 265 | for segment in segments: 266 | segment = segment - np.min(segment, axis=0) 267 | centered_segments.append(segment) 268 | segments = centered_segments 269 | 270 | # store the last scaling factors that were used 271 | self.last_scales = [] 272 | 273 | # rescale coordinates to fit inside voxel matrix 274 | rescaled_segments = [] 275 | for segment in segments: 276 | # choose scale 277 | if self.scale_method == "fixed": 278 | scale = self.scale 279 | segment = segment / scale * (self.voxels - 1) 280 | elif self.scale_method == "aspect": 281 | scale = np.tile(np.max(segment), 3) 282 | segment = segment / scale * (self.voxels - 1) 283 | elif self.scale_method == "fit": 284 | scale = np.max(segment, axis=0) 285 | thresholded_scale = np.maximum(scale, self.min_scale) 286 | segment = segment / thresholded_scale * (self.voxels - 1) 287 | 288 | # recenter segment 289 | if self.center_method != "none": 290 | if self.center_method == "mean": 291 | center = np.mean(segment, axis=0) 292 | elif self.center_method == 
"min_max": 293 | center = np.max(segment, axis=0) / 2.0 294 | 295 | segment = segment + (self.voxels - 1) / 2.0 - center 296 | 297 | self.last_scales.append(scale) 298 | rescaled_segments.append(segment) 299 | 300 | return rescaled_segments 301 | 302 | def _voxelize(self, segments): 303 | voxelized_segments = np.zeros((len(segments),) + tuple(self.voxels)) 304 | for i, segment in enumerate(segments): 305 | # remove out of bounds points 306 | segment = segment[np.all(segment < self.voxels, axis=1), :] 307 | segment = segment[np.all(segment >= 0, axis=1), :] 308 | 309 | # round coordinates 310 | segment = segment.astype(np.int) 311 | 312 | # fill voxel grid 313 | voxelized_segments[i, segment[:, 0], segment[:, 1], segment[:, 2]] = 1 314 | 315 | return voxelized_segments 316 | 317 | def _train_scaler(self, train_ids): 318 | from sklearn.preprocessing import StandardScaler 319 | 320 | scaler = StandardScaler(with_mean=self.remove_mean, with_std=self.remove_std) 321 | 322 | for p in range(self.scaler_train_passes): 323 | for i in np.arange(0, len(train_ids), self.batch_size): 324 | segment_ids = train_ids[i : i + self.batch_size] 325 | segments, _ = self.get_processed(segment_ids) 326 | segments = np.reshape(segments, (segments.shape[0], -1)) 327 | scaler.partial_fit(segments) 328 | 329 | self._scaler = scaler 330 | self._scaler_exists = True 331 | 332 | # remove mean and std 333 | def _normalize_voxel_matrix(self, segments): 334 | segments = np.reshape(segments, (segments.shape[0], -1)) 335 | segments = self._scaler.transform(segments) 336 | segments = np.reshape(segments, (segments.shape[0],) + tuple(self.voxels)) 337 | 338 | return segments 339 | 340 | def save_scaler(self, path): 341 | import pickle 342 | 343 | with open(path, "w") as fp: 344 | pickle.dump(self._scaler, fp) 345 | 346 | def load_scaler(self, path): 347 | import pickle 348 | 349 | with open(path, "r") as fp: 350 | self._scaler = pickle.load(fp) 351 | self._scaler_exists = True 352 | 
-------------------------------------------------------------------------------- /segmentation/segmappy/tools/classifiertools.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from builtins import input 3 | import numpy as np 4 | import sys 5 | 6 | # sequentially view a set of segments 7 | def visualize(segments, extra_info=None, show_all=False, no_ticks=False): 8 | import matplotlib.pyplot as plt 9 | from mpl_toolkits.mplot3d import Axes3D 10 | 11 | # scale the axes to match for all the segments 12 | axes_min = np.array(np.min(segments[0], axis=0)) 13 | axes_max = np.array(np.max(segments[0], axis=0)) 14 | 15 | for seg in segments[1:]: 16 | axes_min = np.minimum(axes_min, np.min(seg, axis=0)) 17 | axes_max = np.maximum(axes_max, np.max(seg, axis=0)) 18 | 19 | # display segments 20 | fig_id = 1 21 | plt.ion() 22 | for i, seg in enumerate(segments): 23 | if show_all: 24 | fig_id = i + 1 25 | 26 | fig = plt.figure(fig_id) 27 | plt.clf() 28 | ax = fig.add_subplot(111, projection="3d") 29 | 30 | ax.set_xlim(axes_min[0], axes_max[0]) 31 | ax.set_ylim(axes_min[1], axes_max[1]) 32 | ax.set_zlim(axes_min[2], axes_max[2]) 33 | 34 | if no_ticks: 35 | tick_count = 3 36 | ax.set_xticks(np.linspace(axes_min[0], axes_max[0], tick_count + 2)[1:-1]) 37 | ax.set_yticks(np.linspace(axes_min[1], axes_max[1], tick_count + 2)[1:-1]) 38 | ax.set_zticks(np.linspace(axes_min[2], axes_max[2], tick_count + 2)[1:-1]) 39 | 40 | plt.setp(ax.get_xmajorticklabels(), visible=False) 41 | plt.setp(ax.get_ymajorticklabels(), visible=False) 42 | plt.setp(ax.get_zmajorticklabels(), visible=False) 43 | 44 | ax.scatter(seg[:, 0], seg[:, 1], seg[:, 2]) 45 | 46 | info = "Segment " + str(i) 47 | if extra_info is not None: 48 | info = info + " " + str(extra_info[i]) 49 | sys.stdout.write(info) 50 | 51 | fig.canvas.flush_events() 52 | 53 | if not show_all: 54 | key = input() 55 | if key == "q": 56 | break 57 | else: 58 | 
sys.stdout.write("\n") 59 | 60 | if show_all: 61 | input() 62 | 63 | plt.ioff() 64 | plt.close("all") 65 | 66 | 67 | def to_onehot(y, n_classes): 68 | y_onehot = np.zeros((len(y), n_classes)) 69 | for i, cls in enumerate(y): 70 | y_onehot[i, cls] = 1 71 | 72 | return y_onehot 73 | 74 | 75 | # sequentially view a set of segments 76 | def visualize_side_by_side(segments, extra_info=None, show_all=False): 77 | 78 | import matplotlib.cm as cm 79 | 80 | n_views = 6.0 81 | if len(segments) < n_views: 82 | return 83 | 84 | import matplotlib.pyplot as plt 85 | from mpl_toolkits.mplot3d import Axes3D 86 | 87 | # scale the axes to match for all the segments 88 | axes_min = np.array(np.min(segments[0], axis=0)) 89 | axes_max = np.array(np.max(segments[0], axis=0)) 90 | max_range = 0 91 | for seg in segments[1:]: 92 | axes_min = np.minimum(axes_min, np.min(seg, axis=0)) 93 | axes_max = np.maximum(axes_max, np.max(seg, axis=0)) 94 | X = seg[:, 0] 95 | Y = seg[:, 1] 96 | Z = seg[:, 2] 97 | new_max_range = np.array([X.max() - X.min(), Y.max() - Y.min()]).max() / 2.0 98 | if new_max_range > max_range: 99 | max_range = new_max_range 100 | fig = plt.figure(1, frameon=False) 101 | plt.clf() 102 | cmap = plt.cm.jet 103 | # fig, axs = plt.subplots(1,len(segments), projection='3d', facecolor='w', edgecolor='w') #figsize=(15, 6) 104 | fig.subplots_adjust(hspace=.5, wspace=.001) 105 | 106 | views_ids = [0] 107 | segments_temp = [] 108 | for i in range(int(n_views - 1)): 109 | idx = i * len(segments) / n_views 110 | print(idx) 111 | views_ids = views_ids + [int(idx)] 112 | segments_temp.append(segments[int(idx)]) 113 | segments_temp.append(segments[len(segments) - 1]) 114 | segments = segments_temp 115 | 116 | print(max_range) 117 | 118 | for i, seg in enumerate(segments): 119 | ax = fig.add_subplot(1, len(segments), i + 1, projection="3d") 120 | ax.set_xlim(axes_min[0], axes_max[0]) 121 | ax.set_ylim(axes_min[1], axes_max[1]) 122 | ax.set_zlim(axes_min[2], axes_max[2]) 123 | 124 | mid_x = 
(seg[:, 0].max() + seg[:, 0].min()) * 0.5 125 | mid_y = (seg[:, 1].max() + seg[:, 1].min()) * 0.5 126 | mid_z = (seg[:, 2].max() + seg[:, 2].min()) * 0.5 127 | ax.set_xlim(mid_x - max_range, mid_x + max_range) 128 | ax.set_ylim(mid_y - max_range, mid_y + max_range) 129 | ax.set_zlim(mid_z - max_range, mid_z + max_range) 130 | ax.set_aspect(1) 131 | 132 | plt.setp(ax.get_xmajorticklabels(), visible=False) 133 | plt.setp(ax.get_ymajorticklabels(), visible=False) 134 | plt.setp(ax.get_zmajorticklabels(), visible=False) 135 | 136 | tick_count = 3 137 | ax.set_xticks(np.linspace(axes_min[0], axes_max[0], tick_count + 2)[1:-1]) 138 | ax.set_yticks(np.linspace(axes_min[1], axes_max[1], tick_count + 2)[1:-1]) 139 | ax.set_zticks(np.linspace(axes_min[2], axes_max[2], tick_count + 2)[1:-1]) 140 | 141 | ax.set_xticklabels([1, 2, 3, 4]) 142 | # fig.patch.set_visible(False) 143 | # ax.axis('off') 144 | ax.scatter( 145 | seg[:, 0], 146 | seg[:, 1], 147 | seg[:, 2], 148 | s=1, 149 | c=seg[:, 2], 150 | marker="o", 151 | lw=0, 152 | depthshade=False, 153 | cmap="jet_r", 154 | ) 155 | ax.grid(b=False) 156 | ax.patch.set_facecolor("white") 157 | ax.set_axis_off() 158 | plt.draw() 159 | plt.pause(0.001) 160 | 161 | key = input() 162 | 163 | 164 | def get_default_dataset(config, folder): 165 | from ..core.dataset import Dataset 166 | 167 | dataset = Dataset( 168 | folder=folder, 169 | base_dir=config.base_dir, 170 | use_merges=config.use_merges, 171 | use_matches=config.use_matches, 172 | min_class_size=config.min_class_size, 173 | require_diff_points=config.require_diff_points, 174 | keep_match_thresh=config.keep_match_thresh, 175 | require_relevance=config.require_relevance, 176 | ) 177 | 178 | return dataset 179 | 180 | 181 | def get_default_preprocessor(config): 182 | from ..core.preprocessor import Preprocessor 183 | 184 | preprocessor = Preprocessor( 185 | augment_angle=config.augment_angle, 186 | augment_remove_random_min=config.augment_remove_random_min, 187 | 
augment_remove_random_max=config.augment_remove_random_max, 188 | augment_remove_plane_min=config.augment_remove_plane_min, 189 | augment_remove_plane_max=config.augment_remove_plane_max, 190 | augment_jitter=config.augment_jitter, 191 | align=config.align, 192 | scale_method=config.scale_method, 193 | scale=config.scale, 194 | center_method=config.center_method, 195 | voxels=config.voxels, 196 | remove_mean=config.remove_mean, 197 | remove_std=config.remove_std, 198 | batch_size=config.batch_size, 199 | ) 200 | 201 | return preprocessor 202 | -------------------------------------------------------------------------------- /utils/augment_scans.py: -------------------------------------------------------------------------------- 1 | """Code for distorting point clouds.""" 2 | 3 | import numpy as np 4 | import open3d as o3d 5 | 6 | def occlude_scan(scan, angle): 7 | # Remove points within a sector of fixed angle (degrees) and random heading direction. 8 | thetas = (180/np.pi) * np.arctan2(scan[:,1],scan[:,0]) 9 | heading = (180-angle/2)*np.random.uniform(-1,1) 10 | occ_scan = np.vstack((scan[thetas < (heading - angle/2)] , scan[thetas > (heading + angle/2)])) 11 | return occ_scan 12 | 13 | def random_rotate_scan(scan, r_angle, is_random = True): 14 | # If is_random = True: Rotate about z-axis by random angle upto 'r_angle'. 15 | # Else: Rotate about z-axis by fixed angle 'r_angle'. 
16 | r_angle = (np.pi/180) * r_angle 17 | if is_random: 18 | r_angle = r_angle*np.random.uniform() 19 | cos_angle = np.cos(r_angle) 20 | sin_angle = np.sin(r_angle) 21 | rot_matrix = np.array([[cos_angle, -sin_angle, 0], 22 | [sin_angle, cos_angle, 0], 23 | [0, 0, 1]]) 24 | augmented_scan = np.dot(scan, rot_matrix) 25 | 26 | return np.asarray(augmented_scan, dtype=np.float32), rot_matrix 27 | 28 | def downsample_scan(scan, voxel_size): 29 | pcd = o3d.geometry.PointCloud() 30 | pcd.points = o3d.utility.Vector3dVector(scan) 31 | downpcd = pcd.voxel_down_sample(voxel_size=voxel_size) 32 | return np.asarray(downpcd.points) 33 | 34 | def distort_scan(scan, n_sigma, r_angle): 35 | # Add gaussian noise and rotate about z-axis. 36 | noise = np.clip(n_sigma * np.random.randn(*scan.shape), -0.1, 0.1) 37 | noisy_scan = scan + noise 38 | 39 | return random_rotate_scan(noisy_scan, r_angle, False) 40 | 41 | def augmented_scan(scan, aug_type, param): 42 | if aug_type == 'occ': 43 | return occlude_scan(scan, param) 44 | elif aug_type == 'rot': 45 | return random_rotate_scan(scan, param) 46 | elif aug_type == 'ds': 47 | return downsample_scan(scan, param) 48 | else: 49 | return [] 50 | 51 | ##################################################################################### 52 | # Test 53 | ##################################################################################### 54 | 55 | 56 | if __name__ == "__main__": 57 | 58 | # Set the dataset location here: 59 | sequence = '06' 60 | 61 | import os 62 | import sys 63 | import glob 64 | import yaml 65 | 66 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 67 | from utils.misc_utils import * 68 | from utils.kitti_dataloader import * 69 | 70 | cfg_file = open('config.yml', 'r') 71 | cfg_params = yaml.load(cfg_file, Loader=yaml.FullLoader) 72 | basedir = cfg_params['paths']['KITTI_dataset'] 73 | 74 | sequence_path = basedir + 'sequences/' + sequence + '/' 75 | bin_files = sorted(glob.glob(os.path.join( 76 | 
sequence_path, 'velodyne', '*.bin'))) 77 | scans = yield_bin_scans(bin_files) 78 | 79 | for i in range(10): 80 | scan = next(scans) 81 | scan = scan[:, :-1] 82 | print('Scan ID: ', i) 83 | visualize_scan_open3d(scan) 84 | visualize_scan_open3d(augmented_scan(scan, 'occ', 90)) 85 | visualize_scan_open3d(augmented_scan(scan, 'rot', 180)[0]) 86 | visualize_scan_open3d(augmented_scan(scan, 'ds', 0.5)) 87 | -------------------------------------------------------------------------------- /utils/docs/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csiro-robotics/locus/781c1ba340db4bad6ac2760fc8f1920998f56230/utils/docs/pipeline.png -------------------------------------------------------------------------------- /utils/docs/robustness_tests.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csiro-robotics/locus/781c1ba340db4bad6ac2760fc8f1920998f56230/utils/docs/robustness_tests.png -------------------------------------------------------------------------------- /utils/get_segmap_data.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Downloading pre-trained model into '~/.segmap/'" 3 | `wget -P ~/.segmap/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/default_training.ini` 4 | `wget -P ~/.segmap/trained_models/segmap64/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/trained_models/segmap64/checkpoint` 5 | `wget -P ~/.segmap/trained_models/segmap64/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/trained_models/segmap64/graph.pb` 6 | `wget -P ~/.segmap/trained_models/segmap64/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/trained_models/segmap64/model.ckpt.index` 7 | `wget -P ~/.segmap/trained_models/segmap64/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/trained_models/segmap64/model.ckpt.meta` 8 | `wget -P 
~/.segmap/trained_models/segmap64/ http://robotics.ethz.ch/~asl-datasets/segmap/segmap_data/trained_models/segmap64/model.ckpt.data-00000-of-00001` 9 | echo "Renaming path" 10 | _model_path=~/.segmap/trained_models/segmap64/model.ckpt 11 | echo "model_checkpoint_path:" \""$_model_path"\" > ~/.segmap/trained_models/segmap64/checkpoint 12 | echo "all_model_checkpoint_paths:" \""$_model_path"\" >> ~/.segmap/trained_models/segmap64/checkpoint 13 | echo "Finished downloading pre-trained model" -------------------------------------------------------------------------------- /utils/kitti_dataloader.py: -------------------------------------------------------------------------------- 1 | """Code for loading KITTI odometry dataset""" 2 | 3 | import glob 4 | import os 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import sys 8 | import open3d as o3d 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) 11 | from utils.misc_utils import * 12 | 13 | ##################################################################################### 14 | # Load poses 15 | ##################################################################################### 16 | 17 | def transfrom_cam2velo(Tcam): 18 | R = np.array([ 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 19 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 20 | ]).reshape(3, 3) 21 | t = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 22 | cam2velo = np.vstack((np.hstack([R, t]), [0, 0, 0, 1])) 23 | 24 | return Tcam @ cam2velo 25 | 26 | def load_poses_from_txt(file_name): 27 | """ 28 | Modified function from: https://github.com/Huangying-Zhan/kitti-odom-eval/blob/master/kitti_odometry.py 29 | """ 30 | f = open(file_name, 'r') 31 | s = f.readlines() 32 | f.close() 33 | transforms = {} 34 | positions = [] 35 | for cnt, line in enumerate(s): 36 | P = np.eye(4) 37 | line_split = [float(i) for i in line.split(" ") if i != ""] 38 | withIdx = 
len(line_split) == 13 39 | for row in range(3): 40 | for col in range(4): 41 | P[row, col] = line_split[row*4 + col + withIdx] 42 | if withIdx: 43 | frame_idx = line_split[0] 44 | else: 45 | frame_idx = cnt 46 | transforms[frame_idx] = transfrom_cam2velo(P) 47 | positions.append([P[0, 3], P[2, 3], P[1, 3]]) 48 | return transforms, np.asarray(positions) 49 | 50 | 51 | def get_delta_pose(transforms): 52 | rel_transforms = [] 53 | for i in range(len(transforms)-1): 54 | w_T_p1 = transforms[i] 55 | w_T_p2 = transforms[i+1] 56 | 57 | p1_T_w = T_inv(w_T_p1) 58 | p1_T_p2 = np.matmul(p1_T_w, w_T_p2) 59 | rel_transforms.append(p1_T_p2) 60 | return rel_transforms 61 | 62 | ##################################################################################### 63 | # Load scans 64 | ##################################################################################### 65 | 66 | 67 | """ Helper functions from https://github.com/utiasSTARS/pykitti/blob/master/pykitti/utils.py """ 68 | 69 | 70 | def load_bin_scan(file): 71 | """Load and reshape binary file containing single point cloud""" 72 | scan = np.fromfile(file, dtype=np.float32) 73 | return scan.reshape((-1, 4)) 74 | 75 | 76 | def yield_bin_scans(bin_files): 77 | """Generator to load multiple point clouds sequentially""" 78 | for file in bin_files: 79 | yield load_bin_scan(file) 80 | 81 | def visualize_scan_open3d(ptcloud_xyz, colors = []): 82 | pcd = o3d.geometry.PointCloud() 83 | pcd.points = o3d.utility.Vector3dVector(ptcloud_xyz) 84 | if colors != []: 85 | pcd.colors = o3d.utility.Vector3dVector(colors) 86 | o3d.visualization.draw_geometries([pcd]) 87 | 88 | def visualize_sequence_open3d(bin_files, n_scans): 89 | """Visualize scans using Open3D""" 90 | 91 | scans = yield_bin_scans(bin_files) 92 | 93 | for i in range(n_scans): 94 | scan = next(scans) 95 | ptcloud_xyz = scan[:, :-1] 96 | print(ptcloud_xyz.shape) 97 | visualize_scan_open3d(ptcloud_xyz) 98 | 99 | 100 | 
##################################################################################### 101 | # Load timestamps 102 | ##################################################################################### 103 | 104 | 105 | def load_timestamps(file_name): 106 | # file_name = data_dir + '/times.txt' 107 | file1 = open(file_name, 'r+') 108 | stimes_list = file1.readlines() 109 | s_exp_list = np.asarray([float(t[-4:-1]) for t in stimes_list]) 110 | times_list = np.asarray([float(t[:-2]) for t in stimes_list]) 111 | times_listn = [times_list[t] * (10**(s_exp_list[t])) 112 | for t in range(len(times_list))] 113 | file1.close() 114 | return times_listn 115 | 116 | ##################################################################################### 117 | # Test 118 | ##################################################################################### 119 | 120 | 121 | if __name__ == "__main__": 122 | 123 | # Set the dataset location here: 124 | basedir = '/mnt/088A6CBB8A6CA742/Datasets/Kitti/dataset/' 125 | 126 | ################## 127 | # Test poses 128 | 129 | fig, axs = plt.subplots(4, 6, constrained_layout=True) 130 | fig.suptitle('KITTI sequences', fontsize=16) 131 | for i in range(22): 132 | sequence = str(i) 133 | if i < 10: 134 | sequence = '0' + str(i) 135 | sequence_path = basedir + 'sequences/' + sequence + '/' 136 | poses_file = sorted( 137 | glob.glob(os.path.join(sequence_path, 'poses.txt'))) 138 | _, positions = load_poses_from_txt(poses_file[0]) 139 | print('seq: ', sequence, 'len', len(positions)) 140 | 141 | axs[i//6, i % 6].plot(positions[:,0], positions[:,1]) 142 | axs[i//6, i % 6].set_title('seq: ' + sequence + 'len' + str(len(positions))) 143 | 144 | plt.show() 145 | 146 | ################## 147 | # Test scans 148 | 149 | sequence = '00' 150 | sequence_path = basedir + 'sequences/' + sequence + '/' 151 | bin_files = sorted(glob.glob(os.path.join( 152 | sequence_path, 'velodyne', '*.bin'))) 153 | # Visualize some scans 154 | 
visualize_sequence_open3d(bin_files, 2) 155 | 156 | ################## 157 | # Test timestamps 158 | timestamps_file = basedir + 'sequences/' + sequence + '/times.txt' 159 | timestamps = load_timestamps(timestamps_file) 160 | print("Start time (s): ", timestamps[0]) 161 | print("End time (s): ", timestamps[-1]) 162 | 163 | print('Test complete.') 164 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | """ Miscellaneous functions """ 2 | import numpy as np 3 | from numpy import dot 4 | from numpy.linalg import norm 5 | import time 6 | import pickle 7 | 8 | ##################################################################################### 9 | # Data loading/saving 10 | 11 | 12 | def load_pickle(file_name): 13 | dbfile1 = open(file_name, 'rb') 14 | file_data = pickle.load(dbfile1) 15 | dbfile1.close() 16 | return file_data 17 | 18 | 19 | def save_pickle(data_variable, file_name): 20 | dbfile2 = open(file_name, 'ab') 21 | pickle.dump(data_variable, dbfile2) 22 | dbfile2.close() 23 | print('Finished saving: ', file_name) 24 | 25 | ##################################################################################### 26 | # Place recognition 27 | 28 | def check_if_revisit(query_pose, db_poses, thres): 29 | num_dbs = np.shape(db_poses)[0] 30 | is_revisit = 0 31 | 32 | for i in range(num_dbs): 33 | dist = norm(query_pose - db_poses[i]) 34 | if ( dist < thres ): 35 | is_revisit = 1 36 | break 37 | 38 | return is_revisit 39 | 40 | ##################################################################################### 41 | # Math 42 | 43 | def cosine_distance(feature_a, feature_b): 44 | return 1 - dot(feature_a, np.transpose(feature_b))/(norm(feature_a)*norm(feature_b)) 45 | 46 | def T_inv(T_in): 47 | """ Return the inverse of input homogeneous transformation matrix """ 48 | R_in = T_in[:3, :3] 49 | t_in = T_in[:3, [-1]] 50 | R_out = R_in.T 
51 | t_out = -np.matmul(R_out, t_in) 52 | return np.vstack((np.hstack((R_out, t_out)), np.array([0, 0, 0, 1]))) 53 | 54 | 55 | def is_nan(x): 56 | return (x != x) 57 | 58 | 59 | def euclidean_to_homogeneous(e_point): 60 | """ Coversion from Eclidean coordinates to Homogeneous """ 61 | h_point = np.concatenate([e_point,[1]]) 62 | return h_point 63 | 64 | 65 | def homogeneous_to_euclidean(h_point): 66 | """ Coversion from Homogeneous coordinates to Eclidean """ 67 | e_point = h_point/ h_point[3] 68 | e_point = e_point[:3] 69 | return e_point 70 | 71 | ##################################################################################### 72 | # Timing 73 | 74 | class Timer(object): 75 | """A simple timer.""" 76 | # Ref: https://github.com/chrischoy/FCGF/blob/master/lib/timer.py 77 | 78 | def __init__(self, binary_fn=None, init_val=0): 79 | self.total_time = 0. 80 | self.calls = 0 81 | self.start_time = 0. 82 | self.diff = 0. 83 | self.binary_fn = binary_fn 84 | self.tmp = init_val 85 | 86 | def reset(self): 87 | self.total_time = 0 88 | self.calls = 0 89 | self.start_time = 0 90 | self.diff = 0 91 | 92 | @property 93 | def avg(self): 94 | return self.total_time / self.calls 95 | 96 | def tic(self): 97 | # using time.time instead of time.clock because time time.clock 98 | # does not normalize for multithreading 99 | self.start_time = time.time() 100 | 101 | def toc(self, average=True): 102 | self.diff = time.time() - self.start_time 103 | self.total_time += self.diff 104 | self.calls += 1 105 | if self.binary_fn: 106 | self.tmp = self.binary_fn(self.tmp, self.diff) 107 | if average: 108 | return self.avg 109 | else: 110 | return self.diff 111 | 112 | ##################################################################################### 113 | # Config 114 | 115 | font = {'family': 'serif', 116 | # 'color': 'black', 117 | 'weight': 'normal', 118 | 'size': 16, 119 | } 120 | font_legend = {'family': 'serif', 121 | # 'color': 'black', 122 | 'weight': 'normal', 123 | 'size': 
12, 124 | } 125 | -------------------------------------------------------------------------------- /utils/setup_python_pcl.txt: -------------------------------------------------------------------------------- 1 | sudo apt-get update -y 2 | sudo apt-get install libpcl-dev -y 3 | conda config --add channels conda-forge 4 | conda install -c sirokujira python-pcl 5 | 6 | If 'import pcl' doesnt' work: 7 | cd ~/anaconda3/envs/locus_env/lib/ 8 | ln -s libboost_system.so.1.64.0 libboost_system.so.1.54.0 9 | ln -s libboost_filesystem.so.1.64.0 libboost_filesystem.so.1.54.0 10 | ln -s libboost_thread.so.1.64.0 libboost_thread.so.1.54.0 11 | ln -s libboost_iostreams.so.1.64.0 libboost_iostreams.so.1.54.0 --------------------------------------------------------------------------------