├── .gitignore ├── .idea └── .gitignore ├── LICENSE ├── README.md ├── config └── config_egonn.txt ├── datasets ├── augmentation.py ├── base_datasets.py ├── dataset_utils.py ├── kitti │ ├── generate_evaluation_sets.py │ ├── kitti_raw.py │ └── utils.py ├── mulran │ ├── generate_evaluation_sets.py │ ├── generate_training_tuples.py │ ├── mulran_raw.py │ ├── mulran_train.py │ └── utils.py ├── quantization.py ├── samplers.py └── southbay │ ├── generate_evaluation_sets.py │ ├── generate_training_tuples.py │ └── southbay_raw.py ├── eval ├── evaluate.py └── evaluate_with_rotations.py ├── images ├── key1_.png ├── keypoints_vis.png ├── pair2.png └── registered_pairs.png ├── layers ├── eca_block.py ├── netvlad.py ├── pooling.py └── senet_block.py ├── misc ├── point_clouds.py ├── poses.py └── utils.py ├── models ├── egonn.txt ├── loss.py ├── loss_utils.py ├── minkfpn.py ├── minkgl.py ├── minkloc.py ├── minkloc3d_mulran.txt ├── model_factory.py └── resnet.py ├── third_party ├── minkloc3d │ └── minkloc.py ├── pypcd.py └── scan_context │ ├── evaluate_scan_context.py │ └── scan_context.py ├── training ├── train.py └── trainer.py └── weights └── model_egonn_20210916_1104.pth /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 jac99 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EgoNN: Egocentric Neural Network for Point Cloud Based 6DoF Relocalization at the City Scale 2 | 3 | Paper: [EgoNN: Egocentric Neural Network for Point Cloud Based 6DoF Relocalization at the City Scale](https://ieeexplore.ieee.org/document/9645340/) 4 | IEEE Robotics and Automation Letters (RA-L) Volume 7 Issue 2 April 2022 5 | 6 | [Jacek Komorowski](mailto:jacek.komorowski@pw.edu.pl), Monika Wysoczanska, Tomasz Trzcinski 7 | 8 | Warsaw University of Technology 9 | 10 | ### What's new ### 11 | * [2021-10-24] Evaluation code and pretrained models released. 12 | * [2021-12-16] Training code released.
13 | 14 | ### Our other projects ### 15 | * MinkLoc3D: Point Cloud Based Large-Scale Place Recognition (WACV 2021): [MinkLoc3D](https://github.com/jac99/MinkLoc3D) 16 | * MinkLoc++: Lidar and Monocular Image Fusion for Place Recognition (IJCNN 2021): [MinkLoc++](https://github.com/jac99/MinkLocMultimodal) 17 | * Large-Scale Topological Radar Localization Using Learned Descriptors (ICONIP 2021): [RadarLoc](https://github.com/jac99/RadarLoc) 18 | * Improving Point Cloud Based Place Recognition with Ranking-based Loss and Large Batch Training (2022): [MinkLoc3Dv2](https://github.com/jac99/MinkLoc3Dv2) 19 | 20 | ### Introduction 21 | The paper presents a **deep neural network-based method for extracting global and local descriptors from a point 22 | cloud acquired by a rotating 3D LiDAR sensor**. The descriptors can be used for two-stage 6DoF relocalization. First, a coarse 23 | position is retrieved by finding candidates with the closest global descriptor in a database of geo-tagged point clouds. Then, 24 | the 6DoF pose between a query point cloud and a database point cloud is estimated by matching local descriptors and using a 25 | robust estimator such as RANSAC. Our method has a simple, fully convolutional architecture and uses a sparse voxelized 26 | representation of the input point cloud. It can **efficiently extract a global descriptor and a set of keypoints with 27 | their local descriptors from large point clouds with tens of thousands of points**. 28 | 29 | ![](images/key1_.png) 30 | ![](images/pair2.png) 31 | 32 | ### Citation 33 | If you find this work useful, please consider citing: 34 | 35 | @ARTICLE{9645340, 36 | author={Komorowski, Jacek and Wysoczanska, Monika and Trzcinski, Tomasz}, 37 | journal={IEEE Robotics and Automation Letters}, 38 | title={EgoNN: Egocentric Neural Network for Point Cloud Based 6DoF Relocalization at the City Scale}, 39 | year={2022}, 40 | volume={7}, 41 | number={2}, 42 | pages={722-729}, 43 | doi={10.1109/LRA.2021.3133593}} 44 | 45 | ### Environment and Dependencies 46 | The code was tested using Python 3.8 with PyTorch 1.10.1 and MinkowskiEngine 0.5.4 on Ubuntu 20.04 with CUDA 10.2. 47 | Note: CUDA 11.1 is not recommended, as there are some issues with MinkowskiEngine 0.5.4 on CUDA 11.1. 48 | 49 | The following Python packages are required: 50 | * PyTorch (version 1.10.1 or above) 51 | * MinkowskiEngine (version 0.5.4 or above) 52 | * pytorch_metric_learning (version 1.0.0 or above) 53 | * Open3D (version 0.14 or above) 54 | * python-lzf (version 0.2.4 or above) 55 | * wandb 56 | 57 | Modify the `PYTHONPATH` environment variable to include the absolute path to the project root folder: 58 | ``` 59 | export PYTHONPATH=$PYTHONPATH:/home/.../Egonn 60 | ``` 61 | 62 | ### Datasets 63 | 64 | **EgoNN** is trained and evaluated using the following datasets: 65 | * MulRan dataset: the Sejong traversal is used.
The traversal is split into training and evaluation parts [link](https://sites.google.com/view/mulran-pr) 66 | * Apollo-SouthBay dataset: the SunnyvaleBigLoop trajectory is used for evaluation; the other 5 trajectories (BaylandsToSeafood, 67 | ColumbiaPark, Highway237, MathildaAVE, SanJoseDowntown) are used for training [link](https://apollo.auto/southbay.html) 68 | * Kitti dataset: Sequence 00 is used for evaluation [link](http://www.cvlibs.net/datasets/kitti/) 69 | 70 | First, you need to download the datasets: 71 | 72 | * For the MulRan dataset, download the ground truth data (*.csv) and LiDAR point clouds (Ouster.zip) for the traversals: 73 | Sejong01 and Sejong02 ([link](https://sites.google.com/view/mulran-pr/download)). 74 | * Download the Apollo-SouthBay dataset using the download link on the dataset website ([link](https://apollo.auto/southbay.html)). 75 | * Download the Kitti odometry dataset (calibration files, ground truth poses, Velodyne laser data) ([link](http://www.cvlibs.net/datasets/kitti/eval_odometry.php)). 76 | 77 | After downloading the datasets, you need to generate **training pickles** for network training and **evaluation pickles** for model 78 | evaluation. 79 | 80 | ##### Training pickles generation 81 | 82 | Generating training tuples is very time-consuming, as ICP is used to refine the ground truth poses between each pair 83 | of neighbouring point clouds. 84 | 85 | ``` 86 | cd datasets/mulran 87 | python generate_training_tuples.py --dataset_root <dataset_root_path> 88 | 89 | cd ../southbay 90 | python generate_training_tuples.py --dataset_root <dataset_root_path> 91 | ``` 92 | 93 | ##### Evaluation pickles generation 94 | 95 | ``` 96 | cd datasets/mulran 97 | python generate_evaluation_sets.py --dataset_root <dataset_root_path> 98 | 99 | cd ../southbay 100 | python generate_evaluation_sets.py --dataset_root <dataset_root_path> 101 | 102 | cd ../kitti 103 | python generate_evaluation_sets.py --dataset_root <dataset_root_path> 104 | ``` 105 | 106 | ### Training 107 | 108 | First, download the datasets and generate training and evaluation pickles as described above. 109 | Edit the configuration file *config_egonn.txt*. 110 | Set the *dataset_folder* parameter to point to the dataset root folder. 111 | Modify the *batch_size_limit* and *secondary_batch_size_limit* parameters depending on the available GPU memory. 112 | The default limits require at least 11GB of GPU RAM. 113 | 114 | To train the EgoNN model, run: 115 | 116 | ``` 117 | cd training 118 | 119 | python train.py --config ../config/config_egonn.txt --model_config ../models/egonn.txt 120 | ``` 121 | 122 | ### Pre-trained Model 123 | 124 | The EgoNN model trained on the training splits of the MulRan and Apollo-SouthBay datasets is available as 125 | *weights/model_egonn_20210916_1104.pth*. 126 | 127 | ### Evaluation 128 | 129 | To evaluate a pretrained model, run the commands below. 130 | Ground truth poses between different traversals in all three datasets are slightly misaligned. 131 | To reproduce the results from the paper, use the `--icp_refine` option to refine the ground truth poses using ICP.
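For context, `evaluate.py` follows the two-stage relocalization procedure described in the introduction: coarse place recognition with global descriptors, followed by 6DoF pose estimation from matched keypoints. The sketch below is only a minimal illustration of that idea, not the repository's actual API: it assumes the model has already extracted global descriptors, keypoints and local descriptors, and the helper itself plus the dictionary keys (`global`, `local`, `keypoints`) are hypothetical names.

```python
import numpy as np
import open3d as o3d

def relocalize(query, db, db_poses):
    # Stage 1: coarse place recognition - retrieve the database cloud with the
    # closest global descriptor (query['global']: (D,), db['global']: (N, D)).
    ndx = np.argmin(np.linalg.norm(db['global'] - query['global'], axis=1))

    # Stage 2: 6DoF pose estimation - mutual nearest neighbours in local
    # descriptor space give tentative keypoint correspondences.
    d = np.linalg.norm(query['local'][:, None] - db['local'][ndx][None], axis=2)
    nn12, nn21 = d.argmin(axis=1), d.argmin(axis=0)
    corr = np.array([[i, j] for i, j in enumerate(nn12) if nn21[j] == i],
                    dtype=np.int32)  # assumes at least a few mutual matches

    source = o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(query['keypoints'].astype(np.float64)))
    target = o3d.geometry.PointCloud(
        o3d.utility.Vector3dVector(db['keypoints'][ndx].astype(np.float64)))
    result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
        source, target, o3d.utility.Vector2iVector(corr),
        max_correspondence_distance=0.5)

    # Query pose in the map frame: database pose composed with the estimated
    # relative transform (source -> target).
    return ndx, db_poses[ndx] @ result.transformation
```

The RANSAC step on descriptor correspondences is what produces the inlier matches shown in the *Visualizations* section below. The actual evaluation commands are: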
132 | 133 | ``` 134 | cd eval 135 | 136 | # To evaluate on the test split of the Mulran dataset 137 | python evaluate.py --dataset_root <dataset_root_path> --dataset_type mulran --eval_set test_Sejong01_Sejong02.pickle --model_config ../models/egonn.txt --weights ../weights/model_egonn_20210916_1104.pth --icp_refine 138 | 139 | # To evaluate on the test split of the Apollo-SouthBay dataset 140 | python evaluate.py --dataset_root <dataset_root_path> --dataset_type southbay --eval_set test_SunnyvaleBigloop_1.0_5.pickle --model_config ../models/egonn.txt --weights ../weights/model_egonn_20210916_1104.pth --icp_refine 141 | 142 | # To evaluate on the test split of the KITTI dataset 143 | python evaluate.py --dataset_root <dataset_root_path> --dataset_type kitti --eval_set kitti_00_eval.pickle --model_config ../models/egonn.txt --weights ../weights/model_egonn_20210916_1104.pth --icp_refine 144 | ``` 145 | 146 | ## Results 147 | 148 | **EgoNN** performance... 149 | 150 | ## Visualizations 151 | 152 | Visualizations of our keypoint detector results. 153 | On the left, we show 128 keypoints with the lowest saliency uncertainty (red dots). 154 | On the right, 128 keypoints with the highest uncertainty (yellow dots). 155 | 156 | ![](images/keypoints_vis.png) 157 | 158 | Successful registration of point cloud pairs from the KITTI dataset, gathered when revisiting the same place from different directions. 159 | On the left, we show keypoint correspondences (RANSAC inliers) found during RANSAC-based 6DoF pose estimation. 160 | On the right, we show point clouds aligned using the estimated poses. 161 | 162 | ![](images/registered_pairs.png) 163 | 164 | ### License 165 | Our code is released under the MIT License (see the LICENSE file for details). 166 | 167 | 168 | -------------------------------------------------------------------------------- /config/config_egonn.txt: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | dataset = Mulran 3 | dataset_folder = /data3/mulran 4 | secondary_dataset = southbay 5 | secondary_dataset_folder = /data2/Apollo-SouthBay 6 | 7 | [TRAIN] 8 | num_workers = 8 9 | batch_size = 32 10 | batch_size_limit = 128 11 | batch_expansion_rate = 1.4 12 | batch_expansion_th = 0.7 13 | secondary_batch_size_limit = 96 14 | local_batch_size = 8 15 | 16 | lr = 1e-3 17 | epochs = 160 18 | scheduler_milestones = 80 19 | aug_mode = 2 20 | weight_decay = 1e-4 21 | 22 | loss = BatchHardTripletMarginLoss 23 | l_gammas = 1., 1., 1., 4.
24 | margin = 0.2 25 | 26 | train_file = train_Sejong01_Sejong02_2_10.pickle 27 | val_file = val_Sejong01_Sejong02_2_10.pickle 28 | secondary_train_file = train_southbay_2_10.pickle 29 | test_file = test_Sejong01_Sejong02.pickle 30 | -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from scipy.linalg import expm, norm 7 | from torchvision import transforms as transforms 8 | 9 | 10 | class TrainTransform: 11 | def __init__(self, aug_mode): 12 | self.aug_mode = aug_mode 13 | if self.aug_mode == 1: 14 | # Augmentations without random rotation around z-axis 15 | t = [JitterPoints(sigma=0.1, clip=0.2), RemoveRandomPoints(r=(0.0, 0.1)), 16 | RandomTranslation(max_delta=0.3), RemoveRandomBlock(p=0.4)] 17 | elif self.aug_mode == 2: 18 | # Augmentations with random rotation around z-axis 19 | t = [JitterPoints(sigma=0.1, clip=0.2), RemoveRandomPoints(r=(0.0, 0.1)), 20 | RandomTranslation(max_delta=0.3), 21 | RandomRotation(max_theta=180, axis=np.array([0, 0, 1])), 22 | RemoveRandomBlock(p=0.4)] 23 | else: 24 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 25 | self.transform = transforms.Compose(t) 26 | 27 | def __call__(self, e): 28 | if self.transform is not None: 29 | e = self.transform(e) 30 | return e 31 | 32 | 33 | class TrainSetTransform: 34 | def __init__(self, aug_mode): 35 | self.aug_mode = aug_mode 36 | if self.aug_mode == 1: 37 | t = [RandomRotation(max_theta=5, axis=np.array([0, 0, 1])), 38 | RandomFlip([0.25, 0.25, 0.])] 39 | elif self.aug_mode == 2: 40 | t = [RandomFlip([0.25, 0.25, 0.])] 41 | else: 42 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 43 | self.transform = transforms.Compose(t) 44 | 45 | def __call__(self, e): 46 | if self.transform is not None: 47 | e = self.transform(e) 48 | return e 49 | 50 | 51 | class RandomFlip: 52 | def __init__(self, p): 53 | # p = [p_x, p_y, p_z] probability of flipping each axis 54 | assert len(p) == 3 55 | assert 0 < sum(p) <= 1, 'sum(p) must be in (0, 1] range, is: {}'.format(sum(p)) 56 | self.p = p 57 | self.p_cum_sum = np.cumsum(p) 58 | 59 | def __call__(self, coords): 60 | r = random.random() 61 | if r <= self.p_cum_sum[0]: 62 | # Flip the first axis 63 | coords[..., 0] = -coords[..., 0] 64 | elif r <= self.p_cum_sum[1]: 65 | # Flip the second axis 66 | coords[..., 1] = -coords[..., 1] 67 | elif r <= self.p_cum_sum[2]: 68 | # Flip the third axis 69 | coords[..., 2] = -coords[..., 2] 70 | 71 | return coords 72 | 73 | 74 | class RandomRotation: 75 | def __init__(self, axis=None, max_theta=180, max_theta2=None): 76 | self.axis = axis 77 | self.max_theta = max_theta # Rotation around axis 78 | self.max_theta2 = max_theta2 # Smaller rotation in random direction 79 | 80 | def _M(self, axis, theta): 81 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)).astype(np.float32) 82 | 83 | def __call__(self, coords): 84 | if self.axis is not None: 85 | axis = self.axis 86 | else: 87 | axis = np.random.rand(3) - 0.5 88 | R = self._M(axis, (np.pi * self.max_theta / 180.) * 2. * (np.random.rand(1) - 0.5)) 89 | if self.max_theta2 is None: 90 | coords = coords @ R 91 | else: 92 | R_n = self._M(np.random.rand(3) - 0.5, (np.pi * self.max_theta2 / 180.) * 2. 
* (np.random.rand(1) - 0.5)) 93 | coords = coords @ R @ R_n 94 | 95 | return coords 96 | 97 | 98 | class Rotation: 99 | def __init__(self, axis=None, theta=180): 100 | self.axis = axis 101 | self.theta = theta # Rotation around axis 102 | 103 | def _M(self, axis, theta): 104 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)).astype(np.float32) 105 | 106 | def __call__(self, coords): 107 | if self.axis is not None: 108 | axis = self.axis 109 | else: 110 | axis = np.random.rand(3) - 0.5 111 | R = self._M(axis, np.pi * self.theta / 180.) 112 | coords = coords @ R 113 | return coords 114 | 115 | 116 | class RandomTranslation: 117 | def __init__(self, max_delta=0.05): 118 | self.max_delta = max_delta 119 | 120 | def __call__(self, coords): 121 | trans = self.max_delta * np.random.randn(1, 3) 122 | return coords + trans.astype(np.float32) 123 | 124 | 125 | class RandomScale: 126 | def __init__(self, min, max): 127 | self.scale = max - min 128 | self.bias = min 129 | 130 | def __call__(self, coords): 131 | s = self.scale * np.random.rand(1) + self.bias 132 | return coords * s.astype(np.float32) 133 | 134 | 135 | class RandomShear: 136 | def __init__(self, delta=0.1): 137 | self.delta = delta 138 | 139 | def __call__(self, coords): 140 | T = np.eye(3) + self.delta * np.random.randn(3, 3) 141 | return coords @ T.astype(np.float32) 142 | 143 | 144 | class JitterPoints: 145 | def __init__(self, sigma=0.01, clip=None, p=1.): 146 | assert 0 < p <= 1. 147 | assert sigma > 0. 148 | 149 | self.sigma = sigma 150 | self.clip = clip 151 | self.p = p 152 | 153 | def __call__(self, e): 154 | """ Randomly jitter points. jittering is per point. 155 | Input: 156 | BxNx3 array, original batch of point clouds 157 | Return: 158 | BxNx3 array, jittered batch of point clouds 159 | """ 160 | 161 | sample_shape = (e.shape[0],) 162 | if self.p < 1.: 163 | # Create a mask for points to jitter 164 | m = torch.distributions.categorical.Categorical(probs=torch.tensor([1 - self.p, self.p])) 165 | mask = m.sample(sample_shape=sample_shape) 166 | else: 167 | mask = torch.ones(sample_shape, dtype=torch.int64 ) 168 | 169 | mask = mask == 1 170 | jitter = self.sigma * torch.randn_like(e[mask]) 171 | 172 | if self.clip is not None: 173 | jitter = torch.clamp(jitter, min=-self.clip, max=self.clip) 174 | 175 | e[mask] = e[mask] + jitter 176 | return e 177 | 178 | 179 | class RemoveRandomPoints: 180 | def __init__(self, r): 181 | if type(r) is list or type(r) is tuple: 182 | assert len(r) == 2 183 | assert 0 <= r[0] <= 1 184 | assert 0 <= r[1] <= 1 185 | self.r_min = float(r[0]) 186 | self.r_max = float(r[1]) 187 | else: 188 | assert 0 <= r <= 1 189 | self.r_min = None 190 | self.r_max = float(r) 191 | 192 | def __call__(self, e): 193 | n = len(e) 194 | if self.r_min is None: 195 | r = self.r_max 196 | else: 197 | # Randomly select removal ratio 198 | r = random.uniform(self.r_min, self.r_max) 199 | 200 | mask = np.random.choice(range(n), size=int(n*r), replace=False) # select elements to remove 201 | e[mask] = torch.zeros_like(e[mask]) 202 | return e 203 | 204 | 205 | class RemoveRandomBlock: 206 | """ 207 | Randomly remove part of the point cloud. Similar to PyTorch RandomErasing but operating on 3D point clouds. 208 | Erases fronto-parallel cuboid. 
209 | Instead of erasing we set coords of removed points to (0, 0, 0) to retain the same number of points 210 | """ 211 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 212 | self.p = p 213 | self.scale = scale 214 | self.ratio = ratio 215 | 216 | def get_params(self, coords): 217 | # Find point cloud 3D bounding box 218 | flattened_coords = coords.view(-1, 3) 219 | min_coords, _ = torch.min(flattened_coords, dim=0) 220 | max_coords, _ = torch.max(flattened_coords, dim=0) 221 | span = max_coords - min_coords 222 | area = span[0] * span[1] 223 | erase_area = random.uniform(self.scale[0], self.scale[1]) * area 224 | aspect_ratio = random.uniform(self.ratio[0], self.ratio[1]) 225 | 226 | h = math.sqrt(erase_area * aspect_ratio) 227 | w = math.sqrt(erase_area / aspect_ratio) 228 | 229 | x = min_coords[0] + random.uniform(0, 1) * (span[0] - w) 230 | y = min_coords[1] + random.uniform(0, 1) * (span[1] - h) 231 | 232 | return x, y, w, h 233 | 234 | def __call__(self, coords): 235 | if random.random() < self.p: 236 | x, y, w, h = self.get_params(coords) # Fronto-parallel cuboid to remove 237 | mask = (x < coords[..., 0]) & (coords[..., 0] < x+w) & (y < coords[..., 1]) & (coords[..., 1] < y+h) 238 | coords[mask] = torch.zeros_like(coords[mask]) 239 | return coords -------------------------------------------------------------------------------- /datasets/base_datasets.py: -------------------------------------------------------------------------------- 1 | # Base dataset classes, inherited by dataset-specific classes 2 | import os 3 | import pickle 4 | from typing import List, Dict 5 | import torch 6 | import numpy as np 7 | from torch.utils.data import Dataset 8 | 9 | from datasets.kitti.kitti_raw import KittiPointCloudLoader 10 | from datasets.mulran.mulran_raw import MulranPointCloudLoader 11 | from datasets.southbay.southbay_raw import SouthbayPointCloudLoader 12 | from misc.point_clouds import PointCloudLoader 13 | 14 | 15 | class TrainingTuple: 16 | # Tuple describing an element for training/validation 17 | def __init__(self, id: int, timestamp: int, rel_scan_filepath: str, positives: np.ndarray, 18 | non_negatives: np.ndarray, pose: np.ndarray, positives_poses: Dict[int, np.ndarray] = None): 19 | # id: element id (ids start from 0 and are consecutive numbers) 20 | # timestamp: timestamp of the scan 21 | # rel_scan_filepath: relative path to the scan 22 | # positives: sorted ndarray of ids of positive elements 23 | # non_negatives: sorted ndarray of ids of non-negative elements 24 | # pose: pose as 4x4 matrix 25 | # positives_poses: relative poses of positive examples refined using ICP 26 | self.id = id 27 | self.timestamp = timestamp 28 | self.rel_scan_filepath = rel_scan_filepath 29 | self.positives = positives 30 | self.non_negatives = non_negatives 31 | self.pose = pose 32 | self.positives_poses = positives_poses 33 | 34 | 35 | class EvaluationTuple: 36 | # Tuple describing an evaluation set element 37 | def __init__(self, timestamp: int, rel_scan_filepath: str, position: np.array, pose: np.array = None): 38 | # position: x, y position in meters 39 | # pose: 6 DoF pose (as 4x4 pose matrix) 40 | assert position.shape == (2,) 41 | assert pose is None or pose.shape == (4, 4) 42 | self.timestamp = timestamp 43 | self.rel_scan_filepath = rel_scan_filepath 44 | self.position = position 45 | self.pose = pose 46 | 47 | def to_tuple(self): 48 | return self.timestamp, self.rel_scan_filepath, self.position, self.pose 49 | 50 | 51 | class TrainingDataset(Dataset): 52 | def __init__(self, dataset_path: str, dataset_type: str,
query_filename: str, transform=None, set_transform=None): 53 | # remove_zero_points: remove points with all zero coords 54 | assert os.path.exists(dataset_path), 'Cannot access dataset path: {}'.format(dataset_path) 55 | self.dataset_path = dataset_path 56 | self.dataset_type = dataset_type 57 | self.query_filepath = os.path.join(dataset_path, query_filename) 58 | assert os.path.exists(self.query_filepath), 'Cannot access query file: {}'.format(self.query_filepath) 59 | self.transform = transform 60 | self.set_transform = set_transform 61 | self.queries: Dict[int, TrainingTuple] = pickle.load(open(self.query_filepath, 'rb')) 62 | print('{} queries in the dataset'.format(len(self))) 63 | 64 | # pc_loader must be set in the inheriting class 65 | self.pc_loader = get_pointcloud_loader(self.dataset_type) 66 | 67 | def __len__(self): 68 | return len(self.queries) 69 | 70 | def __getitem__(self, ndx): 71 | # Load point cloud and apply transform 72 | file_pathname = os.path.join(self.dataset_path, self.queries[ndx].rel_scan_filepath) 73 | query_pc = self.pc_loader(file_pathname) 74 | query_pc = torch.tensor(query_pc, dtype=torch.float) 75 | if self.transform is not None: 76 | query_pc = self.transform(query_pc) 77 | return query_pc, ndx 78 | 79 | def get_positives(self, ndx): 80 | return self.queries[ndx].positives 81 | 82 | def get_non_negatives(self, ndx): 83 | return self.queries[ndx].non_negatives 84 | 85 | 86 | class EvaluationSet: 87 | # Evaluation set consisting of map and query elements 88 | def __init__(self, query_set: List[EvaluationTuple] = None, map_set: List[EvaluationTuple] = None): 89 | self.query_set = query_set 90 | self.map_set = map_set 91 | 92 | def save(self, pickle_filepath: str): 93 | # Pickle the evaluation set 94 | 95 | # Convert data to tuples and save as tuples 96 | query_l = [] 97 | for e in self.query_set: 98 | query_l.append(e.to_tuple()) 99 | 100 | map_l = [] 101 | for e in self.map_set: 102 | map_l.append(e.to_tuple()) 103 | pickle.dump([query_l, map_l], open(pickle_filepath, 'wb')) 104 | 105 | def load(self, pickle_filepath: str): 106 | # Load evaluation set from the pickle 107 | query_l, map_l = pickle.load(open(pickle_filepath, 'rb')) 108 | 109 | self.query_set = [] 110 | for e in query_l: 111 | self.query_set.append(EvaluationTuple(e[0], e[1], e[2], e[3])) 112 | 113 | self.map_set = [] 114 | for e in map_l: 115 | self.map_set.append(EvaluationTuple(e[0], e[1], e[2], e[3])) 116 | 117 | def get_map_positions(self): 118 | # Get map positions as (N, 2) array 119 | positions = np.zeros((len(self.map_set), 2), dtype=self.map_set[0].position.dtype) 120 | for ndx, pos in enumerate(self.map_set): 121 | positions[ndx] = pos.position 122 | return positions 123 | 124 | def get_query_positions(self): 125 | # Get query positions as (N, 2) array 126 | positions = np.zeros((len(self.query_set), 2), dtype=self.query_set[0].position.dtype) 127 | for ndx, pos in enumerate(self.query_set): 128 | positions[ndx] = pos.position 129 | return positions 130 | 131 | 132 | def get_pointcloud_loader(dataset_type) -> PointCloudLoader: 133 | if dataset_type == 'mulran': 134 | return MulranPointCloudLoader() 135 | elif dataset_type == 'southbay': 136 | return SouthbayPointCloudLoader() 137 | elif dataset_type == 'kitti': 138 | return KittiPointCloudLoader() 139 | else: 140 | raise NotImplementedError(f"Unsupported dataset type: {dataset_type}") -------------------------------------------------------------------------------- /datasets/dataset_utils.py: 
-------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import numpy as np 4 | from typing import List 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import MinkowskiEngine as ME 8 | from sklearn.neighbors import KDTree 9 | 10 | from datasets.base_datasets import TrainingDataset, EvaluationTuple 11 | from datasets.mulran.mulran_train import MulranTraining6DOFDataset 12 | from datasets.augmentation import TrainTransform, TrainSetTransform 13 | from datasets.samplers import BatchSampler 14 | from misc.utils import TrainingParams 15 | from datasets.base_datasets import TrainingDataset 16 | 17 | 18 | def make_datasets(params: TrainingParams, local=False, validation=True): 19 | # Create training and validation datasets 20 | datasets = {} 21 | train_transform = TrainTransform(params.aug_mode) 22 | train_set_transform = TrainSetTransform(params.aug_mode) 23 | 24 | datasets['global_train'] = TrainingDataset(params.dataset_folder, params.dataset, params.train_file, 25 | transform=train_transform, set_transform=train_set_transform) 26 | if validation: 27 | datasets['global_val'] = TrainingDataset(params.dataset_folder, params.dataset, params.val_file) 28 | 29 | if params.secondary_dataset is not None: 30 | datasets['secondary_train'] = TrainingDataset(params.secondary_dataset_folder, params.secondary_dataset, 31 | params.secondary_train_file, transform=train_transform, 32 | set_transform=train_set_transform) 33 | 34 | if local: 35 | datasets['local_train'] = MulranTraining6DOFDataset(params.dataset_folder, params.train_file, 36 | params.model_params.quantizer, 37 | rot_max=params.rot_max, trans_max=params.trans_max) 38 | if validation: 39 | datasets['local_val'] = MulranTraining6DOFDataset(params.dataset_folder, params.val_file, 40 | params.model_params.quantizer, 41 | rot_max=params.rot_max, trans_max=params.trans_max) 42 | 43 | return datasets 44 | 45 | 46 | def local_batch_to_device(batch, device): 47 | # Move the batch used to train the local descriptor to the proper device 48 | # Move everything except for len_batch 49 | batch['anc_batch'] = {'coords': batch['anc_batch']['coords'].to(device), 50 | 'features': batch['anc_batch']['features'].to(device)} 51 | batch['pos_batch'] = {'coords': batch['pos_batch']['coords'].to(device), 52 | 'features': batch['pos_batch']['features'].to(device)} 53 | batch['anc_pcd'] = batch['anc_pcd'].to(device) 54 | batch['pos_pcd'] = batch['pos_pcd'].to(device) 55 | batch['T_gt'] = batch['T_gt'].to(device) 56 | 57 | return batch 58 | 59 | 60 | def make_collate_fn(dataset: TrainingDataset, quantizer): 61 | # quantizer: converts to polar (when polar coords are used) and quantizes 62 | def collate_fn(data_list): 63 | # Constructs a batch object 64 | clouds = [e[0] for e in data_list] 65 | labels = [e[1] for e in data_list] 66 | 67 | if dataset.set_transform is not None: 68 | # Apply the same transformation on all dataset elements 69 | lens = [len(cloud) for cloud in clouds] 70 | clouds = torch.cat(clouds, dim=0) 71 | clouds = dataset.set_transform(clouds) 72 | clouds = clouds.split(lens) 73 | 74 | # Convert to polar (when polar coords are used) and quantize 75 | # Use the first value returned by quantizer 76 | coords = [quantizer(e)[0] for e in clouds] 77 | coords = ME.utils.batched_coordinates(coords) 78 | 79 | # Assign a dummy feature equal to 1 to each point 80 | feats = torch.ones((coords.shape[0], 1), dtype=torch.float32) 81 | batch = {'coords': coords, 'features': feats} 82 | 
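# At this point 'coords' is an (M, 4) integer tensor built by ME.utils.batched_coordinates,
# holding [batch_index, x, y, z] quantized coordinates of all clouds in the batch, and
# 'features' is a matching (M, 1) tensor of dummy ones used as network input features.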
83 | # Compute positives and negatives mask 84 | # dataset.queries[label].positives is a sorted ndarray of ids of positive elements 85 | positives_mask = [[in_sorted_array(e, dataset.queries[label].positives) for e in labels] for label in labels] 86 | negatives_mask = [[not in_sorted_array(e, dataset.queries[label].non_negatives) for e in labels] for label in labels] 87 | positives_mask = torch.tensor(positives_mask) 88 | negatives_mask = torch.tensor(negatives_mask) 89 | 90 | # Returns the batch dictionary together with positives_mask and negatives_mask, which are 91 | # batch_size x batch_size boolean tensors 92 | #return batch, positives_mask, negatives_mask, torch.tensor(sampled_positive_ndx), torch.tensor(relative_poses) 93 | return batch, positives_mask, negatives_mask 94 | 95 | return collate_fn 96 | 97 | 98 | def make_collate_fn_6DOF(dataset: TrainingDataset, quantizer, device): 99 | 100 | # quantizer: converts to polar coordinates (when polar coords are used) and quantizes 101 | # Collates (anchor cloud, positive cloud, relative transform) triples into a batch 102 | def collate_fn(data_list): 103 | # Constructs a batch object 104 | anchor_clouds = [e[0] for e in data_list] 105 | positive_clouds = [e[1] for e in data_list] 106 | rel_transforms = [e[2] for e in data_list] 107 | 108 | xyz_batch1, xyz_batch2 = [], [] 109 | trans_batch, len_batch = [], [] 110 | curr_start_inds = np.zeros((1, 2), dtype=np.int64) 111 | 112 | for batch_id, _ in enumerate(anchor_clouds): 113 | N1 = anchor_clouds[batch_id].shape[0] 114 | N2 = positive_clouds[batch_id].shape[0] 115 | 116 | xyz_batch1.append(anchor_clouds[batch_id]) 117 | xyz_batch2.append(positive_clouds[batch_id]) 118 | trans_batch.append(rel_transforms[batch_id]) 119 | 120 | len_batch.append([N1, N2]) 121 | 122 | # Move the head 123 | curr_start_inds[0, 0] += N1 124 | curr_start_inds[0, 1] += N2 125 | 126 | anch_coords = [quantizer(e)[0] for e in anchor_clouds] 127 | anch_coords = ME.utils.batched_coordinates(anch_coords) 128 | pos_coords = [quantizer(e)[0] for e in positive_clouds] 129 | pos_coords = ME.utils.batched_coordinates(pos_coords) 130 | 131 | # Assign a dummy feature equal to 1 to each point 132 | anch_feats = torch.ones((anch_coords.shape[0], 1), dtype=torch.float32) 133 | anch_batch = {'coords': anch_coords, 'features': anch_feats} 134 | pos_feats = torch.ones((pos_coords.shape[0], 1), dtype=torch.float32) 135 | pos_batch = {'coords': pos_coords, 'features': pos_feats} 136 | 137 | # Concatenate point coordinates 138 | xyz_batch1 = torch.cat(xyz_batch1, 0).float() 139 | xyz_batch2 = torch.cat(xyz_batch2, 0).float() 140 | 141 | # Stack transforms 142 | trans_batch = torch.stack(trans_batch, 0).float() 143 | 144 | # Returns a dict with: 145 | # concatenated anchor and positive point coordinates, 146 | # quantized anchor and positive batches for the network, 147 | # ground truth relative transformations and per-pair cloud lengths 148 | return {"anc_pcd": xyz_batch1, "pos_pcd": xyz_batch2, "anc_batch": anch_batch, "pos_batch": pos_batch, 149 | "T_gt": trans_batch, "len_batch": len_batch} 150 | 151 | return collate_fn 152 | 153 | 154 | def make_dataloaders(params: TrainingParams, debug=False, device='cpu', local=False, validation=True): 155 | """ 156 | Create training and validation dataloaders that return groups of k=2 similar elements 157 | :param params: training parameters (including model parameters) 158 | :param local: when True, also create dataloaders for local descriptor training 159 | :return: dictionary of dataloaders 160 | """ 161 | datasets = make_datasets(params, local=local, validation=validation) 162 | 163 | dataloders = {} 164 | train_sampler = BatchSampler(datasets['global_train'],
batch_size=params.batch_size, 165 | batch_size_limit=params.batch_size_limit, 166 | batch_expansion_rate=params.batch_expansion_rate) 167 | 168 | 169 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 170 | quantizer = params.model_params.quantizer 171 | train_collate_fn = make_collate_fn(datasets['global_train'], quantizer) 172 | dataloders['global_train'] = DataLoader(datasets['global_train'], batch_sampler=train_sampler, 173 | collate_fn=train_collate_fn, num_workers=params.num_workers, 174 | pin_memory=True) 175 | if validation and 'global_val' in datasets: 176 | val_collate_fn = make_collate_fn(datasets['global_val'], quantizer) 177 | val_sampler = BatchSampler(datasets['global_val'], batch_size=params.batch_size_limit) 178 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 179 | # Currently validation dataset has empty set_transform function, but it may change in the future 180 | dataloders['global_val'] = DataLoader(datasets['global_val'], batch_sampler=val_sampler, 181 | collate_fn=val_collate_fn, 182 | num_workers=params.num_workers, pin_memory=True) 183 | 184 | if params.secondary_dataset is not None: 185 | secondary_train_sampler = BatchSampler(datasets['secondary_train'], batch_size=params.batch_size, 186 | batch_size_limit=params.secondary_batch_size_limit, 187 | batch_expansion_rate=params.batch_expansion_rate, max_batches=2000) 188 | 189 | secondary_train_collate_fn = make_collate_fn(datasets['secondary_train'], quantizer) 190 | dataloders['secondary_train'] = DataLoader(datasets['secondary_train'], 191 | batch_sampler=secondary_train_sampler, 192 | collate_fn=secondary_train_collate_fn, 193 | num_workers=params.num_workers, 194 | pin_memory=True) 195 | 196 | if local: 197 | train_collate_fn_loc = make_collate_fn_6DOF(datasets['local_train'], quantizer, device) 198 | val_collate_fn_loc = make_collate_fn_6DOF(datasets['local_val'], quantizer, device) 199 | dataloders['local_train'] = DataLoader(datasets['local_train'], shuffle=True, collate_fn=train_collate_fn_loc, 200 | batch_size=params.local_batch_size, num_workers=params.num_workers, 201 | pin_memory=True) 202 | if validation and 'local_val' in datasets: 203 | dataloders['local_val'] = DataLoader(datasets['local_val'], 204 | collate_fn=val_collate_fn_loc, batch_size=params.local_batch_size, 205 | num_workers=params.num_workers, pin_memory=True) 206 | 207 | return dataloders 208 | 209 | 210 | def filter_query_elements(query_set: List[EvaluationTuple], map_set: List[EvaluationTuple], 211 | dist_threshold: float) -> List[EvaluationTuple]: 212 | # Function used in evaluation dataset generation 213 | # Filters out query elements without a corresponding map element within dist_threshold threshold 214 | map_pos = np.zeros((len(map_set), 2), dtype=np.float32) 215 | for ndx, e in enumerate(map_set): 216 | map_pos[ndx] = e.position 217 | 218 | # Build a kdtree 219 | kdtree = KDTree(map_pos) 220 | 221 | filtered_query_set = [] 222 | count_ignored = 0 223 | for ndx, e in enumerate(query_set): 224 | position = e.position.reshape(1, -1) 225 | nn = kdtree.query_radius(position, dist_threshold, count_only=True)[0] 226 | if nn > 0: 227 | filtered_query_set.append(e) 228 | else: 229 | count_ignored += 1 230 | 231 | print(f"{count_ignored} query elements ignored - not having corresponding map element within {dist_threshold} [m] radius") 232 | return filtered_query_set 233 | 234 | 235 | def preprocess_pointcloud(pc, remove_zero_points: bool = 
False, 236 | min_x: float = None, max_x: float = None, 237 | min_y: float = None, max_y: float = None, 238 | min_z: float = None, max_z: float = None): 239 | if remove_zero_points: 240 | mask = np.all(np.isclose(pc, 0.), axis=1) 241 | pc = pc[~mask] 242 | 243 | if min_x is not None: 244 | mask = pc[:, 0] > min_x 245 | pc = pc[mask] 246 | 247 | if max_x is not None: 248 | mask = pc[:, 0] <= max_x 249 | pc = pc[mask] 250 | 251 | if min_y is not None: 252 | mask = pc[:, 1] > min_y 253 | pc = pc[mask] 254 | 255 | if max_y is not None: 256 | mask = pc[:, 1] <= max_y 257 | pc = pc[mask] 258 | 259 | if min_z is not None: 260 | mask = pc[:, 2] > min_z 261 | pc = pc[mask] 262 | 263 | if max_z is not None: 264 | mask = pc[:, 2] <= max_z 265 | pc = pc[mask] 266 | 267 | return pc 268 | 269 | 270 | def in_sorted_array(e: int, array: np.ndarray) -> bool: 271 | # Binary-search membership test on a sorted array 272 | pos = np.searchsorted(array, e) 273 | if pos == len(array): # np.searchsorted returns a position in [0, len(array)] 274 | return False 275 | else: 276 | return array[pos] == e
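# Example: in_sorted_array is used above to build the positives/negatives masks
# in make_collate_fn, e.g.:
#   in_sorted_array(5, np.array([2, 5, 9]))  -> True
#   in_sorted_array(4, np.array([2, 5, 9]))  -> False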
-------------------------------------------------------------------------------- /datasets/kitti/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Test set for Kitti Sequence 00 dataset. 2 | # Following procedures in [cite papers Kitti for place reco] we use the first 170 seconds of the drive for map generation 3 | # and the rest is left for queries 4 | 5 | import numpy as np 6 | import argparse 7 | from typing import List 8 | import os 9 | 10 | from datasets.kitti.kitti_raw import KittiSequence 11 | from datasets.base_datasets import EvaluationTuple, EvaluationSet 12 | from datasets.dataset_utils import filter_query_elements 13 | 14 | 15 | MAP_TIMERANGE = (0, 170) 16 | 17 | 18 | def get_scans(sequence: KittiSequence, min_displacement: float = 0.1, ts_range: tuple = None) -> List[EvaluationTuple]: 19 | # Get a list of all point clouds from the sequence (the full sequence or test split only) 20 | 21 | elems = [] 22 | old_pos = None 23 | count_skipped = 0 24 | displacements = [] 25 | 26 | for ndx in range(len(sequence)): 27 | if ts_range is not None: 28 | if (ts_range[0] > sequence.rel_lidar_timestamps[ndx]) or (ts_range[1] < sequence.rel_lidar_timestamps[ndx]): 29 | continue 30 | pose = sequence.lidar_poses[ndx] 31 | # Kitti poses are in the camera coordinate system, where y is the up axis 32 | position = pose[[0,2], 3] 33 | 34 | if old_pos is not None: 35 | displacements.append(np.linalg.norm(old_pos - position)) 36 | 37 | if np.linalg.norm(old_pos - position) < min_displacement: 38 | # Ignore the point cloud if the vehicle didn't move 39 | count_skipped += 1 40 | continue 41 | 42 | item = EvaluationTuple(sequence.rel_lidar_timestamps[ndx], sequence.rel_scan_filepath[ndx], position, pose) 43 | elems.append(item) 44 | old_pos = position 45 | 46 | print(f'{count_skipped} clouds skipped due to displacement smaller than {min_displacement}') 47 | print(f'mean displacement {np.mean(np.array(displacements))}') 48 | return elems 49 | 50 | 51 | def generate_evaluation_set(dataset_root: str, map_sequence: str, min_displacement: float = 0.1, 52 | dist_threshold: float = 5.) -> EvaluationSet: 53 | 54 | sequence = KittiSequence(dataset_root, map_sequence) 55 | 56 | map_set = get_scans(sequence, min_displacement, MAP_TIMERANGE) 57 | query_set = get_scans(sequence, min_displacement, (MAP_TIMERANGE[-1], sequence.rel_lidar_timestamps[-1])) 58 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 59 | print(f'{len(map_set)} database elements, {len(query_set)} query elements') 60 | return EvaluationSet(query_set, map_set) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Kitti dataset') 65 | parser.add_argument('--dataset_root', type=str, required=True) 66 | parser.add_argument('--min_displacement', type=float, default=0.1) 67 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 68 | parser.add_argument('--dist_threshold', type=float, default=5.) 69 | 70 | args = parser.parse_args() 71 | 72 | # Sequences are fixed 73 | sequence = '00' 74 | print(f'Dataset root: {args.dataset_root}') 75 | print(f'Kitti sequence: {sequence}') 76 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 77 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 78 | 79 | kitti_eval_set = generate_evaluation_set(args.dataset_root, sequence, min_displacement=args.min_displacement, 80 | dist_threshold=args.dist_threshold) 81 | file_path_name = os.path.join(args.dataset_root, f'kitti_{sequence}_eval.pickle') 82 | print(f"Saving evaluation pickle: {file_path_name}") 83 | kitti_eval_set.save(file_path_name) 84 | -------------------------------------------------------------------------------- /datasets/kitti/kitti_raw.py: -------------------------------------------------------------------------------- 1 | # Functions and classes operating on a raw Kitti dataset 2 | 3 | import numpy as np 4 | import os 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from misc.point_clouds import PointCloudLoader 9 | 10 | 11 | class KittiPointCloudLoader(PointCloudLoader): 12 | def set_properties(self): 13 | # Set point cloud properties, such as ground_plane_level.
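# The Velodyne sensor in Kitti is mounted roughly 1.7 m above the road surface, so
# points below this level in sensor coordinates are most likely ground-plane returns.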
14 | self.ground_plane_level = -1.5 15 | 16 | def read_pc(self, file_pathname: str) -> torch.Tensor: 17 | # Reads the point cloud without pre-processing 18 | # Returns Nx3 tensor 19 | pc = np.fromfile(file_pathname, dtype=np.float32) 20 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance 21 | pc = np.reshape(pc, (-1, 4))[:, :3] 22 | return pc 23 | 24 | 25 | class KittiSequence(Dataset): 26 | """ 27 | Point cloud from a sequence from a raw Kitti dataset 28 | """ 29 | def __init__(self, dataset_root: str, sequence_name: str, pose_time_tolerance: float = 1., 30 | remove_zero_points: bool = True): 31 | # pose_time_tolerance: (in seconds) skip point clouds without corresponding pose information (based on 32 | # timestamps difference) 33 | # remove_zero_points: remove (0,0,0) points 34 | 35 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}' 36 | self.dataset_root = dataset_root 37 | self.sequence_name = sequence_name 38 | # self.sequence_path = os.path.join(self.dataset_root, 'sequences') 39 | # assert os.path.exists(self.sequence_path), f'Cannot access sequence: {self.sequence_path}' 40 | self.rel_lidar_path = os.path.join('sequences', self.sequence_name, 'velodyne') 41 | # lidar_path = os.path.join(self.sequence_path, self.rel_lidar_path) 42 | # assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}' 43 | self.pose_file = os.path.join(self.dataset_root, 'poses', self.sequence_name + '.txt') 44 | assert os.path.exists(self.pose_file), f'Cannot access sequence pose file: {self.pose_file}' 45 | self.times_file = os.path.join(self.dataset_root, 'sequences', self.sequence_name, 'times.txt') 46 | assert os.path.exists(self.times_file), f'Cannot access sequence times file: {self.times_file}' 47 | # Maximum discrepancy between timestamps of LiDAR scan and global pose in seconds 48 | self.pose_time_tolerance = pose_time_tolerance 49 | self.remove_zero_points = remove_zero_points 50 | 51 | self.rel_lidar_timestamps, self.lidar_poses, filenames = self._read_lidar_poses() 52 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, '%06d%s' % (e, '.bin')) for e in filenames] 53 | 54 | def __len__(self): 55 | return len(self.rel_lidar_timestamps) 56 | 57 | def __getitem__(self, ndx): 58 | scan_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx]) 59 | pc = load_pc(scan_filepath) 60 | if self.remove_zero_points: 61 | mask = np.all(np.isclose(pc, 0), axis=1) 62 | pc = pc[~mask] 63 | return {'pc': pc, 'pose': self.lidar_poses[ndx], 'ts': self.rel_lidar_timestamps[ndx]} 64 | 65 | def _read_lidar_poses(self): 66 | fnames = os.listdir(os.path.join(self.dataset_root, self.rel_lidar_path)) 67 | temp = os.path.join(self.dataset_root, self.rel_lidar_path) 68 | fnames = [e for e in fnames if os.path.isfile(os.path.join(temp, e))] 69 | assert len(fnames) > 0, f"Make sure that the path {self.rel_lidar_path} contains lidar scans" 70 | filenames = sorted([int(os.path.split(fname)[-1][:-4]) for fname in fnames]) 71 | with open(self.pose_file, "r") as h: 72 | txt_poses = h.readlines() 73 | 74 | n = len(txt_poses) 75 | poses = np.zeros((n, 4, 4), dtype=np.float64) # 4x4 pose matrix 76 | 77 | for ndx, pose in enumerate(txt_poses): 78 | # Split by space and strip whitespace 79 | temp = [e.strip() for e in pose.split(' ')] 80 | assert len(temp) == 12, f'Invalid line in global poses file: {temp}' 81 | # poses in Kitti are in the cam0 reference frame 82 | poses[ndx] = np.array([[float(temp[0]), float(temp[1]), float(temp[2]), float(temp[3])], 83 | [float(temp[4]),
float(temp[5]), float(temp[6]), float(temp[7])], 84 | [float(temp[8]), float(temp[9]), float(temp[10]), float(temp[11])], 85 | [0., 0., 0., 1.]]) 86 | rel_ts = np.genfromtxt(self.times_file) 87 | 88 | return rel_ts, poses, filenames 89 | 90 | 91 | def load_pc(filepath): 92 | # Load point cloud, does not apply any transform 93 | # Returns Nx3 matrix 94 | pc = np.fromfile(filepath, dtype=np.float32) 95 | # PC in Kitti is of size [num_points, 4] -> x,y,z,reflectance 96 | pc = np.reshape(pc, (-1, 4))[:, :3] 97 | return pc 98 | -------------------------------------------------------------------------------- /datasets/kitti/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def velo2cam(): 5 | R = np.array([ 6 | 7.533745e-03, -9.999714e-01, -6.166020e-04, 1.480249e-02, 7.280733e-04, 7 | -9.998902e-01, 9.998621e-01, 7.523790e-03, 1.480755e-02 8 | ]).reshape(3, 3) 9 | T = np.array([-4.069766e-03, -7.631618e-02, -2.717806e-01]).reshape(3, 1) 10 | velo2cam = np.hstack([R, T]) 11 | velo2cam = np.vstack((velo2cam, [0, 0, 0, 1])).T 12 | return velo2cam 13 | 14 | 15 | def get_relative_pose(pose_1, pose_2): 16 | # as seen in https://github.com/chrischoy/FCGF 17 | M = (velo2cam() @ pose_1.T @ np.linalg.inv(pose_2.T) @ np.linalg.inv(velo2cam())).T 18 | return M 19 | -------------------------------------------------------------------------------- /datasets/mulran/generate_evaluation_sets.py: -------------------------------------------------------------------------------- 1 | # Test sets for Mulran dataset. 2 | 3 | import argparse 4 | from typing import List 5 | import os 6 | 7 | from datasets.mulran.mulran_raw import MulranSequence 8 | from datasets.base_datasets import EvaluationTuple, EvaluationSet 9 | from datasets.dataset_utils import filter_query_elements 10 | 11 | DEBUG = False 12 | 13 | 14 | def get_scans(sequence: MulranSequence) -> List[EvaluationTuple]: 15 | # Get a list of all readings from the test area in the sequence 16 | elems = [] 17 | for ndx in range(len(sequence)): 18 | pose = sequence.poses[ndx] 19 | position = pose[:2, 3] 20 | item = EvaluationTuple(sequence.timestamps[ndx], sequence.rel_scan_filepath[ndx], position=position, pose=pose) 21 | elems.append(item) 22 | return elems 23 | 24 | 25 | def generate_evaluation_set(dataset_root: str, map_sequence: str, query_sequence: str, min_displacement: float = 0.2, 26 | dist_threshold=20) -> EvaluationSet: 27 | split = 'test' 28 | map_sequence = MulranSequence(dataset_root, map_sequence, split=split, min_displacement=min_displacement) 29 | query_sequence = MulranSequence(dataset_root, query_sequence, split=split, min_displacement=min_displacement) 30 | 31 | map_set = get_scans(map_sequence) 32 | query_set = get_scans(query_sequence) 33 | 34 | # Function used in evaluation dataset generation 35 | # Filters out query elements without a corresponding map element within dist_threshold threshold 36 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 37 | print(f'{len(map_set)} database elements, {len(query_set)} query elements') 38 | return EvaluationSet(query_set, map_set) 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Mulran dataset') 43 | parser.add_argument('--dataset_root', type=str, required=True) 44 | parser.add_argument('--min_displacement', type=float, default=0.2) 45 | # Ignore query elements that do not have a corresponding map element within the given threshold (in 
meters) 46 | parser.add_argument('--dist_threshold', type=float, default=20) 47 | args = parser.parse_args() 48 | 49 | print(f'Dataset root: {args.dataset_root}') 50 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 51 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 52 | 53 | # Sequences is a list of (map sequence, query sequence) pairs 54 | sequences = [('Sejong01', 'Sejong02')] 55 | if DEBUG: 56 | sequences = [('ParkingLot', 'ParkingLot')] 57 | 58 | for map_sequence, query_sequence in sequences: 59 | print(f'Map sequence: {map_sequence}') 60 | print(f'Query sequence: {query_sequence}') 61 | 62 | test_set = generate_evaluation_set(args.dataset_root, map_sequence, query_sequence, 63 | min_displacement=args.min_displacement, dist_threshold=args.dist_threshold) 64 | 65 | pickle_name = f'test_{map_sequence}_{query_sequence}.pickle' 66 | file_path_name = os.path.join(args.dataset_root, pickle_name) 67 | test_set.save(file_path_name) 68 | -------------------------------------------------------------------------------- /datasets/mulran/generate_training_tuples.py: -------------------------------------------------------------------------------- 1 | # Training tuples generation for Mulran dataset. 2 | 3 | import numpy as np 4 | import argparse 5 | import tqdm 6 | import pickle 7 | import os 8 | 9 | from datasets.mulran.mulran_raw import MulranSequences 10 | from datasets.base_datasets import TrainingTuple 11 | from datasets.mulran.utils import relative_pose 12 | from misc.point_clouds import icp 13 | 14 | DEBUG = False 15 | 16 | 17 | def load_pc(file_pathname): 18 | # Load point cloud, clip x, y and z coords (points far away and the ground plane) 19 | # Returns Nx3 matrix 20 | pc = np.fromfile(file_pathname, dtype=np.float32) 21 | # PC in Mulran is of size [num_points, 4] -> x,y,z,reflectance 22 | pc = np.reshape(pc, (-1, 4))[:, :3] 23 | 24 | mask = np.all(np.isclose(pc, 0.), axis=1) 25 | pc = pc[~mask] 26 | mask = pc[:, 0] > -80 27 | pc = pc[mask] 28 | mask = pc[:, 0] <= 80 29 | pc = pc[mask] 30 | 31 | mask = pc[:, 1] > -80 32 | pc = pc[mask] 33 | mask = pc[:, 1] <= 80 34 | pc = pc[mask] 35 | 36 | mask = pc[:, 2] > -0.9 37 | pc = pc[mask] 38 | return pc 39 | 40 | 41 | def generate_training_tuples(ds: MulranSequences, pos_threshold: float = 10, neg_threshold: float = 50): 42 | # displacement: displacement between consecutive anchors (if None all scans are taken as anchors).
43 | # Use some small displacement to ensure there's only one scan if the vehicle does not move 44 | fitness_l, inlier_rmse_l = [], [] # ICP stats, accumulated over all anchors for the summary below 45 | tuples = {} # Dictionary of training tuples: tuples[ndx] = (set of positives, set of non-negatives) 46 | for anchor_ndx in tqdm.tqdm(range(len(ds))): 47 | anchor_pos = ds.get_xy()[anchor_ndx] 48 | 49 | # Find timestamps of positive and negative elements 50 | positives = ds.find_neighbours_ndx(anchor_pos, pos_threshold) 51 | non_negatives = ds.find_neighbours_ndx(anchor_pos, neg_threshold) 52 | # Remove anchor element from positives, but leave it in non_negatives 53 | positives = positives[positives != anchor_ndx] 54 | 55 | # Sort ascending order 56 | positives = np.sort(positives) 57 | non_negatives = np.sort(non_negatives) 58 | 59 | # ICP pose refinement 60 | positive_poses = {} 61 | 62 | if DEBUG: 63 | # Use ground truth transform without pose refinement 64 | anchor_pose = ds.poses[anchor_ndx] 65 | for positive_ndx in positives: 66 | positive_pose = ds.poses[positive_ndx] 67 | # Compute initial relative pose 68 | m, fitness, inlier_rmse = relative_pose(anchor_pose, positive_pose), 1., 1. 69 | fitness_l.append(fitness) 70 | inlier_rmse_l.append(inlier_rmse) 71 | positive_poses[positive_ndx] = m 72 | else: 73 | anchor_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[anchor_ndx])) 74 | anchor_pose = ds.poses[anchor_ndx] 75 | for positive_ndx in positives: 76 | positive_pc = load_pc(os.path.join(ds.dataset_root, ds.rel_scan_filepath[positive_ndx])) 77 | positive_pose = ds.poses[positive_ndx] 78 | # Compute initial relative pose 79 | transform = relative_pose(anchor_pose, positive_pose) 80 | # Refine the pose using ICP 81 | m, fitness, inlier_rmse = icp(anchor_pc, positive_pc, transform) 82 | 83 | fitness_l.append(fitness) 84 | inlier_rmse_l.append(inlier_rmse) 85 | positive_poses[positive_ndx] = m 86 | 87 | # Tuple(id: int, timestamp: int, rel_scan_filepath: str, positives: List[int], non_negatives: List[int]) 88 | tuples[anchor_ndx] = TrainingTuple(id=anchor_ndx, timestamp=ds.timestamps[anchor_ndx], 89 | rel_scan_filepath=ds.rel_scan_filepath[anchor_ndx], 90 | positives=positives, non_negatives=non_negatives, pose=anchor_pose, 91 | positives_poses=positive_poses) 92 | 93 | print(f'{len(tuples)} training tuples generated') 94 | print('ICP pose refinement stats:') 95 | print(f'Fitness - min: {np.min(fitness_l):0.3f} mean: {np.mean(fitness_l):0.3f} max: {np.max(fitness_l):0.3f}') 96 | print(f'Inlier RMSE - min: {np.min(inlier_rmse_l):0.3f} mean: {np.mean(inlier_rmse_l):0.3f} max: {np.max(inlier_rmse_l):0.3f}') 97 | 98 | return tuples 99 | 100 | 101 | if __name__ == '__main__': 102 | parser = argparse.ArgumentParser(description='Generate training tuples') 103 | parser.add_argument('--dataset_root', type=str, required=True) 104 | parser.add_argument('--pos_threshold', type=float, default=2) 105 | parser.add_argument('--neg_threshold', type=float, default=10) 106 | parser.add_argument('--min_displacement', type=float, default=0.2) 107 | args = parser.parse_args() 108 | 109 | sequences = ['Sejong01', 'Sejong02'] 110 | if DEBUG: 111 | sequences = ['ParkingLot', 'ParkingLot'] 112 | 113 | print(f'Dataset root: {args.dataset_root}') 114 | print(f'Sequences: {sequences}') 115 | print(f'Threshold for positive examples: {args.pos_threshold}') 116 | print(f'Threshold for negative examples: {args.neg_threshold}') 117 | print(f'Minimum displacement between consecutive anchors: {args.min_displacement}') 118 | 119 | ds = MulranSequences(args.dataset_root, sequences,
split='train', min_displacement=args.min_displacement)
122 | train_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold)
123 | pickle_name = f'train_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle'
124 | train_tuples_filepath = os.path.join(args.dataset_root, pickle_name)
125 | pickle.dump(train_tuples, open(train_tuples_filepath, 'wb'))
126 | train_tuples = None
127 |
128 | ds = MulranSequences(args.dataset_root, sequences, split='test', min_displacement=args.min_displacement)
129 | test_tuples = generate_training_tuples(ds, args.pos_threshold, args.neg_threshold)
130 | pickle_name = f'val_{sequences[0]}_{sequences[1]}_{args.pos_threshold}_{args.neg_threshold}.pickle'
131 | test_tuples_filepath = os.path.join(args.dataset_root, pickle_name)
132 | pickle.dump(test_tuples, open(test_tuples_filepath, 'wb'))
--------------------------------------------------------------------------------
/datasets/mulran/mulran_raw.py:
--------------------------------------------------------------------------------
1 | # Functions and classes operating on a raw MulRan dataset
2 |
3 | import numpy as np
4 | import os
5 | from typing import List
6 | from torch.utils.data import Dataset, ConcatDataset
7 | from sklearn.neighbors import KDTree
8 | import torch
9 |
10 | from datasets.mulran.utils import read_lidar_poses, in_test_split, in_train_split
11 | from misc.point_clouds import PointCloudLoader
12 |
13 |
14 | class MulranPointCloudLoader(PointCloudLoader):
15 | def set_properties(self):
16 | # Set point cloud properties, such as ground_plane_level.
17 | self.ground_plane_level = -0.9
18 |
19 | def read_pc(self, file_pathname: str) -> np.ndarray:
20 | # Reads the point cloud without pre-processing
21 | # Returns an (N, 3) numpy array
22 | pc = np.fromfile(file_pathname, dtype=np.float32)
23 | # PC in MulRan is of size [num_points, 4] -> x,y,z,reflectance
24 | pc = np.reshape(pc, (-1, 4))[:, :3]
25 | return pc
26 |
27 |
28 | class MulranSequence(Dataset):
29 | """
30 | Dataset returns a point cloud from a train or test split from one sequence from a raw MulRan dataset
31 | """
32 | def __init__(self, dataset_root: str, sequence_name: str, split: str, min_displacement: float = 0.2):
33 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}'
34 | assert split in ['train', 'test', 'all']
35 |
36 | self.dataset_root = dataset_root
37 | self.sequence_name = sequence_name
38 | sequence_path = os.path.join(self.dataset_root, self.sequence_name)
39 | assert os.path.exists(sequence_path), f'Cannot access sequence: {sequence_path}'
40 | self.split = split
41 | self.min_displacement = min_displacement
42 | # Maximum discrepancy between timestamps of LiDAR scan and global pose in seconds
43 | self.pose_time_tolerance = 1.
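# Note (added for clarity): read_lidar_poses() compares timestamps in nanoseconds,
# so this 1-second tolerance corresponds to 1e9 ns when matching scans to poses.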
44 |
45 | self.pose_file = os.path.join(sequence_path, 'global_pose.csv')
46 | assert os.path.exists(self.pose_file), f'Cannot access global pose file: {self.pose_file}'
47 |
48 | self.rel_lidar_path = os.path.join(self.sequence_name, 'Ouster')
49 | lidar_path = os.path.join(self.dataset_root, self.rel_lidar_path)
50 | assert os.path.exists(lidar_path), f'Cannot access lidar scans: {lidar_path}'
51 | self.pc_loader = MulranPointCloudLoader()
52 |
53 | timestamps, poses = read_lidar_poses(self.pose_file, lidar_path, self.pose_time_tolerance)
54 | self.timestamps, self.poses = self.filter(timestamps, poses)
55 | self.rel_scan_filepath = [os.path.join(self.rel_lidar_path, str(e) + '.bin') for e in self.timestamps]
56 |
57 | assert len(self.timestamps) == len(self.poses)
58 | assert len(self.timestamps) == len(self.rel_scan_filepath)
59 | print(f'{len(self.timestamps)} scans in {sequence_name}-{split}')
60 |
61 | def __len__(self):
62 | return len(self.rel_scan_filepath)
63 |
64 | def __getitem__(self, ndx):
65 | reading_filepath = os.path.join(self.dataset_root, self.rel_scan_filepath[ndx])
66 | reading = self.pc_loader(reading_filepath)
67 | return {'pc': reading, 'pose': self.poses[ndx], 'ts': self.timestamps[ndx],
68 | 'position': self.poses[ndx][:2, 3]}
69 |
70 | def filter(self, ts: np.ndarray, poses: np.ndarray):
71 | # Filter out scans - retain only scans within a given split with minimum displacement
72 | positions = poses[:, :2, 3]
73 |
74 | # Retain elements in the given split
75 | # Only the Sejong sequences have a train/test split
76 | if self.split != 'all' and self.sequence_name.lower()[:6] == 'sejong':
77 | if self.split == 'train':
78 | mask = in_train_split(positions)
79 | elif self.split == 'test':
80 | mask = in_test_split(positions)
81 |
82 | ts = ts[mask]
83 | poses = poses[mask]
84 | positions = positions[mask]
85 | #print(f'Split: {self.split} Mask len: {len(mask)} Mask True: {np.sum(mask)}')
86 |
87 | # Filter out scans - retain only scans displaced by more than min_displacement from the previous scan
88 | prev_position = None
89 | mask = []
90 | for ndx, position in enumerate(positions):
91 | if prev_position is None:
92 | mask.append(ndx)
93 | else:
94 | displacement = np.linalg.norm(prev_position - position)
95 | if displacement > self.min_displacement:
96 | mask.append(ndx)
97 | prev_position = position
98 |
99 | ts = ts[mask]
100 | poses = poses[mask]
101 | return ts, poses
102 |
103 |
104 | class MulranSequences(Dataset):
105 | """
106 | Multiple MulRan sequences indexed as a single dataset. Each element is identified by a unique global index.
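Usage (a minimal sketch; the dataset root path is hypothetical):
    ds = MulranSequences('/data/MulRan', ['Sejong01', 'Sejong02'], split='train')
    item = ds[0]  # dict with 'pc', 'pose', 'ts' and 'position' keys
    neighbours = ds.find_neighbours_ndx(item['position'], radius=2.)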
107 | """ 108 | def __init__(self, dataset_root: str, sequence_names: List[str], split: str, min_displacement: float = 0.2): 109 | assert len(sequence_names) > 0 110 | assert os.path.exists(dataset_root), f'Cannot access dataset root: {dataset_root}' 111 | assert split in ['train', 'test', 'all'] 112 | 113 | self.dataset_root = dataset_root 114 | self.sequence_names = sequence_names 115 | self.split = split 116 | self.min_displacement = min_displacement 117 | 118 | sequences = [] 119 | for seq_name in self.sequence_names: 120 | ds = MulranSequence(self.dataset_root, seq_name, split=split, min_displacement=min_displacement) 121 | sequences.append(ds) 122 | 123 | self.dataset = ConcatDataset(sequences) 124 | 125 | # Concatenate positions from all sequences 126 | self.poses = np.zeros((len(self.dataset), 4, 4), dtype=np.float64) 127 | self.timestamps = np.zeros((len(self.dataset),), dtype=np.int64) 128 | self.rel_scan_filepath = [] 129 | 130 | for cum_size, ds in zip(self.dataset.cumulative_sizes, sequences): 131 | # Consolidated lidar positions, timestamps and relative filepaths 132 | self.poses[cum_size - len(ds): cum_size, :] = ds.poses 133 | self.timestamps[cum_size - len(ds): cum_size] = ds.timestamps 134 | self.rel_scan_filepath.extend(ds.rel_scan_filepath) 135 | 136 | assert len(self.timestamps) == len(self.poses) 137 | assert len(self.timestamps) == len(self.rel_scan_filepath) 138 | 139 | # Build a kdtree based on X, Y position 140 | self.kdtree = KDTree(self.get_xy()) 141 | 142 | def __len__(self): 143 | return len(self.dataset) 144 | 145 | def __getitem__(self, ndx): 146 | return self.dataset[ndx] 147 | 148 | def get_xy(self): 149 | # Get X, Y position from (4, 4) pose 150 | return self.poses[:, :2, 3] 151 | 152 | def find_neighbours_ndx(self, position, radius): 153 | # Returns indices of neighbourhood point clouds for a given position 154 | assert position.ndim == 1 155 | assert position.shape[0] == 2 156 | # Reshape into (1, 2) axis 157 | position = position.reshape(1, -1) 158 | neighbours = self.kdtree.query_radius(position, radius)[0] 159 | return neighbours.astype(np.int32) 160 | 161 | 162 | if __name__ == '__main__': 163 | dataset_root = '/media/sf_Datasets/MulRan' 164 | sequence_names = ['Sejong01'] 165 | 166 | db = MulranSequences(dataset_root, sequence_names, split='train') 167 | print(f'Number of scans in the sequence: {len(db)}') 168 | e = db[0] 169 | 170 | res = db.find_neighbours_ndx(e['position'], radius=50) 171 | print('.') 172 | 173 | 174 | 175 | -------------------------------------------------------------------------------- /datasets/mulran/mulran_train.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | # Dataset wrapper for Mulran lidar scans dataset 3 | 4 | import os 5 | import random 6 | import numpy as np 7 | import torch 8 | 9 | from datasets.base_datasets import TrainingDataset 10 | from datasets.quantization import Quantizer 11 | from misc.poses import apply_transform 12 | from datasets.base_datasets import TrainingDataset 13 | 14 | DEBUG = False 15 | 16 | 17 | class MulranTraining6DOFDataset(TrainingDataset): 18 | """ 19 | Dataset wrapper for Mulran dataset for 6dof estimation. 
20 | """ 21 | def __init__(self, dataset_path: str, query_filename: str, quantizer: Quantizer, 22 | rot_max: float = 0., trans_max: float = 0., **vargs): 23 | dataset_type = 'mulran' 24 | super().__init__(dataset_path, dataset_type, query_filename, **vargs) 25 | self.quantizer = quantizer 26 | self.rot_max = rot_max 27 | self.trans_max = trans_max 28 | 29 | def __getitem__(self, ndx): 30 | # pose is a global coordinate system pose 3x4 R|T matrix 31 | query_pc, _ = super().__getitem__(ndx) 32 | 33 | # get random positive 34 | positives = self.get_positives(ndx) 35 | positive_idx = np.random.choice(positives, 1)[0] 36 | positive_pc, _ = super().__getitem__(positive_idx) 37 | 38 | # get relative pose taking two global poses 39 | transform = self.queries[ndx].positives_poses[positive_idx] 40 | 41 | # Apply random transform to the positive point cloud 42 | rotation_angle = np.random.uniform(low=-self.rot_max, high=self.rot_max) 43 | cosval = np.cos(rotation_angle) 44 | sinval = np.sin(rotation_angle) 45 | m = torch.eye(4, dtype=torch.float) 46 | #m[:3, :3] = np.array([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]]) 47 | m[:3, :3] = torch.tensor([[cosval, sinval, 0.], [-sinval, cosval, 0.], [0., 0., 1.]], dtype=m.dtype) 48 | m[:2, 3] = torch.rand((1, 2)) * 2. * self.trans_max - self.trans_max 49 | positive_pc = apply_transform(positive_pc, m) 50 | transform = m @ transform 51 | 52 | # Find indices of unique quantized coordinates and filter out points to leave max 1 point per voxel 53 | coords1, idx1 = self.quantizer(query_pc) 54 | coords2, idx2 = self.quantizer(positive_pc) 55 | pc1_cop = query_pc[idx1, :] 56 | pc2_trans_cop = positive_pc[idx2, :] 57 | 58 | return pc1_cop, pc2_trans_cop, transform 59 | -------------------------------------------------------------------------------- /datasets/mulran/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from scipy.spatial import distance_matrix 4 | 5 | # Faulty point clouds (with 0 points) 6 | FAULTY_POINTCLOUDS = [1566279795718079314] 7 | 8 | # Coordinates of test region centres (in Sejong sequence) 9 | TEST_REGION_CENTRES = np.array([[345090.0743, 4037591.323], [345090.483, 4044700.04], 10 | [350552.0308, 4041000.71], [349252.0308, 4044800.71]]) 11 | 12 | # Radius of the test region 13 | TEST_REGION_RADIUS = 500 14 | 15 | # Boundary between training and test region - to ensure there's no overlap between training and test clouds 16 | TEST_TRAIN_BOUNDARY = 50 17 | 18 | 19 | def in_train_split(pos): 20 | # returns true if pos is in train split 21 | assert pos.ndim == 2 22 | assert pos.shape[1] == 2 23 | dist = distance_matrix(pos, TEST_REGION_CENTRES) 24 | mask = (dist > TEST_REGION_RADIUS + TEST_TRAIN_BOUNDARY).all(axis=1) 25 | return mask 26 | 27 | 28 | def in_test_split(pos): 29 | # returns true if position is in evaluation split 30 | assert pos.ndim == 2 31 | assert pos.shape[1] == 2 32 | dist = distance_matrix(pos, TEST_REGION_CENTRES) 33 | mask = (dist < TEST_REGION_RADIUS).any(axis=1) 34 | return mask 35 | 36 | 37 | def find_nearest_ndx(ts, timestamps): 38 | ndx = np.searchsorted(timestamps, ts) 39 | if ndx == 0: 40 | return ndx 41 | elif ndx == len(timestamps): 42 | return ndx - 1 43 | else: 44 | assert timestamps[ndx-1] <= ts <= timestamps[ndx] 45 | if ts - timestamps[ndx-1] < timestamps[ndx] - ts: 46 | return ndx - 1 47 | else: 48 | return ndx 49 | 50 | 51 | def read_lidar_poses(poses_filepath: str, lidar_filepath: str, pose_time_tolerance: float = 
1.):
52 | # Read global poses from .csv file and link each lidar scan with the nearest pose
53 | # pose_time_tolerance: maximum allowed time difference (in seconds) between a scan and its pose
54 | # Returns: (timestamps, poses) - scan timestamps (in ns) and the corresponding (4, 4) pose matrices
55 |
56 | with open(poses_filepath, "r") as h:
57 | txt_poses = h.readlines()
58 |
59 | n = len(txt_poses)
60 | system_timestamps = np.zeros((n,), dtype=np.int64)
61 | poses = np.zeros((n, 4, 4), dtype=np.float64)       # 4x4 pose matrix
62 |
63 | for ndx, pose in enumerate(txt_poses):
64 | # Split by comma and remove whitespaces
65 | temp = [e.strip() for e in pose.split(',')]
66 | assert len(temp) == 13, f'Invalid line in global poses file: {temp}'
67 | system_timestamps[ndx] = int(temp[0])
68 | poses[ndx] = np.array([[float(temp[1]), float(temp[2]), float(temp[3]), float(temp[4])],
69 | [float(temp[5]), float(temp[6]), float(temp[7]), float(temp[8])],
70 | [float(temp[9]), float(temp[10]), float(temp[11]), float(temp[12])],
71 | [0., 0., 0., 1.]])
72 |
73 | # Ensure timestamps and poses are sorted in ascending order
74 | sorted_ndx = np.argsort(system_timestamps, axis=0)
75 | system_timestamps = system_timestamps[sorted_ndx]
76 | poses = poses[sorted_ndx]
77 |
78 | # List LiDAR scan timestamps
79 | all_lidar_timestamps = [int(os.path.splitext(f)[0]) for f in os.listdir(lidar_filepath) if
80 | os.path.splitext(f)[1] == '.bin']
81 | all_lidar_timestamps.sort()
82 |
83 | lidar_timestamps = []
84 | lidar_poses = []
85 | count_rejected = 0
86 |
87 | for ndx, lidar_ts in enumerate(all_lidar_timestamps):
88 | # Skip faulty point clouds
89 | if lidar_ts in FAULTY_POINTCLOUDS:
90 | continue
91 |
92 | # Find index of the closest timestamp
93 | closest_ts_ndx = find_nearest_ndx(lidar_ts, system_timestamps)
94 | delta = abs(system_timestamps[closest_ts_ndx] - lidar_ts)
95 | # Timestamp is in nanoseconds = 1e-9 second
96 | if delta > pose_time_tolerance * 1000000000:
97 | # Reject point cloud without corresponding pose
98 | count_rejected += 1
99 | continue
100 |
101 | lidar_timestamps.append(lidar_ts)
102 | lidar_poses.append(poses[closest_ts_ndx])
103 |
104 | lidar_timestamps = np.array(lidar_timestamps, dtype=np.int64)
105 | lidar_poses = np.array(lidar_poses, dtype=np.float64)     # (N, 4, 4) SE(3) pose matrices
106 |
107 | print(f'{len(lidar_timestamps)} scans with valid pose, {count_rejected} rejected due to unknown pose')
108 | return lidar_timestamps, lidar_poses
109 |
110 |
111 | def relative_pose(m1, m2):
112 | # An SE(3) pose is a 4x4 matrix, such that
113 | # Pw = [R | T] @ [P]
114 | #      [0 | 1]   [1]
115 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame
116 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame
117 | # m2: coords in camera/lidar2 coords -> world coordinate frame
118 | # returns: relative pose of the first camera with respect to the second camera, i.e. the
119 | # transformation matrix to convert coords in camera/lidar1 reference frame to coords in
120 | # camera/lidar2 reference frame
121 | #
122 | m = np.linalg.inv(m2) @ m1
123 | # !!!!!!!!!! Fix for relative pose !!!!!!!!!!!!!
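# A worked check of the standard part above, assuming identity rotations: if m1 places
# lidar1 at x=5 and m2 places lidar2 at x=2 in the world frame, then inv(m2) @ m1 has
# translation x=3, i.e. lidar1 expressed in the lidar2 frame. The line below additionally
# negates the translation, which appears to be a correction for the MulRan pose convention.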
124 | m[:3, 3] = -m[:3, 3]
125 | return m
--------------------------------------------------------------------------------
/datasets/quantization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import List
3 | from abc import ABC, abstractmethod
4 | import torch
5 | import MinkowskiEngine as ME
6 |
7 |
8 | class Quantizer(ABC):
9 | @abstractmethod
10 | def __call__(self, pc):
11 | pass
12 |
13 | @abstractmethod
14 | def dequantize(self, coords):
15 | pass
16 |
17 | @abstractmethod
18 | def keypoint_position(self, supervoxel_centers, stride, kp_offset):
19 | pass
20 |
21 |
22 | class PolarQuantizer(Quantizer):
23 | def __init__(self, quant_step: List[float]):
24 | assert len(quant_step) == 3, '3 quantization steps expected: for sector (in degrees), ring and z-coordinate (in meters)'
25 | self.quant_step = torch.tensor(quant_step, dtype=torch.float)
26 | self.theta_range = int(360. // self.quant_step[0])
27 |
28 |
29 | def __call__(self, pc):
30 | # Convert to polar coordinates and quantize with different step size for each coordinate
31 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z)
32 | assert pc.shape[1] == 3
33 |
34 | # theta is an angle in degrees in 0..360 range
35 | theta = 180. + torch.atan2(pc[:, 1], pc[:, 0]) * 180./np.pi
36 | # dist is a distance from a coordinate origin
37 | dist = torch.sqrt(pc[:, 0]**2 + pc[:, 1]**2)
38 | z = pc[:, 2]
39 | polar_pc = torch.stack([theta, dist, z], dim=1)
40 | # Scale each coordinate so after quantization with step 1. we get the required quantization step in each dim
41 | polar_pc = polar_pc / self.quant_step
42 | quantized_polar_pc, ndx = ME.utils.sparse_quantize(polar_pc, quantization_size=1., return_index=True)
43 | # Return quantized coordinates and index of selected elements
44 | return quantized_polar_pc, ndx
45 |
46 | def to_cartesian(self, pc):
47 | # Convert theta to radians in the -pi..pi range
48 | theta = np.pi * (pc[:, 0] - 180.) / 180.
49 | x = torch.cos(theta) * pc[:, 1]
50 | y = torch.sin(theta) * pc[:, 1]
51 | z = pc[:, 2]
52 | cartesian_pc = torch.stack([x, y, z], dim=1)
53 | return cartesian_pc
54 |
55 | def dequantize(self, coords):
56 | # Dequantize coords and convert to cartesian as (N, 3) tensor of floats
57 | pc = (0.5 + coords) * self.quant_step.to(coords.device)
58 | return self.to_cartesian(pc)
59 |
60 | def keypoint_position(self, supervoxel_centres, stride, kp_offset):
61 | # Compute the keypoint position as the supervoxel centre (voxel coords shifted by 0.5 and
62 | # scaled by the quantization step) plus the regressed offset kp_offset (in -1..1 range per
63 | # axis) scaled to half of the supervoxel extent
64 | device = supervoxel_centres.device
65 | supervoxel_centres = (supervoxel_centres + 0.5) * self.quant_step.to(device)
66 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * \
67 | self.quant_step.to(device)
68 | #kp_pos = supervoxel_centres
69 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2.
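# kp_pos is still in (scaled) polar coordinates at this point; it is converted
# back to Cartesian XYZ below.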
70 |
71 | kp_pos = self.to_cartesian(kp_pos)
72 | return kp_pos
73 |
74 |
75 | class CartesianQuantizer(Quantizer):
76 | def __init__(self, quant_step: float):
77 | self.quant_step = quant_step
78 |
79 | def __call__(self, pc):
80 | # Quantize Cartesian coordinates with the same step size for each coordinate
81 | # pc: (N, 3) point cloud with Cartesian coordinates (X, Y, Z)
82 | assert pc.shape[1] == 3
83 | quantized_pc, ndx = ME.utils.sparse_quantize(pc, quantization_size=self.quant_step, return_index=True)
84 | # Return quantized coordinates and index of selected elements
85 | return quantized_pc, ndx
86 |
87 | def dequantize(self, coords):
88 | # Dequantize coords and return as (N, 3) tensor of floats
89 | # Use coords of the voxel center
90 | pc = (0.5 + coords) * self.quant_step
91 | return pc
92 |
93 | def keypoint_position(self, supervoxel_centers, stride, kp_offset):
94 | # Compute the keypoint position as the supervoxel centre (voxel coords shifted by 0.5 and
95 | # scaled by the quantization step) plus the regressed offset kp_offset (in -1..1 range per
96 | # axis) scaled to half of the supervoxel extent
97 | supervoxel_centres = (supervoxel_centers + 0.5) * self.quant_step
98 | supervoxel_size = torch.tensor(stride, dtype=torch.float, device=supervoxel_centres.device) * self.quant_step
99 | if kp_offset is not None:
100 | kp_pos = supervoxel_centres + kp_offset * supervoxel_size / 2.
101 | else:
102 | kp_pos = supervoxel_centres
103 | return kp_pos
104 |
105 |
106 | if __name__ == "__main__":
107 | n = 1000
108 | cart = torch.rand((n, 3), dtype=torch.float)
109 | cart[:, 0] = cart[:, 0] * 200. - 100.
110 | cart[:, 1] = cart[:, 1] * 200. - 100.
111 | cart[:, 2] = cart[:, 2] * 30. - 10.
112 |
113 | quantizer = PolarQuantizer([0.5, 0.3, 0.2])
114 | polar_quant, ndx = quantizer(cart)
115 | back2cart = quantizer.dequantize(polar_quant)
116 | cart_filtered = cart[ndx]
117 | dist = torch.norm(back2cart - cart_filtered, dim=1)
118 | print(f'Residual error - min: {torch.min(dist):0.5f} max: {torch.max(dist):0.5f} mean: {torch.mean(dist):0.5f}')
119 |
120 |
--------------------------------------------------------------------------------
/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | # Warsaw University of Technology
2 |
3 | import random
4 | import copy
5 | from torch.utils.data import Sampler
6 |
7 | from datasets.base_datasets import TrainingDataset
8 |
9 | VERBOSE = False
10 |
11 |
12 | class ListDict(object):
13 | def __init__(self, items=None):
14 | if items is not None:
15 | self.items = copy.deepcopy(items)
16 | self.item_to_position = {item: ndx for ndx, item in enumerate(items)}
17 | else:
18 | self.items = []
19 | self.item_to_position = {}
20 |
21 | def add(self, item):
22 | if item in self.item_to_position:
23 | return
24 | self.items.append(item)
25 | self.item_to_position[item] = len(self.items)-1
26 |
27 | def remove(self, item):
28 | position = self.item_to_position.pop(item)
29 | last_item = self.items.pop()
30 | if position != len(self.items):
31 | self.items[position] = last_item
32 | self.item_to_position[last_item] = position
33 |
34 | def choose_random(self):
35 | return random.choice(self.items)
36 |
37 | def __contains__(self, item):
38 | return item in self.item_to_position
39 |
40 | def __iter__(self):
41 | return iter(self.items)
42 |
43 | def __len__(self):
44 | return len(self.items)
45 |
46 |
47 | class BatchSampler(Sampler):
48 | # Sampler returning list of indices to form a mini-batch
49 | # Samples elements in
groups consisting of k=2 similar elements (positives)
50 | # Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k
51 | def __init__(self, dataset: TrainingDataset, batch_size: int, batch_size_limit: int = None,
52 | batch_expansion_rate: float = None, max_batches: int = None):
53 | if batch_expansion_rate is not None:
54 | assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1'
55 | assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size'
56 |
57 | self.batch_size = batch_size
58 | self.batch_size_limit = batch_size_limit
59 | self.batch_expansion_rate = batch_expansion_rate
60 | self.max_batches = max_batches
61 | self.dataset = dataset
62 | self.k = 2  # Number of positive examples per group must be 2
63 | if self.batch_size < 2 * self.k:
64 | self.batch_size = 2 * self.k
65 | print('WARNING: Batch too small. Batch size increased to {}.'.format(self.batch_size))
66 |
67 | self.batch_idx = []     # Index of elements in each batch (re-generated every epoch)
68 | self.elems_ndx = list(self.dataset.queries)    # List of point cloud indexes
69 |
70 | def __iter__(self):
71 | # Re-generate batches every epoch
72 | self.generate_batches()
73 | for batch in self.batch_idx:
74 | yield batch
75 |
76 | def __len__(self):
77 | return len(self.batch_idx)
78 |
79 | def expand_batch(self):
80 | if self.batch_expansion_rate is None:
81 | print('WARNING: batch_expansion_rate is None')
82 | return
83 |
84 | if self.batch_size >= self.batch_size_limit:
85 | return
86 |
87 | old_batch_size = self.batch_size
88 | self.batch_size = int(self.batch_size * self.batch_expansion_rate)
89 | self.batch_size = min(self.batch_size, self.batch_size_limit)
90 | print('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size))
91 |
92 | def generate_batches(self):
93 | # Generate training/evaluation batches.
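# For illustration: with k=2 and batch_size=4 a generated batch looks like
# [a1, p1, a2, p2] - two (anchor, positive) pairs, so every element has an
# in-batch positive and the other pair can serve as negatives.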
94 | # batch_idx holds indexes of elements in each batch as a list of lists
95 | self.batch_idx = []
96 |
97 | unused_elements_ndx = ListDict(self.elems_ndx)
98 | current_batch = []
99 |
100 | assert self.k == 2, 'sampler can sample only k=2 elements from the same class'
101 |
102 | while True:
103 | if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0:
104 | # Flush out batch, when it has a desired size, or a smaller batch, when there are no more
105 | # elements to process
106 | if len(current_batch) >= 2*self.k:
107 | # Ensure there are at least two groups of similar elements, otherwise it would not be possible
108 | # to find negative examples in the batch
109 | assert len(current_batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(current_batch))
110 | self.batch_idx.append(current_batch)
111 | current_batch = []
112 | if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches):
113 | break
114 | if len(unused_elements_ndx) == 0:
115 | break
116 |
117 | # Add k=2 similar elements to the batch
118 | selected_element = unused_elements_ndx.choose_random()
119 | unused_elements_ndx.remove(selected_element)
120 | positives = self.dataset.get_positives(selected_element)
121 | if len(positives) == 0:
122 | # Broken dataset element without any positives
123 | continue
124 |
125 | unused_positives = [e for e in positives if e in unused_elements_ndx]
126 | # If there are unused elements similar to selected_element, sample from them;
127 | # otherwise sample from all similar elements
128 | if len(unused_positives) > 0:
129 | second_positive = random.choice(unused_positives)
130 | unused_elements_ndx.remove(second_positive)
131 | else:
132 | second_positive = random.choice(list(positives))
133 |
134 | current_batch += [selected_element, second_positive]
135 |
136 | for batch in self.batch_idx:
137 | assert len(batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(batch))
138 |
139 |
140 | if __name__ == '__main__':
141 | pass
142 |
143 |
--------------------------------------------------------------------------------
/datasets/southbay/generate_evaluation_sets.py:
--------------------------------------------------------------------------------
1 | # Generate evaluation sets
2 | # - Map point clouds are taken from the MapData folder
3 | # - Query point clouds are taken from TestData
4 | # For each area (BaylandsToSeafood, ColumbiaPark, HighWay237, MathildaAVE, SanJoseDowntown, SunnyvaleBigloop) a
5 | # separate evaluation set is created. We do not match clouds from different areas.
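# Example invocation (a sketch; the dataset root path is hypothetical):
#   python -m datasets.southbay.generate_evaluation_sets --dataset_root /data/ApolloSouthBay --min_displacement 1.0 --dist_threshold 5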
6 | 7 | import argparse 8 | import numpy as np 9 | from typing import List 10 | import os 11 | 12 | from datasets.southbay.southbay_raw import SouthBayDataset 13 | from datasets.base_datasets import EvaluationTuple, EvaluationSet 14 | from datasets.dataset_utils import filter_query_elements 15 | 16 | 17 | def get_scans(ds: SouthBayDataset, split: str, area: str, min_displacement: float = 0.1) -> List[EvaluationTuple]: 18 | elems = [] 19 | for ndx in ds.location_ndx[split][area]: 20 | pose = ds.global_ndx[ndx].pose 21 | position = pose[0:2, 3] # (x, y) position in global coordinate frame 22 | rel_scan_filepath = ds.global_ndx[ndx].rel_scan_filepath 23 | timestamp = ds.global_ndx[ndx].timestamp 24 | 25 | item = EvaluationTuple(timestamp, rel_scan_filepath, position=position, pose=pose) 26 | elems.append(item) 27 | 28 | print(f"{len(elems)} total elements in {split} split") 29 | 30 | # Filter-out elements leaving only 1 per grid cell with min_displacement size 31 | pos = np.zeros((len(elems), 2), dtype=np.float32) 32 | for ndx, e in enumerate(elems): 33 | pos[ndx] = e.position 34 | 35 | # Quantize x-y coordinates. Quantized coords start from 0 36 | pos = np.floor(pos / min_displacement) 37 | pos = pos.astype(int) 38 | _, unique_ndx = np.unique(pos, axis=0, return_index=True) 39 | 40 | # Leave only unique elements 41 | elems = [elems[i] for i in unique_ndx] 42 | print(f"{len(elems)} filtered elements in {split} split with grid cell size = {min_displacement}") 43 | 44 | return elems 45 | 46 | 47 | def generate_evaluation_set(ds: SouthBayDataset, area: str, min_displacement: float = 0.1, dist_threshold=5) -> \ 48 | EvaluationSet: 49 | map_set = get_scans(ds, 'MapData', area, min_displacement) 50 | query_set = get_scans(ds, 'TestData', area, min_displacement) 51 | query_set = filter_query_elements(query_set, map_set, dist_threshold) 52 | print(f'Area: {area} - {len(map_set)} database elements, {len(query_set)} query elements\n') 53 | return EvaluationSet(query_set, map_set) 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser(description='Generate evaluation sets for Apollo SouthBay dataset') 58 | parser.add_argument('--dataset_root', type=str, required=True) 59 | parser.add_argument('--min_displacement', type=float, default=1.0) 60 | # Ignore query elements that do not have a corresponding map element within the given threshold (in meters) 61 | parser.add_argument('--dist_threshold', type=float, default=5) 62 | 63 | args = parser.parse_args() 64 | print(f'Dataset root: {args.dataset_root}') 65 | print(f'Minimum displacement between scans in each set (map/query): {args.min_displacement}') 66 | print(f'Ignore query elements without a corresponding map element within a threshold [m]: {args.dist_threshold}') 67 | 68 | ds = SouthBayDataset(args.dataset_root) 69 | ds.print_info() 70 | 71 | min_displacement = args.min_displacement 72 | 73 | area = 'SunnyvaleBigloop' # Evaluation area 74 | assert area in ds.location_ndx['TestData'] 75 | eval_set = generate_evaluation_set(ds, area, min_displacement=min_displacement, 76 | dist_threshold=args.dist_threshold) 77 | pickle_name = f'test_{area}_{args.min_displacement}_{args.dist_threshold}.pickle' 78 | file_path_name = os.path.join(args.dataset_root, pickle_name) 79 | eval_set.save(file_path_name) 80 | -------------------------------------------------------------------------------- /datasets/southbay/generate_training_tuples.py: -------------------------------------------------------------------------------- 1 | # Generate training 
triplets
2 |
3 | import pickle
4 | import argparse
5 | import numpy as np
6 | import tqdm
7 | import os
8 |
9 | from datasets.southbay.southbay_raw import SouthBayDataset
10 | from datasets.base_datasets import TrainingTuple
11 |
12 |
13 | class Triplet:
14 | def __init__(self, anchor: int, positives: np.ndarray, non_negatives: np.ndarray):
15 | self.anchor = anchor
16 | self.positives = positives
17 | self.non_negatives = non_negatives
18 |
19 |
20 | def generate_triplets(ds: SouthBayDataset, map_split: str, query_split: str,
21 | positives_th: int = 2, negatives_th: int = 10, min_displacement: float = 0.1):
22 | # All elements (anchors, positives and negatives) are taken from both map_split and query_split
23 | assert positives_th < negatives_th
24 |
25 | # Create a master table with positions of all point clouds from the query split and the map split
26 | pc_ids, pc_poses = ds.get_poses2([query_split, map_split])
27 | pc_coords = pc_poses[:, :3, 3]
28 |
29 | # Quantize x-y-z coordinates
30 | pos = np.floor(pc_coords / min_displacement)
31 | pos = pos.astype(int)
32 | _, unique_ndx = np.unique(pos, axis=0, return_index=True)
33 |
34 | # Leave only unique elements
35 | pc_ids = pc_ids[unique_ndx]
36 | pc_coords = pc_coords[unique_ndx]
37 | print(f'{len(pc_ids)} point clouds left from {len(pc_poses)} after filtering with min_displacement={min_displacement}')
38 |
39 | triplets = []
40 | count_zero_positives = 0
41 | for anchor_id in tqdm.tqdm(pc_ids):
42 | anchor_coords = ds.global_ndx[anchor_id].pose[:3, 3]
43 | dist = np.linalg.norm(pc_coords - anchor_coords, axis=1)
44 | positives_mask = dist <= positives_th
45 | non_negatives_mask = dist <= negatives_th
46 |
47 | positives_ndx = pc_ids[positives_mask]
48 | # remove anchor_id from positives
49 | positives_ndx = np.array([e for e in positives_ndx if e != anchor_id])
50 | non_negatives_ndx = pc_ids[non_negatives_mask]
51 |
52 | if len(positives_ndx) == 0:
53 | # Skip examples without positives
54 | count_zero_positives += 1
55 | continue
56 |
57 | t = Triplet(anchor_id, positives_ndx, non_negatives_ndx)
58 | triplets.append(t)
59 |
60 | print(f'{count_zero_positives} filtered out due to no positives')
61 | print(f'{len(triplets)} training tuples generated')
62 |
63 | # Remove ids from positives and negatives that are not anchors
64 | anchors_set = set([e.anchor for e in triplets])
65 | triplets = [Triplet(e.anchor, [p for p in e.positives if p in anchors_set],
66 | [nn for nn in e.non_negatives if nn in anchors_set]) for e in triplets]
67 |
68 | # All used global ids
69 | used_ids = set()
70 | for triplet in triplets:
71 | used_ids.add(triplet.anchor)
72 | used_ids.update(list(triplet.positives))
73 | used_ids.update(list(triplet.non_negatives))
74 |
75 | # New ids, consecutive and starting from 0
76 | new_ids = {old_ndx: ndx for ndx, old_ndx in enumerate(used_ids)}
77 |
78 | tuples = {}
79 | for triplet in triplets:
80 | new_anchor_ndx = new_ids[triplet.anchor]
81 | pc = ds.global_ndx[triplet.anchor]
82 | positives = np.array([new_ids[e] for e in triplet.positives], dtype=np.int32)
83 | non_negatives = np.array([new_ids[e] for e in triplet.non_negatives], dtype=np.int32)
84 |
85 | # Sort in ascending order
86 | positives = np.sort(positives)
87 | non_negatives = np.sort(non_negatives)
88 |
89 | train_tuple = TrainingTuple(id=new_anchor_ndx, timestamp=pc.timestamp,
90 | rel_scan_filepath=pc.rel_scan_filepath,
91 | positives=positives, non_negatives=non_negatives,
92 | pose=pc.pose, positives_poses=None)
93 | tuples[new_anchor_ndx] = train_tuple
94 |
95 |
return tuples 96 | 97 | 98 | if __name__ == '__main__': 99 | parser = argparse.ArgumentParser(description='Generate training triplets for Apollo SouthBay dataset') 100 | parser.add_argument('--dataset_root', type=str, required=True, help='Path to Apollo SouthBay root folder') 101 | parser.add_argument('--pos_th', type=float, default=2, help='Positives threshold') 102 | parser.add_argument('--neg_th', type=float, default=10, help='Negatives threshold') 103 | parser.add_argument('--min_displacement', type=float, default=1.0) 104 | 105 | query_split = 'TrainData' 106 | 107 | args = parser.parse_args() 108 | print(f'Dataset root folder: {args.dataset_root}') 109 | print(f'Split for positives/negatives: {query_split}') 110 | print(f'Positives threshold: {args.pos_th}') 111 | print(f'Negatives threshold: {args.neg_th}') 112 | print(f'Minimum displacement between consecutive scans: {args.min_displacement}') 113 | 114 | ds = SouthBayDataset(args.dataset_root) 115 | ds.print_info() 116 | 117 | triplets = generate_triplets(ds, 'MapData', query_split, positives_th=args.pos_th, negatives_th=args.neg_th, 118 | min_displacement=args.min_displacement) 119 | print(f'{len(triplets)} anchors generated') 120 | 121 | pickle_name = f'train_southbay_{args.pos_th}_{args.neg_th}.pickle' 122 | pickle_filepath = os.path.join(args.dataset_root, pickle_name) 123 | pickle.dump(triplets, open(pickle_filepath, 'wb')) 124 | -------------------------------------------------------------------------------- /datasets/southbay/southbay_raw.py: -------------------------------------------------------------------------------- 1 | # Classes and functions to read Apollo SouthBay dataset 2 | 3 | import os 4 | import numpy as np 5 | import csv 6 | from typing import List 7 | 8 | import third_party.pypcd as pypcd 9 | import misc.poses as poses 10 | from misc.point_clouds import PointCloudLoader 11 | 12 | 13 | class GroundTruthPoses: 14 | def __init__(self, pose_filepath): 15 | assert os.path.isfile(pose_filepath), f'Cannot access pose file: {pose_filepath}' 16 | self.pose_filepath = pose_filepath 17 | self.pose_ndx = {} 18 | self.read_poses() 19 | 20 | def read_poses(self): 21 | with open(self.pose_filepath) as h: 22 | csv_reader = csv.reader(h, delimiter=' ') 23 | for ndx, row in enumerate(csv_reader): 24 | assert len(row) == 9, f'Incorrect format of row {ndx}: {row}' 25 | ndx = int(row[0]) 26 | ts = float(row[1]) 27 | x = float(row[2]) 28 | y = float(row[3]) 29 | z = float(row[4]) 30 | qx = float(row[5]) 31 | qy = float(row[6]) 32 | qz = float(row[7]) 33 | qr = float(row[8]) 34 | se3 = np.eye(4, dtype=np.float64) 35 | # Expects quaternions in w, x, y, z format 36 | se3[0:3, 0:3] = poses.q2r((qr, qx, qy, qz)) 37 | se3[0:3, 3] = np.array([x, y, z]) 38 | self.pose_ndx[ndx] = (se3, ts) # (pose, timestamp) 39 | 40 | 41 | class PointCloud: 42 | id: int = 0 # global PointCloud id (unique for each cloud) 43 | 44 | def __init__(self, rel_scan_filepath: str, pose: np.ndarray, timestamp: float): 45 | self.rel_scan_filepath = rel_scan_filepath 46 | self.pose = pose 47 | self.timestamp = timestamp 48 | filename = os.path.split(rel_scan_filepath)[1] 49 | # Relative point cloud ids start from 1 in each subfolder/traversal 50 | self.rel_id = int(os.path.splitext(filename)[0]) 51 | self.id = PointCloud.id 52 | PointCloud.id += 1 53 | 54 | 55 | class SouthBayDataset: 56 | def __init__(self, dataset_root): 57 | assert os.path.isdir(dataset_root), f'Cannot access directory: {dataset_root}' 58 | self.dataset_root = dataset_root 59 | 60 | self.splits = 
['MapData', 'TestData', 'TrainData']
61 | self.pcd_extension = '.pcd'    # Point cloud extension
62 |
63 | # location_ndx[split][location] = [... list of global ids in this location ...]
64 | self.location_ndx = {}
65 |
66 | # global_ndx[global_id] = PointCloud
67 | self.global_ndx = {}
68 |
69 | for split in self.splits:
70 | self.location_ndx[split] = {}
71 | self.index_split(split)
72 |
73 | def index_split(self, split):
74 | path = os.path.join(self.dataset_root, split)
75 | assert os.path.isdir(path), f"Missing split: {split}"
76 |
77 | # Index locations
78 | locations = os.listdir(path)
79 | locations = [f for f in locations if os.path.isdir(os.path.join(path, f))]
80 | locations.sort()
81 | for loc in locations:
82 | # Locations may contain a hierarchy of multiple subfolders
83 | # All point clouds in all subfolders are stored as one list
84 | rel_working_path = os.path.join(split, loc)
85 | self.location_ndx[split][loc] = []
86 | self.index_location(split, loc, rel_working_path)
87 |
88 | def index_location(self, split, loc, rel_working_path):
89 | working_path = os.path.join(self.dataset_root, rel_working_path)
90 | subfolders = os.listdir(working_path)
91 | if 'pcds' in subfolders and 'poses' in subfolders:
92 | # Process point clouds and poses
93 | rel_pcds_path = os.path.join(rel_working_path, 'pcds')
94 | poses_path = os.path.join(working_path, 'poses')
95 | poses_filepath = os.path.join(poses_path, 'gt_poses.txt')
96 | assert os.path.isfile(poses_filepath), f'Missing poses file: {poses_filepath}'
97 | tp = GroundTruthPoses(poses_filepath)
98 | for e in tp.pose_ndx:
99 | se3, ts = tp.pose_ndx[e]
100 | rel_pcd_filepath = os.path.join(rel_pcds_path, str(e) + self.pcd_extension)
101 | pcd_filepath = os.path.join(self.dataset_root, rel_pcd_filepath)
102 | if not os.path.exists(pcd_filepath):
103 | print(f'Missing pcd file: {pcd_filepath}')
104 | pc = PointCloud(rel_pcd_filepath, se3, ts)
105 | self.global_ndx[pc.id] = pc
106 | self.location_ndx[split][loc].append(pc.id)
107 | elif 'pcds' in subfolders or 'poses' in subfolders:
108 | assert False, 'Something wrong.
Either pcds or poses folder is missing'
109 |
110 | # Recursively process other subfolders - check if they contain point data
111 | rel_subfolders = [os.path.join(rel_working_path, p) for p in subfolders]
112 | rel_subfolders = [p for p in rel_subfolders if os.path.isdir(os.path.join(self.dataset_root, p))]
113 |
114 | for sub in rel_subfolders:
115 | self.index_location(split, loc, sub)
116 |
117 | def print_info(self):
118 | print(f'Dataset root: {self.dataset_root}')
119 | print(f"Splits: {self.splits}")
120 | for split in self.location_ndx:
121 | locations = self.location_ndx[split].keys()
122 | print(f"Locations in {split}: {locations}")
123 | for loc in locations:
124 | pc_list = self.location_ndx[split][loc]
125 | print(f"{len(pc_list)} point clouds in location {split} - {loc}")
126 | print("")
127 | print(f'Last point cloud id: {PointCloud.id - 1}')
128 |
129 | def get_poses(self, split, location=None):
130 | # Get ids and poses of all point clouds from the given split and optionally within a location
131 | if location is None:
132 | locations = list(self.location_ndx[split])
133 | else:
134 | locations = [location]
135 |
136 | # Count point clouds
137 | count_pc = 0
138 | for loc in locations:
139 | count_pc += len(self.location_ndx[split][loc])
140 |
141 | # Point cloud global ids
142 | pc_ids = np.zeros(count_pc, dtype=np.int64)
143 | # Poses
144 | pc_poses = np.zeros((count_pc, 4, 4), dtype=np.float64)
145 |
146 | # Fill ids and pose tables
147 | n = 0
148 | for loc in locations:
149 | for pc_id in self.location_ndx[split][loc]:
150 | pc = self.global_ndx[pc_id]
151 | pc_ids[n] = pc_id
152 | pc_poses[n] = pc.pose
153 | n += 1
154 |
155 | return pc_ids, pc_poses
156 |
157 | def get_poses2(self, splits: List[str]):
158 | # Get ids and poses of all point clouds from the given splits
159 |
160 | locations = list(self.location_ndx[splits[0]])
161 | print(f"Locations: {locations}")
162 |
163 | # Count point clouds
164 | count_pc = 0
165 | for split in splits:
166 | for loc in locations:
167 | count_pc += len(self.location_ndx[split][loc])
168 |
169 | # Point cloud global ids
170 | pc_ids = np.zeros(count_pc, dtype=np.int32)
171 | # Poses
172 | pc_poses = np.zeros((count_pc, 4, 4), dtype=np.float64)
173 |
174 | # Fill ids and pose tables
175 | n = 0
176 | for split in splits:
177 | for loc in locations:
178 | for pc_id in self.location_ndx[split][loc]:
179 | pc = self.global_ndx[pc_id]
180 | pc_ids[n] = pc_id
181 | pc_poses[n] = pc.pose
182 | n += 1
183 |
184 | return pc_ids, pc_poses
185 |
186 |
187 | class SouthbayPointCloudLoader(PointCloudLoader):
188 | def set_properties(self):
189 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes.
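# ground_plane_level is expressed in meters in the sensor frame; the base
# PointCloudLoader presumably uses it to filter out ground-plane points during preprocessing.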
190 | self.ground_plane_level = -1.6
191 |
192 | def read_pc(self, file_pathname):
193 | pc = pypcd.PointCloud.from_path(file_pathname)
194 | # pc.pc_data has the data as a structured array
195 | # pc.fields, pc.count, etc have the metadata
196 | pc = np.stack([pc.pc_data['x'], pc.pc_data['y'], pc.pc_data['z']], axis=1)
197 | # Replace NaNs with all-zero coords
198 | nan_mask = np.isnan(pc).any(axis=1)
199 | pc[nan_mask] = np.array([0., 0., 0.], dtype=np.float64)
200 | return pc
--------------------------------------------------------------------------------
/eval/evaluate_with_rotations.py:
--------------------------------------------------------------------------------
1 | # Evaluate rotational invariance of a global descriptor using a joint global-local model
2 |
3 | import argparse
4 |
5 | import numpy as np
6 | import tqdm
7 | import os
8 | import random
9 | import pickle
10 | from typing import List
11 | import torch
12 | import MinkowskiEngine as ME
13 |
14 | from models.model_factory import model_factory
15 | from misc.utils import ModelParams, get_datetime
16 | from datasets.base_datasets import EvaluationTuple, EvaluationSet, get_pointcloud_loader
17 | from datasets.quantization import Quantizer
18 | from eval.evaluate import Evaluator
19 | from datasets.augmentation import RandomRotation, Rotation
20 | from models.minkgl import MinkTrunk
21 |
22 | DEBUG = False
23 |
24 |
25 | class MyEvaluator(Evaluator):
26 | # Evaluation of MinkLoc-based and MinkDD methods on the MulRan or Apollo SouthBay dataset
27 | def __init__(self, dataset_root: str, dataset_type: str, eval_set_pickle: str, device: str,
28 | radius: List[float], k: int = 20, n_samples=None, quantizer: Quantizer = None):
29 | super().__init__(dataset_root, dataset_type, eval_set_pickle, device, radius, k, n_samples)
30 | if DEBUG:
31 | self.eval_set.map_set = self.eval_set.map_set[:10]
32 | self.eval_set.query_set = self.eval_set.query_set[:10]
33 |
34 | assert quantizer is not None
35 | self.quantizer = quantizer
36 |
37 | def model2eval(self, models):
38 | # This method may be overloaded when model is a tuple consisting of a few models (as in DiSCO)
39 | [model.eval() for model in models]
40 |
41 | def evaluate(self, model, *args, **kwargs):
42 | map_embeddings = self.compute_embeddings(self.eval_set.map_set, model)
43 | global_metrics = {}
44 | rotations = np.arange(0., 181., 10.)
45 |
46 | for rotation in rotations:
47 | print(f'Rotation: {rotation}')
48 | query_embeddings = self.compute_embeddings(self.eval_set.query_set, model, rotation=rotation)
49 |
50 | map_positions = self.eval_set.get_map_positions()
51 | query_positions = self.eval_set.get_query_positions()
52 |
53 | if self.n_samples is None or len(query_embeddings) <= self.n_samples:
54 | query_indexes = list(range(len(query_embeddings)))
55 | self.n_samples = len(query_embeddings)
56 | else:
57 | query_indexes = random.sample(range(len(query_embeddings)), self.n_samples)
58 |
59 | # Dictionary to store the number of true positives (for global desc.
metrics) for different radius and NN number 60 | global_metrics[rotation] = {'tp': {r: [0] * self.k for r in self.radius}} 61 | 62 | for query_ndx in tqdm.tqdm(query_indexes): 63 | # Check if the query element has a true match within each radius 64 | query_pos = query_positions[query_ndx] 65 | 66 | # Nearest neighbour search in the embedding space 67 | query_embedding = query_embeddings[query_ndx] 68 | embed_dist = np.linalg.norm(map_embeddings - query_embedding, axis=1) 69 | nn_ndx = np.argsort(embed_dist)[:self.k] 70 | 71 | # GLOBAL DESCRIPTOR EVALUATION 72 | # Euclidean distance between the query and nn 73 | # Here we use non-icp refined poses, but for the global descriptor it's fine 74 | delta = query_pos - map_positions[nn_ndx] # (k, 2) array 75 | euclid_dist = np.linalg.norm(delta, axis=1) # (k,) array 76 | # Count true positives for different radius and NN number 77 | global_metrics[rotation]['tp'] = {r: [global_metrics[rotation]['tp'][r][nn] + (1 if (euclid_dist[:nn + 1] <= r).any() else 0) for nn in range(self.k)] for r in self.radius} 78 | 79 | # Calculate mean metrics 80 | global_metrics[rotation]["recall"] = {r: [global_metrics[rotation]['tp'][r][nn] / self.n_samples for nn in range(self.k)] for r in self.radius} 81 | print(f"Recall@1 (radius: {self.radius[0]} m.): {global_metrics[rotation]['recall'][self.radius[0]]}") 82 | 83 | return global_metrics 84 | 85 | def compute_embeddings(self, eval_subset: List[EvaluationTuple], model, *args, **kwargs): 86 | self.model2eval((model,)) 87 | if 'rotation' in kwargs: 88 | # Rotation around z-axis 89 | rotation = RandomRotation(max_theta=kwargs['rotation'], max_theta2=None, axis=np.array([0, 0, 1])) 90 | else: 91 | rotation = None 92 | 93 | global_embeddings = None 94 | for ndx, e in tqdm.tqdm(enumerate(eval_subset)): 95 | scan_filepath = os.path.join(self.dataset_root, e.rel_scan_filepath) 96 | assert os.path.exists(scan_filepath) 97 | pc = self.pc_loader(scan_filepath) 98 | if rotation is not None: 99 | # Rotate around z-axis 100 | pc = rotation(pc) 101 | 102 | pc = torch.tensor(pc, dtype=torch.float) 103 | global_embedding = self.compute_embedding(pc, model) 104 | if global_embeddings is None: 105 | global_embeddings = np.zeros((len(eval_subset), global_embedding.shape[1]), dtype=global_embedding.dtype) 106 | global_embeddings[ndx] = global_embedding 107 | 108 | return global_embeddings 109 | 110 | def compute_embedding(self, pc, model, *args, **kwargs): 111 | """ 112 | Returns global embedding (np.array) as well as keypoints and corresponding descriptors (torch.tensors) 113 | """ 114 | coords, _ = self.quantizer(pc) 115 | with torch.no_grad(): 116 | bcoords = ME.utils.batched_coordinates([coords]) 117 | feats = torch.ones((bcoords.shape[0], 1), dtype=torch.float32) 118 | batch = {'coords': bcoords.to(self.device), 'features': feats.to(self.device)} 119 | 120 | # Compute global descriptor 121 | y = model(batch) 122 | global_embedding = y['global'].detach().cpu().numpy() 123 | 124 | return global_embedding 125 | 126 | def print_results(self, metrics): 127 | # Global descriptor results are saved with the last n_k entry 128 | for rotation in metrics: 129 | print(f'Rotation: {rotation}') 130 | recall = metrics[rotation]['recall'] 131 | for r in recall: 132 | print(f"Radius: {r} [m] : ", end='') 133 | for x in recall[r]: 134 | print("{:0.3f}, ".format(x), end='') 135 | print("") 136 | print("") 137 | 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser(description='Evaluate MinkLoc model') 141 | 
parser.add_argument('--dataset_root', type=str, required=True, help='Path to the dataset root')
142 | parser.add_argument('--eval_set', type=str, required=True,
143 | help='File name of the evaluation pickle (must be located in dataset_root)')
144 | parser.add_argument('--radius', nargs='+', type=float, default=[5, 20], help='True Positive thresholds in meters')
145 | parser.add_argument('--n_samples', type=int, default=None, help='Number of elements sampled from the query sequence')
146 | parser.add_argument('--model_config', type=str, required=True, help='Path to the global model configuration file')
147 | parser.add_argument('--weights', type=str, default=None, help='Trained global model weights')
148 |
149 | args = parser.parse_args()
150 | print(f'Dataset root: {args.dataset_root}')
151 | print(f'Evaluation set: {args.eval_set}')
152 | print(f'Radius: {args.radius} [m]')
153 | print(f'Model config path: {args.model_config}')
154 | if args.weights is None:
155 | w = 'RANDOM WEIGHTS'
156 | else:
157 | w = args.weights
158 | print(f'Weights: {w}')
159 | print('')
160 |
161 | model_params = ModelParams(args.model_config)
162 | model_params.print()
163 |
164 | if torch.cuda.is_available():
165 | device = "cuda"
166 | else:
167 | device = "cpu"
168 | print('Device: {}'.format(device))
169 |
170 | model = model_factory(model_params)
171 |
172 | if args.weights is not None:
173 | assert os.path.exists(args.weights), 'Cannot open network weights: {}'.format(args.weights)
174 | print('Loading weights: {}'.format(args.weights))
175 | model.load_state_dict(torch.load(args.weights, map_location=device))
176 | model.to(device)
177 |
178 | evaluator = MyEvaluator(args.dataset_root, 'mulran', args.eval_set, device, radius=args.radius,
179 | n_samples=args.n_samples, quantizer=model_params.quantizer)
180 | metrics = evaluator.evaluate(model)
181 | with open('evaluate_with_rotations_results_' + get_datetime() + '.pickle', 'wb') as f:
182 | pickle.dump(metrics, f)
183 | evaluator.print_results(metrics)
--------------------------------------------------------------------------------
/images/key1_.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jac99/Egonn/84d5cc2197c81792f3254b43d3273cdccf18294c/images/key1_.png
--------------------------------------------------------------------------------
/images/keypoints_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jac99/Egonn/84d5cc2197c81792f3254b43d3273cdccf18294c/images/keypoints_vis.png
--------------------------------------------------------------------------------
/images/pair2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jac99/Egonn/84d5cc2197c81792f3254b43d3273cdccf18294c/images/pair2.png
--------------------------------------------------------------------------------
/images/registered_pairs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jac99/Egonn/84d5cc2197c81792f3254b43d3273cdccf18294c/images/registered_pairs.png
--------------------------------------------------------------------------------
/layers/eca_block.py:
--------------------------------------------------------------------------------
1 | # Implementation of Efficient Channel Attention ECA block
2 |
3 | import numpy as np
4 | import torch.nn as nn
5 |
6 | import MinkowskiEngine as ME
7 |
8 | from
MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 9 | 10 | 11 | class ECALayer(nn.Module): 12 | def __init__(self, channels, gamma=2, b=1): 13 | super().__init__() 14 | t = int(abs((np.log2(channels) + b) / gamma)) 15 | k_size = t if t % 2 else t + 1 16 | self.avg_pool = ME.MinkowskiGlobalPooling() 17 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 18 | self.sigmoid = nn.Sigmoid() 19 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() 20 | 21 | def forward(self, x): 22 | # feature descriptor on the global spatial information 23 | y_sparse = self.avg_pool(x) 24 | 25 | # Apply 1D convolution along the channel dimension 26 | y = self.conv(y_sparse.F.unsqueeze(-1).transpose(-1, -2)).transpose(-1, -2).squeeze(-1) 27 | # y is (batch_size, channels) tensor 28 | 29 | # Multi-scale information fusion 30 | y = self.sigmoid(y) 31 | # y is (batch_size, channels) tensor 32 | 33 | y_sparse = ME.SparseTensor(y, coordinate_manager=y_sparse.coordinate_manager, 34 | coordinate_map_key=y_sparse.coordinate_map_key) 35 | # y must be features reduced to the origin 36 | return self.broadcast_mul(x, y_sparse) 37 | 38 | 39 | class ECABasicBlock(BasicBlock): 40 | def __init__(self, 41 | inplanes, 42 | planes, 43 | stride=1, 44 | dilation=1, 45 | downsample=None, 46 | dimension=3): 47 | super(ECABasicBlock, self).__init__( 48 | inplanes, 49 | planes, 50 | stride=stride, 51 | dilation=dilation, 52 | downsample=downsample, 53 | dimension=dimension) 54 | self.eca = ECALayer(planes, gamma=2, b=1) 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.norm1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.norm2(out) 65 | out = self.eca(out) 66 | 67 | if self.downsample is not None: 68 | residual = self.downsample(x) 69 | 70 | out += residual 71 | out = self.relu(out) 72 | 73 | return out 74 | -------------------------------------------------------------------------------- /layers/netvlad.py: -------------------------------------------------------------------------------- 1 | """ 2 | PointNet code taken from PointNetVLAD Pytorch implementation. 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.utils.data 8 | import torch.nn.functional as F 9 | import math 10 | 11 | """ 12 | NOTE: The toolbox can only pool lists of features of the same length. It was specifically optimized to efficiently 13 | do so. One way to handle multiple lists of features of variable length is to create, via a data augmentation 14 | technique, a tensor of shape: 'batch_size'x'max_samples'x'feature_size'. Where 'max_samples' would be the maximum 15 | number of feature per list. Then for each list, you would fill the tensor with 0 values. 
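In this repository that padding is done in NetVLADWrapper (layers/pooling.py), which pads
the per-cloud feature lists with torch.nn.utils.rnn.pad_sequence before calling NetVLADLoupe.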
16 | """ 17 | 18 | class NetVLADLoupe(nn.Module): 19 | def __init__(self, feature_size, cluster_size, output_dim, gating=True, add_batch_norm=True): 20 | super().__init__() 21 | self.feature_size = feature_size 22 | self.output_dim = output_dim 23 | self.gating = gating 24 | self.add_batch_norm = add_batch_norm 25 | self.cluster_size = cluster_size 26 | self.softmax = nn.Softmax(dim=-1) 27 | self.cluster_weights = nn.Parameter(torch.randn(feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 28 | self.cluster_weights2 = nn.Parameter(torch.randn(1, feature_size, cluster_size) * 1 / math.sqrt(feature_size)) 29 | self.hidden1_weights = nn.Parameter( 30 | torch.randn(cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size)) 31 | 32 | if add_batch_norm: 33 | self.cluster_biases = None 34 | self.bn1 = nn.BatchNorm1d(cluster_size) 35 | else: 36 | self.cluster_biases = nn.Parameter(torch.randn(cluster_size) * 1 / math.sqrt(feature_size)) 37 | self.bn1 = None 38 | 39 | self.bn2 = nn.BatchNorm1d(output_dim) 40 | 41 | if gating: 42 | self.context_gating = GatingContext(output_dim, add_batch_norm=add_batch_norm) 43 | 44 | def forward(self, x): 45 | # Expects (batch_size, num_points, channels) tensor 46 | assert x.dim() == 3 47 | num_points = x.shape[1] 48 | activation = torch.matmul(x, self.cluster_weights) 49 | if self.add_batch_norm: 50 | # activation = activation.transpose(1,2).contiguous() 51 | activation = activation.view(-1, self.cluster_size) 52 | activation = self.bn1(activation) 53 | activation = activation.view(-1, num_points, self.cluster_size) 54 | # activation = activation.transpose(1,2).contiguous() 55 | else: 56 | activation = activation + self.cluster_biases 57 | activation = self.softmax(activation) 58 | activation = activation.view((-1, num_points, self.cluster_size)) 59 | 60 | a_sum = activation.sum(-2, keepdim=True) 61 | a = a_sum * self.cluster_weights2 62 | 63 | activation = torch.transpose(activation, 2, 1) 64 | x = x.view((-1, num_points, self.feature_size)) 65 | vlad = torch.matmul(activation, x) 66 | vlad = torch.transpose(vlad, 2, 1) 67 | vlad = vlad - a 68 | 69 | vlad = F.normalize(vlad, dim=1, p=2) 70 | vlad = vlad.reshape((-1, self.cluster_size * self.feature_size)) 71 | vlad = F.normalize(vlad, dim=1, p=2) 72 | 73 | vlad = torch.matmul(vlad, self.hidden1_weights) 74 | 75 | vlad = self.bn2(vlad) 76 | 77 | if self.gating: 78 | vlad = self.context_gating(vlad) 79 | 80 | return vlad 81 | 82 | 83 | class GatingContext(nn.Module): 84 | def __init__(self, dim, add_batch_norm=True): 85 | super(GatingContext, self).__init__() 86 | self.dim = dim 87 | self.add_batch_norm = add_batch_norm 88 | self.gating_weights = nn.Parameter( 89 | torch.randn(dim, dim) * 1 / math.sqrt(dim)) 90 | self.sigmoid = nn.Sigmoid() 91 | 92 | if add_batch_norm: 93 | self.gating_biases = None 94 | self.bn1 = nn.BatchNorm1d(dim) 95 | else: 96 | self.gating_biases = nn.Parameter( 97 | torch.randn(dim) * 1 / math.sqrt(dim)) 98 | self.bn1 = None 99 | 100 | def forward(self, x): 101 | gates = torch.matmul(x, self.gating_weights) 102 | 103 | if self.add_batch_norm: 104 | gates = self.bn1(gates) 105 | else: 106 | gates = gates + self.gating_biases 107 | 108 | gates = self.sigmoid(gates) 109 | 110 | activation = x * gates 111 | 112 | return activation 113 | -------------------------------------------------------------------------------- /layers/pooling.py: -------------------------------------------------------------------------------- 1 | # Pooling methods code based on: 
https://github.com/filipradenovic/cnnimageretrieval-pytorch 2 | # Global covariance pooling methods implementation taken from: 3 | # https://github.com/jiangtaoxie/fast-MPN-COV 4 | # and ported to MinkowskiEngine by Jacek Komorowski 5 | 6 | import torch 7 | import torch.nn as nn 8 | import MinkowskiEngine as ME 9 | 10 | from layers.netvlad import NetVLADLoupe 11 | 12 | 13 | class PoolingWrapper(nn.Module): 14 | def __init__(self, pool_method, in_dim, output_dim): 15 | super().__init__() 16 | 17 | self.pool_method = pool_method 18 | self.in_dim = in_dim 19 | self.output_dim = output_dim 20 | 21 | if pool_method == 'MAC': 22 | # Global max pooling 23 | assert in_dim == output_dim 24 | self.pooling = MAC(input_dim=in_dim) 25 | elif pool_method == 'SPoC': 26 | # Global average pooling 27 | assert in_dim == output_dim 28 | self.pooling = SPoC(input_dim=in_dim) 29 | elif pool_method == 'GeM': 30 | # Generalized mean pooling 31 | assert in_dim == output_dim 32 | self.pooling = GeM(input_dim=in_dim) 33 | elif self.pool_method == 'netvlad': 34 | # NetVLAD 35 | self.pooling = NetVLADWrapper(feature_size=in_dim, output_dim=output_dim, gating=False) 36 | elif self.pool_method == 'netvladgc': 37 | # NetVLAD with Gating Context 38 | self.pooling = NetVLADWrapper(feature_size=in_dim, output_dim=output_dim, gating=True) 39 | else: 40 | raise NotImplementedError('Unknown pooling method: {}'.format(pool_method)) 41 | 42 | def forward(self, x: ME.SparseTensor): 43 | return self.pooling(x) 44 | 45 | 46 | class MAC(nn.Module): 47 | def __init__(self, input_dim): 48 | super().__init__() 49 | self.input_dim = input_dim 50 | # Same output number of channels as input number of channels 51 | self.output_dim = self.input_dim 52 | self.f = ME.MinkowskiGlobalMaxPooling() 53 | 54 | def forward(self, x: ME.SparseTensor): 55 | x = self.f(x) 56 | return x.F # Return (batch_size, n_features) tensor 57 | 58 | 59 | class SPoC(nn.Module): 60 | def __init__(self, input_dim): 61 | super().__init__() 62 | self.input_dim = input_dim 63 | # Same output number of channels as input number of channels 64 | self.output_dim = self.input_dim 65 | self.f = ME.MinkowskiGlobalAvgPooling() 66 | 67 | def forward(self, x: ME.SparseTensor): 68 | x = self.f(x) 69 | return x.F # Return (batch_size, n_features) tensor 70 | 71 | 72 | class GeM(nn.Module): 73 | def __init__(self, input_dim, p=3, eps=1e-6): 74 | super(GeM, self).__init__() 75 | self.input_dim = input_dim 76 | # Same output number of channels as input number of channels 77 | self.output_dim = self.input_dim 78 | self.p = nn.Parameter(torch.ones(1) * p) 79 | self.eps = eps 80 | self.f = ME.MinkowskiGlobalAvgPooling() 81 | 82 | def forward(self, x: ME.SparseTensor): 83 | # This implicitly applies ReLU on x (clamps negative values) 84 | temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C) 85 | temp = self.f(temp) # Apply ME.MinkowskiGlobalAvgPooling 86 | return temp.F.pow(1./self.p) # Return (batch_size, n_features) tensor 87 | 88 | 89 | class NetVLADWrapper(nn.Module): 90 | def __init__(self, feature_size, output_dim, gating=True): 91 | super().__init__() 92 | self.feature_size = feature_size 93 | self.output_dim = output_dim 94 | self.net_vlad = NetVLADLoupe(feature_size=feature_size, cluster_size=64, output_dim=output_dim, gating=gating, 95 | add_batch_norm=True) 96 | 97 | def forward(self, x: ME.SparseTensor): 98 | # x is a ME.SparseTensor with (num_points, feature_size) features (not a dense image tensor) 99 | assert x.F.shape[1] == self.feature_size 100 | features = x.decomposed_features 101 | # features is a list of 
(n_points, feature_size) tensors with variable number of points 102 | batch_size = len(features) 103 | features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True) 104 | # features is (batch_size, n_points, feature_size) tensor padded with zeros 105 | 106 | x = self.net_vlad(features) 107 | assert x.shape[0] == batch_size 108 | assert x.shape[1] == self.output_dim 109 | return x # Return (batch_size, output_dim) tensor 110 | -------------------------------------------------------------------------------- /layers/senet_block.py: -------------------------------------------------------------------------------- 1 | # Changed by JK 2 | # Parameter D in SELayer, SEBasicBlock and SEBottleneck changed to dimensions to be the same as for Resnet blocks 3 | 4 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 5 | # 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 7 | # this software and associated documentation files (the "Software"), to deal in 8 | # the Software without restriction, including without limitation the rights to 9 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 10 | # of the Software, and to permit persons to whom the Software is furnished to do 11 | # so, subject to the following conditions: 12 | # 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | # 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 23 | # 24 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 25 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 26 | # of the code. 
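# Usage sketch (illustrative, not part of the original file): the SE blocks below
# are drop-in replacements for the plain Minkowski ResNet BasicBlock/Bottleneck,
# e.g. when assembling the MinkFPN backbone defined in models/minkfpn.py:
#
#   from models.minkfpn import MinkFPN
#   backbone = MinkFPN(in_channels=1, out_channels=256, block=SEBasicBlock)
#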
27 | import torch.nn as nn 28 | 29 | import MinkowskiEngine as ME 30 | 31 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 32 | 33 | 34 | class SELayer(nn.Module): 35 | 36 | def __init__(self, channel, reduction=16, dimension=-1): 37 | # Global coords does not require coords_key 38 | super(SELayer, self).__init__() 39 | self.fc = nn.Sequential( 40 | ME.MinkowskiLinear(channel, channel // reduction), 41 | ME.MinkowskiReLU(inplace=True), 42 | ME.MinkowskiLinear(channel // reduction, channel), 43 | ME.MinkowskiSigmoid()) 44 | self.pooling = ME.MinkowskiGlobalPooling() 45 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() 46 | 47 | def forward(self, x): 48 | y = self.pooling(x) 49 | y = self.fc(y) 50 | return self.broadcast_mul(x, y) 51 | 52 | 53 | class SEBasicBlock(BasicBlock): 54 | 55 | def __init__(self, 56 | inplanes, 57 | planes, 58 | stride=1, 59 | dilation=1, 60 | downsample=None, 61 | reduction=16, 62 | dimension=3): 63 | super(SEBasicBlock, self).__init__( 64 | inplanes, 65 | planes, 66 | stride=stride, 67 | dilation=dilation, 68 | downsample=downsample, 69 | dimension=dimension) 70 | self.se = SELayer(planes, reduction=reduction, dimension=dimension) 71 | 72 | def forward(self, x): 73 | residual = x 74 | 75 | out = self.conv1(x) 76 | out = self.norm1(out) 77 | out = self.relu(out) 78 | 79 | out = self.conv2(out) 80 | out = self.norm2(out) 81 | out = self.se(out) 82 | 83 | if self.downsample is not None: 84 | residual = self.downsample(x) 85 | 86 | out += residual 87 | out = self.relu(out) 88 | 89 | return out 90 | 91 | 92 | class SEBottleneck(Bottleneck): 93 | 94 | def __init__(self, 95 | inplanes, 96 | planes, 97 | stride=1, 98 | dilation=1, 99 | downsample=None, 100 | dimension=3, 101 | reduction=16): 102 | super(SEBottleneck, self).__init__( 103 | inplanes, 104 | planes, 105 | stride=stride, 106 | dilation=dilation, 107 | downsample=downsample, 108 | dimension=dimension) 109 | self.se = SELayer(planes * self.expansion, reduction=reduction, dimension=dimension) 110 | 111 | def forward(self, x): 112 | residual = x 113 | 114 | out = self.conv1(x) 115 | out = self.norm1(out) 116 | out = self.relu(out) 117 | 118 | out = self.conv2(out) 119 | out = self.norm2(out) 120 | out = self.relu(out) 121 | 122 | out = self.conv3(out) 123 | out = self.norm3(out) 124 | out = self.se(out) 125 | 126 | if self.downsample is not None: 127 | residual = self.downsample(x) 128 | 129 | out += residual 130 | out = self.relu(out) 131 | 132 | return out 133 | -------------------------------------------------------------------------------- /misc/point_clouds.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | 8 | def draw_registration_result(source, target, transformation): 9 | source_temp = copy.deepcopy(source) 10 | target_temp = copy.deepcopy(target) 11 | source_temp.paint_uniform_color([1, 0.706, 0]) 12 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 13 | source_temp.transform(transformation) 14 | o3d.visualization.draw_geometries([source_temp, target_temp], 15 | zoom=0.4459, 16 | front=[0.9288, -0.2951, -0.2242], 17 | lookat=[1.6784, 2.0612, 1.4451], 18 | up=[-0.3402, -0.9189, -0.1996]) 19 | 20 | 21 | def draw_pc(pc): 22 | pc = copy.deepcopy(pc) 23 | pc.paint_uniform_color([1, 0.706, 0]) 24 | o3d.visualization.draw_geometries([pc], 25 | zoom=0.4459, 26 | front=[0.9288, -0.2951, -0.2242], 27 | lookat=[1.6784, 2.0612, 1.4451], 28 | up=[-0.3402, 
-0.9189, -0.1996]) 29 | 30 | 31 | def icp(anchor_pc, positive_pc, transform: np.ndarray = None, point2plane: bool = False, 32 | inlier_dist_threshold: float = 1.2, max_iteration: int = 200): 33 | # transform: initial alignment transform 34 | if transform is not None: 35 | transform = transform.astype(float) 36 | 37 | voxel_size = 0.1 38 | pcd1 = o3d.geometry.PointCloud() 39 | pcd1.points = o3d.utility.Vector3dVector(anchor_pc) 40 | pcd1 = pcd1.voxel_down_sample(voxel_size=voxel_size) 41 | 42 | pcd2 = o3d.geometry.PointCloud() 43 | pcd2.points = o3d.utility.Vector3dVector(positive_pc) 44 | pcd2 = pcd2.voxel_down_sample(voxel_size=voxel_size) 45 | 46 | if point2plane: 47 | pcd1.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 48 | pcd2.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20)) 49 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPlane() 50 | else: 51 | transform_estimation = o3d.pipelines.registration.TransformationEstimationPointToPoint() 52 | 53 | if transform is not None: 54 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, transform, 55 | estimation_method=transform_estimation, 56 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 57 | else: 58 | reg_p2p = o3d.pipelines.registration.registration_icp(pcd1, pcd2, inlier_dist_threshold, 59 | estimation_method=transform_estimation, 60 | criteria=o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) 61 | 62 | return reg_p2p.transformation, reg_p2p.fitness, reg_p2p.inlier_rmse 63 | 64 | 65 | def make_open3d_feature(data, dim, npts): 66 | feature = o3d.pipelines.registration.Feature() 67 | feature.resize(dim, npts) 68 | feature.data = data.cpu().numpy().astype('d').transpose() 69 | return feature 70 | 71 | 72 | def make_open3d_point_cloud(xyz, color=None): 73 | pcd = o3d.geometry.PointCloud() 74 | pcd.points = o3d.utility.Vector3dVector(xyz) 75 | if color is not None: 76 | pcd.colors = o3d.utility.Vector3dVector(color) 77 | return pcd 78 | 79 | 80 | class PointCloudLoader: 81 | # Generic point cloud loader class 82 | def __init__(self): 83 | # remove_zero_points: remove points with all zero coordinates 84 | # remove_ground_plane: remove points on ground plane level and below 85 | # ground_plane_level: ground plane level 86 | self.remove_zero_points = True 87 | self.remove_ground_plane = True 88 | self.ground_plane_level = None 89 | self.set_properties() 90 | 91 | def set_properties(self): 92 | # Set point cloud properties, such as ground_plane_level. Must be defined in inherited classes. 
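# A minimal subclass sketch (illustrative; the loader name, file layout and
# ground plane level below are hypothetical, not part of the original file):
#
#   class MyKittiLikeLoader(PointCloudLoader):
#       def set_properties(self):
#           self.ground_plane_level = -1.5  # hypothetical sensor mounting height
#
#       def read_pc(self, file_pathname):
#           # KITTI-style .bin files store (x, y, z, reflectance) float32 records
#           pc = np.fromfile(file_pathname, dtype=np.float32).reshape(-1, 4)
#           return pc[:, :3]
#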
93 | raise NotImplementedError('set_properties must be defined in inherited classes') 94 | 95 | def __call__(self, file_pathname): 96 | # Reads the point cloud from disk and preprocesses it (optional removal of zero points and points on the ground 97 | # plane and below) 98 | # file_pathname: relative file path 99 | assert os.path.exists(file_pathname), f"Cannot open point cloud: {file_pathname}" 100 | pc = self.read_pc(file_pathname) 101 | assert pc.shape[1] == 3 102 | 103 | if self.remove_zero_points: 104 | mask = np.all(np.isclose(pc, 0), axis=1) 105 | pc = pc[~mask] 106 | 107 | if self.remove_ground_plane: 108 | mask = pc[:, 2] > self.ground_plane_level 109 | pc = pc[mask] 110 | 111 | return pc 112 | 113 | def read_pc(self, file_pathname): 114 | # Reads the point cloud without pre-processing 115 | raise NotImplementedError("read_pc must be overloaded in an inheriting class") 116 | -------------------------------------------------------------------------------- /misc/poses.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def q2r(q): 6 | # Rotation matrix from Hamiltonian quaternion 7 | # Source: https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles 8 | w, x, y, z = tuple(q) 9 | 10 | n = 1.0/np.sqrt(x*x+y*y+z*z+w*w) 11 | x *= n 12 | y *= n 13 | z *= n 14 | w *= n 15 | r = np.array([[1.0 - 2.0*y*y - 2.0*z*z, 2.0*x*y - 2.0*z*w, 2.0*x*z + 2.0*y*w], 16 | [2.0*x*y + 2.0*z*w, 1.0 - 2.0*x*x - 2.0*z*z, 2.0*y*z - 2.0*x*w], 17 | [2.0*x*z - 2.0*y*w, 2.0*y*z + 2.0*x*w, 1.0 - 2.0*x*x - 2.0*y*y]]) 18 | return r 19 | 20 | 21 | def m2ypr(m): 22 | # Get yaw, pitch, roll angles from 4x4 transformation matrix 23 | # Based on formulas in Section 2.5.1 in: 24 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 25 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 26 | assert m.shape == (4, 4) 27 | pitch = np.arctan2(-m[2][0], np.sqrt(m[0][0]**2 + m[1][0]**2)) 28 | # We do not handle the degenerate case when pitch is 90 degrees, a.k.a. gimbal lock 29 | assert not np.isclose(np.abs(pitch), np.pi/2) 30 | yaw = np.arctan2(m[1][0], m[0][0]) 31 | roll = np.arctan2(m[2][1], m[2][2]) 32 | return yaw, pitch, roll 33 | 34 | 35 | def m2xyz_ypr(m): 36 | # Get translation (x, y, z) and yaw, pitch, roll angles from 4x4 transformation matrix 37 | # Based on formulas in Section 2.5.1 in: 38 | # A tutorial on SE(3) transformation parameterizations and on-manifold optimization 39 | # https://ingmec.ual.es/~jlblanco/papers/jlblanco2010geometry3D_techrep.pdf 40 | assert m.shape == (4, 4) 41 | yaw, pitch, roll = m2ypr(m) 42 | return m[0, 3], m[1, 3], m[2, 3], yaw, pitch, roll 43 | 44 | 45 | def ypr2m(yaw, pitch, roll): 46 | # Construct 4x4 transformation matrix with rotation part set to given yaw, pitch, roll. Translation is set to 0. 
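# The rotation part is R = Rz(yaw) @ Ry(pitch) @ Rx(roll) (Z-Y-X Tait-Bryan convention),
# so m2ypr(ypr2m(yaw, pitch, roll)) recovers the input angles away from the gimbal-lock case.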
47 | # Based on formulas in Section 2.2.1 48 | m = np.array([[np.cos(yaw) * np.cos(pitch), np.cos(yaw) * np.sin(pitch) * np.sin(roll) - np.sin(yaw) * np.cos(roll), 49 | np.cos(yaw) * np.sin(pitch) * np.cos(roll) + np.sin(yaw) * np.sin(roll), 0.], 50 | [np.sin(yaw) * np.cos(pitch), np.sin(yaw) * np.sin(pitch) * np.sin(roll) + np.cos(yaw) * np.cos(roll), 51 | np.sin(yaw) * np.sin(pitch) * np.cos(roll) - np.cos(yaw) * np.sin(roll), 0.], 52 | [-np.sin(pitch), np.cos(pitch) * np.sin(roll), np.cos(pitch) * np.cos(roll), 0.], 53 | [0., 0., 0., 1.]], dtype=np.float32) 54 | 55 | return m 56 | 57 | 58 | def xyz_ypr2m(x, y, z, yaw, pitch, roll): 59 | # Construct 4x4 transformation matrix with given translation and rotation part set to given yaw, pitch, roll. 60 | # Based on formulas in Section 2.2.1 61 | m = ypr2m(yaw, pitch, roll) 62 | m[0, 3] = x 63 | m[1, 3] = y 64 | m[2, 3] = z 65 | return m 66 | 67 | 68 | def apply_transform(pc: torch.Tensor, m: torch.Tensor): 69 | # Apply 4x4 SE(3) transformation matrix on (N, 3) point cloud or 3x3 transformation on (N, 2) point cloud 70 | assert pc.ndim == 2 71 | n_dim = pc.shape[1] 72 | assert n_dim == 2 or n_dim == 3 73 | assert m.shape == (n_dim + 1, n_dim + 1) 74 | # (m @ pc.t).t = pc @ m.t 75 | pc = pc @ m[:n_dim, :n_dim].transpose(1, 0) + m[:n_dim, -1] 76 | return pc 77 | 78 | 79 | def relative_pose(m1, m2): 80 | # !!! DO NOT USE THIS FUNCTION FOR MULRAN POSES !!! 81 | # SE(3) pose is a 4x4 matrix, such that 82 | # Pw = [R | T] @ [P] 83 | # [0 | 1] [1] 84 | # where Pw are coordinates in the world reference frame and P are coordinates in the camera frame 85 | # m1: coords in camera/lidar1 reference frame -> world coordinate frame 86 | # m2: coords in camera/lidar2 coords -> world coordinate frame 87 | # returns: coords in camera/lidar1 reference frame -> coords in camera/lidar2 reference frame 88 | # relative pose of the first camera with respect to the second camera 89 | return np.linalg.inv(m2) @ m1 90 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import os 4 | import configparser 5 | import time 6 | import numpy as np 7 | 8 | from datasets.quantization import PolarQuantizer, CartesianQuantizer 9 | 10 | 11 | class ModelParams: 12 | def __init__(self, model_params_path): 13 | config = configparser.ConfigParser() 14 | config.read(model_params_path) 15 | params = config['MODEL'] 16 | 17 | self.model_params_path = model_params_path 18 | self.model = params.get('model') 19 | self.output_dim = params.getint('output_dim', 256) # Size of the final descriptor 20 | 21 | ####################################################################### 22 | # Model dependent 23 | ####################################################################### 24 | 25 | self.coordinates = params.get('coordinates', 'polar') 26 | assert self.coordinates in ['polar', 'cartesian'], f'Unsupported coordinates: {self.coordinates}' 27 | 28 | if 'polar' in self.coordinates: 29 | # 3 quantization steps for polar coordinates: for sectors (in degrees), rings (in meters) and z coordinate (in meters) 30 | self.quantization_step = [float(e) for e in params['quantization_step'].split(',')] 31 | assert len(self.quantization_step) == 3, f'Expected 3 quantization steps: for sectors (degrees), rings (meters) and z coordinate (meters)' 32 | self.quantizer = PolarQuantizer(quant_step=self.quantization_step) 33 | elif 'cartesian' 
in self.coordinates: 34 | # Single quantization step for cartesian coordinates 35 | self.quantization_step = params.getfloat('quantization_step') 36 | self.quantizer = CartesianQuantizer(quant_step=self.quantization_step) 37 | else: 38 | raise NotImplementedError(f"Unsupported coordinates: {self.coordinates}") 39 | 40 | #self.remove_ground_plane = params.getboolean('remove_ground_plane', False) 41 | 42 | if 'MinkLoc' in self.model: 43 | # Size of the local features from backbone network (only for MinkNet based models) 44 | self.feature_size = params.getint('feature_size', 256) 45 | if 'planes' in params: 46 | self.planes = [int(e) for e in params['planes'].split(',')] 47 | else: 48 | self.planes = [32, 64, 64] 49 | 50 | if 'layers' in params: 51 | self.layers = [int(e) for e in params['layers'].split(',')] 52 | else: 53 | self.layers = [1, 1, 1] 54 | 55 | self.num_top_down = params.getint('num_top_down', 1) 56 | self.conv0_kernel_size = params.getint('conv0_kernel_size', 5) 57 | self.block = params.get('block', 'BasicBlock') 58 | self.pooling = params.get('pooling', 'GeM') 59 | 60 | def print(self): 61 | print('Model parameters:') 62 | param_dict = vars(self) 63 | for e in param_dict: 64 | if e == 'quantization_step': 65 | s = param_dict[e] 66 | if self.coordinates == 'polar': 67 | print(f'quantization_step - sector: {s[0]} [deg] / ring: {s[1]} [m] / z: {s[2]} [m]') 68 | else: 69 | print(f'quantization_step: {s} [m]') 70 | else: 71 | print('{}: {}'.format(e, param_dict[e])) 72 | 73 | print('') 74 | 75 | 76 | def get_datetime(): 77 | return time.strftime("%Y%m%d_%H%M") 78 | 79 | 80 | class TrainingParams: 81 | """ 82 | Parameters for model training 83 | """ 84 | def __init__(self, params_path, model_params_path): 85 | """ 86 | Configuration files 87 | :param params_path: Training configuration file 88 | :param model_params_path: Model-specific configuration file 89 | """ 90 | 91 | assert os.path.exists(params_path), 'Cannot find configuration file: {}'.format(params_path) 92 | assert os.path.exists(model_params_path), 'Cannot find model-specific configuration file: {}'.format(model_params_path) 93 | self.params_path = params_path 94 | self.model_params_path = model_params_path 95 | 96 | config = configparser.ConfigParser() 97 | 98 | config.read(self.params_path) 99 | params = config['DEFAULT'] 100 | self.dataset = params.get('dataset', 'mulran').lower() 101 | self.dataset_folder = params.get('dataset_folder') 102 | # Secondary dataset for global descriptor training 103 | self.secondary_dataset = params.get('secondary_dataset', None) 104 | if self.secondary_dataset is not None: 105 | self.secondary_dataset = self.secondary_dataset.lower() 106 | self.secondary_dataset_folder = params.get('secondary_dataset_folder', None) 107 | 108 | # Maximum random rotation and translation applied when generating pairs for local descriptor 109 | self.rot_max = params.getfloat('rot_max', np.pi) 110 | self.trans_max = params.getfloat('trans_max', 5.) 
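# (rot_max is in radians, np.pi by default; trans_max is presumably in meters)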
111 | 112 | params = config['TRAIN'] 113 | 114 | self.save_freq = params.getint('save_freq', 20) # Model saving frequency (in epochs) 115 | self.num_workers = params.getint('num_workers', 4) 116 | # Initial batch size for global descriptors (for both main and secondary dataset) 117 | self.batch_size = params.getint('batch_size', 64) 118 | # Batch size for local descriptors 119 | self.local_batch_size = params.getint('local_batch_size', 2) 120 | 121 | # Set batch_expansion_th to turn on dynamic batch sizing 122 | # When number of non-zero triplets falls below batch_expansion_th, expand batch size 123 | self.batch_expansion_th = params.getfloat('batch_expansion_th', None) 124 | if self.batch_expansion_th is not None: 125 | assert 0. < self.batch_expansion_th < 1., 'batch_expansion_th must be between 0 and 1' 126 | self.batch_size_limit = params.getint('batch_size_limit', 256) 127 | # Batch size expansion rate 128 | self.batch_expansion_rate = params.getfloat('batch_expansion_rate', 1.5) 129 | assert self.batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 130 | else: 131 | self.batch_size_limit = self.batch_size 132 | self.batch_expansion_rate = None 133 | 134 | if 'secondary_batch_size_limit' in params: 135 | self.secondary_batch_size_limit = params.getint('secondary_batch_size_limit') 136 | else: 137 | self.secondary_batch_size_limit = self.batch_size_limit 138 | 139 | self.loss_gammas = params.get('l_gammas', None) 140 | if self.loss_gammas is not None: 141 | self.loss_gammas = [float(e) for e in self.loss_gammas.split(',')] 142 | self.lr = params.getfloat('lr', 1e-3) 143 | 144 | self.scheduler = params.get('scheduler', 'MultiStepLR') 145 | if self.scheduler is not None: 146 | if self.scheduler == 'CosineAnnealingLR': 147 | self.min_lr = params.getfloat('min_lr') 148 | elif self.scheduler == 'MultiStepLR': 149 | scheduler_milestones = params.get('scheduler_milestones') 150 | self.scheduler_milestones = [int(e) for e in scheduler_milestones.split(',')] 151 | else: 152 | raise NotImplementedError('Unsupported LR scheduler: {}'.format(self.scheduler)) 153 | 154 | self.epochs = params.getint('epochs', 20) 155 | self.weight_decay = params.getfloat('weight_decay', None) 156 | self.loss = params.get('loss') 157 | 158 | if 'Contrastive' in self.loss: 159 | self.pos_margin = params.getfloat('pos_margin', 0.2) 160 | self.neg_margin = params.getfloat('neg_margin', 0.65) 161 | elif 'Triplet' in self.loss: 162 | self.margin = params.getfloat('margin', 0.4) # Margin used in loss function 163 | else: 164 | raise NotImplementedError('Unsupported loss function: {}'.format(self.loss)) 165 | 166 | self.aug_mode = params.getint('aug_mode', 1) # Augmentation mode (1 is default) 167 | 168 | self.train_file = params.get('train_file') 169 | self.val_file = params.get('val_file', None) 170 | self.secondary_train_file = params.get('secondary_train_file', None) 171 | self.test_file = params.get('test_file', None) 172 | 173 | # Read model parameters 174 | self.model_params = ModelParams(self.model_params_path) 175 | self._check_params() 176 | 177 | def _check_params(self): 178 | assert os.path.exists(self.dataset_folder), 'Cannot access dataset: {}'.format(self.dataset_folder) 179 | 180 | def print(self): 181 | print('Parameters:') 182 | param_dict = vars(self) 183 | for e in param_dict: 184 | if e != 'model_params': 185 | print('{}: {}'.format(e, param_dict[e])) 186 | 187 | self.model_params.print() 188 | print('') 189 | 190 | -------------------------------------------------------------------------------- 
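A minimal usage sketch for the two parameter classes above (illustrative, not part of the original file; it assumes the repository's config/config_egonn.txt and models/egonn.txt files, and that dataset_folder in the training config points at an existing directory, since _check_params asserts this):

    from misc.utils import TrainingParams

    params = TrainingParams('config/config_egonn.txt', 'models/egonn.txt')
    params.print()
    # For the egonn model (polar coordinates), the nested ModelParams exposes
    # a PolarQuantizer built from the three polar quantization steps
    quantizer = params.model_params.quantizer
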
/models/egonn.txt: -------------------------------------------------------------------------------- 1 | [MODEL] 2 | model = egonn 3 | coordinates = polar 4 | # Quantization steps for sectors (in degrees), rings (in meters) and z coordinate (in meters) 5 | quantization_step = 1., 0.3, 0.2 6 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch_metric_learning import losses, miners, reducers 6 | from pytorch_metric_learning.distances import LpDistance 7 | from misc.utils import TrainingParams 8 | from misc.poses import apply_transform 9 | from models.loss_utils import * 10 | 11 | 12 | def make_losses(params: TrainingParams): 13 | if params.loss == 'BatchHardTripletMarginLoss': 14 | # BatchHard mining with triplet margin loss 15 | # Expects input: embeddings, positives_mask, negatives_mask 16 | gl_loss_fn = BatchHardTripletLossWithMasks(params.margin) 17 | elif params.loss == 'BatchHardContrastiveLoss': 18 | gl_loss_fn = BatchHardContrastiveLossWithMasks(params.pos_margin, params.neg_margin) 19 | else: 20 | print('Unknown loss: {}'.format(params.loss)) 21 | raise NotImplementedError 22 | 23 | if params.loss_gammas is not None: 24 | gamma_chamfer, gamma_p2p, gamma_c, beta = params.loss_gammas 25 | else: 26 | gamma_chamfer, gamma_p2p, gamma_c, beta = [1., 1., 1., 2.] 27 | loc_loss_fn = KeypointCorrLoss(gamma_c=gamma_c, gamma_chamfer=gamma_chamfer, gamma_p2p=gamma_p2p, beta=beta) 28 | 29 | return gl_loss_fn, loc_loss_fn 30 | 31 | 32 | class KeypointCorrLoss: 33 | # Loss combining keypoint loss and correspondence matrix loss 34 | def __init__(self, gamma_c=1., gamma_k=1., gamma_chamfer=1., gamma_p2p=1., beta=1.,dist_th=0.5): 35 | # alpha and beta are parameters in Eq. 
(5) in RPM-Net paper 36 | self.keypoint_loss = KeypointLoss(gamma_chamfer=gamma_chamfer, gamma_p2p=gamma_p2p, prob_chamfer_loss=True, 37 | p2p_loss=True, repeatability_dist_th=dist_th) 38 | self.correspondence_loss = CorrespondenceLoss(beta=beta, dist_th=dist_th) 39 | 40 | self.gamma_k = gamma_k 41 | self.gamma_c = gamma_c 42 | 43 | def __call__(self, clouds1: torch.Tensor, keypoints1: torch.Tensor, sigma1: torch.Tensor, descriptors1: torch.Tensor, 44 | clouds2: torch.Tensor, keypoints2: torch.Tensor, sigma2: torch.Tensor, descriptors2: torch.Tensor, 45 | M_gt, len_batch): 46 | assert clouds1.dim() == 2 47 | assert clouds2.dim() == 2 48 | assert len(keypoints1) == len(sigma1) 49 | assert len(keypoints2) == len(sigma2) 50 | assert len(keypoints1) == len(descriptors1) 51 | assert len(keypoints2) == len(descriptors2) 52 | 53 | len_clouds1 = [e[0] for e in len_batch] # Len of source point clouds 54 | len_clouds2 = [e[1] for e in len_batch] # Len of target point clouds 55 | len_clouds1 = [0] + len_clouds1 56 | len_clouds2 = [0] + len_clouds2 57 | cum_len_clouds1 = np.cumsum(len_clouds1) # Cumulative length 58 | cum_len_clouds2 = np.cumsum(len_clouds2) 59 | 60 | assert cum_len_clouds1[-1] == len(clouds1) 61 | assert cum_len_clouds2[-1] == len(clouds2) 62 | 63 | batch_metrics = [] 64 | batch_loss = [] 65 | 66 | for i, (kp1_pos, s1, desc1, kp2_pos, s2, desc2, Mi_gt) in enumerate(zip(keypoints1, sigma1, descriptors1, 67 | keypoints2, sigma2, descriptors2, 68 | M_gt)): 69 | pc1 = clouds1[cum_len_clouds1[i]:cum_len_clouds1[i + 1]] 70 | pc2 = clouds2[cum_len_clouds2[i]:cum_len_clouds2[i + 1]] 71 | 72 | kp1_pos_trans = apply_transform(kp1_pos, Mi_gt) 73 | dist_kp1_trans_kp2 = torch.cdist(kp1_pos_trans, kp2_pos) # Euclidean distance 74 | 75 | metrics = {} 76 | metrics['kp_per_cloud'] = 0.5 * (len(kp1_pos) + len(kp2_pos)) 77 | loss_keypoints, m = self.keypoint_loss(pc1, kp1_pos, s1, pc2, kp2_pos, s2, dist_kp1_trans_kp2) 78 | metrics.update(m) 79 | 80 | loss_correspondence, m = self.correspondence_loss(desc1, desc2, dist_kp1_trans_kp2) 81 | 82 | metrics.update(m) 83 | loss = self.gamma_k * loss_keypoints + self.gamma_c * loss_correspondence 84 | metrics['loss'] = loss.item() 85 | batch_metrics.append(metrics) 86 | batch_loss.append(loss) 87 | 88 | # Mean values per batch 89 | batch_loss = torch.stack(batch_loss).mean() 90 | batch_metrics = metrics_mean(batch_metrics) 91 | 92 | return batch_loss, batch_metrics 93 | 94 | 95 | class HardTripletMinerWithMasks: 96 | # Hard triplet miner 97 | def __init__(self, distance): 98 | self.distance = distance 99 | # Stats 100 | self.max_pos_pair_dist = None 101 | self.max_neg_pair_dist = None 102 | self.mean_pos_pair_dist = None 103 | self.mean_neg_pair_dist = None 104 | self.min_pos_pair_dist = None 105 | self.min_neg_pair_dist = None 106 | 107 | def __call__(self, embeddings, positives_mask, negatives_mask): 108 | assert embeddings.dim() == 2 109 | d_embeddings = embeddings.detach() 110 | with torch.no_grad(): 111 | hard_triplets = self.mine(d_embeddings, positives_mask, negatives_mask) 112 | return hard_triplets 113 | 114 | def mine(self, embeddings, positives_mask, negatives_mask): 115 | # Based on pytorch-metric-learning implementation 116 | dist_mat = self.distance(embeddings) 117 | (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask) 118 | (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 119 | a_keep_idx = torch.where(a1p_keep & a2n_keep) 120 | a = 
torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx] 121 | p = hardest_positive_indices[a_keep_idx] 122 | n = hardest_negative_indices[a_keep_idx] 123 | self.max_pos_pair_dist = torch.max(hardest_positive_dist).item() 124 | self.max_neg_pair_dist = torch.max(hardest_negative_dist).item() 125 | self.mean_pos_pair_dist = torch.mean(hardest_positive_dist).item() 126 | self.mean_neg_pair_dist = torch.mean(hardest_negative_dist).item() 127 | self.min_pos_pair_dist = torch.min(hardest_positive_dist).item() 128 | self.min_neg_pair_dist = torch.min(hardest_negative_dist).item() 129 | return a, p, n 130 | 131 | 132 | def get_max_per_row(mat, mask): 133 | non_zero_rows = torch.any(mask, dim=1) 134 | mat_masked = mat.clone() 135 | mat_masked[~mask] = 0 136 | return torch.max(mat_masked, dim=1), non_zero_rows 137 | 138 | 139 | def get_min_per_row(mat, mask): 140 | non_inf_rows = torch.any(mask, dim=1) 141 | mat_masked = mat.clone() 142 | mat_masked[~mask] = float('inf') 143 | return torch.min(mat_masked, dim=1), non_inf_rows 144 | 145 | 146 | class BatchHardTripletLossWithMasks: 147 | def __init__(self, margin): 148 | self.margin = margin 149 | self.distance = LpDistance(normalize_embeddings=False, collect_stats=True) 150 | # We use triplet loss with Euclidean distance 151 | self.miner_fn = HardTripletMinerWithMasks(distance=self.distance) 152 | reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True) 153 | self.loss_fn = losses.TripletMarginLoss(margin=self.margin, swap=True, distance=self.distance, 154 | reducer=reducer_fn, collect_stats=True) 155 | 156 | def __call__(self, embeddings, positives_mask, negatives_mask): 157 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 158 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 159 | loss = self.loss_fn(embeddings, dummy_labels, hard_triplets) 160 | 161 | stats = {'loss': loss.item(), 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm, 162 | 'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter, 163 | 'num_triplets': len(hard_triplets[0]), 164 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 165 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 166 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 167 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 168 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 169 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist 170 | } 171 | 172 | return loss, stats, hard_triplets 173 | 174 | 175 | class BatchHardContrastiveLossWithMasks: 176 | def __init__(self, pos_margin, neg_margin): 177 | self.pos_margin = pos_margin 178 | self.neg_margin = neg_margin 179 | self.distance = LpDistance(normalize_embeddings=False, collect_stats=True) 180 | self.miner_fn = HardTripletMinerWithMasks(distance=self.distance) 181 | # We use contrastive loss with squared Euclidean distance 182 | reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True) 183 | self.loss_fn = losses.ContrastiveLoss(pos_margin=self.pos_margin, neg_margin=self.neg_margin, 184 | distance=self.distance, reducer=reducer_fn, collect_stats=True) 185 | 186 | def __call__(self, embeddings, positives_mask, negatives_mask): 187 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 188 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 189 | loss = self.loss_fn(embeddings, dummy_labels, hard_triplets) 190 | stats = {'loss': loss.item(), 'avg_embedding_norm': 
self.loss_fn.distance.final_avg_query_norm, 191 | 'pos_pairs_above_threshold': self.loss_fn.reducer.reducers['pos_loss'].pos_pairs_above_threshold, 192 | 'neg_pairs_above_threshold': self.loss_fn.reducer.reducers['neg_loss'].neg_pairs_above_threshold, 193 | 'pos_loss': self.loss_fn.reducer.reducers['pos_loss'].pos_loss.item(), 194 | 'neg_loss': self.loss_fn.reducer.reducers['neg_loss'].neg_loss.item(), 195 | 'num_pairs': 2*len(hard_triplets[0]), 196 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 197 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 198 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 199 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 200 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 201 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist 202 | } 203 | 204 | return loss, stats, hard_triplets 205 | -------------------------------------------------------------------------------- /models/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | # Functions and classes used by different loss functions 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import math 8 | EPS = 1e-5 9 | 10 | 11 | class KeypointLoss: 12 | # Computes loss for regressed keypoints (based on USIP paper) 13 | def __init__(self, gamma_chamfer=1., gamma_p2p=1., prob_chamfer_loss=True, p2p_loss=True, 14 | repeatability_dist_th=0.5): 15 | # prob_chamfer_loss: if True probabilistic chamfer loss is calculated using predicted saliency uncertainty 16 | # p2p_loss: if True, point-2-point loss is calculated and added to chamfer loss 17 | self.gamma_chamfer = gamma_chamfer 18 | self.gamma_p2p = gamma_p2p 19 | self.prob_chamfer_loss = prob_chamfer_loss 20 | self.p2p_loss = p2p_loss 21 | self.repeatability_dist_th = repeatability_dist_th 22 | 23 | def __call__(self, pc1: torch.Tensor, kp1: torch.Tensor, sigma1: torch.Tensor, 24 | pc2: torch.Tensor, kp2: torch.Tensor, sigma2: torch.Tensor, 25 | dist_kp1_trans_kp2: torch.Tensor): 26 | # dist_kp1_trans_kp2: Euclidean distance matrix between transformed points in the first cloud 27 | # and points in the second cloud 28 | assert pc1.shape[1] == 3 29 | assert pc2.shape[1] == 3 30 | assert kp1.shape[1] == 3 31 | assert kp2.shape[1] == 3 32 | assert sigma1.shape[1] == 1 33 | assert sigma2.shape[1] == 1 34 | assert kp1.shape[0] == sigma1.shape[0] 35 | assert kp2.shape[0] == sigma2.shape[0] 36 | assert dist_kp1_trans_kp2.shape[0] == kp1.shape[0] 37 | assert dist_kp1_trans_kp2.shape[1] == kp2.shape[0] 38 | 39 | # Convert (n, 1) tensors to (n,) tensor 40 | sigma1 = sigma1.squeeze(1) 41 | sigma2 = sigma2.squeeze(1) 42 | 43 | min_dist1, min_dist_ndx1 = torch.min(dist_kp1_trans_kp2, dim=1) 44 | min_dist2, min_dist_ndx2 = torch.min(dist_kp1_trans_kp2, dim=0) 45 | 46 | # COMPUTE CHAMFER LOSS BETWEEN KEYPOINTS REGRESSED IN THE SOURCE AND TARGET CLOUDS 47 | # In USIP code they use non-squared chamfer distance, in USIP paper squared distance is used 48 | 49 | # Match points from the first cloud to the closest points in the second cloud 50 | if self.prob_chamfer_loss: 51 | selected_sigma1 = sigma2[min_dist_ndx1] # (n_keypoints1) tensor 52 | sigma12 = (sigma1 + selected_sigma1) / 2 53 | loss1 = (torch.log(sigma12) + min_dist1 / sigma12).mean() 54 | else: 55 | loss1 = min_dist1.mean() 56 | 57 | # Match points from the second cloud to the closest points in the first cloud 58 | if self.prob_chamfer_loss: 59 | selected_sigma2 = 
sigma1[min_dist_ndx2] # (n_keypoints2) tensor 60 | sigma21 = (sigma2 + selected_sigma2) / 2 61 | loss2 = (torch.log(sigma21) + min_dist2 / sigma21).mean() 62 | else: 63 | loss2 = min_dist2.mean() 64 | 65 | # Metrics not included in optimization, for reporting only 66 | metrics = {} 67 | metrics['repeatability'] = torch.mean((min_dist1 <= self.repeatability_dist_th).float()).item() 68 | metrics['chamfer_pure'] = 0.5 * (min_dist1.detach().mean() + min_dist2.detach().mean()).item() 69 | if self.prob_chamfer_loss: 70 | weight_src_dst = (1.0 / sigma12.detach()) / (1.0 / sigma12.detach()).mean() 71 | weight_dst_src = (1.0 / sigma21.detach()) / (1.0 / sigma21.detach()).mean() 72 | chamfer_weighted = 0.5 * (weight_src_dst * min_dist1.detach()).mean() + \ 73 | 0.5 * (weight_dst_src * min_dist2.detach()).mean() 74 | metrics['chamfer_weighted'] = chamfer_weighted.item() 75 | 76 | metrics['mean_sigma'] = 0.5 * (sigma12.detach().mean() + sigma21.detach().mean()).item() 77 | loss = self.gamma_chamfer * 0.5 * (loss1 + loss2) 78 | metrics['loss_chamfer'] = loss.item() 79 | 80 | if self.p2p_loss: 81 | # keypoints1 distance to the first point cloud 82 | dist1 = torch.cdist(kp1, pc1) 83 | min_dist1, _ = torch.min(dist1, dim=1) 84 | 85 | # keypoints2 distance to the second point cloud 86 | dist2 = torch.cdist(kp2, pc2) 87 | min_dist2, _ = torch.min(dist2, dim=1) 88 | 89 | loss_p2p = 0.5 * (min_dist1.mean() + min_dist2.mean()) 90 | metrics['loss_p2p'] = loss_p2p.item() 91 | loss = loss + self.gamma_p2p * loss_p2p 92 | 93 | metrics['keypoint_loss'] = loss.item() 94 | 95 | return loss, metrics 96 | 97 | 98 | class CorrespondenceLoss(nn.Module): 99 | # Loss inspired by "Correspondence Matrices are Underrated" paper 100 | # Modified to handle partially overlapping clouds 101 | def __init__(self, beta, dist_th=0.5): 102 | # dist_th: distance threshold (in meters) for establishing ground truth correspondences 103 | super().__init__() 104 | self.beta = beta # Inverse of the temperature 105 | self.dist_th = dist_th 106 | self.loss = torch.nn.CrossEntropyLoss() 107 | 108 | def forward(self, desc1: torch.Tensor, desc2: torch.Tensor, dist_kp1_trans_kp2: torch.Tensor): 109 | # dist_kp1_trans_kp2: Euclidean distance matrix between transformed points in the first cloud 110 | # and points in the second cloud 111 | assert dist_kp1_trans_kp2.shape[0] == desc1.shape[0] 112 | assert dist_kp1_trans_kp2.shape[1] == desc2.shape[0] 113 | 114 | # For each keypoint from the first cloud, find its true class, that is a corresponding keypoint from 115 | # the second cloud. Filter out keypoints without correspondence. 116 | min_dist1, min_dist_ndx1 = torch.min(dist_kp1_trans_kp2, dim=1) 117 | mask = min_dist1 <= self.dist_th 118 | 119 | similarity_matrix = torch.mm(desc1[mask], desc2.t()) * math.exp(self.beta) 120 | loss = self.loss(similarity_matrix, min_dist_ndx1[mask]) 121 | # Keypoints having a corresponding keypoint in the other cloud (below dist_th meters) 122 | matching_keypoints = torch.sum(mask).float().item() 123 | # Keypoints that are correctly matched by searching for the closest descriptor 124 | if matching_keypoints > 0: 125 | tp_mask = torch.max(similarity_matrix, 1)[1] == min_dist_ndx1[mask] 126 | matching_descriptors = torch.sum(tp_mask).float().item() 127 | pos_similarity = torch.mean(torch.max(similarity_matrix, 1)[0]).item() 128 | neg_mat = similarity_matrix.clone().detach() 129 | neg_mat[:, min_dist_ndx1[mask]] = 0. 130 | neg_similarity = torch.mean(torch.max(neg_mat, 1)[0]).float().item() 131 | else: 132 | matching_descriptors = 0. 133 | pos_similarity = 0. 
134 | neg_similarity = 0. 135 | 136 | metrics = {'correspondence_loss': loss.item(), 'matching_keypoints': matching_keypoints, 137 | 'matching_descriptors': matching_descriptors, 'pos_similarity': pos_similarity, 138 | 'neg_similarity': neg_similarity} 139 | return loss, metrics 140 | 141 | 142 | def metrics_mean(l): 143 | # Compute the mean and return as Python number 144 | metrics = {} 145 | for e in l: 146 | for metric_name in e: 147 | if metric_name not in metrics: 148 | metrics[metric_name] = [] 149 | metrics[metric_name].append(e[metric_name]) 150 | 151 | for metric_name in metrics: 152 | metrics[metric_name] = np.mean(np.array(metrics[metric_name])) 153 | 154 | return metrics 155 | 156 | 157 | def squared_euclidean_distance(x, y): 158 | ''' 159 | Compute squared Euclidean distance 160 | Input: x is Nxd matrix 161 | y is Mxd matrix 162 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 163 | i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2 164 | Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3 165 | ''' 166 | x_norm = (x ** 2).sum(1).view(-1, 1) 167 | y_t = torch.transpose(y, 0, 1) 168 | y_norm = (y ** 2).sum(1).view(1, -1) 169 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) 170 | return torch.clamp(dist, 0.0, np.inf) 171 | -------------------------------------------------------------------------------- /models/minkfpn.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import torch.nn as nn 4 | import MinkowskiEngine as ME 5 | from MinkowskiEngine.modules.resnet_block import BasicBlock 6 | from models.resnet import ResNetBase 7 | 8 | 9 | class MinkFPN(ResNetBase): 10 | # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks 11 | def __init__(self, in_channels, out_channels, num_top_down=1, conv0_kernel_size=5, block=BasicBlock, 12 | layers=(1, 1, 1), planes=(32, 64, 64)): 13 | assert len(layers) == len(planes) 14 | assert 1 <= len(layers) 15 | assert 0 <= num_top_down <= len(layers) 16 | self.num_bottom_up = len(layers) 17 | self.num_top_down = num_top_down 18 | self.conv0_kernel_size = conv0_kernel_size 19 | self.block = block 20 | self.layers = layers 21 | self.planes = planes 22 | self.lateral_dim = out_channels 23 | self.init_dim = planes[0] 24 | ResNetBase.__init__(self, in_channels, out_channels, D=3) 25 | 26 | def network_initialization(self, in_channels, out_channels, D): 27 | assert len(self.layers) == len(self.planes) 28 | assert len(self.planes) == self.num_bottom_up 29 | 30 | self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 31 | self.bn = nn.ModuleList() # Bottom-up BatchNorms 32 | self.blocks = nn.ModuleList() # Bottom-up blocks 33 | self.tconvs = nn.ModuleList() # Top-down transposed convolutions 34 | self.conv1x1 = nn.ModuleList() # 1x1 convolutions in lateral connections 35 | 36 | # The first convolution is a special case, with kernel size = 5 37 | self.inplanes = self.planes[0] 38 | self.conv0 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, 39 | dimension=D) 40 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 41 | 42 | for plane, layer in zip(self.planes, self.layers): 43 | self.convs.append(ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)) 44 | self.bn.append(ME.MinkowskiBatchNorm(self.inplanes)) 45 | self.blocks.append(self._make_layer(self.block, plane, layer)) 46 | 47 | 
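# Each lateral 1x1 convolution below projects a bottom-up feature map to
# lateral_dim (= out_channels) channels; each transposed convolution then
# upsamples the top-down map by a factor of 2 before it is summed with the
# corresponding lateral output in forward().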
# Lateral connections 48 | for i in range(self.num_top_down): 49 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - i], self.lateral_dim, kernel_size=1, 50 | stride=1, dimension=D)) 51 | self.tconvs.append(ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2, 52 | stride=2, dimension=D)) 53 | # There's one more lateral connection than top-down TConv blocks 54 | if self.num_top_down < self.num_bottom_up: 55 | # Lateral connection from Conv block 1 or above 56 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - self.num_top_down], self.lateral_dim, kernel_size=1, 57 | stride=1, dimension=D)) 58 | else: 59 | # Lateral connection from Conv0 block 60 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[0], self.lateral_dim, kernel_size=1, 61 | stride=1, dimension=D)) 62 | 63 | self.relu = ME.MinkowskiReLU(inplace=True) 64 | 65 | def forward(self, x): 66 | # *** BOTTOM-UP PASS *** 67 | # First bottom-up convolution is special (with a bigger kernel and stride 1) 68 | feature_maps = [] 69 | x = self.conv0(x) 70 | x = self.bn0(x) 71 | x = self.relu(x) 72 | if self.num_top_down == self.num_bottom_up: 73 | feature_maps.append(x) 74 | 75 | # BOTTOM-UP PASS 76 | for ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)): 77 | x = conv(x) # Decreases spatial resolution (conv stride=2) 78 | x = bn(x) 79 | x = self.relu(x) 80 | x = block(x) 81 | if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1: 82 | feature_maps.append(x) 83 | 84 | assert len(feature_maps) == self.num_top_down 85 | 86 | x = self.conv1x1[0](x) 87 | 88 | # TOP-DOWN PASS 89 | for ndx, tconv in enumerate(self.tconvs): 90 | x = tconv(x) # Upsample using transposed convolution 91 | x = x + self.conv1x1[ndx+1](feature_maps[-ndx - 1]) 92 | 93 | return x 94 | -------------------------------------------------------------------------------- /models/minkloc.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import torch 4 | import MinkowskiEngine as ME 5 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 6 | 7 | from models.minkfpn import MinkFPN 8 | import layers.pooling as pooling 9 | from layers.senet_block import SEBasicBlock, SEBottleneck 10 | from layers.eca_block import ECABasicBlock 11 | 12 | 13 | class MinkLoc(torch.nn.Module): 14 | def __init__(self, in_channels, feature_size, output_dim, planes, layers, num_top_down, conv0_kernel_size, 15 | block='BasicBlock', pooling_method='GeM'): 16 | # block: Type of the network building block: BasicBlock or SEBasicBlock 17 | # pooling_method: Name of the global pooling method (see layers/pooling.py) 18 | # Remaining arguments (planes, layers, num_top_down, conv0_kernel_size) are passed to the MinkFPN backbone 19 | 20 | super().__init__() 21 | self.in_channels = in_channels 22 | self.feature_size = feature_size # Size of local features produced by local feature extraction block 23 | self.output_dim = output_dim # Dimensionality of the global descriptor produced by pooling layer 24 | self.block = block 25 | 26 | if block == 'BasicBlock': 27 | block_module = BasicBlock 28 | elif block == 'Bottleneck': 29 | block_module = Bottleneck 30 | elif block == 'SEBasicBlock': 31 | block_module = SEBasicBlock 32 | elif block == 'ECABasicBlock': 33 | block_module = ECABasicBlock 34 | else: 35 | raise NotImplementedError('Unsupported network block: {}'.format(block)) 36 | 37 | self.pooling_method = pooling_method 38 | self.backbone = MinkFPN(in_channels=in_channels, out_channels=self.feature_size, 
num_top_down=num_top_down, 39 | conv0_kernel_size=conv0_kernel_size, block=block_module, layers=layers, planes=planes) 40 | self.pooling = pooling.PoolingWrapper(pool_method=pooling_method, in_dim=self.feature_size, 41 | output_dim=output_dim) 42 | self.pooled_feature_size = self.pooling.output_dim # Number of channels returned by pooling layer 43 | 44 | def forward(self, batch): 45 | # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation 46 | x = ME.SparseTensor(batch['features'], coordinates=batch['coords']) 47 | x = self.backbone(x) 48 | 49 | # x is (num_points, n_features) tensor 50 | assert x.shape[1] == self.feature_size, 'Backbone output tensor has: {} channels. Expected: {}'.format(x.shape[1], self.feature_size) 51 | 52 | x = self.pooling(x) 53 | if x.dim() == 3 and x.shape[2] == 1: 54 | # Reshape (batch_size, output_dim, 1) tensor to (batch_size, output_dim) 55 | x = x.flatten(1) 56 | 57 | assert x.dim() == 2, 'Expected 2-dimensional tensor (batch_size,output_dim). Got {} dimensions.'.format(x.dim()) 58 | assert x.shape[1] == self.pooled_feature_size, 'Backbone output tensor has: {} channels. Expected: {}'.format(x.shape[1], self.pooled_feature_size) 59 | assert x.shape[1] == self.output_dim, 'Output tensor has: {} channels. Expected: {}'.format(x.shape[1], self.output_dim) 60 | # x is (batch_size, output_dim) tensor 61 | return {'global': x} 62 | 63 | def print_info(self): 64 | print('Model class: MinkLoc') 65 | n_params = sum([param.nelement() for param in self.parameters()]) 66 | print('Total parameters: {}'.format(n_params)) 67 | n_params = sum([param.nelement() for param in self.backbone.parameters()]) 68 | print('Backbone parameters: {}'.format(n_params)) 69 | print('Backbone building block: {}'.format(self.block)) 70 | print('Pooling method: {}'.format(self.pooling_method)) 71 | n_params = sum([param.nelement() for param in self.pooling.parameters()]) 72 | print('Pooling parameters: {}'.format(n_params)) 73 | print('# channels from feature extraction block: {}'.format(self.feature_size)) 74 | print('# channels from pooling block: {}'.format(self.pooled_feature_size)) 75 | print('# output channels: {}'.format(self.output_dim)) 76 | -------------------------------------------------------------------------------- /models/minkloc3d_mulran.txt: -------------------------------------------------------------------------------- 1 | # MinkLoc3D model 2 | [MODEL] 3 | model = MinkFPN 4 | mink_quantization_size = 0.3 5 | planes = 32,64,64 6 | layers = 1,1,1 7 | num_top_down = 1 8 | conv0_kernel_size = 5 9 | feature_size = 256 10 | block=ECABasicBlock 11 | pooling=GeM 12 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | from layers.eca_block import ECABasicBlock 4 | from models.minkgl import MinkHead, MinkTrunk, MinkGL 5 | 6 | 7 | from models.minkloc import MinkLoc 8 | from third_party.minkloc3d.minkloc import MinkLoc3D 9 | from misc.utils import ModelParams 10 | 11 | 12 | def model_factory(model_params: ModelParams): 13 | in_channels = 1 14 | 15 | if model_params.model == 'MinkLoc': 16 | model = MinkLoc(in_channels=in_channels, feature_size=model_params.feature_size, 17 | output_dim=model_params.output_dim, planes=model_params.planes, 18 | layers=model_params.layers, num_top_down=model_params.num_top_down, 19 | conv0_kernel_size=model_params.conv0_kernel_size, block=model_params.block, 20 | 
pooling_method=model_params.pooling) 21 | elif model_params.model == 'MinkLoc3D': 22 | model = MinkLoc3D() 23 | elif 'egonn' in model_params.model: 24 | model = create_egonn_model(model_params) 25 | else: 26 | raise NotImplementedError('Model not implemented: {}'.format(model_params.model)) 27 | 28 | return model 29 | 30 | 31 | def create_egonn_model(model_params: ModelParams): 32 | model_name = model_params.model 33 | 34 | global_normalize = False 35 | local_normalize = True 36 | 37 | if model_name == 'egonn': 38 | # THIS IS OUR BEST MODEL 39 | block = ECABasicBlock 40 | planes = [32, 64, 64, 128, 128, 128, 128] 41 | layers = [1, 1, 1, 1, 1, 1, 1] 42 | 43 | global_in_levels = [5, 6, 7] 44 | global_map_channels = 128 45 | global_descriptor_size = 256 46 | 47 | local_in_levels = [3, 4] 48 | local_map_channels = 64 49 | local_descriptor_size = 128 50 | 51 | else: 52 | raise NotImplementedError(f'Unknown model: {model_name}') 53 | 54 | # planes lists the number of channels for level 1 and above 55 | global_in_channels = [planes[i-1] for i in global_in_levels] 56 | head_global = MinkHead(global_in_levels, global_in_channels, global_map_channels) 57 | 58 | if len(local_in_levels) > 0: 59 | local_in_channels = [planes[i-1] for i in local_in_levels] 60 | head_local = MinkHead(local_in_levels, local_in_channels, local_map_channels) 61 | else: 62 | head_local = None 63 | 64 | min_out_level = len(planes) 65 | if len(global_in_levels) > 0: 66 | min_out_level = min(min_out_level, min(global_in_levels)) 67 | if len(local_in_levels) > 0: 68 | min_out_level = min(min_out_level, min(local_in_levels)) 69 | 70 | trunk = MinkTrunk(in_channels=1, planes=planes, layers=layers, conv0_kernel_size=5, block=block, 71 | min_out_level=min_out_level) 72 | 73 | net = MinkGL(trunk, local_head=head_local, local_descriptor_size=local_descriptor_size, 74 | local_normalize=local_normalize, global_head=head_global, 75 | global_descriptor_size=global_descriptor_size, global_pool_method='GeM', 76 | global_normalize=global_normalize, quantizer=model_params.quantizer) 77 | 78 | return net -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 
20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | 25 | import torch.nn as nn 26 | 27 | import MinkowskiEngine as ME 28 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 29 | 30 | 31 | class ResNetBase(nn.Module): 32 | block = None 33 | layers = () 34 | init_dim = 64 35 | planes = (64, 128, 256, 512) 36 | 37 | def __init__(self, in_channels, out_channels, D=3): 38 | nn.Module.__init__(self) 39 | self.D = D 40 | assert self.block is not None 41 | 42 | self.network_initialization(in_channels, out_channels, D) 43 | self.weight_initialization() 44 | 45 | def network_initialization(self, in_channels, out_channels, D): 46 | self.inplanes = self.init_dim 47 | self.conv1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 49 | 50 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 51 | self.relu = ME.MinkowskiReLU(inplace=True) 52 | 53 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 54 | 55 | self.layer1 = self._make_layer( 56 | self.block, self.planes[0], self.layers[0], stride=2) 57 | self.layer2 = self._make_layer( 58 | self.block, self.planes[1], self.layers[1], stride=2) 59 | self.layer3 = self._make_layer( 60 | self.block, self.planes[2], self.layers[2], stride=2) 61 | self.layer4 = self._make_layer( 62 | self.block, self.planes[3], self.layers[3], stride=2) 63 | 64 | self.conv5 = ME.MinkowskiConvolution( 65 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 66 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 67 | 68 | self.glob_avg = ME.MinkowskiGlobalMaxPooling() # global max pooling, despite the attribute name 69 | 70 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 71 | 72 | def weight_initialization(self): 73 | for m in self.modules(): 74 | if isinstance(m, ME.MinkowskiConvolution): 75 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 76 | 77 | if isinstance(m, ME.MinkowskiBatchNorm): 78 | nn.init.constant_(m.bn.weight, 1) 79 | nn.init.constant_(m.bn.bias, 0) 80 | 81 | def _make_layer(self, 82 | block, 83 | planes, 84 | blocks, 85 | stride=1, 86 | dilation=1, 87 | bn_momentum=0.1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | ME.MinkowskiConvolution( 92 | self.inplanes, 93 | planes * block.expansion, 94 | kernel_size=1, 95 | stride=stride, 96 | dimension=self.D), 97 | ME.MinkowskiBatchNorm(planes * block.expansion)) 98 | layers = [] 99 | layers.append( 100 | block( 101 | self.inplanes, 102 | planes, 103 | stride=stride, 104 | dilation=dilation, 105 | downsample=downsample, 106 | dimension=self.D)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, blocks): 109 | layers.append( 110 | block( 111 | self.inplanes, 112 | planes, 113 | stride=1, 114 | dilation=dilation, 115 | dimension=self.D)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.conv1(x) 121 | x = self.bn1(x) 122 | x = self.relu(x) 123 | x = self.pool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.conv5(x) 131 | x = self.bn5(x) 132 | x = self.relu(x) 133 | 134 | x = self.glob_avg(x) 135 | return self.final(x) 136 | 137 | 138 | class ResNet14(ResNetBase): 139 | block = BasicBlock 140 | layers = (1, 1, 1, 1) 141 | 142 | 143 | class ResNet18(ResNetBase): 
144 |     block = BasicBlock
145 |     layers = (2, 2, 2, 2)
146 | 
147 | 
148 | class ResNet34(ResNetBase):
149 |     block = BasicBlock
150 |     layers = (3, 4, 6, 3)
151 | 
152 | 
153 | class ResNet50(ResNetBase):
154 |     block = Bottleneck
155 |     layers = (3, 4, 6, 3)
156 | 
157 | 
158 | class ResNet101(ResNetBase):
159 |     block = Bottleneck
160 |     layers = (3, 4, 23, 3)
161 | 
162 | 
-------------------------------------------------------------------------------- /third_party/minkloc3d/minkloc.py: --------------------------------------------------------------------------------
1 | # Adapted MinkLoc3D model
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from models.minkfpn import MinkFPN
6 | import MinkowskiEngine as ME
7 | 
8 | 
9 | class MinkLoc3D(torch.nn.Module):
10 |     def __init__(self):
11 |         super().__init__()
12 |         self.feature_size = 256
13 |         self.output_dim = 256
14 |         self.backbone = MinkFPN(in_channels=1, out_channels=self.feature_size, num_top_down=1,
15 |                                 conv0_kernel_size=5, layers=[1, 1, 1], planes=[32, 64, 64])
16 |         assert self.feature_size == self.output_dim, 'output_dim must be the same as feature_size'
17 |         self.pooling = GeM()
18 | 
19 |     def forward(self, batch, disable_local_head: bool = True):
20 |         # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
21 |         assert disable_local_head, "MinkLoc3D model has only the global head"
22 |         x = ME.SparseTensor(batch['features'], coordinates=batch['coords'])
23 |         x = self.backbone(x)
24 | 
25 |         # x is (num_points, n_features) tensor
26 |         assert x.shape[1] == self.feature_size, 'Backbone output tensor has: {} channels. Expected: {}'.format(x.shape[1], self.feature_size)
27 |         x = self.pooling(x)
28 |         assert x.dim() == 2, 'Expected 2-dimensional tensor (batch_size,output_dim). Got {} dimensions.'.format(x.dim())
29 |         assert x.shape[1] == self.output_dim, 'Output tensor has: {} channels. Expected: {}'.format(x.shape[1], self.output_dim)
30 |         # x is (batch_size, output_dim) tensor
31 |         return {'global': x}
32 | 
33 |     def print_info(self):
34 |         print('Model class: MinkLoc')
35 |         n_params = sum([param.nelement() for param in self.parameters()])
36 |         print('Total parameters: {}'.format(n_params))
37 |         n_params = sum([param.nelement() for param in self.backbone.parameters()])
38 |         print('Backbone parameters: {}'.format(n_params))
39 |         n_params = sum([param.nelement() for param in self.pooling.parameters()])
40 |         print('Aggregation parameters: {}'.format(n_params))
41 |         if hasattr(self.backbone, 'print_info'):
42 |             self.backbone.print_info()
43 |         if hasattr(self.pooling, 'print_info'):
44 |             self.pooling.print_info()
45 | 
46 | 
47 | class GeM(nn.Module):
48 |     def __init__(self, p=3, eps=1e-6):
49 |         super(GeM, self).__init__()
50 |         self.p = nn.Parameter(torch.ones(1) * p)
51 |         self.eps = eps
52 |         self.f = ME.MinkowskiGlobalAvgPooling()
53 | 
54 |     def forward(self, x: ME.SparseTensor):
55 |         # This implicitly applies ReLU on x (clamps negative values)
56 |         temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C)
57 |         temp = self.f(temp)  # Apply ME.MinkowskiGlobalAvgPooling
58 |         return temp.F.pow(1./self.p)  # Return (batch_size, n_features) tensor
59 | 
60 | 
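# --- Added example (not part of the original source). GeM (generalized-mean) pooling computes
# GeM_p(x) = (mean_i max(x_i, eps)^p)^(1/p) per feature channel: p=1 gives average pooling and
# p -> infinity approaches max pooling. A minimal dense-tensor equivalent of the sparse version above:
#
#   import torch
#
#   def gem(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
#       # x: (batch_size, num_points, n_features) dense point features
#       return x.clamp(min=eps).pow(p).mean(dim=1).pow(1.0 / p)
# --- End of added example ---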
-------------------------------------------------------------------------------- /third_party/pypcd.py: --------------------------------------------------------------------------------
1 | """
2 | Read and write PCL .pcd files in python.
3 | dimatura@cmu.edu, 2013-2018
4 | - TODO better API for wacky operations.
5 | - TODO add a cli for common operations.
6 | - TODO deal properly with padding
7 | - TODO deal properly with multicount fields
8 | - TODO better support for rgb nonsense
9 | """
10 | 
11 | import re
12 | import struct
13 | import copy
14 | from io import BytesIO as sio  # PCD data is read as bytes; BytesIO (rather than StringIO) matches the binary file interface
15 | import numpy as np
16 | import warnings
17 | import lzf
18 | 
19 | 
20 | numpy_pcd_type_mappings = [(np.dtype('float32'), ('F', 4)),
21 |                            (np.dtype('float64'), ('F', 8)),
22 |                            (np.dtype('uint8'), ('U', 1)),
23 |                            (np.dtype('uint16'), ('U', 2)),
24 |                            (np.dtype('uint32'), ('U', 4)),
25 |                            (np.dtype('uint64'), ('U', 8)),
26 |                            (np.dtype('int16'), ('I', 2)),
27 |                            (np.dtype('int32'), ('I', 4)),
28 |                            (np.dtype('int64'), ('I', 8))]
29 | numpy_type_to_pcd_type = dict(numpy_pcd_type_mappings)
30 | pcd_type_to_numpy_type = dict((q, p) for (p, q) in numpy_pcd_type_mappings)
31 | 
32 | 
33 | def parse_header(lines):
34 |     """ Parse header of PCD files.
35 |     """
36 |     metadata = {}
37 |     for ln in lines:
38 |         if ln.startswith('#') or len(ln) < 2:
39 |             continue
40 |         match = re.match(r'(\w+)\s+([\w\s\.]+)', str(ln))  # raw string avoids invalid-escape warnings
41 |         if not match:
42 |             warnings.warn("warning: can't understand line: %s" % ln)
43 |             continue
44 |         key, value = match.group(1).lower(), match.group(2)
45 |         if key == 'version':
46 |             metadata[key] = value
47 |         elif key in ('fields', 'type'):
48 |             metadata[key] = value.split()
49 |         elif key in ('size', 'count'):
50 |             #print('found size and count k %s v %s '% (key, value))
51 |             metadata[key] = list(map(int, value.split()))
52 |             #print(list(map(int,value.split())))
53 |         elif key in ('width', 'height', 'points'):
54 |             metadata[key] = int(value)
55 |         elif key == 'viewpoint':
56 |             metadata[key] = list(map(float, value.split()))  # materialize the list; a bare map() is exhausted after one use
57 |         elif key == 'data':
58 |             metadata[key] = value.strip().lower()
59 |     # TODO apparently count is not required?
60 |     # add some reasonable defaults
61 |     if 'count' not in metadata:
62 |         metadata['count'] = [1]*len(metadata['fields'])
63 |     if 'viewpoint' not in metadata:
64 |         metadata['viewpoint'] = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
65 |     if 'version' not in metadata:
66 |         metadata['version'] = '.7'
67 |     return metadata
68 | 
69 | 
70 | def write_header(metadata, rename_padding=False):
71 |     """ Given metadata as dictionary, return a string header.
72 |     """
73 |     template = """\
74 | VERSION {version}
75 | FIELDS {fields}
76 | SIZE {size}
77 | TYPE {type}
78 | COUNT {count}
79 | WIDTH {width}
80 | HEIGHT {height}
81 | VIEWPOINT {viewpoint}
82 | POINTS {points}
83 | DATA {data}
84 | """
85 |     str_metadata = metadata.copy()
86 | 
87 |     if not rename_padding:
88 |         str_metadata['fields'] = ' '.join(metadata['fields'])
89 |     else:
90 |         new_fields = []
91 |         for f in metadata['fields']:
92 |             if f == '_':
93 |                 new_fields.append('padding')
94 |             else:
95 |                 new_fields.append(f)
96 |         str_metadata['fields'] = ' '.join(new_fields)
97 |     str_metadata['size'] = ' '.join(map(str, metadata['size']))
98 |     str_metadata['type'] = ' '.join(metadata['type'])
99 |     str_metadata['count'] = ' '.join(map(str, metadata['count']))
100 |     str_metadata['width'] = str(metadata['width'])
101 |     str_metadata['height'] = str(metadata['height'])
102 |     str_metadata['viewpoint'] = ' '.join(map(str, metadata['viewpoint']))
103 |     str_metadata['points'] = str(metadata['points'])
104 |     tmpl = template.format(**str_metadata)
105 |     return tmpl
106 | 
107 | 
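# --- Added example (not part of the original source). What parse_header() produces for a typical
# x/y/z cloud; header lines on the left, resulting metadata on the right (illustrative values only):
#
#   VERSION .7                ->  {'version': '.7',
#   FIELDS x y z              ->   'fields': ['x', 'y', 'z'],
#   SIZE 4 4 4                ->   'size': [4, 4, 4],
#   TYPE F F F                ->   'type': ['F', 'F', 'F'],
#   COUNT 1 1 1               ->   'count': [1, 1, 1],
#   WIDTH 213                 ->   'width': 213,
#   HEIGHT 1                  ->   'height': 1,
#   VIEWPOINT 0 0 0 1 0 0 0   ->   'viewpoint': [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
#   POINTS 213                ->   'points': 213,
#   DATA binary               ->   'data': 'binary'}
# --- End of added example ---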
108 | def _metadata_is_consistent(metadata):
109 |     """ Sanity check for metadata. Just some basic checks.
110 |     """
111 |     checks = []
112 |     required = ('version', 'fields', 'size', 'width', 'height', 'points',
113 |                 'viewpoint', 'data')
114 |     for f in required:
115 |         if f not in metadata:
116 |             print('%s required' % f)
117 |     checks.append((lambda m: all([k in m for k in required]),
118 |                    'missing field'))
119 |     checks.append((lambda m: len(m['type']) == len(list(m['count'])) ==
120 |                    len(m['fields']),
121 |                    'length of type, count and fields must be equal'))
122 |     checks.append((lambda m: m['height'] > 0,
123 |                    'height must be greater than 0'))
124 |     checks.append((lambda m: m['width'] > 0,
125 |                    'width must be greater than 0'))
126 |     checks.append((lambda m: m['points'] > 0,
127 |                    'points must be greater than 0'))
128 |     checks.append((lambda m: m['data'].lower() in ('ascii', 'binary',
129 |                    'binary_compressed'),
130 |                    'unknown data type: '
131 |                    'should be ascii/binary/binary_compressed'))
132 |     ok = True
133 |     for check, msg in checks:
134 |         if not check(metadata):
135 |             print('error:', msg)
136 |             ok = False
137 |     return ok
138 | 
139 | # def pcd_type_to_numpy(pcd_type, pcd_sz):
140 | #     """ convert from a pcd type string and size to numpy dtype."""
141 | #     typedict = {'F' : { 4:np.float32, 8:np.float64 },
142 | #                 'I' : { 1:np.int8, 2:np.int16, 4:np.int32, 8:np.int64 },
143 | #                 'U' : { 1:np.uint8, 2:np.uint16, 4:np.uint32 , 8:np.uint64 }}
144 | #     return typedict[pcd_type][pcd_sz]
145 | 
146 | 
147 | def _build_dtype(metadata):
148 |     """ Build numpy structured array dtype from pcl metadata.
149 |     Note that fields with count > 1 are 'flattened' by creating multiple
150 |     single-count fields.
151 |     *TODO* allow 'proper' multi-count fields.
152 |     """
153 |     fieldnames = []
154 |     typenames = []
155 |     for f, c, t, s in zip(metadata['fields'],
156 |                           metadata['count'],
157 |                           metadata['type'],
158 |                           metadata['size']):
159 |         np_type = pcd_type_to_numpy_type[(t, s)]
160 |         if c == 1:
161 |             fieldnames.append(f)
162 |             typenames.append(np_type)
163 |         else:
164 |             fieldnames.extend(['%s_%04d' % (f, i) for i in range(c)])
165 |             typenames.extend([np_type]*c)
166 |     dtype = np.dtype(list(zip(fieldnames, typenames)))
167 |     return dtype
168 | 
169 | 
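# --- Added example (not part of the original source). How _build_dtype() flattens fields: for
# metadata with fields=['x', 'y', 'z', 'rgb'], count=[1, 1, 1, 1], type=['F', 'F', 'F', 'F'] and
# size=[4, 4, 4, 4] it returns
#
#   np.dtype([('x', np.float32), ('y', np.float32), ('z', np.float32), ('rgb', np.float32)])
#
# while a field named 'normal' with count=3 would be flattened into three single-count fields
# 'normal_0000', 'normal_0001', 'normal_0002'.
# --- End of added example ---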
170 | def build_ascii_fmtstr(pc):
171 |     """ Make a format string for printing to ascii.
172 |     Note %.8f is minimum for rgb.
173 |     """
174 |     fmtstr = []
175 |     for t, cnt in zip(pc.type, pc.count):
176 |         if t == 'F':
177 |             fmtstr.extend(['%.10f']*cnt)
178 |         elif t == 'I':
179 |             fmtstr.extend(['%d']*cnt)
180 |         elif t == 'U':
181 |             fmtstr.extend(['%u']*cnt)
182 |         else:
183 |             raise ValueError("don't know about type %s" % t)
184 |     return fmtstr
185 | 
186 | 
187 | def parse_ascii_pc_data(f, dtype, metadata):
188 |     """ Use numpy to parse ascii pointcloud data.
189 |     """
190 |     return np.loadtxt(f, dtype=dtype, delimiter=' ')
191 | 
192 | 
193 | def parse_binary_pc_data(f, dtype, metadata):
194 |     rowstep = metadata['points']*dtype.itemsize
195 |     # for some reason pcl adds empty space at the end of files
196 |     buf = f.read(rowstep)
197 |     return np.frombuffer(buf, dtype=dtype).copy()  # np.fromstring is deprecated; copy() keeps the result writable
198 | 
199 | 
200 | def parse_binary_compressed_pc_data(f, dtype, metadata):
201 |     """ Parse lzf-compressed data.
202 |     Format is undocumented but seems to be:
203 |     - compressed size of data (uint32)
204 |     - uncompressed size of data (uint32)
205 |     - compressed data
206 |     - junk
207 |     """
208 |     fmt = 'II'
209 |     compressed_size, uncompressed_size =\
210 |         struct.unpack(fmt, f.read(struct.calcsize(fmt)))
211 |     compressed_data = f.read(compressed_size)
212 |     # TODO what to use as second argument? if buf is None
213 |     # (compressed > uncompressed)
214 |     # should we read buf as raw binary?
215 |     buf = lzf.decompress(compressed_data, uncompressed_size)
216 |     if len(buf) != uncompressed_size:
217 |         raise IOError('Error decompressing data')
218 |     # the data is stored field-by-field
219 |     # JK: metadata['width'] replaced with metadata['width'] * metadata['height']
220 |     pc_data = np.zeros(metadata['width'] * metadata['height'], dtype=dtype)
221 |     ix = 0
222 |     for dti in range(len(dtype)):
223 |         dt = dtype[dti]
224 |         n_bytes = dt.itemsize * metadata['width'] * metadata['height']  # renamed from 'bytes' to avoid shadowing the builtin
225 |         column = np.frombuffer(buf[ix:(ix + n_bytes)], dt)  # np.fromstring is deprecated
226 |         pc_data[dtype.names[dti]] = column
227 |         ix += n_bytes
228 |     return pc_data
229 | 
230 | 
231 | def point_cloud_from_fileobj(f):
232 |     """ Parse pointcloud coming from file object f
233 |     """
234 |     header = []
235 |     while True:
236 |         ln = f.readline().strip().decode('ascii')
237 |         header.append(ln)
238 |         if ln.startswith('DATA'):
239 |             metadata = parse_header(header)
240 |             dtype = _build_dtype(metadata)
241 |             break
242 |     if metadata['data'] == 'ascii':
243 |         pc_data = parse_ascii_pc_data(f, dtype, metadata)
244 |     elif metadata['data'] == 'binary':
245 |         pc_data = parse_binary_pc_data(f, dtype, metadata)
246 |     elif metadata['data'] == 'binary_compressed':
247 |         pc_data = parse_binary_compressed_pc_data(f, dtype, metadata)
248 |     else:
249 |         raise ValueError('DATA field is none of "ascii", "binary" or'
250 |                          ' "binary_compressed"')
251 |     return PointCloud(metadata, pc_data)
252 | 
253 | 
254 | def point_cloud_from_path(fname):
255 |     """ load point cloud in binary format
256 |     """
257 |     with open(fname, 'rb') as f:
258 |         pc = point_cloud_from_fileobj(f)
259 |     return pc
260 | 
261 | 
262 | def point_cloud_from_buffer(buf):
263 |     fileobj = sio(buf)  # sio is the BytesIO class itself (was sio.StringIO, which raised AttributeError)
264 |     pc = point_cloud_from_fileobj(fileobj)
265 |     fileobj.close()  # necessary?
266 |     return pc
267 | 
268 | 
269 | def update_field(pc, field, pc_data):
270 |     """ Updates field in-place.
271 |     """
272 |     pc.pc_data[field] = pc_data
273 |     return pc
274 | 
275 | 
276 | def encode_rgb_for_pcl(rgb):
277 |     """ Encode bit-packed RGB for use with PCL.
278 |     :param rgb: Nx3 uint8 array with RGB values.
279 |     :rtype: Nx1 float32 array with bit-packed RGB, for PCL.
280 |     """
281 |     assert(rgb.dtype == np.uint8)
282 |     assert(rgb.ndim == 2)
283 |     assert(rgb.shape[1] == 3)
284 |     rgb = rgb.astype(np.uint32)
285 |     rgb = np.array((rgb[:, 0] << 16) | (rgb[:, 1] << 8) | (rgb[:, 2] << 0),
286 |                    dtype=np.uint32)
287 |     rgb.dtype = np.float32
288 |     return rgb
289 | 
290 | 
291 | def decode_rgb_from_pcl(rgb):
292 |     """ Decode the bit-packed RGBs used by PCL.
293 |     :param rgb: An Nx1 array.
294 |     :rtype: Nx3 uint8 array with one column per color.
295 |     """
296 | 
297 |     rgb = rgb.copy()
298 |     rgb.dtype = np.uint32
299 |     r = np.asarray((rgb >> 16) & 255, dtype=np.uint8)
300 |     g = np.asarray((rgb >> 8) & 255, dtype=np.uint8)
301 |     b = np.asarray(rgb & 255, dtype=np.uint8)
302 |     rgb_arr = np.zeros((len(rgb), 3), dtype=np.uint8)
303 |     rgb_arr[:, 0] = r
304 |     rgb_arr[:, 1] = g
305 |     rgb_arr[:, 2] = b
306 |     return rgb_arr
307 | 
308 | 
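# --- Added example (not part of the original source). The PCL 'rgb' field packs three uint8
# channels into the bit pattern of a single float32: bits = (r << 16) | (g << 8) | b. A quick
# round-trip using the two helpers above:
#
#   import numpy as np
#   rgb = np.array([[255, 128, 0]], dtype=np.uint8)   # one orange point
#   packed = encode_rgb_for_pcl(rgb)                  # (1,) float32 with bit pattern 0x00FF8000
#   assert (decode_rgb_from_pcl(packed) == rgb).all()
# --- End of added example ---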
309 | class PointCloud(object):
310 |     """ Wrapper for point cloud data.
311 |     The variable members of this class parallel the ones used by
312 |     the PCD metadata (and similar to PCL and ROS PointCloud2 messages),
313 |     ``pc_data`` holds the actual data as a structured numpy array.
314 |     The other relevant metadata variables are:
315 |     - ``version``: Version, usually .7
316 |     - ``fields``: Field names, e.g. ``['x', 'y', 'z']``.
317 |     - ``size``: Field sizes in bytes, e.g. ``[4, 4, 4]``.
318 |     - ``count``: Counts per field e.g. ``[1, 1, 1]``. NB: Multi-count field
319 |       support is sketchy.
320 |     - ``width``: Number of points, for unstructured point clouds (assumed by
321 |       most operations).
322 |     - ``height``: 1 for unstructured point clouds (again, what we assume most
323 |       of the time).
324 |     - ``viewpoint``: A pose for the viewpoint of the cloud, as
325 |       x y z qw qx qy qz, e.g. ``[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]``.
326 |     - ``points``: Number of points.
327 |     - ``type``: Data type of each field, e.g. ``[F, F, F]``.
328 |     - ``data``: Data storage format. One of ``ascii``, ``binary`` or ``binary_compressed``.
329 |     See the PCL docs
330 |     for more information.
331 |     """
332 | 
333 |     def __init__(self, metadata, pc_data):
334 |         self.metadata_keys = metadata.keys()
335 |         self.__dict__.update(metadata)
336 |         self.pc_data = pc_data
337 |         self.check_sanity()
338 | 
339 |     def get_metadata(self):
340 |         """ returns copy of metadata """
341 |         metadata = {}
342 |         for k in self.metadata_keys:
343 |             metadata[k] = copy.copy(getattr(self, k))
344 |         return metadata
345 | 
346 |     def check_sanity(self):
347 |         # pdb.set_trace()
348 |         md = self.get_metadata()
349 |         assert(_metadata_is_consistent(md))
350 |         assert(len(self.pc_data) == self.points)
351 |         assert(self.width*self.height == self.points)
352 |         assert(len(self.fields) == len(self.count))
353 |         assert(len(self.fields) == len(self.type))
354 | 
355 |     def copy(self):
356 |         new_pc_data = np.copy(self.pc_data)
357 |         new_metadata = self.get_metadata()
358 |         return PointCloud(new_metadata, new_pc_data)
359 | 
360 |     @staticmethod
361 |     def from_path(fname):
362 |         return point_cloud_from_path(fname)
363 | 
364 |     @staticmethod
365 |     def from_fileobj(fileobj):
366 |         return point_cloud_from_fileobj(fileobj)
367 | 
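# --- Added example (not part of the original source). Typical read path with this module; the
# file name is hypothetical:
#
#   import numpy as np
#   from third_party.pypcd import PointCloud
#
#   pc = PointCloud.from_path('cloud.pcd')  # parses the header, then ascii/binary/binary_compressed data
#   xyz = np.stack([pc.pc_data['x'], pc.pc_data['y'], pc.pc_data['z']], axis=1)  # (points, 3) array
# --- End of added example ---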
-------------------------------------------------------------------------------- /third_party/scan_context/evaluate_scan_context.py: --------------------------------------------------------------------------------
1 | # Evaluation of ScanContext
2 | 
3 | import argparse
4 | import numpy as np
5 | import tqdm
6 | import os
7 | import random
8 | from typing import List
9 | 
10 | from datasets.base_datasets import EvaluationSet, get_pointcloud_loader
11 | from third_party.scan_context.scan_context import ScanContextManager
12 | from datasets.dataset_utils import preprocess_pointcloud
13 | 
14 | 
15 | def print_results(stats):
16 |     recall1 = stats['recall1']
17 |     for r in recall1:
18 |         print(f"Radius: {r} [m] : ", end='')
19 |         for x in recall1[r]:
20 |             print("{:0.3f}, ".format(x), end='')
21 |         print("\n")
22 | 
23 | 
24 | def evaluate(dataset_root: str, dataset_type: str, eval_set_pickle: str, radius: List[float], k: int = 50,
25 |              n_samples=None, reranking=True, debug=False):
26 |     # radius: list of thresholds (in meters) to consider an element from the map sequence a true positive
27 |     # k: maximum number of nearest neighbours to consider
28 |     # n_samples: number of samples taken from a query sequence (None=all query elements)
29 |     # debug: when True, process a single query and exit early
30 | 
31 |     load_pc_fn = get_pointcloud_loader(dataset_type)
32 | 
33 |     # Use default parameters
34 |     # load_pc methods for all 3 datasets (Kitti, Mulran, Southbay) remove ground plane,
35 |     # so default lidar_height parameter (=2 meters) is OK for all datasets, as ground points are already removed
36 |     sc = ScanContextManager()
37 | 
38 |     eval_set_filepath = os.path.join(dataset_root, eval_set_pickle)
39 |     assert os.path.exists(eval_set_filepath), f'Cannot access evaluation set pickle: {eval_set_filepath}'
40 |     eval_set = EvaluationSet()
41 |     eval_set.load(eval_set_filepath)
42 | 
43 |     # Compute database point clouds descriptors
44 |     for ndx, e in tqdm.tqdm(enumerate(eval_set.map_set)):
45 |         scan_filepath = os.path.join(dataset_root, e.rel_scan_filepath)
46 |         pc = load_pc_fn(scan_filepath)
47 |         sc.add_node(pc)
48 | 
49 |     map_positions = eval_set.get_map_positions()
50 |     query_positions = eval_set.get_query_positions()
51 | 
52 |     # Dictionary to store the number of true positives for different radius and NN number
53 |     tp = {r: [0] * k for r in radius}
54 | 
55 |     if n_samples is None or len(eval_set.query_set) < n_samples:
56 |         query_indexes = list(range(len(eval_set.query_set)))
57 |         n_samples = len(eval_set.query_set)
58 |     else:
59 |         query_indexes = random.sample(range(len(eval_set.query_set)), n_samples)
60 | 
61 |     # Randomly sample n_samples clouds from the query sequence and NN search in the target sequence
62 |     for query_ndx in tqdm.tqdm(query_indexes):
63 |         # Check if the query element has a true match within each radius
64 |         query_pos = query_positions[query_ndx]
65 | 
66 |         # Get the nearest neighbours
67 |         query_scan_filepath = os.path.join(dataset_root, eval_set.query_set[query_ndx].rel_scan_filepath)
68 |         query_pc = load_pc_fn(query_scan_filepath)
69 | 
70 |         nn_ndx, _, _ = sc.query(query_pc, k, reranking=reranking)
71 | 
72 |         assert sc.curr_node_idx == len(map_positions), f"{sc.curr_node_idx} {len(map_positions)}"
73 |         assert np.max(nn_ndx) < len(map_positions)
74 | 
75 |         # Euclidean distance between the query and nn
76 |         delta = query_pos - map_positions[nn_ndx]  # (k, 2) array
77 |         euclid_dist = np.linalg.norm(delta, axis=1)  # (k,) array
78 |         # Count true positives for different radius and NN number
79 |         tp = {r: [tp[r][nn] + (1 if (euclid_dist[:nn+1] <= r).any() else 0) for nn in range(k)] for r in radius}
80 |         if debug:
81 |             break
82 | 
83 |     recall1 = {r: [tp[r][nn]/n_samples for nn in range(k)] for r in radius}
84 |     return {'recall1': recall1}
85 | 
86 | 
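# --- Added note (not part of the original source). The dict comprehension above accumulates
# Recall@N for N=1..k: a query counts as a true positive at N if ANY of its top-N retrieved map
# elements lies within radius r of the query position. For example, with k=3, per-neighbour
# distances [25.0, 4.1, 30.0] and r=5:
#
#   top-1 -> miss, top-2 -> hit, top-3 -> hit, so this query adds [0, 1, 1] to tp[5]
# --- End of added note ---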
87 | if __name__ == "__main__":
88 |     parser = argparse.ArgumentParser(description='Evaluate ScanContext model')
89 |     parser.add_argument('--dataset_root', type=str, required=True, help='Path to the dataset root')
90 |     parser.add_argument('--dataset_type', type=str, required=True, choices=['mulran', 'southbay', 'kitti'])
91 |     parser.add_argument('--eval_set', type=str, required=True)
92 |     parser.add_argument('--radius', type=float, nargs='+', default=[5, 20], help='True Positive thresholds in meters')  # nargs='+' parses a space-separated list (type=list would split a string into characters)
93 |     parser.add_argument('--nn', type=int, default=50, help='Maximum number of nearest neighbours to consider')
94 |     parser.add_argument('--n_samples', type=int, default=None, help='Number of elements sampled from the query sequence')
95 | 
96 |     args = parser.parse_args()
97 |     print(f'Dataset root: {args.dataset_root}')
98 |     print(f'Dataset type: {args.dataset_type}')
99 |     print(f'Evaluation set: {args.eval_set}')
100 |     print(f'Radius: {args.radius} [m]')
101 |     print(f'Maximum number of nearest neighbours for re-ranking: {args.nn}')
102 |     print(f'Number of samples from the query set: {args.n_samples}')
103 |     print('')
104 | 
105 |     # Evaluate without and with full-descriptor re-ranking
106 |     print('Reranking: FALSE')
107 |     stats = evaluate(args.dataset_root, args.dataset_type, args.eval_set, radius=args.radius, n_samples=args.n_samples,
108 |                      reranking=False, k=args.nn)
109 |     print_results(stats)
110 | 
111 |     print('Reranking: TRUE')
112 |     stats = evaluate(args.dataset_root, args.dataset_type, args.eval_set, radius=args.radius, n_samples=args.n_samples,
113 |                      reranking=True, k=args.nn)
114 |     print_results(stats)
115 | 
-------------------------------------------------------------------------------- /third_party/scan_context/scan_context.py: --------------------------------------------------------------------------------
1 | # Code based on ScanContext implementation: https://github.com/irapkaist/scancontext/blob/master/python/make_sc_example.py
2 | # Partially vectorized implementation by Jacek Komorowski
3 | 
4 | import numpy as np
5 | import numpy_indexed as npi
6 | from sklearn.neighbors import KDTree
7 | 
8 | 
9 | def pt2rs(points, gap_ring, gap_sector):
10 |     # np.arctan2 produces values in -pi..pi range
11 |     theta = np.arctan2(points[:, 1], points[:, 0]) + np.pi
12 |     eps = 1e-6
13 | 
14 |     theta = np.clip(theta, a_min=0., a_max=2*np.pi-eps)
15 |     faraway = np.linalg.norm(points[:, 0:2], axis=1)
16 | 
17 |     idx_ring = (faraway // gap_ring).astype(int)
18 |     idx_sector = (theta // gap_sector).astype(int)
19 | 
20 |     return idx_ring, idx_sector
21 | 
22 | 
23 | class ScanContext:
24 |     def __init__(self, num_sector=60, num_ring=20, max_length=80, lidar_height=2.0):
25 |         # lidar_height: approximate lidar height above the ground; for Kitti it is set to 2.0
26 |         self.lidar_height = lidar_height
27 |         self.num_sector = num_sector
28 |         self.num_ring = num_ring
29 |         self.max_length = max_length
30 |         self.gap_ring = self.max_length / self.num_ring
31 |         self.gap_sector = 2. * np.pi / self.num_sector
32 | 
33 |     def __call__(self, x):
34 |         idx_ring, idx_sector = pt2rs(x, self.gap_ring, self.gap_sector)
35 |         height = x[:, 2] + self.lidar_height
36 | 
37 |         # Filter out points that are self.max_length or further away
38 |         mask = idx_ring < self.num_ring
39 |         idx_ring = idx_ring[mask]
40 |         idx_sector = idx_sector[mask]
41 |         height = height[mask]
42 | 
43 |         assert idx_ring.shape == idx_sector.shape
44 |         assert idx_ring.shape == height.shape
45 | 
46 |         # Convert idx_ring and idx_sector to a linear index
47 |         idx_linear = idx_ring * self.num_sector + idx_sector
48 |         idx, max_height = npi.group_by(idx_linear).max(height)
49 | 
50 |         sc = np.zeros([self.num_ring, self.num_sector])
51 |         # As per original ScanContext implementation, the minimum height is always non-negative
52 |         # (they take max from 0 and height values)
53 |         sc[idx // self.num_sector, idx % self.num_sector] = np.clip(max_height, a_min=0., a_max=None)
54 | 
55 |         return sc
56 | 
57 | 
58 | def distance_sc(sc1, sc2):
59 |     # Distance between 2 scan context descriptors
60 |     num_sectors = sc1.shape[1]
61 | 
62 |     # Repeatedly shift sc1 by one column and compare against sc2
63 |     _one_step = 1  # const
64 |     sim_for_each_cols = np.zeros(num_sectors)
65 |     for i in range(num_sectors):
66 |         # Shift
67 |         sc1 = np.roll(sc1, _one_step, axis=1)  # column shift
68 | 
69 |         # Compare
70 |         sc1_norm = np.linalg.norm(sc1, axis=0)
71 |         sc2_norm = np.linalg.norm(sc2, axis=0)
72 |         mask = ~np.logical_or(np.isclose(sc1_norm, 0.), np.isclose(sc2_norm, 0.))
73 | 
74 |         # Compute cosine similarity between columns of sc1 and sc2
75 |         cossim = np.sum(np.multiply(sc1[:, mask], sc2[:, mask]), axis=0) / (sc1_norm[mask] * sc2_norm[mask])
76 | 
77 |         sim_for_each_cols[i] = np.sum(cossim) / np.sum(mask)
78 | 
79 |     yaw_diff = (np.argmax(sim_for_each_cols) + 1) % sc1.shape[1]  # +1 compensates for the initial one-column roll
80 |     sim = np.max(sim_for_each_cols)
81 |     dist = 1. - sim
82 | 
83 |     return dist, yaw_diff
84 | 
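# --- Added note (not part of the original source). distance_sc() gains rotation invariance by
# brute-force column shifts: each shift of the (num_ring x num_sector) descriptor corresponds to
# rotating the sensor by one sector. The returned yaw_diff is a shift count; with the default
# num_sector=60 it converts to an angle as:
#
#   yaw_deg = yaw_diff * 360.0 / 60  # 6 degrees per sector
# --- End of added note ---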
85 | 
86 | def sc2rk(sc):
87 |     # Scan context to ring key
88 |     return np.mean(sc, axis=1)
89 | 
90 | 
91 | class ScanContextManager:
92 |     def __init__(self, num_sector=60, num_ring=20, max_length=80, lidar_height=2.0, max_capacity=100000):
93 |         # max_capacity: maximum number of nodes
94 |         self.num_sector = num_sector
95 |         self.num_ring = num_ring
96 |         self.max_length = max_length
97 |         self.lidar_height = lidar_height
98 |         self.max_capacity = max_capacity
99 | 
100 |         self.sc = ScanContext(self.num_sector, self.num_ring, self.max_length, self.lidar_height)
101 |         self.scancontexts = np.zeros((self.max_capacity, self.num_ring, self.num_sector))
102 |         self.ringkeys = np.zeros((self.max_capacity, self.num_ring))
103 |         self.curr_node_idx = 0
104 |         self.knn_tree_idx = -1  # Number of nodes for which the KDTree was computed (recalculate the KDTree only after new nodes are added)
105 |         self.ringkey_tree = None
106 | 
107 |     def add_node(self, pc):
108 |         # Compute and store the point cloud descriptor
109 |         # pc: (N, 3) point cloud
110 |         assert pc.ndim == 2
111 |         assert pc.shape[1] == 3
112 | 
113 |         # Compute
114 |         sc = self.sc(pc)  # (num_ring, num_sector) array
115 |         rk = sc2rk(sc)  # (num_ring,) array
116 | 
117 |         self.scancontexts[self.curr_node_idx] = sc
118 |         self.ringkeys[self.curr_node_idx] = rk
119 |         self.curr_node_idx += 1
120 |         assert self.curr_node_idx <= self.max_capacity, f'Maximum ScanContextManager capacity exceeded: {self.max_capacity}'
121 | 
122 |     def query(self, query_pc, k=1, reranking=True):
123 |         assert self.curr_node_idx > 0, 'Empty database'
124 | 
125 |         if self.curr_node_idx != self.knn_tree_idx:
126 |             # recalculate KDTree
127 |             self.ringkey_tree = KDTree(self.ringkeys[:self.curr_node_idx])  # include all stored nodes; [:curr_node_idx-1] would drop the most recent one
128 |             self.knn_tree_idx = self.curr_node_idx
129 | 
130 |         # find candidates
131 |         query_sc = self.sc(query_pc)  # (num_ring, num_sector) array
132 |         query_rk = sc2rk(query_sc)  # (num_ring,) array
133 |         # KDTree query expects (n_samples, sample_dimension) array
134 |         _, nn_ndx = self.ringkey_tree.query(query_rk.reshape(1, -1), k=k)
135 |         # nn_ndx is (n_samples=1, k) array
136 |         nn_ndx = nn_ndx[0]
137 | 
138 |         # Step 2: distances of the candidates computed with the full descriptor
139 |         sc_dist = np.zeros((k,))
140 |         sc_yaw_diff = np.zeros((k,))
141 | 
142 |         if not reranking:
143 |             return nn_ndx, None, None
144 | 
145 |         # Reranking using the full ScanContext descriptor
146 |         for i in range(k):
147 |             candidate_ndx = nn_ndx[i]
148 |             candidate_sc = self.scancontexts[candidate_ndx]
149 |             sc_dist[i], sc_yaw_diff[i] = distance_sc(candidate_sc, query_sc)
150 | 
151 |         reranking_order = np.argsort(sc_dist)
152 |         nn_ndx = nn_ndx[reranking_order]
153 |         sc_yaw_diff = sc_yaw_diff[reranking_order]
154 |         sc_dist = sc_dist[reranking_order]
155 | 
156 |         return nn_ndx, sc_dist, sc_yaw_diff
157 | 
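# --- Added example (not part of the original source). Two-stage retrieval sketch with the manager
# above; map_clouds and query_cloud are hypothetical (N, 3) numpy arrays:
#
#   scm = ScanContextManager()
#   for pc in map_clouds:          # stage 1 input: build the database of descriptors
#       scm.add_node(pc)
#   nn_ndx, sc_dist, sc_yaw_diff = scm.query(query_cloud, k=10, reranking=True)
#   best = nn_ndx[0]               # ring-key KNN candidates, re-ranked by full descriptor distance
# --- End of added example ---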
-------------------------------------------------------------------------------- /training/train.py: --------------------------------------------------------------------------------
1 | # Warsaw University of Technology
2 | 
3 | import argparse
4 | import torch
5 | 
6 | from training.trainer import do_train
7 | from misc.utils import TrainingParams
8 | 
9 | 
10 | if __name__ == '__main__':
11 |     parser = argparse.ArgumentParser(description='Train Minkowski Net embeddings using BatchHard negative mining')
12 |     parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
13 |     parser.add_argument('--model_config', type=str, required=True, help='Path to the model-specific configuration file')
14 |     parser.add_argument('--debug', dest='debug', action='store_true')
15 |     parser.set_defaults(debug=False)
16 |     parser.add_argument('--visualize', dest='visualize', action='store_true')
17 |     parser.set_defaults(visualize=False)
18 | 
19 |     args = parser.parse_args()
20 |     print('Training config path: {}'.format(args.config))
21 |     print('Model config path: {}'.format(args.model_config))
22 |     print('Debug mode: {}'.format(args.debug))
23 |     print('Visualize: {}'.format(args.visualize))
24 | 
25 |     params = TrainingParams(args.config, args.model_config)
26 |     params.print()
27 | 
28 |     if args.debug:
29 |         torch.autograd.set_detect_anomaly(True)
30 | 
31 |     if torch.cuda.is_available():
32 |         device = "cuda"
33 |     else:
34 |         device = "cpu"
35 | 
36 |     do_train(params, debug=args.debug, visualize=args.visualize, device=device)
37 | 
-------------------------------------------------------------------------------- /weights/model_egonn_20210916_1104.pth: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/jac99/Egonn/84d5cc2197c81792f3254b43d3273cdccf18294c/weights/model_egonn_20210916_1104.pth
--------------------------------------------------------------------------------