├── LICENSE
├── README.md
├── config
│   └── config_base.yaml
├── datasets
│   ├── augmentation.py
│   ├── dataset_utils.py
│   ├── etna.py
│   ├── oxford.py
│   └── samplers.py
├── eval
│   ├── evaluate.py
│   └── utils.py
├── figures
│   └── UMF_architecture.png
├── misc
│   ├── lr_scheduler.py
│   └── utils.py
├── models
│   ├── UMF
│   │   ├── UMFnet.py
│   │   ├── UMFnet.yml
│   │   ├── UMFnet_ransac.yml
│   │   ├── UMFnet_superfeat.yml
│   │   ├── __init__.py
│   │   ├── resnet_fpn.py
│   │   ├── spconv_backbone.py
│   │   ├── utils
│   │   │   ├── linear_attention.py
│   │   │   ├── lit.py
│   │   │   ├── lit3d.py
│   │   │   ├── local_transformer.py
│   │   │   ├── local_transformer3D.py
│   │   │   ├── multimodal_fusion.py
│   │   │   ├── patch_matcher.py
│   │   │   ├── pooling.py
│   │   │   ├── spconv_utils.py
│   │   │   ├── swin_transformer.py
│   │   │   └── transformer.py
│   │   └── voxel_encoder.py
│   ├── loss.py
│   ├── losses
│   │   ├── loss_utils.py
│   │   ├── super_feature_losses.py
│   │   └── truncated_smoothap.py
│   └── model_factory.py
└── training
    ├── train.py
    └── trainer.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) [year] [fullname]
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UMF: Unifying Local and Global Multimodal Features
2 | ![alt text](figures/UMF_architecture.png "UMF architecture")
3 | 
4 | 
5 | ### Abstract
6 | 
7 | Perceptual aliasing and weak textures pose significant challenges to place recognition, hindering the performance of Simultaneous Localization and Mapping (SLAM) systems. This paper presents a novel model, called UMF (for Unifying Local and Global Multimodal Features), that 1) leverages multi-modality through cross-attention blocks between vision and LiDAR features, and 2) includes a re-ranking stage that uses local feature matching to re-order the top-k candidates retrieved with a global representation. Our experiments, particularly on sequences captured in a planetary-analogue environment, show that UMF significantly outperforms previous baselines in these challenging aliased environments. Since our work aims to enhance the reliability of SLAM in all situations, we also evaluate its performance on the widely used RobotCar dataset for broader applicability.
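For intuition only, here is a minimal sketch of the cross-attention fusion idea described above. It is not the actual UMF implementation (the real fusion code lives in `models/UMF/utils/multimodal_fusion.py`); the class name, dimensions and layer choices below are illustrative assumptions:

```
import torch.nn as nn

class CrossModalFusion(nn.Module):
    # Toy example: each modality attends to the other and keeps a residual path.
    def __init__(self, dim=256, heads=4):
        super().__init__()
        self.img_from_pc = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.pc_from_img = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, img_tokens, pc_tokens):
        # img_tokens: (B, N_img, dim) vision features, pc_tokens: (B, N_pc, dim) LiDAR features
        img_fused, _ = self.img_from_pc(img_tokens, pc_tokens, pc_tokens)
        pc_fused, _ = self.pc_from_img(pc_tokens, img_tokens, img_tokens)
        return img_tokens + img_fused, pc_tokens + pc_fused
```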
8 | 
9 | ### Citation
10 | If you find our work useful, please cite us:
11 | ```
12 | @INPROCEEDINGS{10611563,
13 | author={García-Hernández, Alberto and Giubilato, Riccardo and Strobl, Klaus H. and Civera, Javier and Triebel, Rudolph},
14 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)},
15 | title={Unifying Local and Global Multimodal Features for Place Recognition in Aliased and Low-Texture Environments},
16 | year={2024},
17 | volume={},
18 | number={},
19 | pages={3991-3998},
20 | keywords={Visualization;Simultaneous localization and mapping;Laser radar;Codes;Fuses;Transformers;Data models},
21 | doi={10.1109/ICRA57147.2024.10611563}}
22 | ```
23 | 
24 | 
25 | ### Environment and Dependencies
26 | 
27 | Code was tested using Python 3.8 with PyTorch 1.9.1 on Ubuntu 20.04 with CUDA 10.6.
28 | 
29 | The following Python packages are required:
30 | * PyTorch
31 | * pytorch_metric_learning (version 1.0 or above)
32 | * spconv
33 | * einops
34 | * opencv-python
35 | 
36 | 
37 | 
38 | * Install [HOW](https://github.com/albertogarci/how)
39 | ```
40 | git clone https://github.com/albertogarci/how
41 | export PYTHONPATH=${PYTHONPATH}:$(realpath how)
42 | ```
43 | 
44 | * Install [cirtorch](https://github.com/filipradenovic/cnnimageretrieval-pytorch/)
45 | ```
46 | wget "https://github.com/filipradenovic/cnnimageretrieval-pytorch/archive/v1.2.zip"
47 | unzip v1.2.zip
48 | rm v1.2.zip
49 | export PYTHONPATH=${PYTHONPATH}:$(realpath cnnimageretrieval)
50 | ```
51 | 
52 | Modify the `PYTHONPATH` environment variable to include the absolute path to the project root folder:
53 | ```
54 | export PYTHONPATH=$PYTHONPATH:/home/.../UMF
55 | export PYTHONPATH=$PYTHONPATH:/home_local/$USER/UMF
58 | export PYTHONPATH=${PYTHONPATH}:$(realpath how)
59 | export PYTHONPATH=${PYTHONPATH}:$(realpath cnnimageretrieval)
60 | ```
61 | 
62 | 
63 | 
64 | ### Training
65 | 
66 | To train UMF following the procedure in the paper, first download the RobotCar dataset. Otherwise, adapt the dataloader accordingly.
67 | 
68 | Edit the configuration files:
69 | - `config_base.yaml` # select the etna or robotcar version
70 | - `./models/UMF/UMFnet.yml` # fusion model only
71 | - `./models/UMF/UMFnet_ransac.yml` # multimodal with ransac reranking
72 | - `./models/UMF/UMFnet_superfeat.yml` # multimodal with superfeatures reranking
73 | 
74 | Modify the `batch_size` parameter depending on the available GPU memory.
75 | 
76 | 
77 | Set the `dataset_folder` parameter to the dataset root folder, where the 3D point clouds are located.
78 | Set the `image_path` parameter to the path with the RGB images corresponding to the 3D point clouds.
79 | 
80 | 
81 | To train, run:
82 | 
83 | ```train
84 | # Fusion only
85 | python train.py --config ../config/config_base.yaml --model_config ../models/UMF/UMFnet.yml
86 | # RANSAC
87 | python train.py --config ../config/config_base.yaml --model_config ../models/UMF/UMFnet_ransac.yml
88 | ```
89 | We provide the pre-trained models for the RobotCar dataset ([link](https://drive.google.com/drive/folders/1MXOhMC6wxjU0FjsDM1GzUIzJJ0e-5mjQ?usp=sharing)).
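At evaluation time, retrieval is two-stage: the global embedding is used to fetch the top-k database candidates, which are then re-ordered by local feature matching (RANSAC or super-feature matching, depending on the variant). The sketch below illustrates that flow under the assumption of a generic `local_match_score(candidate_index)` callable that compares the query's local features against a database entry; the actual logic lives in `eval/evaluate.py` and `eval/utils.py`:

```
import numpy as np
from sklearn.neighbors import KDTree

def retrieve_and_rerank(query_embedding, db_embeddings, local_match_score, top_k=25):
    # Stage 1: coarse retrieval with the global descriptor
    tree = KDTree(db_embeddings)
    _, indices = tree.query(query_embedding[None, :], k=top_k)
    candidates = indices[0]
    # Stage 2: re-rank the candidates by local feature matching
    # (lower score = better match, as in eval/utils.py)
    scores = np.array([local_match_score(j) for j in candidates])
    return candidates[np.argsort(scores)]
```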
90 | 91 | 92 | ### Evaluation 93 | 94 | To evaluate pretrained models run the following commands: 95 | 96 | ``` 97 | cd eval 98 | 99 | # Evaluate with superfeatures or ransac variant 100 | python evaluate.py --config ../config/config_base.yaml --model_config ../models/UMFnet_ransac.yml --weights 101 | 102 | ``` 103 | 104 | ### S3LI Dataset 105 | In order to use the S3LI dataset to generate train and test samples, check out this repo: https://github.com/DLR-RM/s3li-toolkit 106 | -------------------------------------------------------------------------------- /config/config_base.yaml: -------------------------------------------------------------------------------- 1 | DEFAULT: 2 | num_points: 4096 3 | use_rgb: True 4 | use_cloud: True 5 | dataset: "robotcar" # robotcar or etna 6 | dataset_folder: /home/user/benchmark_datasets 7 | image_path: /home/user/images4lidar_small20 8 | 9 | TRAIN: 10 | num_workers: 26 11 | train_step: 'single_step' # single_step, multistaged 12 | optimizer: "AdamW" 13 | scheduler: "CosineAnnealingLR" # MultiStepLR OneCycleLR CosineAnnealingLR LinearWarmupCosineAnnealingLR 14 | batch_size: 32 # 64 15 | val_batch_size: 88 # 256 16 | batch_size_limit: 88 # 2048 17 | batch_expansion_rate: 1.4 18 | batch_expansion_th: 0.5 19 | lr: 1e-5 20 | image_lr: 1e-5 21 | epochs: 200 22 | scheduler_milestones: [10, 30, 60] 23 | aug_mode: 1 24 | weight_decay: 5e-4 25 | warmup_epochs: 5 26 | loss: MultiBatchHardTripletLossWithMasksAugmented 27 | weights: [0.5, 0.5, 0.1] 28 | normalize_embeddings: False 29 | margin: 0.2 30 | train_file: training_queries_baseline.pickle 31 | val_file: validation_queries_baseline.pickle 32 | -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | from scipy.linalg import expm, norm 4 | import random 5 | import torch 6 | import torchvision.transforms as transforms 7 | 8 | 9 | class TrainTransform: 10 | def __init__(self, aug_mode): 11 | # 1 is default mode, no transform 12 | self.aug_mode = aug_mode 13 | if self.aug_mode == 0: 14 | self.transform = None 15 | return 16 | elif self.aug_mode == 1: 17 | t = [JitterPoints(sigma=0.001, clip=0.002), 18 | RemoveRandomPoints(r=(0.15, 0.3)), 19 | RandomTranslation(max_delta=0.02), 20 | #RandomRotation(angle_range=(-5, 5)), 21 | RemoveRandomBlock(p=0.5), 22 | ] 23 | else: 24 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 25 | self.transform = transforms.Compose(t) 26 | 27 | def __call__(self, e): 28 | if self.transform is not None: 29 | e = self.transform(e) 30 | return e 31 | 32 | 33 | class TrainRGBTransform: 34 | def __init__(self, aug_mode): 35 | # 1 is default mode, no transform 36 | self.aug_mode = aug_mode 37 | if self.aug_mode == 0: 38 | t = [ 39 | transforms.ToTensor(), 40 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] 41 | elif self.aug_mode == 1: 42 | t = [ 43 | 44 | transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1), 45 | #transforms.RandomAffine(degrees=5, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=0.1), 46 | #transforms.RandomHorizontalFlip(p=0.5), 47 | transforms.ToTensor(), 48 | transforms.RandomErasing(scale=(0.1, 0.4)), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] 50 | else: 51 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 52 | self.transform = transforms.Compose(t) 53 | 54 | 55 | def 
__call__(self, e): 56 | if self.transform is not None: 57 | e = self.transform(e) 58 | return e 59 | 60 | 61 | class TrainGreyTransform: 62 | def __init__(self, aug_mode): 63 | # 1 is default mode, no transform 64 | self.aug_mode = aug_mode 65 | if self.aug_mode == 0: 66 | t = [transforms.Resize((224, 224)), 67 | transforms.ToTensor(), 68 | transforms.Normalize([0.5], [0.5])] 69 | elif self.aug_mode == 1: 70 | t = [ 71 | transforms.Resize((224, 224)), 72 | transforms.ColorJitter(brightness=0.25, contrast=0.2, saturation=0.2, hue=0.1), 73 | transforms.ToTensor(), 74 | transforms.RandomErasing(scale=(0.15, 0.4)), 75 | transforms.Normalize([0.5], [0.5])] 76 | else: 77 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 78 | self.transform = transforms.Compose(t) 79 | 80 | 81 | def __call__(self, e): 82 | if self.transform is not None: 83 | e = self.transform(e) 84 | return e 85 | 86 | class ValRGBTransform: 87 | def __init__(self): 88 | # 1 is default mode, no transform 89 | t = [ 90 | transforms.ToTensor(), 91 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] 92 | self.transform = transforms.Compose(t) 93 | 94 | def __call__(self, e): 95 | e = self.transform(e) 96 | return e 97 | 98 | class ValGreyTransform: 99 | def __init__(self): 100 | t = [transforms.Resize((224, 224)), 101 | transforms.ToTensor(), 102 | transforms.Normalize([0.5], [0.5])] 103 | self.transform = transforms.Compose(t) 104 | 105 | def __call__(self, e): 106 | e = self.transform(e) 107 | return e 108 | 109 | class RandomFlip: 110 | def __init__(self, p): 111 | # p = [p_x, p_y, p_z] probability of flipping each axis 112 | assert len(p) == 3 113 | assert 0 < sum(p) <= 1, 'sum(p) must be in (0, 1] range, is: {}'.format(sum(p)) 114 | self.p = p 115 | self.p_cum_sum = np.cumsum(p) 116 | 117 | def __call__(self, coords): 118 | r = random.random() 119 | if r <= self.p_cum_sum[0]: 120 | # Flip the first axis 121 | coords[..., 0] = -coords[..., 0] 122 | elif r <= self.p_cum_sum[1]: 123 | # Flip the second axis 124 | coords[..., 1] = -coords[..., 1] 125 | elif r <= self.p_cum_sum[2]: 126 | # Flip the third axis 127 | coords[..., 2] = -coords[..., 2] 128 | 129 | return coords 130 | 131 | 132 | class RandomRotation: 133 | def __init__(self, axis=None, max_theta=180, max_theta2=15): 134 | self.axis = axis 135 | self.max_theta = max_theta # Rotation around axis 136 | self.max_theta2 = max_theta2 # Smaller rotation in random direction 137 | 138 | def _M(self, axis, theta): 139 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)).astype(np.float32) 140 | 141 | def __call__(self, coords): 142 | if self.axis is not None: 143 | axis = self.axis 144 | else: 145 | axis = np.random.rand(3) - 0.5 146 | R = self._M(axis, (np.pi * self.max_theta / 180) * 2 * (np.random.rand(1) - 0.5)) 147 | if self.max_theta2 is None: 148 | coords = coords @ R 149 | else: 150 | R_n = self._M(np.random.rand(3) - 0.5, (np.pi * self.max_theta2 / 180) * 2 * (np.random.rand(1) - 0.5)) 151 | coords = coords @ R @ R_n 152 | 153 | return coords 154 | 155 | 156 | class RandomTranslation: 157 | def __init__(self, max_delta=0.05): 158 | self.max_delta = max_delta 159 | 160 | def __call__(self, coords): 161 | trans = self.max_delta * np.random.randn(1, 3) 162 | return coords + trans.astype(np.float32) 163 | 164 | 165 | class RandomScale: 166 | def __init__(self, min, max): 167 | self.scale = max - min 168 | self.bias = min 169 | 170 | def __call__(self, coords): 171 | s = self.scale * np.random.rand(1) + self.bias 172 | return 
coords * s.astype(np.float32) 173 | 174 | 175 | class RandomShear: 176 | def __init__(self, delta=0.1): 177 | self.delta = delta 178 | 179 | def __call__(self, coords): 180 | T = np.eye(3) + self.delta * np.random.randn(3, 3) 181 | return coords @ T.astype(np.float32) 182 | 183 | 184 | class JitterPoints: 185 | def __init__(self, sigma=0.01, clip=None, p=1.): 186 | assert 0 < p <= 1. 187 | assert sigma > 0. 188 | 189 | self.sigma = sigma 190 | self.clip = clip 191 | self.p = p 192 | 193 | def __call__(self, e): 194 | """ Randomly jitter points. jittering is per point. 195 | Input: 196 | BxNx3 array, original batch of point clouds 197 | Return: 198 | BxNx3 array, jittered batch of point clouds 199 | """ 200 | 201 | sample_shape = (e.shape[0],) 202 | if self.p < 1.: 203 | # Create a mask for points to jitter 204 | m = torch.distributions.categorical.Categorical(probs=torch.tensor([1 - self.p, self.p])) 205 | mask = m.sample(sample_shape=sample_shape) 206 | else: 207 | mask = torch.ones(sample_shape, dtype=torch.int64 ) 208 | 209 | mask = mask == 1 210 | jitter = self.sigma * torch.randn_like(e[mask]) 211 | 212 | if self.clip is not None: 213 | jitter = torch.clamp(jitter, min=-self.clip, max=self.clip) 214 | 215 | e[mask] = e[mask] + jitter 216 | return e 217 | 218 | 219 | class RemoveRandomPoints: 220 | def __init__(self, r): 221 | if type(r) is list or type(r) is tuple: 222 | assert len(r) == 2 223 | assert 0 <= r[0] <= 1 224 | assert 0 <= r[1] <= 1 225 | self.r_min = float(r[0]) 226 | self.r_max = float(r[1]) 227 | else: 228 | assert 0 <= r <= 1 229 | self.r_min = None 230 | self.r_max = float(r) 231 | 232 | def __call__(self, e): 233 | n = len(e) 234 | if self.r_min is None: 235 | r = self.r_max 236 | else: 237 | # Randomly select removal ratio 238 | r = random.uniform(self.r_min, self.r_max) 239 | 240 | mask = np.random.choice(range(n), size=int(n*r), replace=False) # select elements to remove 241 | e[mask] = torch.zeros_like(e[mask]) 242 | return e 243 | 244 | 245 | class RandomScaling: 246 | def __init__(self, scale_range=(0.9, 1.1)): 247 | """ 248 | Initializes the RandomScaling transformer with a specified scale range. 249 | 250 | Parameters: 251 | - scale_range: A tuple of two floats, specifying the minimum and maximum 252 | scaling factors. 253 | """ 254 | self.scale_range = scale_range 255 | 256 | def __call__(self, points): 257 | """ 258 | Applies random scaling to the point cloud. 259 | 260 | Parameters: 261 | - points: A numpy array of shape (N, 3), where N is the number of points. 262 | 263 | Returns: 264 | - Scaled points as a numpy array of shape (N, 3). 265 | """ 266 | scale_factor = np.random.uniform(self.scale_range[0], self.scale_range[1]) 267 | 268 | # Scale the points 269 | scaled_points = points * scale_factor 270 | 271 | return scaled_points 272 | 273 | class RemoveRandomBlock: 274 | """ 275 | Randomly remove part of the point cloud. Similar to PyTorch RandomErasing but operating on 3D point clouds. 276 | Erases fronto-parallel cuboid. 
277 | Instead of erasing we set coords of removed points to (0, 0, 0) to retain the same number of points 278 | """ 279 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 280 | self.p = p 281 | self.scale = scale 282 | self.ratio = ratio 283 | 284 | def get_params(self, coords): 285 | # Find point cloud 3D bounding box 286 | flattened_coords = coords.view(-1, 3) 287 | min_coords, _ = torch.min(flattened_coords, dim=0) 288 | max_coords, _ = torch.max(flattened_coords, dim=0) 289 | span = max_coords - min_coords 290 | area = span[0] * span[1] 291 | erase_area = random.uniform(self.scale[0], self.scale[1]) * area 292 | aspect_ratio = random.uniform(self.ratio[0], self.ratio[1]) 293 | 294 | h = math.sqrt(erase_area * aspect_ratio) 295 | w = math.sqrt(erase_area / aspect_ratio) 296 | 297 | x = min_coords[0] + random.uniform(0, 1) * (span[0] - w) 298 | y = min_coords[1] + random.uniform(0, 1) * (span[1] - h) 299 | 300 | return x, y, w, h 301 | 302 | def __call__(self, coords): 303 | if random.random() < self.p: 304 | x, y, w, h = self.get_params(coords) # Fronto-parallel cuboid to remove 305 | mask = (x < coords[..., 0]) & (coords[..., 0] < x+w) & (y < coords[..., 1]) & (coords[..., 1] < y+h) 306 | coords[mask] = torch.zeros_like(coords[mask]) 307 | return coords 308 | 309 | 310 | class TrainSetTransform: 311 | def __init__(self, aug_mode): 312 | self.aug_mode = aug_mode 313 | self.transform = None 314 | if aug_mode == 0: 315 | t = None 316 | elif aug_mode == 1: 317 | t = [RandomRotation(max_theta=5, max_theta2=0, axis=np.array([0, 0, 1])), 318 | RandomFlip([0.25, 0.25, 0.])] 319 | else: 320 | raise NotImplementedError('Unknown aug_mode: {}'.format(aug_mode)) 321 | if t is None: 322 | self.transform = None 323 | else: 324 | self.transform = transforms.Compose(t) 325 | 326 | def __call__(self, e): 327 | if self.transform is not None: 328 | e = self.transform(e) 329 | return e 330 | 331 | 332 | def tensor2img(x): 333 | t = transforms.Compose([transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 334 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225]), 335 | transforms.ToPILImage()]) 336 | return t(x) -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import numpy as np 5 | import torch 6 | import dotsi 7 | import yaml 8 | from torch.utils.data import DataLoader 9 | from datasets.oxford import OxfordDataset 10 | from datasets.etna import EtnaDataset 11 | 12 | from typing import List, Union 13 | 14 | from datasets.augmentation import TrainTransform, TrainSetTransform, TrainRGBTransform, ValRGBTransform, TrainGreyTransform, ValGreyTransform 15 | from datasets.samplers import BatchSampler 16 | from misc.utils import UMFParams 17 | 18 | def load_config(config_file): 19 | with open(config_file, 'r') as ymlfile: 20 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 21 | return cfg 22 | 23 | 24 | def make_datasets(params: UMFParams, debug=False): 25 | # Create training and validation datasets 26 | datasets = {} 27 | train_transform = TrainTransform(params.aug_mode) 28 | train_set_transform = TrainSetTransform(params.aug_mode) 29 | cfg = load_config(params.model_params_path) 30 | cfg = dotsi.Dict(cfg) 31 | 32 | if params.dataset == 'robotcar': 33 | image_train_transform = TrainRGBTransform(params.aug_mode) 34 | image_val_transform = ValRGBTransform() 35 | 36 | 
datasets['train'] = OxfordDataset(params.dataset_folder, params.train_file, image_path=params.image_path, 37 | lidar2image_ndx_path=params.lidar2image_ndx_path, transform=train_transform, 38 | set_transform=train_set_transform, image_transform=image_train_transform, 39 | use_cloud=params.use_cloud, cfg=cfg) 40 | val_transform = None 41 | if params.val_file is not None: 42 | datasets['val'] = OxfordDataset(params.dataset_folder, params.val_file, image_path=params.image_path, 43 | lidar2image_ndx_path=params.lidar2image_ndx_path, transform=val_transform, 44 | set_transform=None, image_transform=image_val_transform, 45 | use_cloud=params.use_cloud, cfg=cfg) 46 | 47 | elif params.dataset == 'etna': 48 | image_train_transform = TrainGreyTransform(params.aug_mode) 49 | image_val_transform = ValGreyTransform() 50 | 51 | datasets['train'] = EtnaDataset(params.dataset_folder, params.train_file, 52 | transform=train_transform, set_transform=train_set_transform, 53 | image_transform=image_train_transform, 54 | use_cloud=params.use_cloud, use_rgb=params.use_rgb, cfg=cfg) 55 | val_transform = None 56 | if params.val_file is not None: 57 | datasets['val'] = EtnaDataset(params.dataset_folder, params.val_file, 58 | transform=val_transform, set_transform=None, 59 | image_transform=image_val_transform, 60 | use_cloud=params.use_cloud, use_rgb=params.use_rgb, cfg=cfg) 61 | 62 | return datasets 63 | 64 | 65 | 66 | def gather_features_by_pc_voxel_id(seg_res_features: torch.Tensor, pc_voxel_id: torch.Tensor, invalid_value: Union[int, float] = 0): 67 | """This function is used to gather segmentation result to match origin pc. 68 | """ 69 | if seg_res_features.device != pc_voxel_id.device: 70 | pc_voxel_id = pc_voxel_id.to(seg_res_features.device) 71 | res_feature_shape = (pc_voxel_id.shape[0], *seg_res_features.shape[1:]) 72 | if invalid_value == 0: 73 | res = torch.zeros(res_feature_shape, dtype=seg_res_features.dtype, device=seg_res_features.device) 74 | else: 75 | res = torch.full(res_feature_shape, invalid_value, dtype=seg_res_features.dtype, device=seg_res_features.device) 76 | pc_voxel_id_valid = pc_voxel_id != -1 77 | pc_voxel_id_valid_ids = torch.nonzero(pc_voxel_id_valid).view(-1) 78 | seg_res_features_valid = seg_res_features[pc_voxel_id[pc_voxel_id_valid_ids]] 79 | res[pc_voxel_id_valid_ids] = seg_res_features_valid 80 | return res 81 | 82 | def make_collate_fn(dataset: OxfordDataset): 83 | # set_transform: the transform to be applied to all batch elements 84 | def collate_fn(data_list): 85 | # Constructs a batch object 86 | labels = [e['ndx'] for e in data_list] 87 | 88 | # Compute positives and negatives mask 89 | positives_mask = [[in_sorted_array(e, dataset.queries[label].positives) for e in labels] for label in labels] 90 | negatives_mask = [[not in_sorted_array(e, dataset.queries[label].non_negatives) for e in labels] for label in labels] 91 | positives_mask = torch.tensor(positives_mask) 92 | negatives_mask = torch.tensor(negatives_mask) 93 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 94 | 95 | # Returns (batch_size, n_points, 3) tensor and positives_mask and 96 | # negatives_mask which are batch_size x batch_size boolean tensors 97 | result = {'positives_mask': positives_mask, 'negatives_mask': negatives_mask} 98 | if 'cloud' in data_list[0]: 99 | clouds = [e['cloud'] for e in data_list] 100 | clouds = torch.stack(clouds, dim=0) 101 | partial_clouds = clouds 102 | if dataset.set_transform is not None: 103 | # Apply the same transformation on all dataset elements 
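# (i.e. one random rotation / axis flip is drawn per batch and applied identically
# to every cloud in the batch; see TrainSetTransform in datasets/augmentation.py)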
104 | partial_clouds = dataset.set_transform(clouds) 105 | 106 | images = [e['image'] for e in data_list] 107 | result['images'] = torch.stack(images, dim=0) 108 | 109 | voxel_data = [dataset.gen(partial_clouds[batch_id]) for batch_id in range(len(data_list))] 110 | batch_voxel, batch_voxel_coords, batch_num_pts_in_voxels = zip(*voxel_data) 111 | batch_voxel_coords = list(batch_voxel_coords) 112 | 113 | 114 | for batch_id in range(len(batch_voxel_coords)): 115 | coordinates = batch_voxel_coords[batch_id] 116 | 117 | discrete_coords = torch.cat( 118 | ( 119 | torch.zeros(coordinates.shape[0], 1, dtype=torch.int32), 120 | coordinates, 121 | ), 122 | 1, 123 | ) 124 | 125 | batch_voxel_coords[batch_id] = discrete_coords 126 | # Move batchids to the beginning 127 | batch_voxel_coords[batch_id][:, 0] = batch_id 128 | 129 | batch_voxel_coords = torch.cat(batch_voxel_coords, dim=0) 130 | 131 | 132 | num_points = torch.cat(batch_num_pts_in_voxels, 0) 133 | batch_voxel = torch.cat(batch_voxel, dim=0) 134 | feats_batch = torch.ones((batch_voxel.shape[0], 1), dtype=torch.float32, device=partial_clouds.device) 135 | 136 | result["coordinates"] = batch_voxel_coords 137 | result["voxel_features"] = feats_batch 138 | 139 | return result 140 | 141 | return collate_fn 142 | 143 | 144 | 145 | def mean_vfe(voxel_features, voxel_num_points): 146 | points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False) 147 | normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features) 148 | points_mean = points_mean / normalizer 149 | voxel_features = points_mean.contiguous() 150 | return voxel_features 151 | 152 | 153 | def make_dataloaders(params: UMFParams, debug=False): 154 | """ 155 | Create training and validation dataloaders that return groups of k=2 similar elements 156 | :param train_params: 157 | :param model_params: 158 | :return: 159 | """ 160 | datasets = make_datasets(params, debug=debug) 161 | 162 | dataloders = {} 163 | train_sampler = BatchSampler(datasets['train'], batch_size=params.batch_size, 164 | batch_size_limit=params.batch_size_limit, 165 | batch_expansion_rate=params.batch_expansion_rate) 166 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 167 | train_collate_fn = make_collate_fn(datasets['train']) 168 | dataloders['train'] = DataLoader(datasets['train'], batch_sampler=train_sampler, collate_fn=train_collate_fn, 169 | num_workers=params.num_workers, pin_memory=True, persistent_workers=True) 170 | 171 | if 'val' in datasets: 172 | val_sampler = BatchSampler(datasets['val'], batch_size=params.val_batch_size) 173 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 174 | # Currently validation dataset has empty set_transform function, but it may change in the future 175 | val_collate_fn = make_collate_fn(datasets['val']) 176 | dataloders['val'] = DataLoader(datasets['val'], batch_sampler=val_sampler, collate_fn=val_collate_fn, 177 | num_workers=params.num_workers, pin_memory=True, persistent_workers=True, 178 | shuffle=False) 179 | 180 | return dataloders 181 | 182 | 183 | def in_sorted_array(e: int, array: np.ndarray) -> bool: 184 | pos = np.searchsorted(array, e) 185 | if pos == len(array) or pos == -1: 186 | return False 187 | else: 188 | return array[pos] == e 189 | -------------------------------------------------------------------------------- /datasets/etna.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import pickle 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | import random 8 | from typing import Dict 9 | import math 10 | from spconv.pytorch.utils import PointToVoxel 11 | 12 | DEBUG = False 13 | 14 | 15 | class EtnaDataset(Dataset): 16 | """ 17 | Dataset wrapper for Oxford laser scans dataset from PointNetVLAD project. 18 | """ 19 | 20 | def __init__(self, dataset_path: str, query_filename: str, 21 | transform=None, set_transform=None, image_transform=None, 22 | use_cloud: bool = True, use_rgb: bool = True, cfg: dict = None): 23 | assert os.path.exists( 24 | dataset_path), 'Cannot access dataset path: {}'.format(dataset_path) 25 | self.dataset_path = dataset_path 26 | self.query_filepath = os.path.join(dataset_path, query_filename) 27 | assert os.path.exists( 28 | self.query_filepath), 'Cannot access query file: {}'.format(self.query_filepath) 29 | 30 | self.transform = transform 31 | self.set_transform = set_transform 32 | 33 | self.queries: Dict[int, TrainingTuple] = pickle.load( 34 | open(self.query_filepath, 'rb')) 35 | self.image_transform = image_transform 36 | self.n_points = 4096 # 4096 # pointclouds in the dataset are downsampled to 4096 points 37 | self.image_ext = '.png' 38 | self.use_cloud = use_cloud 39 | self.use_rgb = use_rgb 40 | 41 | self.point_cloud_range = cfg.model.point_cloud.range 42 | self.cfg = cfg 43 | print('{} queries in the dataset'.format(len(self))) 44 | 45 | self.queries = pickle.load(open(self.query_filepath, "rb")) 46 | self.dataset = pickle.load( 47 | open(os.path.join(dataset_path, "etna_complete_dataset2.pickle"), "rb")) 48 | self.images = use_rgb 49 | 50 | self.gen = PointToVoxel(vsize_xyz=cfg.model.point_cloud.voxel_size, 51 | coors_range_xyz=cfg.model.point_cloud.range, 52 | num_point_features=cfg.model.point_cloud.num_point_features, 53 | max_num_voxels=cfg.model.point_cloud.max_num_voxels, 54 | max_num_points_per_voxel=cfg.model.point_cloud.max_num_points_per_voxel) 55 | 56 | def __len__(self): 57 | return len(self.queries) 58 | 59 | def __getitem__(self, ndx): 60 | # Load point cloud and apply transform 61 | result = {'ndx': ndx} 62 | file = self.queries[ndx].timestamp 63 | value = self.dataset.loc[self.dataset['file'] == file] 64 | 65 | if self.use_cloud: 66 | # Load point cloud and apply transform 67 | query_pc= self.load_pc(value) 68 | 69 | if self.transform is not None: 70 | query_pc = self.transform(query_pc) 71 | result['cloud'] = query_pc 72 | 73 | if self.images: 74 | img_path = value["img_path"].values[0].split("/")[-3:] 75 | 76 | img_path = os.path.join( 77 | self.dataset_path, '/'.join(str(x) for x in img_path)) 78 | img_path = img_path.replace("s3li_crater_inout", "s3li_zcrater_inout") 79 | img = Image.open(img_path).convert("L") 80 | # crop image to remobe black borders, 2% of image size 81 | img = img.crop((int(img.size[0]*0.02), int(img.size[1]*0.02), int(img.size[0]*0.98), int(img.size[1]*0.98))) 82 | if img is None: 83 | print("Image not found: {}".format(img_path)) 84 | if self.image_transform is not None: 85 | img = self.image_transform(img) 86 | result['image'] = img 87 | 88 | return result 89 | 90 | def get_positives(self, ndx): 91 | return self.queries[ndx].positives 92 | 93 | def get_non_negatives(self, ndx): 94 | return self.queries[ndx].non_negatives 95 | 96 | def load_pc(self, file): 97 | # Load point cloud, does not apply any transform 98 | pc = np.array(file["point_cloud"].values[0]) 99 | 100 | voxel_range = [-30, -10, 0, 30, 10, 40] # xmin, ymin, zmin, 
xmax, ymax, zmax 101 | pc = pc[(pc[:, 0] > voxel_range[0]) & (pc[:, 0] < voxel_range[3]) & 102 | (pc[:, 1] > voxel_range[1]) & (pc[:, 1] < voxel_range[4]) & 103 | (pc[:, 2] > voxel_range[2]) & (pc[:, 2] < voxel_range[5])] 104 | 105 | 106 | 107 | N = pc.shape[0] 108 | if N == 0: 109 | assert False, "Empty point cloud" 110 | subsample_idx = np.random.choice(N, self.n_points) 111 | pc = pc[subsample_idx, :] 112 | pc = pc[:, :3] 113 | 114 | pc = torch.tensor(pc, dtype=torch.float).contiguous() 115 | return pc 116 | 117 | 118 | 119 | class TrainingTuple: 120 | # Tuple describing an element for training/validation 121 | def __init__(self, id: int, timestamp: int, positives: np.ndarray, 122 | non_negatives: np.ndarray, position: np.ndarray): 123 | # id: element id (ids start from 0 and are consecutive numbers) 124 | # ts: timestamp 125 | # positives: sorted ndarray of positive elements id 126 | # negatives: sorted ndarray of elements id 127 | # position: x, y position in meters (northing, easting) 128 | assert position.shape == (2,) 129 | 130 | self.id = id 131 | self.timestamp = timestamp 132 | self.positives = positives 133 | self.non_negatives = non_negatives 134 | self.position = position 135 | -------------------------------------------------------------------------------- /datasets/oxford.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Dataset wrapper for Oxford laser scans dataset from PointNetVLAD project 5 | # For information on dataset see: https://github.com/mikacuy/pointnetvlad 6 | 7 | import os 8 | import pickle 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset 12 | from PIL import Image 13 | import random 14 | from typing import Dict 15 | from spconv.pytorch.utils import PointToVoxel 16 | import concurrent.futures 17 | from tqdm import tqdm 18 | DEBUG = False 19 | log_file = "log.csv" 20 | 21 | 22 | class TrainingTuple: 23 | # Tuple describing an element for training/validation 24 | def __init__(self, id: int, timestamp: int, rel_scan_filepath: str, positives: np.ndarray, 25 | non_negatives: np.ndarray, position: np.ndarray): 26 | # id: element id (ids start from 0 and are consecutive numbers) 27 | # ts: timestamp 28 | # rel_scan_filepath: relative path to the scan 29 | # positives: sorted ndarray of positive elements id 30 | # negatives: sorted ndarray of elements id 31 | # position: x, y position in meters (northing, easting) 32 | assert position.shape == (2,) 33 | 34 | self.id = id 35 | self.timestamp = timestamp 36 | self.rel_scan_filepath = rel_scan_filepath 37 | self.positives = positives 38 | self.non_negatives = non_negatives 39 | self.position = position 40 | 41 | 42 | def mean_vfe(voxel_features, voxel_num_points): 43 | # voxel_features, voxel_num_points = batch_dict['voxels'], batch_dict['voxel_num_points'] 44 | points_mean = voxel_features[:, :, :].sum(dim=1, keepdim=False) 45 | normalizer = torch.clamp_min(voxel_num_points.view(-1, 1), min=1.0).type_as(voxel_features) 46 | points_mean = points_mean / normalizer 47 | voxel_features = points_mean.contiguous() 48 | # optional use all 1s to represent voxel feature 49 | # voxel_features = voxel_features.new_ones((voxel_features.shape[0], voxel_features.shape[1])) 50 | 51 | return voxel_features 52 | 53 | class OxfordDataset(Dataset): 54 | """ 55 | Dataset wrapper for Oxford laser scans dataset from PointNetVLAD project. 
56 | """ 57 | def __init__(self, dataset_path: str, query_filename: str, image_path: str = None, 58 | lidar2image_ndx_path: str = None, transform=None, set_transform=None, image_transform=None, 59 | use_cloud: bool = True, cfg: dict = None): 60 | assert os.path.exists(dataset_path), 'Cannot access dataset path: {}'.format(dataset_path) 61 | self.dataset_path = dataset_path 62 | self.query_filepath = os.path.join(dataset_path, query_filename) 63 | assert os.path.exists(self.query_filepath), 'Cannot access query file: {}'.format(self.query_filepath) 64 | self.transform = transform 65 | self.set_transform = set_transform 66 | print('Loading queries from: {}'.format(self.query_filepath)) 67 | self.queries: Dict[int, TrainingTuple] = pickle.load(open(self.query_filepath, 'rb')) 68 | 69 | self.image_path = image_path 70 | self.lidar2image_ndx_path = lidar2image_ndx_path 71 | self.image_transform = image_transform 72 | self.n_points = 4096 # pointclouds in the dataset are downsampled to 4096 points 73 | self.image_ext = '.png' 74 | self.use_cloud = use_cloud 75 | print('{} queries in the dataset'.format(len(self))) 76 | self.point_cloud_range = cfg.model.point_cloud.range 77 | self.cfg = cfg 78 | self.log_file = "log.csv" 79 | self.logger = open(self.log_file, "w") 80 | 81 | 82 | self.gen = PointToVoxel(vsize_xyz=cfg.model.point_cloud.voxel_size, 83 | coors_range_xyz=cfg.model.point_cloud.range, 84 | num_point_features=cfg.model.point_cloud.num_point_features, 85 | max_num_voxels=cfg.model.point_cloud.max_num_voxels, 86 | max_num_points_per_voxel=cfg.model.point_cloud.max_num_points_per_voxel) 87 | 88 | assert os.path.exists(self.lidar2image_ndx_path), f"Cannot access lidar2image_ndx: {self.lidar2image_ndx_path}" 89 | self.lidar2image_ndx = pickle.load(open(self.lidar2image_ndx_path, 'rb')) 90 | 91 | def __len__(self): 92 | return len(self.queries) 93 | 94 | def __getitem__(self, ndx): 95 | # Load point cloud and apply transform 96 | filename = self.queries[ndx].rel_scan_filepath 97 | result = {'ndx': ndx} 98 | if self.use_cloud: 99 | # Load point cloud and apply transform 100 | file_pathname = os.path.join(self.dataset_path, self.queries[ndx].rel_scan_filepath) 101 | query_pc = self.load_pc(file_pathname) 102 | if self.transform is not None: 103 | query_pc = self.transform(query_pc) 104 | result['cloud'] = query_pc 105 | 106 | if self.image_path is not None: 107 | img = image4lidar(filename, self.image_path, self.image_ext, self.lidar2image_ndx, k=None) 108 | if img is None: 109 | # log in log_file the ndx and rel_scan_filepath 110 | print(str(self.queries[ndx].timestamp) + "," + str(self.queries[ndx].rel_scan_filepath)) 111 | img = Image.new('RGB', (224, 224), color='red') 112 | if self.image_transform is not None: 113 | img = self.image_transform(img) 114 | result['image'] = img 115 | 116 | return result 117 | 118 | def point_batch_to_voxel(self, clouds): 119 | device = clouds.device 120 | # set device cpu 121 | clouds = clouds.cpu() 122 | 123 | voxel_data = [self.gen(clouds[batch_id]) for batch_id in range(len(clouds))] 124 | batch_voxel, batch_voxel_coords, batch_num_pts_in_voxels = zip(*voxel_data) 125 | batch_voxel_coords = list(batch_voxel_coords) 126 | 127 | 128 | for batch_id in range(len(batch_voxel_coords)): 129 | coordinates = batch_voxel_coords[batch_id] 130 | 131 | discrete_coords = torch.cat( 132 | ( 133 | torch.zeros(coordinates.shape[0], 1, dtype=torch.int32), 134 | coordinates, 135 | ), 136 | 1, 137 | ) 138 | 139 | batch_voxel_coords[batch_id] = discrete_coords 140 | # Move 
batchids to the beginning 141 | batch_voxel_coords[batch_id][:, 0] = batch_id 142 | 143 | batch_voxel_coords = torch.cat(batch_voxel_coords, dim=0) 144 | 145 | feats_batch = torch.ones((batch_voxel_coords.shape[0], 1), dtype=torch.float32, device=device) 146 | feats_batch = feats_batch.to(device) 147 | batch_voxel_coords = batch_voxel_coords.to(device) 148 | return batch_voxel_coords, feats_batch 149 | 150 | 151 | def is_loadable(self, query: TrainingTuple) -> str: 152 | """ 153 | Given a lidar filename, return the associated image filepath. 154 | """ 155 | lidar_filename = query.rel_scan_filepath 156 | 157 | img = image4lidar(lidar_filename, self.image_path, self.image_ext, self.lidar2image_ndx, k=None) 158 | if img is None: 159 | # log in log_file the ndx and rel_scan_filepath 160 | self.logger.write(str(query.timestamp) + "," + str(self.queries[query.id].rel_scan_filepath) + "\n") 161 | # Assuming that the `image4lidar` function returns the filepath 162 | return (query.id, image4lidar(lidar_filename, self.image_path, self.image_ext, self.lidar2image_ndx, k=None)) 163 | 164 | 165 | def get_positives(self, ndx): 166 | return self.queries[ndx].positives 167 | 168 | def get_non_negatives(self, ndx): 169 | return self.queries[ndx].non_negatives 170 | 171 | def load_pc(self, filename): 172 | # Load point cloud, does not apply any transform 173 | # Returns Nx3 matrix 174 | file_path = os.path.join(self.dataset_path, filename) 175 | pc = np.fromfile(file_path, dtype=np.float64) 176 | # coords are within -1..1 range in each dimension 177 | assert pc.shape[0] == self.n_points * 3, "Error in point cloud shape: {}".format(file_path) 178 | pc = np.reshape(pc, (pc.shape[0] // 3, 3)) 179 | pc = torch.tensor(pc, dtype=torch.float) 180 | return pc 181 | 182 | 183 | def ts_from_filename(filename): 184 | # Extract timestamp (as integer) from the file path/name 185 | temp = os.path.split(filename)[1] 186 | lidar_ts = os.path.splitext(temp)[0] # LiDAR timestamp 187 | assert lidar_ts.isdigit(), 'Incorrect lidar timestamp: {}'.format(lidar_ts) 188 | 189 | temp = os.path.split(filename)[0] 190 | temp = os.path.split(temp)[0] 191 | traversal = os.path.split(temp)[1] 192 | assert len(traversal) == 19, 'Incorrect traversal name: {}'.format(traversal) 193 | 194 | return int(lidar_ts), traversal 195 | 196 | 197 | def image4lidar(filename, image_path, image_ext, lidar2image_ndx, k=None): 198 | # Return an image corresponding to the given lidar point cloud (given as a path to .bin file) 199 | # k: Number of closest images to randomly select from 200 | lidar_ts, traversal = ts_from_filename(filename) 201 | assert lidar_ts in lidar2image_ndx, 'Unknown lidar timestamp: {}'.format(lidar_ts) 202 | 203 | # Randomly select one of images linked with the point cloud 204 | if k is None or k > len(lidar2image_ndx[lidar_ts]): 205 | k = len(lidar2image_ndx[lidar_ts]) 206 | 207 | image_ts = random.choice(lidar2image_ndx[lidar_ts][:k]) 208 | image_file_path = os.path.join(image_path, traversal, str(image_ts) + image_ext) 209 | #image_file_path = '/media/sf_Datasets/images4lidar/2014-05-19-13-20-57/1400505893134088.png' 210 | try: 211 | img = Image.open(image_file_path) 212 | # check TypeError: Unexpected type 213 | if img is None: 214 | print('Cannot access image file: {}'.format(image_file_path)) 215 | return None 216 | return img 217 | except: 218 | return None 219 | 220 | 221 | import os 222 | -------------------------------------------------------------------------------- /datasets/samplers.py: 
-------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | import random 4 | import copy 5 | from torch.utils.data import Sampler 6 | from datasets.oxford import OxfordDataset 7 | 8 | 9 | class ListDict(object): 10 | def __init__(self, items=None): 11 | if items is not None: 12 | self.items = copy.deepcopy(items) 13 | self.item_to_position = {item: ndx for ndx, item in enumerate(items)} 14 | else: 15 | self.items = [] 16 | self.item_to_position = {} 17 | 18 | def add(self, item): 19 | if item in self.item_to_position: 20 | return 21 | self.items.append(item) 22 | self.item_to_position[item] = len(self.items)-1 23 | 24 | def remove(self, item): 25 | position = self.item_to_position.pop(item) 26 | last_item = self.items.pop() 27 | if position != len(self.items): 28 | self.items[position] = last_item 29 | self.item_to_position[last_item] = position 30 | 31 | def choose_random(self): 32 | return random.choice(self.items) 33 | 34 | def __contains__(self, item): 35 | return item in self.item_to_position 36 | 37 | def __iter__(self): 38 | return iter(self.items) 39 | 40 | def __len__(self): 41 | return len(self.items) 42 | 43 | 44 | class BatchSampler(Sampler): 45 | # Sampler returning list of indices to form a mini-batch 46 | # Samples elements in groups consisting of k=2 similar elements (positives) 47 | # Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k 48 | def __init__(self, dataset: OxfordDataset, batch_size: int, batch_size_limit: int = None, 49 | batch_expansion_rate: float = None, max_batches: int = None): 50 | if batch_expansion_rate is not None: 51 | assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 52 | assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size' 53 | 54 | self.batch_size = batch_size 55 | self.batch_size_limit = batch_size_limit 56 | self.batch_expansion_rate = batch_expansion_rate 57 | self.max_batches = max_batches 58 | self.dataset = dataset 59 | self.k = 2 # Number of positive examples per group must be 2 60 | if self.batch_size < 2 * self.k: 61 | self.batch_size = 2 * self.k 62 | print('WARNING: Batch too small. Batch size increased to {}.'.format(self.batch_size)) 63 | 64 | self.batch_idx = [] # Index of elements in each batch (re-generated every epoch) 65 | self.elems_ndx = list(self.dataset.queries) # List of point cloud indexes 66 | 67 | def __iter__(self): 68 | # Re-generate batches every epoch 69 | self.generate_batches() 70 | for batch in self.batch_idx: 71 | yield batch 72 | 73 | def __len(self): 74 | return len(self.batch_idx) 75 | 76 | def expand_batch(self): 77 | if self.batch_expansion_rate is None: 78 | print('WARNING: batch_expansion_rate is None') 79 | return 80 | 81 | if self.batch_size >= self.batch_size_limit: 82 | return 83 | 84 | old_batch_size = self.batch_size 85 | self.batch_size = int(self.batch_size * self.batch_expansion_rate) 86 | self.batch_size = min(self.batch_size, self.batch_size_limit) 87 | print('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size)) 88 | 89 | def generate_batches(self): 90 | # Generate training/evaluation batches. 
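# Sampling strategy: repeatedly pick a random unused element plus one of its
# positives and append the pair; flush the batch once it reaches batch_size.
# Batches with fewer than 2*k elements are dropped, since they cannot contain
# both a positive pair and at least one negative.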
91 | # batch_idx holds indexes of elements in each batch as a list of lists 92 | self.batch_idx = [] 93 | 94 | unused_elements_ndx = ListDict(self.elems_ndx) 95 | current_batch = [] 96 | 97 | assert self.k == 2, 'sampler can sample only k=2 elements from the same class' 98 | 99 | while True: 100 | if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0: 101 | # Flush out batch, when it has a desired size, or a smaller batch, when there's no more 102 | # elements to process 103 | if len(current_batch) >= 2*self.k: 104 | # Ensure there're at least two groups of similar elements, otherwise, it would not be possible 105 | # to find negative examples in the batch 106 | assert len(current_batch) % self.k == 0, 'Incorrect bach size: {}'.format(len(current_batch)) 107 | self.batch_idx.append(current_batch) 108 | current_batch = [] 109 | if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches): 110 | break 111 | if len(unused_elements_ndx) == 0: 112 | break 113 | 114 | # Add k=2 similar elements to the batch 115 | selected_element = unused_elements_ndx.choose_random() 116 | unused_elements_ndx.remove(selected_element) 117 | positives = self.dataset.get_positives(selected_element) 118 | if len(positives) == 0: 119 | # Broken dataset element without any positives 120 | continue 121 | 122 | unused_positives = [e for e in positives if e in unused_elements_ndx] 123 | # If there're unused elements similar to selected_element, sample from them 124 | # otherwise sample from all similar elements 125 | if len(unused_positives) > 0: 126 | second_positive = random.choice(unused_positives) 127 | unused_elements_ndx.remove(second_positive) 128 | else: 129 | second_positive = random.choice(list(positives)) 130 | 131 | current_batch += [selected_element, second_positive] 132 | 133 | for batch in self.batch_idx: 134 | assert len(batch) % self.k == 0, 'Incorrect bach size: {}'.format(len(batch)) 135 | 136 | -------------------------------------------------------------------------------- /eval/evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | import os 4 | import sys 5 | import argparse 6 | import torch 7 | import tqdm 8 | from misc.utils import UMFParams 9 | from models.model_factory import model_factory 10 | from datasets.oxford import image4lidar 11 | from datasets.augmentation import ValRGBTransform 12 | import logging 13 | from spconv.pytorch.utils import PointToVoxel 14 | import dotsi 15 | import yaml 16 | from eval.utils import load_config, compute_and_log_stats 17 | 18 | logging.basicConfig(filename='rerank_log.txt', filemode='w', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 19 | LOAD_FEATURES = False 20 | 21 | 22 | def load_config(config_file): 23 | with open(config_file, 'r') as ymlfile: 24 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 25 | return cfg 26 | 27 | 28 | def evaluate(model, device, params, silent=True): 29 | assert len(params.eval_database_files) == len(params.eval_query_files) 30 | 31 | stats = {} 32 | stats_img = {} 33 | stats_pc = {} 34 | stats_combined = {} 35 | 36 | for database_file, query_file in zip(params.eval_database_files, params.eval_query_files): 37 | # Extract location name from query and database files 38 | location_name = database_file.split('_')[0] 39 | temp = query_file.split('_')[0] 40 | assert location_name == temp, 'Database location: {} does not match query location: {}'.format(database_file, 41 | query_file) 42 | 
43 | p = os.path.join(params.dataset_folder, database_file) 44 | with open(p, 'rb') as f: 45 | database_sets = pickle.load(f) 46 | 47 | p = os.path.join(params.dataset_folder, query_file) 48 | with open(p, 'rb') as f: 49 | query_sets = pickle.load(f) 50 | 51 | temp, temp_img, temp_pc, temp_combined = evaluate_dataset(model, device, params, database_sets, query_sets, silent=silent) 52 | stats[location_name] = temp 53 | stats_img[location_name] = temp_img 54 | stats_pc[location_name] = temp_pc 55 | stats_combined[location_name] = temp_combined 56 | 57 | return stats, stats_img, stats_pc, stats_combined 58 | 59 | 60 | def evaluate_dataset(model, device, params, database_sets, query_sets, silent=True): 61 | features_folder = 'features' 62 | 63 | if not os.path.exists(features_folder): 64 | os.makedirs(features_folder) 65 | 66 | initialize_model(model, database_sets, features_folder, device, params) 67 | 68 | if not LOAD_FEATURES: 69 | # Process and save database sets 70 | for idx, set in enumerate(tqdm.tqdm(database_sets, disable=silent)): 71 | out = get_latent_vectors(model, set, device, params, dim_reduction=True, normalize=False) 72 | with open(os.path.join(features_folder, f'db_set_{idx}.pkl'), 'wb') as f: 73 | pickle.dump(out, f) 74 | 75 | # Process and save query sets 76 | for idx, set in enumerate(tqdm.tqdm(query_sets, disable=silent)): 77 | out = get_latent_vectors(model, set, device, params, dim_reduction=True, normalize=False) 78 | with open(os.path.join(features_folder, f'query_set_{idx}.pkl'), 'wb') as f: 79 | pickle.dump(out, f) 80 | 81 | stats = [] 82 | mode = params.model_params.params['mode'] 83 | 84 | modalities = ['base'] 85 | if mode != 'fusion': 86 | modalities.extend(['img', 'pc', 'combined']) 87 | 88 | for modality in modalities: 89 | stat = compute_and_log_stats(mode, modality, query_sets, database_sets, features_folder) 90 | stats.append(stat) 91 | 92 | return tuple(stats) 93 | 94 | def load_data_item(file_name, params): 95 | # returns Nx3 matrix 96 | file_path = os.path.join(params.dataset_folder, file_name) 97 | 98 | result = {} 99 | if params.use_cloud: 100 | pc = np.fromfile(file_path, dtype=np.float64) 101 | # coords are within -1..1 range in each dimension 102 | assert pc.shape[0] == params.num_points * 3, "Error in point cloud shape: {}".format(file_path) 103 | pc = np.reshape(pc, (pc.shape[0] // 3, 3)) 104 | pc = torch.tensor(pc, dtype=torch.float) 105 | result['cloud'] = pc 106 | 107 | if params.use_rgb: 108 | # Get the first closest image for each LiDAR scan 109 | assert os.path.exists(params.lidar2image_ndx_path), f"Cannot find lidar2image_ndx pickle: {params.lidar2image_ndx_path}" 110 | lidar2image_ndx = pickle.load(open(params.lidar2image_ndx_path, 'rb')) 111 | img = image4lidar(file_name, params.image_path, '.png', lidar2image_ndx, k=1) 112 | transform = ValRGBTransform() 113 | # Convert to tensor and normalize 114 | result['image'] = transform(img) 115 | return result 116 | 117 | 118 | def process_cloud_data(clouds, gen, device): 119 | voxel_data = [gen(c) for c in clouds] 120 | batch_voxel_features, batch_voxel_coords, _ = zip(*voxel_data) 121 | # Prepare batch data 122 | batch_voxel_coords = [ 123 | torch.cat(( 124 | torch.full((coords.shape[0], 1), batch_id, dtype=torch.int32, device=device), 125 | coords.to(device)), 126 | dim=1) 127 | for batch_id, coords in enumerate(batch_voxel_coords) 128 | ] 129 | return torch.cat(batch_voxel_coords, dim=0), torch.cat(batch_voxel_features, dim=0) 130 | 131 | def process_data_item(item, device, params, gen): 132 | 
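# Build a single-element batch for one query: load the cloud and/or image from disk,
# voxelize the cloud with the spconv PointToVoxel generator ('gen'), and return the
# dict of tensors the model expects ('coordinates', 'voxel_features', 'images').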
x = load_data_item(item["query"], params) 133 | batch = {} 134 | if params.use_cloud: 135 | clouds = [x['cloud'].to(device)] 136 | batch_voxel_coords, batch_voxel = process_cloud_data(clouds, gen, device) 137 | feats_batch = torch.ones((batch_voxel.shape[0], 1), dtype=torch.float32, device=device) 138 | batch['coordinates'] = batch_voxel_coords 139 | batch['voxel_features'] = feats_batch 140 | if params.use_rgb: 141 | batch['images'] = x['image'].unsqueeze(0).to(device) 142 | return batch 143 | 144 | 145 | def get_latent_vectors(model, dataset, device, params, dim_reduction=True, normalize=False): 146 | cfg = load_config(params.model_params_path) 147 | cfg = dotsi.Dict(cfg) 148 | 149 | gen = PointToVoxel(vsize_xyz=cfg.model.point_cloud.voxel_size, 150 | coors_range_xyz=cfg.model.point_cloud.range, 151 | num_point_features=cfg.model.point_cloud.num_point_features, 152 | max_num_voxels=cfg.model.point_cloud.max_num_voxels, 153 | max_num_points_per_voxel=cfg.model.point_cloud.max_num_points_per_voxel, 154 | device=device) 155 | 156 | model.eval() 157 | embeddings_l = [] 158 | 159 | feature_lists = { 160 | 'img_fine_feat': [], 'img_attns': [], 'pc_fine_feat': [], 'pc_attns': [], 161 | 'img_super_feat': [], 'img_strengths': [], 'pc_super_feat': [], 'pc_strengths': [] 162 | } 163 | 164 | for elem_ndx in dataset: 165 | if not dim_reduction and np.random.rand() > 0.1: 166 | continue 167 | 168 | batch = process_data_item(dataset[elem_ndx], device, params, gen) 169 | with torch.no_grad(): 170 | outputs = model(batch, dim_reduction=dim_reduction, normalize=normalize) 171 | if params.normalize_embeddings: 172 | outputs['embedding'] = torch.nn.functional.normalize(outputs['embedding'], p=2, dim=1) 173 | embedding = outputs['embedding'].detach().cpu().numpy()[0] 174 | embeddings_l.append(embedding) 175 | 176 | for key in feature_lists.keys(): 177 | if key in outputs: 178 | feature_lists[key].append(outputs[key].cpu()) 179 | 180 | out = {'embedding': torch.tensor(np.array(embeddings_l), dtype=torch.float32)} 181 | for key, value in feature_lists.items(): 182 | if value: 183 | out[key] = torch.stack(value, dim=0) if 'feat' in key else torch.cat(value, dim=0) 184 | 185 | return out 186 | 187 | def print_eval_stats(stats): 188 | for database_name in stats: 189 | print('Dataset: {}'.format(database_name)) 190 | t = 'Avg. top 1% recall: {:.2f} Avg. similarity: {:.4f} Avg. 
recall @N:' 191 | print(t.format(stats[database_name]['ave_one_percent_recall'], 192 | stats[database_name]['average_similarity'])) 193 | print(stats[database_name]['ave_recall']) 194 | 195 | 196 | def initialize_model(model, train_set, features_folder, device, params): 197 | model.eval() 198 | mode = params.model_params.params['mode'] 199 | des_img_train = [] 200 | des_pc_train = [] 201 | train_set = train_set[:] 202 | 203 | if not LOAD_FEATURES: 204 | for set in train_set: 205 | out = get_latent_vectors(model, set, device, params, dim_reduction=False, normalize=False) 206 | 207 | if 'img_fine_feat' in out: 208 | des_img_train.append(out['img_fine_feat'].detach().cpu()) 209 | des_pc_train.append(out['pc_fine_feat'].detach().cpu()) 210 | 211 | elif 'img_super_feat' in out: 212 | des_img_train.append(out['img_super_feat'].detach().cpu()) 213 | des_pc_train.append(out['pc_super_feat'].detach().cpu()) 214 | 215 | des_img_train = torch.cat(des_img_train, dim=0) 216 | des_pc_train = torch.cat(des_pc_train, dim=0) 217 | 218 | if mode == 'ransac': 219 | pc_feat_dim = des_pc_train.shape[2] 220 | img_feat_dim = des_img_train.shape[2] 221 | elif mode == 'superfeatures': 222 | pc_feat_dim = des_pc_train.shape[2] 223 | img_feat_dim = des_img_train.shape[2] 224 | 225 | max_num_samples = 50000 226 | if des_img_train.shape[0] > max_num_samples: 227 | indices_img = np.random.choice(des_img_train.shape[0], max_num_samples, replace=False) 228 | des_img_train_sampled = des_img_train[indices_img].reshape(-1, img_feat_dim) 229 | else: 230 | des_img_train_sampled = des_img_train.reshape(-1, img_feat_dim) 231 | 232 | 233 | if des_pc_train.shape[0] > max_num_samples: 234 | indices_pc = np.random.choice(des_pc_train.shape[0], max_num_samples, replace=False) 235 | des_pc_train_sampled = des_pc_train[indices_pc].reshape(-1, pc_feat_dim) 236 | else: 237 | des_pc_train_sampled = des_pc_train.reshape(-1, pc_feat_dim) 238 | 239 | if mode == 'ransac': 240 | mi, Pi = model.lt.reduction_layer.initialize_pca_whitening(des_img_train_sampled) 241 | mp, Pp = model.lt3d.reduction_layer.initialize_pca_whitening(des_pc_train_sampled) 242 | elif mode == 'superfeatures': 243 | mi, Pi = model.lit.reduction_layer.initialize_pca_whitening(des_img_train_sampled) 244 | mp, Pp = model.lit3d.reduction_layer.initialize_pca_whitening(des_pc_train_sampled) 245 | 246 | del des_img_train, des_pc_train, des_img_train_sampled, des_pc_train_sampled 247 | 248 | with open(f"{features_folder}/initialization.pkl", 'wb') as f: 249 | pickle.dump({'mi': mi, 'Pi': Pi, 'mp': mp, 'Pp': Pp}, f) 250 | 251 | # load initialization values 252 | with open(f"{features_folder}/initialization.pkl", 'rb') as f: 253 | init = pickle.load(f) 254 | mi, Pi, mp, Pp = init['mi'], init['Pi'], init['mp'], init['Pp'] 255 | 256 | if hasattr(model, 'lit'): 257 | print('-> Loading initialization values for lit and lit3d') 258 | model.lit.reduction_layer.load_initialization(mi, Pi) 259 | model.lit3d.reduction_layer.load_initialization(mp, Pp) 260 | elif hasattr(model, 'lt'): 261 | print('-> Loading initialization values for lt and lt3d') 262 | model.lt.reduction_layer.load_initialization(mi, Pi) 263 | model.lt3d.reduction_layer.load_initialization(mp, Pp) 264 | 265 | 266 | 267 | if __name__ == "__main__": 268 | parser = argparse.ArgumentParser(description='Evaluate model on RobotCar dataset') 269 | parser.add_argument('--config', type=str, required=True, help='Path to configuration file') 270 | parser.add_argument('--model_config', type=str, required=True, help='Path to the 
model-specific configuration file') 271 | parser.add_argument('--weights', type=str, required=False, help='Trained model weights') 272 | 273 | args = parser.parse_args() 274 | print('Config path: {}'.format(args.config)) 275 | print('Model config path: {}'.format(args.model_config)) 276 | if args.weights is None: 277 | sys.exit('Please provide a path to the trained model weights') 278 | 279 | params = UMFParams(args.config, args.model_config) 280 | params.print() 281 | 282 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 283 | print('Device: {}'.format(device)) 284 | model = model_factory(params) 285 | if args.weights is not None: 286 | assert os.path.exists(args.weights), 'Cannot open network weights: {}'.format(args.weights) 287 | print('Loading weights: {}'.format(args.weights)) 288 | 289 | pretrained_state_dict = torch.load(args.weights, map_location='cpu') 290 | dim_reduction_prefix = '.reduction_layer.' 291 | filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if dim_reduction_prefix not in k} 292 | model.load_state_dict(filtered_state_dict, strict=False) 293 | 294 | model.to(device) 295 | 296 | 297 | stats, stats_img, stats_pc, stats_combined = evaluate(model, device, params, silent=False) 298 | print("----- Without reranking -----") 299 | 300 | print_eval_stats(stats) 301 | 302 | print("----- With Img reranking -----") 303 | print_eval_stats(stats_img) 304 | 305 | 306 | print("----- With PC reranking -----") 307 | print_eval_stats(stats_pc) 308 | 309 | print("----- With Combined reranking -----") 310 | print_eval_stats(stats_combined) 311 | 312 | 313 | 314 | -------------------------------------------------------------------------------- /eval/utils.py: -------------------------------------------------------------------------------- 1 | from sklearn.neighbors import KDTree 2 | import numpy as np 3 | import pickle 4 | import os 5 | import torch 6 | import tqdm 7 | from how.layers import functional as HF 8 | from models.UMF.utils.patch_matcher import PatchMatcher 9 | from models.losses.super_feature_losses import match_super 10 | import logging 11 | import yaml 12 | 13 | scores_cache_img = {} 14 | scores_cache_pc = {} 15 | NUM_NEIGHBORS = 25 16 | 17 | 18 | def load_config(config_file): 19 | with open(config_file, 'r') as ymlfile: 20 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 21 | return cfg 22 | 23 | def compute_and_log_stats(mode, modality, query_sets, database_sets, features_folder): 24 | count = 0 25 | recall = np.zeros(NUM_NEIGHBORS) 26 | one_percent_recall = [] 27 | similarity = [] 28 | 29 | for i in tqdm.tqdm(range(len(query_sets))): 30 | with open(os.path.join(features_folder, f'query_set_{i}.pkl'), 'rb') as f: 31 | query_outputs = pickle.load(f) 32 | for j in range(len(database_sets)): 33 | if i == j: 34 | continue 35 | with open(os.path.join(features_folder, f'db_set_{j}.pkl'), 'rb') as f: 36 | db_outputs = pickle.load(f) 37 | 38 | if mode == 'superfeatures': 39 | pair_recall, pair_similarity, pair_opr = get_recall_super(i, j, query_outputs, db_outputs, 40 | query_sets[i], mode=modality) 41 | else: 42 | pair_recall, pair_similarity, pair_opr = get_recall(i, j, query_outputs, db_outputs, 43 | query_sets[i], mode=modality) 44 | 45 | recall += np.array(pair_recall) 46 | count += 1 47 | one_percent_recall.append(pair_opr) 48 | similarity.extend(pair_similarity) 49 | 50 | ave_recall = recall / count 51 | average_similarity = np.mean(similarity) 52 | ave_one_percent_recall = np.mean(one_percent_recall) 53 | stats = 
{'ave_one_percent_recall': ave_one_percent_recall, 'ave_recall': ave_recall, 54 | 'average_similarity': average_similarity} 55 | 56 | logging.info("-------------------------------------------------------------------") 57 | logging.info(f"Mode: {mode}") 58 | logging.info(f"ave_one_percent_recall: {ave_one_percent_recall}") 59 | logging.info(f"ave_recall: {ave_recall}") 60 | logging.info(f"average_similarity: {average_similarity}") 61 | logging.info("-------------------------------------------------------------------") 62 | 63 | return stats 64 | 65 | 66 | def select_superfeatures(superfeatures, strengths, scales=[1], N=20): 67 | selected_superfeatures = torch.zeros(N, superfeatures.shape[-1]) 68 | strengths = strengths.unsqueeze(0).unsqueeze(2) 69 | features, _, _, scales = HF.how_select_local(superfeatures, strengths, scales=scales, features_num=N) 70 | selected_superfeatures = features 71 | return selected_superfeatures 72 | 73 | 74 | def get_recall_super(n, m, query_outputs, db_outputs, query_set, mode='base'): 75 | global scores_cache_img, scores_cache_pc 76 | database_output = db_outputs['embedding'] 77 | queries_output = query_outputs['embedding'] 78 | 79 | # log mode 80 | logging.info(f"-------------------------------------------------------------------") 81 | logging.info(f"Mode: {mode}") 82 | logging.info(f"-------------------------------------------------------------------") 83 | 84 | if mode == 'img' or mode == 'combined': 85 | db_feat_img = db_outputs['img_super_feat'] 86 | q_feat_img = query_outputs['img_super_feat'] 87 | db_strengths_img = db_outputs['img_strengths'] 88 | q_strengths_img = query_outputs['img_strengths'] 89 | N = 64 90 | LoweRatioTh = 0.9 91 | 92 | if mode == 'pc' or mode == 'combined': 93 | db_feat_pc = db_outputs['pc_super_feat'] 94 | q_feat_pc = query_outputs['pc_super_feat'] 95 | db_strengths_pc = db_outputs['pc_strengths'] 96 | q_strengths_pc = query_outputs['pc_strengths'] 97 | N = 128 98 | LoweRatioTh = 0.9 99 | database_nbrs = KDTree(database_output) 100 | recall = [0] * NUM_NEIGHBORS 101 | 102 | top1_similarity_score = [] 103 | one_percent_retrieved = 0 104 | threshold = max(int(round(len(database_output)/100.0)), 1) 105 | 106 | num_evaluated = 0 107 | for i in range(len(queries_output)): 108 | # i is query element ndx 109 | query_details = query_set[i] # {'query': path, 'northing': , 'easting': } 110 | true_neighbors = query_details[m] 111 | if len(true_neighbors) == 0: 112 | continue 113 | num_evaluated += 1 114 | distances, indices = database_nbrs.query(np.array([queries_output[i]]), k=NUM_NEIGHBORS) 115 | 116 | if len(indices) == 0: 117 | print("-----> Error") 118 | continue 119 | 120 | candidates = indices[0] 121 | final_candidates = candidates 122 | final_scores = distances[0] 123 | cache_key = (m, n, i) 124 | if mode != 'base': 125 | 126 | if mode == 'img': 127 | q_img_superfeatures = select_superfeatures(q_feat_img[i], q_strengths_img[i], N=N) 128 | db_img_superfeatures = [select_superfeatures(db_feat_img[j], db_strengths_img[j], N=N) for j in candidates] 129 | scores = np.array([ -len(match_super(q_img_superfeatures, db_img_superfeatures[j], LoweRatioTh=LoweRatioTh)) / N for j in range(len(candidates))]) 130 | scores_cache_img[cache_key] = scores 131 | 132 | elif mode == 'pc': 133 | q_pc_superfeatures = select_superfeatures(q_feat_pc[i], q_strengths_pc[i], N=N) 134 | db_pc_superfeatures = [select_superfeatures(db_feat_pc[j], db_strengths_pc[j], N=N) for j in candidates] 135 | 136 | scores = np.array([ -len(match_super(q_pc_superfeatures, 
db_pc_superfeatures[j], LoweRatioTh=LoweRatioTh)) / N for j in range(len(candidates))]) 137 | scores_cache_pc[cache_key] = scores 138 | elif mode == 'combined': 139 | scores_img = scores_cache_img[cache_key] 140 | scores_pc = scores_cache_pc[cache_key] 141 | 142 | scores = 1. * scores_img + 1 * scores_pc 143 | 144 | scores = scores + distances[0] 145 | scores = [(score, candidate) for score, candidate in zip(scores, candidates)] 146 | 147 | final_candidates = [x for _, x in sorted(scores, key=lambda el: el[0])] 148 | final_scores = [x[0] for x in sorted(scores)] 149 | 150 | correct = [] 151 | for j in range(len(final_candidates)): 152 | if final_candidates[j] in true_neighbors: 153 | correct.append( (final_candidates[j], final_scores[j]) ) 154 | if j == 0: 155 | similarity = np.dot(queries_output[i], database_output[final_candidates[j]]) 156 | top1_similarity_score.append(similarity) 157 | recall[j] += 1 158 | break 159 | 160 | if len(list(set(final_candidates[0:threshold]).intersection(set(true_neighbors)))) > 0: 161 | one_percent_retrieved += 1 162 | 163 | if mode != 'base': 164 | candidates = torch.tensor(candidates, dtype=torch.int32) 165 | final_candidates = torch.tensor(final_candidates, dtype=torch.int32) 166 | 167 | candidates_list = candidates.tolist() 168 | final_candidates_list = final_candidates.tolist() 169 | 170 | if len(correct) > 0: 171 | log_reranking_stats(candidates_list, final_candidates_list, correct, true_neighbors) 172 | 173 | one_percent_recall = (one_percent_retrieved/float(num_evaluated))*100 174 | recall = (np.cumsum(recall)/float(num_evaluated))*100 175 | return recall, top1_similarity_score, one_percent_recall 176 | 177 | 178 | def get_recall(n, m, query_outputs, db_outputs, query_set, mode='base'): 179 | global scores_cache_img, scores_cache_pc 180 | database_output = db_outputs['embedding'] 181 | queries_output = query_outputs['embedding'] 182 | 183 | logging.info(f"-------------------------------------------------------------------") 184 | logging.info(f"Mode: {mode}") 185 | logging.info(f"-------------------------------------------------------------------") 186 | 187 | if mode == 'img' or mode == 'combined': 188 | db_feat_img = db_outputs['img_fine_feat'] 189 | q_feat_img = query_outputs['img_fine_feat'] 190 | q_attn_img = query_outputs['img_attns'] 191 | db_attn_img = db_outputs['img_attns'] 192 | 193 | if mode == 'pc' or mode == 'combined': 194 | db_feat_pc = db_outputs['pc_fine_feat'] 195 | q_feat_pc = query_outputs['pc_fine_feat'] 196 | q_attn_pc = query_outputs['pc_attns'] 197 | db_attn_pc = db_outputs['pc_attns'] 198 | 199 | 200 | database_nbrs = KDTree(database_output) 201 | 202 | recall = [0] * NUM_NEIGHBORS 203 | 204 | top1_similarity_score = [] 205 | one_percent_retrieved = 0 206 | threshold = max(int(round(len(database_output)/100.0)), 1) 207 | 208 | num_evaluated = 0 209 | for i in range(len(queries_output)): 210 | # i is query element ndx 211 | query_details = query_set[i] # {'query': path, 'northing': , 'easting': } 212 | true_neighbors = query_details[m] 213 | if len(true_neighbors) == 0: 214 | continue 215 | num_evaluated += 1 216 | distances, indices = database_nbrs.query(np.array([queries_output[i]]), k=NUM_NEIGHBORS) 217 | 218 | if len(indices) == 0: 219 | print("-----> Error") 220 | continue 221 | 222 | candidates = indices[0] 223 | final_candidates = candidates 224 | final_scores = distances[0] 225 | cache_key = (m, n, i) 226 | if mode != 'base': 227 | matcher = PatchMatcher( 228 | patch_sizes=[1], # [1, 5, 7] 229 | strides=[1], # [1, 
2, 3] 230 | patch_size3D=[1], # [1, 3] 231 | stride3D=[1], # [1, 2] 232 | th_img=0.6, 233 | th_pc=0.5, 234 | ) 235 | 236 | if mode == 'img': 237 | scores = matcher.match(q_feat_img[i], db_feat_img[candidates], q_attn_img[i], db_attn_img[candidates]) 238 | scores_cache_img[cache_key] = scores 239 | elif mode == 'pc': 240 | scores = matcher.match_pc(q_feat_pc[i], db_feat_pc[candidates], q_attn_pc[i], db_attn_pc[candidates]) 241 | scores_cache_pc[cache_key] = scores 242 | 243 | elif mode == 'combined': 244 | if cache_key not in scores_cache_img: 245 | scores_img = matcher.match(q_feat_img[i], db_feat_img[candidates], q_attn_img[i], db_attn_img[candidates]) 246 | scores_cache_img[cache_key] = scores_img 247 | else: 248 | scores_img = scores_cache_img[cache_key] 249 | 250 | if cache_key not in scores_cache_pc: 251 | scores_pc = matcher.match_pc(q_feat_pc[i], db_feat_pc[candidates], q_attn_pc[i], db_attn_pc[candidates]) 252 | scores_cache_pc[cache_key] = scores_pc 253 | else: 254 | scores_pc = scores_cache_pc[cache_key] 255 | 256 | scores = 1. * scores_img + 1 * scores_pc 257 | 258 | scores = 1 * scores + distances[0] 259 | scores = [(score, candidate) for score, candidate in zip(scores, candidates)] 260 | 261 | final_candidates = [x for _, x in sorted(scores, key=lambda el: el[0])] 262 | final_scores = [x[0] for x in sorted(scores)] 263 | 264 | correct = [] 265 | for j in range(len(final_candidates)): 266 | if final_candidates[j] in true_neighbors: 267 | correct.append( (final_candidates[j], final_scores[j]) ) 268 | if j == 0: 269 | similarity = np.dot(queries_output[i], database_output[final_candidates[j]]) 270 | top1_similarity_score.append(similarity) 271 | recall[j] += 1 272 | break 273 | 274 | if len(list(set(final_candidates[0:threshold]).intersection(set(true_neighbors)))) > 0: 275 | one_percent_retrieved += 1 276 | 277 | if mode != 'base': 278 | candidates = torch.tensor(candidates, dtype=torch.int32) 279 | final_candidates = torch.tensor(final_candidates, dtype=torch.int32) 280 | 281 | candidates_list = candidates.tolist() 282 | final_candidates_list = final_candidates.tolist() 283 | 284 | if len(correct) > 0: 285 | log_reranking_stats(candidates_list, final_candidates_list, correct, true_neighbors) 286 | 287 | one_percent_recall = (one_percent_retrieved/float(num_evaluated))*100 288 | recall = (np.cumsum(recall)/float(num_evaluated))*100 289 | return recall, top1_similarity_score, one_percent_recall 290 | 291 | 292 | def log_evaluation_stats(mode, ave_one_percent_recall, ave_recall, average_similarity): 293 | logging.info("-------------------------------------------------------------------") 294 | logging.info(f"Mode: {mode}") 295 | logging.info(f"ave_one_percent_recall: {ave_one_percent_recall}") 296 | logging.info(f"ave_recall: {ave_recall}") 297 | logging.info(f"average_similarity: {average_similarity}") 298 | logging.info("-------------------------------------------------------------------") 299 | 300 | def log_reranking_stats(candidates, final_candidates, correct, true_neighbors): 301 | candidate, score = correct[0] 302 | 303 | # Identify the index of the correct match in the initial and final candidate lists 304 | initial_correct_index = next((candidates.index(candidate) for candidate in candidates if candidate in true_neighbors), None) 305 | final_correct_index = next((final_candidates.index(candidate) for candidate in final_candidates if candidate in true_neighbors), None) 306 | 307 | # Compute distance to top-1 for initial and final rankings 308 | initial_distance_to_top1 = 
initial_correct_index if initial_correct_index is not None else -1 309 | final_distance_to_top1 = final_correct_index if final_correct_index is not None else -1 310 | 311 | if final_distance_to_top1 < initial_distance_to_top1: 312 | best_method = "with reranking" 313 | elif final_distance_to_top1 == initial_distance_to_top1: 314 | best_method = "equal" 315 | else: 316 | best_method = "without reranking" 317 | 318 | log_message = (f"Candidate: {candidate}, Score: {score:.2f}, Initial top-1: {initial_distance_to_top1}, " 319 | f"Reranking top-1: {final_distance_to_top1}, Best: {best_method}.") 320 | 321 | 322 | logging.info(log_message) 323 | -------------------------------------------------------------------------------- /figures/UMF_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLR-RM/UMF/77f71f94cd73d6623846eb7bf6fd87a7fe2e45ec/figures/UMF_architecture.png -------------------------------------------------------------------------------- /misc/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """Modified with utility for self-supervised learning.""" 2 | 3 | import math 4 | import warnings 5 | from typing import List 6 | 7 | from torch.optim import Optimizer 8 | from torch.optim.lr_scheduler import LambdaLR, _LRScheduler 9 | 10 | __all__ = ["LinearLR", "ExponentialLR"] 11 | 12 | 13 | class _LRSchedulerMONAI(_LRScheduler): 14 | """Base class for increasing the learning rate.""" 15 | 16 | def __init__( 17 | self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1 18 | ) -> None: 19 | """Initialize the _LRSchedulerMONAI scheduler. 20 | 21 | Args: 22 | optimizer: wrapped optimizer. 23 | end_lr: the final learning rate. 24 | num_iter: the number of iterations over which the test occurs. 25 | last_epoch: the index of last epoch. 26 | """ 27 | self.end_lr = end_lr 28 | self.num_iter = num_iter 29 | super().__init__(optimizer, last_epoch) 30 | 31 | 32 | class LinearLR(_LRSchedulerMONAI): 33 | """Linearly increases the learning rate.""" 34 | 35 | def get_lr(self): 36 | """Compute learning rate using base lr and end lr.""" 37 | # pylint: disable=invalid-name 38 | r = self.last_epoch / (self.num_iter - 1) 39 | return [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs] 40 | 41 | 42 | class ExponentialLR(_LRSchedulerMONAI): 43 | """Exponentially increases the learning rate.""" 44 | 45 | def get_lr(self): 46 | """Compute learning rate using base lr and end lr.""" 47 | # pylint: disable=invalid-name 48 | r = self.last_epoch / (self.num_iter - 1) 49 | return [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs] 50 | 51 | 52 | class WarmupCosineSchedule(LambdaLR): 53 | """Linear warmup and then cosine decay. 54 | Based on https://huggingface.co/ implementation. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | optimizer: Optimizer, 60 | warmup_steps: int, 61 | t_total: int, 62 | cycles: float = 0.5, 63 | last_epoch: int = -1, 64 | ) -> None: 65 | """Initialize the WarmupCosineSchedule scheduler. 66 | 67 | Args: 68 | optimizer: wrapped optimizer. 69 | warmup_steps: number of warmup iterations. 70 | t_total: total number of training iterations. 71 | cycles: cosine cycles parameter. 72 | last_epoch: the index of last epoch. 
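        Note (editor's addition, not part of the original docstring): the schedule
        implemented by lr_lambda below returns step / warmup_steps while
        step < warmup_steps and, afterwards,
        max(0, 0.5 * (1 + cos(2 * pi * cycles * progress))) with
        progress = (step - warmup_steps) / (t_total - warmup_steps).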
73 | """ 74 | self.warmup_steps = warmup_steps 75 | self.t_total = t_total 76 | self.cycles = cycles 77 | super().__init__(optimizer, self.lr_lambda, last_epoch) 78 | 79 | def lr_lambda(self, step): 80 | """Compute learning rate.""" 81 | if step < self.warmup_steps: 82 | return float(step) / float(max(1.0, self.warmup_steps)) 83 | progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps)) 84 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) 85 | 86 | 87 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 88 | """Linear warmup and then cosine decay.""" 89 | 90 | def __init__( 91 | self, 92 | optimizer: Optimizer, 93 | warmup_epochs: int, 94 | max_epochs: int, 95 | warmup_start_lr: float = 0.0, 96 | eta_min: float = 0.0, 97 | last_epoch: int = -1, 98 | ) -> None: 99 | """Initialize the LinearWarmupCosineAnnealingLR scheduler. 100 | 101 | Args: 102 | optimizer (Optimizer): Wrapped optimizer. 103 | warmup_epochs (int): Maximum number of iterations for linear warmup 104 | max_epochs (int): Maximum number of iterations 105 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 106 | eta_min (float): Minimum learning rate. Default: 0. 107 | last_epoch (int): The index of last epoch. Default: -1. 108 | """ 109 | self.warmup_epochs = warmup_epochs 110 | self.max_epochs = max_epochs 111 | self.warmup_start_lr = warmup_start_lr 112 | self.eta_min = eta_min 113 | 114 | super().__init__(optimizer, last_epoch) 115 | 116 | def get_lr(self) -> List[float]: 117 | """Compute linearly increasing learning rates for each parameter group.""" 118 | if not self._get_lr_called_within_step: 119 | warnings.warn( 120 | "To get the last learning rate computed by the scheduler, " 121 | "please use `get_last_lr()`.", 122 | UserWarning, 123 | ) 124 | 125 | if self.last_epoch == 0: 126 | return [self.warmup_start_lr] * len(self.base_lrs) 127 | if self.last_epoch < self.warmup_epochs: 128 | return [ 129 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 130 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 131 | ] 132 | if self.last_epoch == self.warmup_epochs: 133 | return self.base_lrs 134 | if (self.last_epoch - 1 - self.max_epochs) % ( 135 | 2 * (self.max_epochs - self.warmup_epochs) 136 | ) == 0: 137 | return [ 138 | group["lr"] 139 | + (base_lr - self.eta_min) 140 | * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) 141 | / 2 142 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 143 | ] 144 | 145 | return [ 146 | ( 147 | 1 148 | + math.cos( 149 | math.pi 150 | * (self.last_epoch - self.warmup_epochs) 151 | / (self.max_epochs - self.warmup_epochs) 152 | ) 153 | ) 154 | / ( 155 | 1 156 | + math.cos( 157 | math.pi 158 | * (self.last_epoch - self.warmup_epochs - 1) 159 | / (self.max_epochs - self.warmup_epochs) 160 | ) 161 | ) 162 | * (group["lr"] - self.eta_min) 163 | + self.eta_min 164 | for group in self.optimizer.param_groups 165 | ] 166 | 167 | def _get_closed_form_lr(self) -> List[float]: 168 | """Call when epoch is passed as a param to the `step` function of the scheduler.""" 169 | if self.last_epoch < self.warmup_epochs: 170 | return [ 171 | self.warmup_start_lr 172 | + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 173 | for base_lr in self.base_lrs 174 | ] 175 | 176 | return [ 177 | self.eta_min 178 | + 0.5 179 | * (base_lr - self.eta_min) 180 | * ( 181 | 1 182 | + math.cos( 183 | math.pi 184 | * 
(self.last_epoch - self.warmup_epochs) 185 | / (self.max_epochs - self.warmup_epochs) 186 | ) 187 | ) 188 | for base_lr in self.base_lrs 189 | ] 190 | -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | import yaml 5 | 6 | 7 | class ModelParams: 8 | def __init__(self, model_params_path): 9 | assert os.path.exists(model_params_path), 'Cannot find model-specific configuration file: {}'.format(model_params_path) 10 | with open(model_params_path, 'r') as f: 11 | params = yaml.safe_load(f) 12 | 13 | self.model_params_path = model_params_path 14 | self.params = params.get('model') 15 | self.model = self.params.get("name") 16 | 17 | def print(self): 18 | print('Model parameters:') 19 | param_dict = vars(self) 20 | for e in param_dict: 21 | print('{}: {}'.format(e, param_dict[e])) 22 | 23 | print('') 24 | 25 | def to_dict(self): 26 | param_dict = {} 27 | for e in param_dict: 28 | param_dict[e] = getattr(self, e) 29 | 30 | def get_datetime(): 31 | return time.strftime("%Y%m%d_%H%M") 32 | 33 | 34 | class UMFParams: 35 | """ 36 | Params for training MinkLoc models on Oxford dataset 37 | """ 38 | def __init__(self, params_path, model_params_path=None): 39 | """ 40 | Configuration files 41 | :param path: General configuration file 42 | :param model_params: Model-specific configuration 43 | """ 44 | 45 | assert os.path.exists(params_path), 'Cannot find configuration file: {}'.format(params_path) 46 | self.params_path = params_path 47 | self.model_params_path = model_params_path 48 | 49 | with open(self.params_path, 'r') as f: 50 | config = yaml.safe_load(f) 51 | 52 | params = config['DEFAULT'] 53 | 54 | self.dataset_path = params.get('dataset_path') 55 | self.dataset = params.get('dataset', 'robotcar') 56 | 57 | self.num_points = int(params.get('num_points', 4096)) 58 | self.dataset_folder = params.get('dataset_folder') 59 | self.use_cloud = bool(params.get('use_cloud', True)) 60 | self.debug = bool(params.get('debug', False)) 61 | if self.debug: 62 | torch.autograd.set_detect_anomaly(True) 63 | print('Debug mode: ON') 64 | 65 | # Train with RGB 66 | # Evaluate on Oxford only (no images for InHouse datasets) 67 | self.use_rgb = True 68 | self.image_path = params.get('image_path') 69 | 70 | if self.dataset.lower() == 'robotcar': 71 | if 'lidar2image_ndx_path' not in params: 72 | self.lidar2image_ndx_path = os.path.join(self.image_path, 'lidar2image_ndx.pickle') 73 | else: 74 | self.lidar2image_ndx_path = params.get('lidar2image_ndx_path') 75 | 76 | 77 | self.eval_database_files = ['oxford_evaluation_database.pickle'] 78 | self.eval_query_files = ['oxford_evaluation_query.pickle'] 79 | 80 | elif self.dataset.lower() == 'etna': 81 | self.eval_database_files = ["etna_evaluation_database.pickle"] 82 | self.eval_query_files = ["etna_evaluation_query.pickle"] 83 | 84 | assert len(self.eval_database_files) == len(self.eval_query_files) 85 | 86 | params = config['TRAIN'] 87 | self.num_workers = int(params.get('num_workers', 0)) 88 | self.train_step = params.get('train_step', 'single_step') 89 | self.batch_size = int(params.get('batch_size', 128)) 90 | # Validation batch size is fixed and does not grow 91 | self.val_batch_size = int(params.get('val_batch_size', 64)) 92 | 93 | # Set batch_expansion_th to turn on dynamic batch sizing 94 | # When number of non-zero triplets falls below batch_expansion_th, expand batch size 95 | 
self.batch_expansion_th = float(params.get('batch_expansion_th', None)) 96 | if self.batch_expansion_th is not None: 97 | assert 0. < self.batch_expansion_th < 1., 'batch_expansion_th must be between 0 and 1' 98 | self.batch_size_limit = int(params.get('batch_size_limit', 256)) 99 | # Batch size expansion rate 100 | self.batch_expansion_rate = float(params.get('batch_expansion_rate', 1.5)) 101 | assert self.batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 102 | else: 103 | self.batch_size_limit = self.batch_size 104 | self.batch_expansion_rate = None 105 | 106 | self.lr = float(params.get('lr', 1e-3)) 107 | # lr for image feature extraction 108 | self.image_lr = float(params.get('image_lr', 1e-4)) 109 | 110 | self.load_weights = None 111 | if "load_weights" in params: 112 | self.load_weights = params.get('load_weights') 113 | 114 | self.optimizer = params['optimizer'] 115 | 116 | self.scheduler = params.get('scheduler', 'MultiStepLR') 117 | if self.scheduler is not None: 118 | if self.scheduler == 'CosineAnnealingLR': 119 | self.min_lr = float(params.get('min_lr', 1e-5)) 120 | 121 | elif self.scheduler == 'WarmupCosineSchedule': 122 | self.warmup_steps = int(params.get('warmup_steps', 10)) 123 | 124 | elif self.scheduler == 'LinearWarmupCosineAnnealingLR': 125 | self.warmup_epochs = int(params.get('warmup_epochs', 10)) 126 | elif self.scheduler == 'OneCycleLR': 127 | pass 128 | elif self.scheduler == 'MultiStepLR': 129 | scheduler_milestones = params.get('scheduler_milestones') 130 | if not isinstance(scheduler_milestones, list): 131 | self.scheduler_milestones = scheduler_milestones 132 | else: 133 | self.scheduler_milestones = [int(e) for e in scheduler_milestones] 134 | elif self.scheduler == 'ReduceLROnPlateau': 135 | self.patience = int(params.get('patience', 2)) 136 | self.factor = float(params.get('factor', 0.9)) 137 | elif self.scheduler == 'ExpotentialLR': 138 | self.gamma = float(params.get('gamma', 0.95)) 139 | else: 140 | raise NotImplementedError('Unsupported LR scheduler: {}'.format(self.scheduler)) 141 | 142 | self.epochs = int(params.get('epochs', 20)) 143 | self.weight_decay = float(params.get('weight_decay', None)) 144 | self.normalize_embeddings = bool(params.get('normalize_embeddings', True)) # Normalize embeddings during training and evaluation 145 | self.loss = params.get('loss') 146 | 147 | weights = params.get('weights', [0.3, 0.3, 0.3]) 148 | self.weights = [float(e) for e in weights] 149 | 150 | if 'Contrastive' in self.loss: 151 | self.pos_margin = float(params.get('pos_margin', 0.2)) 152 | self.neg_margin = float(params.get('neg_margin', 0.65)) 153 | elif 'Triplet' in self.loss: 154 | self.margin = float(params.get('margin', 0.4)) # Margin used in loss function 155 | else: 156 | raise 'Unsupported loss function: {}'.format(self.loss) 157 | 158 | self.aug_mode = int(params.get('aug_mode', 1)) # Augmentation mode (1 is default) 159 | 160 | self.train_file = params.get('train_file') 161 | self.val_file = params.get('val_file', None) 162 | 163 | # Read model parameters 164 | if self.model_params_path is not None: 165 | self.model_params = ModelParams(self.model_params_path) 166 | else: 167 | self.model_params = None 168 | 169 | self._check_params() 170 | 171 | def _check_params(self): 172 | assert os.path.exists(self.dataset_folder), 'Cannot access dataset: {}'.format(self.dataset_folder) 173 | 174 | def print(self): 175 | print('*Parameters:') 176 | param_dict = vars(self) 177 | for e in param_dict: 178 | if e not in ['model_params']: 179 | 
print('{}: {}'.format(e, param_dict[e])) 180 | 181 | if self.model_params is not None: 182 | self.model_params.print() 183 | print('') 184 | 185 | def to_dict(self): 186 | param_dict = {} 187 | for e in param_dict: 188 | if e not in ['model_params']: 189 | param_dict[e] = param_dict[e] 190 | 191 | if self.model_params is not None: 192 | param_dict["model_params"] = self.model_params.to_dict() 193 | -------------------------------------------------------------------------------- /models/UMF/UMFnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from how import layers 6 | import yaml 7 | import dotsi 8 | from models.UMF.resnet_fpn import ResNetFPN 9 | from models.UMF.utils.local_transformer import LocalFeatureTransformer 10 | from models.UMF.utils.local_transformer3D import LocalFeatureTransformer3D 11 | from models.UMF.utils.lit import LocalfeatureIntegrationTransformer 12 | from models.UMF.utils.lit3d import LocalfeatureIntegrationTransformer3D 13 | from models.UMF.voxel_encoder import Voxel_MAE 14 | from models.UMF.utils.multimodal_fusion import FusionEncoder 15 | 16 | 17 | def load_config(config_file): 18 | with open(config_file, 'r') as ymlfile: 19 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 20 | return cfg 21 | 22 | 23 | class UMFnet(torch.nn.Module): 24 | def __init__(self, cfg, grid_size=None, voxel_size=None, final_block: str = None): 25 | super().__init__() 26 | print(" final_block: ", final_block) 27 | point_cloud_range = cfg.model.point_cloud.range 28 | 29 | self.cloud_fe_size = cfg.model.point_cloud.out_dim 30 | self.image_fe_size = cfg.model.visual.out_dim 31 | pc_input_dim = cfg.model.point_cloud.input_dim 32 | img_input_dim = cfg.model.visual.input_dim 33 | cfg = dotsi.Dict(cfg) 34 | self.cfg = cfg 35 | self.fusion_only = cfg.model.mode== "fusion" 36 | 37 | grid_size = cfg.model.point_cloud.grid_size 38 | self.cloud_fe = Voxel_MAE(pc_input_dim, grid_size, voxel_size, point_cloud_range, 39 | fpn=not self.fusion_only) 40 | 41 | vis_model_name = cfg.model.visual.architecture 42 | self.image_fe = ResNetFPN(vis_model_name, input_dim=img_input_dim, 43 | fpn=not self.fusion_only) 44 | 45 | visual_dim = 2048 if vis_model_name == "resnet50" else 512 46 | lidar_dim = 128 47 | 48 | d_attn = cfg.model.fusion.d_attn 49 | num_heads = cfg.model.fusion.num_heads 50 | self.fusion_encoder = FusionEncoder(visual_dim, lidar_dim, d_attn, num_heads) 51 | self.output_dim = cfg.model.fusion.d_embedding 52 | self.final_block = final_block 53 | 54 | 55 | self.fused_dim = d_attn 56 | if self.final_block is None: 57 | self.final_net = None 58 | elif self.final_block == 'fc': 59 | self.final_net = nn.Linear(self.fused_dim, self.output_dim) 60 | 61 | elif self.final_block == 'mlp': 62 | temp_channels = self.output_dim 63 | self.final_net = nn.Sequential(nn.Linear(self.fused_dim, temp_channels, bias=False), 64 | nn.BatchNorm1d(temp_channels, affine=True), 65 | nn.ReLU(inplace=True), nn.Linear(temp_channels, self.output_dim)) 66 | else: 67 | raise NotImplementedError('Unsupported final block: {}'.format(self.final_block)) 68 | 69 | # FPN dim 70 | visual_dim = 512 if vis_model_name == "resnet50" else 128 71 | lidar_dim = 32 72 | 73 | if not self.fusion_only: 74 | if cfg.model.mode == "superfeatures": 75 | lit_cfg_im = cfg.model.visual.local_superfeatures 76 | lit_cfg_pc = cfg.model.point_cloud.local_superfeatures 77 | 78 | self.runtime = cfg.model.visual.local_superfeatures.runtime 
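                # Editor's note (comment added for clarity, not part of the original source):
                # in "superfeatures" mode two LiT heads are built, one per modality. self.lit
                # consumes the image FPN map y_img['fine_2'] (512 channels for resnet50) and
                # self.lit3d the dense voxel map y_cloud['fine_2'] (32 channels); each returns
                # N learnable SuperFeatures plus attention maps, which eval/utils.py later
                # re-ranks with match_super(). T, N, dim and out_dim come from the
                # local_superfeatures blocks of UMFnet_superfeat.yml.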
79 | 80 | self.lit = LocalfeatureIntegrationTransformer(lit_cfg_im.T, lit_cfg_im.N, 81 | visual_dim, feat_dim=lit_cfg_im.dim, 82 | out_dim=lit_cfg_im.out_dim) 83 | self.lit3d = LocalfeatureIntegrationTransformer3D(lit_cfg_pc.T, lit_cfg_pc.N, 84 | lidar_dim, 85 | feat_dim=lit_cfg_pc.dim, 86 | out_dim=lit_cfg_pc.out_dim) 87 | if cfg.model.smoothing: 88 | self.smoothing = layers.smoothing.Smoothing() 89 | self.attention = layers.attention.L2Attention() 90 | self.attention3d = layers.attention.L2Attention() 91 | 92 | elif cfg.model.mode == "ransac": 93 | dim_local_feat_im = cfg.model.visual.local_ransac.dim 94 | dim_local_feat_pc = cfg.model.point_cloud.local_ransac.dim 95 | self.lt = LocalFeatureTransformer(input_dim=visual_dim, out_dim=dim_local_feat_im) 96 | self.lt3d = LocalFeatureTransformer3D(input_dim=lidar_dim, out_dim=dim_local_feat_pc) 97 | else: 98 | raise NotImplementedError('Unsupported mode: {}'.format(cfg.model.mode)) 99 | 100 | if cfg.model.visual.pretrained: 101 | print("load visual pretrained model: ", cfg.model.visual.pretrained) 102 | self.image_fe.load_checkpoint(cfg.model.visual.pretrained,) 103 | 104 | 105 | if cfg.model.point_cloud.pretrained: 106 | print("load point cloud pretrained model: ", cfg.model.point_cloud.pretrained) 107 | self.cloud_fe.load_checkpoint(cfg.model.point_cloud.pretrained) 108 | 109 | if cfg.model.pretrained: 110 | path_ckpt = cfg.model.pretrained 111 | pretrained_state_dict = torch.load(path_ckpt, map_location='cpu') 112 | dim_reduction_prefix = '.reduction_layer.' 113 | filtered_state_dict = {k: v for k, v in pretrained_state_dict.items() if dim_reduction_prefix not in k} 114 | 115 | self.load_state_dict(filtered_state_dict, strict=False) 116 | print("--> UMF - load pretrained model: ", path_ckpt) 117 | 118 | #self.image_fe.freeze_backbone() 119 | #self.cloud_fe.freeze_backbone() 120 | #self.freeze_fusion_branch() 121 | 122 | def freeze_fusion_branch(self): 123 | for param in self.fusion_encoder.parameters(): 124 | param.requires_grad = False 125 | 126 | 127 | def get_superfeatures(self, x, dim_reduction=False): 128 | """ 129 | return a list of tuple (features, attentionmpas) where each is a list containing requested scales 130 | features is a tensor BxDxNx1 131 | attentionmaps is a tensor BxNxHxW 132 | """ 133 | feats = [] 134 | attns = [] 135 | strengths = [] 136 | 137 | o, attn = self.lit(x, dim_reduction=dim_reduction) 138 | strength = self.attention(o) 139 | if self.cfg.model.smoothing: 140 | o = self.smoothing(o) 141 | 142 | feats = o.permute(0, 2, 1) 143 | attns = attn 144 | if strength.dim() != 3: 145 | strength = strength.unsqueeze(0) 146 | 147 | strengths = strength 148 | return feats, attns, strengths 149 | 150 | 151 | 152 | def get_local_features(self, x,dim_reduction=False, normalize=False): 153 | embed, feats, attn = self.lt(x, dim_reduction=dim_reduction, normalize=normalize) 154 | return embed, feats, attn 155 | 156 | 157 | def get_superfeatures_points(self, x, dim_reduction=False): 158 | """ 159 | return a list of tuple (features, attentionmpas) where each is a list containing requested scales 160 | features is a tensor BxDxNx1 161 | attentionmaps is a tensor BxNxHxW 162 | """ 163 | feats = [] 164 | attns = [] 165 | strengths = [] 166 | 167 | o, attn = self.lit3d(x, dim_reduction=dim_reduction) 168 | strength = self.attention3d(o) 169 | if self.cfg.model.smoothing: 170 | o = self.smoothing(o) 171 | feats = o.permute(0, 2, 1) 172 | attns = attn 173 | if strength.dim() != 3: 174 | strength = strength.unsqueeze(0) 175 | strengths = 
strength 176 | return feats, attns, strengths 177 | 178 | 179 | def get_local_features_3d(self, x, dim_reduction=False, normalize=False): 180 | embed, feats, attn = self.lt3d(x, dim_reduction=dim_reduction, normalize=normalize) 181 | return embed, feats, attn 182 | 183 | 184 | def forward(self, batch, dim_reduction=False, normalize=False): 185 | y = {} 186 | y_img = self.image_fe(batch["images"]) 187 | 188 | # cloud features 189 | batch_dict = { 190 | "voxel_features": batch["voxel_features"], 191 | "coordinates": batch["coordinates"], 192 | "batch_size": batch["images"].shape[0], 193 | } 194 | y_cloud = self.cloud_fe(batch_dict) 195 | 196 | if not self.fusion_only: 197 | fine_img = y_img['fine_2'] 198 | fine_cloud = y_cloud['fine_2'] 199 | 200 | im_feat = y_img['coarse'] 201 | pc_feat = y_cloud['coarse'] 202 | out = self.fusion_encoder(im_feat, pc_feat) 203 | 204 | if self.final_block is not None: 205 | out = self.final_net(out) 206 | y['embedding'] = out 207 | 208 | if self.cfg.model.mode == "superfeatures": 209 | img_super_feat, img_attns, img_strengths = self.get_superfeatures(fine_img, dim_reduction=dim_reduction) 210 | 211 | y['img_super_feat'] = img_super_feat 212 | y["img_attns"] = img_attns 213 | y["img_strengths"] = img_strengths 214 | 215 | pc_super_feat, pc_attns, pc_strengths = self.get_superfeatures_points(fine_cloud, dim_reduction=dim_reduction) 216 | y['pc_super_feat'] = pc_super_feat 217 | y["pc_attns"] = pc_attns 218 | y["pc_strengths"] = pc_strengths 219 | 220 | 221 | if self.cfg.model.mode == "ransac": 222 | # Visual local 223 | img_embed, img_local_feat, img_attn = self.get_local_features(fine_img, 224 | dim_reduction=dim_reduction, normalize=normalize) 225 | y['img_local_feat'] = img_local_feat 226 | y['image_embedding'] = img_embed 227 | y['img_fine_feat'] = img_local_feat 228 | y["img_attns"] = img_attn 229 | 230 | pc_embed, pc_local_feat, pc_attns = self.get_local_features_3d(fine_cloud, 231 | dim_reduction=dim_reduction, normalize=normalize) 232 | y['pc_local_feat'] = pc_local_feat 233 | y['pc_fine_feat'] = pc_local_feat 234 | y["pc_attns"] = pc_attns 235 | y['cloud_embedding'] = pc_embed 236 | 237 | return y 238 | 239 | def print_info(self): 240 | print('Model class: UMFnet') 241 | n_params = sum([param.nelement() for param in self.parameters()]) 242 | print('Total parameters: {}'.format(n_params)) 243 | 244 | n_params = sum([param.nelement() for param in self.image_fe.parameters()]) 245 | print('Image feature extractor parameters: {}'.format(n_params)) 246 | 247 | n_params = sum([param.nelement() for param in self.cloud_fe.parameters()]) 248 | print('Cloud feature extractor parameters: {}'.format(n_params)) 249 | 250 | # Fusion branch 251 | if hasattr(self, 'fusion_encoder'): 252 | n_params = sum([param.nelement() for param in self.fusion_encoder.parameters()]) 253 | print('Fusion model parameters: {}'.format(n_params)) 254 | 255 | 256 | print('Final block: {}'.format(self.final_block)) 257 | print('Dimensionality of cloud features: {}'.format(self.cloud_fe_size)) 258 | print('Dimensionality of image features: {}'.format(self.image_fe_size)) 259 | print('Dimensionality of final descriptor: {}'.format(self.output_dim)) 260 | 261 | 262 | 263 | -------------------------------------------------------------------------------- /models/UMF/UMFnet.yml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | name: "UMF" 4 | pretrained: null 5 | mode: "fusion" # superfeatures, ransac, fusion 6 | 7 | visual: 8 | architecture: 
resnet50 # Backbone network 9 | pretrained: null 10 | out_dim: 256 11 | image_size: 224 12 | input_dim: 3 13 | 14 | point_cloud: 15 | architecture: VoxelNet 16 | pretrained: null 17 | out_dim: 256 18 | input_dim: 3 19 | grid_size: [200, 200, 200] 20 | max_num_points_per_voxel: 5 21 | num_point_features: 3 22 | max_num_voxels: 4096 23 | voxel_size: [0.01, 0.01, 0.01] 24 | range: [-1, -1, -1, 1, 1, 1] # xmin, ymin, zmin, xmax, ymax, zmax 25 | 26 | fusion: 27 | d_attn: 512 28 | d_embedding: 256 29 | num_heads: 8 30 | final_block: fc #fc mpl null -------------------------------------------------------------------------------- /models/UMF/UMFnet_ransac.yml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | name: "UMF" 4 | pretrained: null 5 | mode: "ransac" # superfeatures, ransac, fusion 6 | smoothing: false 7 | local_feat_margin: [0.4, 0.4] 8 | 9 | visual: 10 | architecture: resnet50 # Backbone network 11 | pretrained: null 12 | out_dim: 256 13 | image_size: 224 14 | input_dim: 3 15 | 16 | local_ransac: 17 | num_heads: 8 18 | hidden_dim: 512 19 | dim: 128 20 | 21 | point_cloud: 22 | architecture: VoxelNet # Backbone network 23 | pretrained: null 24 | out_dim: 256 25 | input_dim: 3 26 | grid_size: [200, 200, 200] 27 | max_num_points_per_voxel: 5 28 | num_point_features: 3 29 | max_num_voxels: 4096 30 | voxel_size: [0.01, 0.01, 0.01] 31 | range: [-1, -1, -1, 1, 1, 1] # xmin, ymin, zmin, xmax, ymax, zmax 32 | 33 | local_ransac: 34 | num_heads: 8 35 | hidden_dim: 256 36 | dim: 32 37 | 38 | fusion: 39 | d_attn: 512 40 | d_embedding: 256 41 | num_heads: 8 42 | final_block: fc #fc mpl null -------------------------------------------------------------------------------- /models/UMF/UMFnet_superfeat.yml: -------------------------------------------------------------------------------- 1 | 2 | model: 3 | name: "UMF" 4 | pretrained: null 5 | mode: "superfeatures" # superfeatures, ransac, fusion 6 | smoothing: false 7 | local_feat_weights: [0.02, 0.02] # pc, img 8 | local_attn_weights: [0.01, 0.01] 9 | local_feat_margin: [1.1, 1.1] 10 | 11 | visual: 12 | architecture: resnet50 # Backbone network 13 | pretrained: "" # path to the pretrained model 14 | out_dim: 128 # The output dimension 15 | image_size: 224 16 | input_dim: 3 17 | 18 | local_superfeatures: 19 | T: 6 20 | N: 64 21 | dim: 1024 22 | out_dim: 128 23 | 24 | point_cloud: 25 | architecture: VoxelNet # Backbone network 26 | pretrained: "" 27 | out_dim: 128 # The output dimension 28 | input_dim: 3 # The input dimension 29 | grid_size: [200, 200, 200] 30 | max_num_points_per_voxel: 5 31 | num_point_features: 3 32 | max_num_voxels: 4096 33 | voxel_size: [0.01, 0.01, 0.01] 34 | range: [-1, -1, -1, 1, 1, 1] # xmin, ymin, zmin, xmax, ymax, zmax 35 | 36 | local_superfeatures: 37 | T: 6 38 | N: 128 39 | dim: 256 40 | out_dim: 32 41 | 42 | fusion: 43 | d_attn: 512 44 | d_embedding: 256 45 | num_heads: 8 46 | final_block: fc #fc mpl null -------------------------------------------------------------------------------- /models/UMF/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DLR-RM/UMF/77f71f94cd73d6623846eb7bf6fd87a7fe2e45ec/models/UMF/__init__.py -------------------------------------------------------------------------------- /models/UMF/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import timm 5 
| 6 | def conv1x1(in_planes, out_planes, stride=1): 7 | """1x1 convolution without padding""" 8 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) 9 | 10 | 11 | def conv3x3(in_planes, out_planes, stride=1): 12 | """3x3 convolution with padding""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 | class ResNetFPN(nn.Module): 17 | """ 18 | ResNet+FPN, output resolution are 1/16 and 1/4. 19 | Each block has 2 layers. 20 | """ 21 | 22 | def __init__(self, slug, input_dim=3, fpn=True): 23 | super().__init__() 24 | # Config 25 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | if slug == "resnet18": 28 | block_dims = [64, 128, 256, 512] 29 | 30 | elif slug == "resnet50": 31 | block_dims = [256, 512, 1024, 2048] 32 | 33 | slug = "se" + slug 34 | self.fpn = fpn 35 | self.resnet = timm.create_model(slug, pretrained=True, in_chans=input_dim) 36 | 37 | if input_dim != 3: 38 | print("Changing input channels from 3 to {}".format(input_dim)) 39 | self.resnet.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) 40 | if fpn: 41 | self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) 42 | self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) 43 | self.layer3_outconv2 = nn.Sequential( 44 | conv3x3(block_dims[3], block_dims[3]), 45 | nn.BatchNorm2d(block_dims[3]), 46 | nn.LeakyReLU(), 47 | conv3x3(block_dims[3], block_dims[2]), 48 | ) 49 | 50 | self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) 51 | self.layer2_outconv2 = nn.Sequential( 52 | conv3x3(block_dims[2], block_dims[2]), 53 | nn.BatchNorm2d(block_dims[2]), 54 | nn.LeakyReLU(), 55 | conv3x3(block_dims[2], block_dims[1]), 56 | ) 57 | 58 | def load_checkpoint(self, ckpt_path): 59 | print("Loading checkpoint from {}".format(ckpt_path)) 60 | checkpoint = torch.load(ckpt_path, map_location=self.device) 61 | self.load_state_dict(checkpoint, strict=False) 62 | print("Loaded checkpoint from {}".format(ckpt_path)) 63 | 64 | 65 | def freeze_backbone(self): 66 | for param in self.resnet.parameters(): 67 | param.requires_grad = False 68 | 69 | 70 | def _make_layer(self, block, dim, stride=1): 71 | layer1 = block(self.in_planes, dim, stride=stride) 72 | layer2 = block(dim, dim, stride=1) 73 | layers = (layer1, layer2) 74 | 75 | self.in_planes = dim 76 | return nn.Sequential(*layers) 77 | 78 | def forward(self, x): 79 | # ResNet Backbone 80 | y = {} 81 | x0 = self.resnet.conv1(x) 82 | x0 = self.resnet.bn1(x0) 83 | x0 = F.relu(x0) 84 | #x0 = self.resnet.maxpool(x0) 85 | 86 | x1 = self.resnet.layer1(x0) # 1/2 87 | x2 = self.resnet.layer2(x1) # 1/4 88 | x3 = self.resnet.layer3(x2) # 1/8 89 | x4 = self.resnet.layer4(x3) # 1/16 90 | 91 | if not self.fpn: 92 | y["coarse"] = x4 93 | return y 94 | 95 | # FPN 96 | x4_out = self.layer4_outconv(x4) 97 | 98 | x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) 99 | x3_out = self.layer3_outconv(x3) 100 | x3_out = self.layer3_outconv2(x3_out+x4_out_2x) 101 | 102 | x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) 103 | x2_out = self.layer2_outconv(x2) 104 | x2_out = self.layer2_outconv2(x2_out+x3_out_2x) 105 | 106 | #y["fine_1"] = x3_out 107 | y["fine_2"] = x2_out 108 | y["coarse"] = x4 109 | return y 110 | -------------------------------------------------------------------------------- /models/UMF/spconv_backbone.py: 
-------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch.nn as nn 4 | 5 | from models.UMF.utils.spconv_utils import replace_feature, spconv 6 | from spconv.pytorch.modules import SparseModule, SparseSequential 7 | 8 | 9 | def post_act_block(in_channels, out_channels, kernel_size, indice_key=None, stride=1, padding=0, 10 | conv_type='subm', norm_fn=None): 11 | 12 | if conv_type == 'subm': 13 | conv = spconv.SubMConv3d(in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key, algo=spconv.ConvAlgo.Native) 14 | elif conv_type == 'spconv': 15 | conv = spconv.SparseConv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, 16 | bias=False, indice_key=indice_key, algo=spconv.ConvAlgo.Native) 17 | elif conv_type == 'inverseconv': 18 | conv = spconv.SparseInverseConv3d(in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False, algo=spconv.ConvAlgo.Native) 19 | else: 20 | raise NotImplementedError 21 | 22 | m = spconv.SparseSequential( 23 | conv, 24 | norm_fn(out_channels), 25 | nn.GELU(), 26 | ) 27 | 28 | return m 29 | 30 | 31 | class SparseBasicBlock(SparseModule): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, norm_fn=None, downsample=None, indice_key=None): 35 | super(SparseBasicBlock, self).__init__() 36 | 37 | assert norm_fn is not None 38 | bias = norm_fn is not None 39 | self.conv1 = spconv.SubMConv3d( 40 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key 41 | ) 42 | self.bn1 = norm_fn(planes) 43 | self.relu = nn.GELU() 44 | self.conv2 = spconv.SubMConv3d( 45 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=bias, indice_key=indice_key 46 | ) 47 | self.bn2 = norm_fn(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | identity = x 53 | 54 | out = self.conv1(x) 55 | out = replace_feature(out, self.bn1(out.features)) 56 | out = replace_feature(out, self.relu(out.features)) 57 | 58 | out = self.conv2(out) 59 | out = replace_feature(out, self.bn2(out.features)) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out = replace_feature(out, out.features + identity.features) 65 | out = replace_feature(out, self.relu(out.features)) 66 | 67 | return out 68 | 69 | 70 | class VoxelBackBone8x(nn.Module): 71 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs): 72 | super().__init__() 73 | self.model_cfg = model_cfg 74 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 75 | 76 | self.sparse_shape = grid_size[::-1] + [1, 0, 0] 77 | 78 | self.conv_input = spconv.SparseSequential( 79 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'), 80 | norm_fn(16), 81 | nn.GELU(), 82 | ) 83 | block = post_act_block 84 | 85 | self.conv1 = spconv.SparseSequential( 86 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), 87 | ) 88 | 89 | self.conv2 = spconv.SparseSequential( 90 | # [1600, 1408, 41] <- [800, 704, 21] 91 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 92 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 93 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 94 | ) 95 | 96 | self.conv3 = spconv.SparseSequential( 97 | # [800, 704, 21] <- [400, 352, 11] 98 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 
99 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 100 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 101 | ) 102 | 103 | self.conv4 = spconv.SparseSequential( 104 | # [400, 352, 11] <- [200, 176, 5] 105 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'), 106 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 107 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 108 | ) 109 | 110 | last_pad = 0 111 | last_pad = self.model_cfg.get('last_pad', last_pad) 112 | self.conv_out = spconv.SparseSequential( 113 | # [200, 150, 5] -> [200, 150, 2] 114 | spconv.SparseConv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, 115 | bias=False, indice_key='spconv_down2'), 116 | norm_fn(128), 117 | nn.GELU(), 118 | ) 119 | self.num_point_features = 128 120 | self.backbone_channels = { 121 | 'x_conv1': 16, 122 | 'x_conv2': 32, 123 | 'x_conv3': 64, 124 | 'x_conv4': 64 125 | } 126 | 127 | 128 | 129 | def forward(self, batch_dict): 130 | """ 131 | Args: 132 | batch_dict: 133 | batch_size: int 134 | vfe_features: (num_voxels, C) 135 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 136 | Returns: 137 | batch_dict: 138 | encoded_spconv_tensor: sparse tensor 139 | """ 140 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords'] 141 | batch_size = batch_dict['batch_size'] 142 | input_sp_tensor = spconv.SparseConvTensor( 143 | features=voxel_features, 144 | indices=voxel_coords.int(), 145 | spatial_shape=self.sparse_shape, 146 | batch_size=batch_size 147 | ) 148 | 149 | x = self.conv_input(input_sp_tensor) 150 | 151 | x_conv1 = self.conv1(x) 152 | x_conv2 = self.conv2(x_conv1) 153 | x_conv3 = self.conv3(x_conv2) 154 | x_conv4 = self.conv4(x_conv3) 155 | 156 | # for detection head 157 | # [200, 176, 5] -> [200, 176, 2] 158 | out = self.conv_out(x_conv4) 159 | 160 | batch_dict.update({ 161 | 'encoded_spconv_tensor': out, 162 | 'encoded_spconv_tensor_stride': 8 163 | }) 164 | batch_dict.update({ 165 | 'multi_scale_3d_features': { 166 | 'x_conv1': x_conv1, 167 | 'x_conv2': x_conv2, 168 | 'x_conv3': x_conv3, 169 | 'x_conv4': x_conv4, 170 | } 171 | }) 172 | batch_dict.update({ 173 | 'multi_scale_3d_strides': { 174 | 'x_conv1': 1, 175 | 'x_conv2': 2, 176 | 'x_conv3': 4, 177 | 'x_conv4': 8, 178 | } 179 | }) 180 | 181 | return batch_dict 182 | 183 | 184 | class VoxelResBackBone8x(nn.Module): 185 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs): 186 | super().__init__() 187 | self.model_cfg = model_cfg 188 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 189 | 190 | self.sparse_shape = grid_size[::-1] + [1, 0, 0] 191 | 192 | self.conv_input = spconv.SparseSequential( 193 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1'), 194 | norm_fn(16), 195 | nn.GELU(), 196 | ) 197 | block = post_act_block 198 | 199 | self.conv1 = spconv.SparseSequential( 200 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'), 201 | SparseBasicBlock(16, 16, norm_fn=norm_fn, indice_key='res1'), 202 | ) 203 | 204 | self.conv2 = spconv.SparseSequential( 205 | # [1600, 1408, 41] <- [800, 704, 21] 206 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 207 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'), 208 | SparseBasicBlock(32, 32, norm_fn=norm_fn, indice_key='res2'), 209 | ) 210 | 211 | self.conv3 = spconv.SparseSequential( 212 | # 
[800, 704, 21] <- [400, 352, 11] 213 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 214 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'), 215 | SparseBasicBlock(64, 64, norm_fn=norm_fn, indice_key='res3'), 216 | ) 217 | 218 | self.conv4 = spconv.SparseSequential( 219 | # [400, 352, 11] <- [200, 176, 5] 220 | block(64, 128, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='spconv'), 221 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'), 222 | SparseBasicBlock(128, 128, norm_fn=norm_fn, indice_key='res4'), 223 | ) 224 | 225 | last_pad = 0 226 | last_pad = self.model_cfg.get('last_pad', last_pad) 227 | self.conv_out = spconv.SparseSequential( 228 | # [200, 150, 5] -> [200, 150, 2] 229 | spconv.SparseConv3d(128, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, 230 | bias=False, indice_key='spconv_down2'), 231 | norm_fn(128), 232 | nn.GELU(), 233 | ) 234 | self.num_point_features = 128 235 | self.backbone_channels = { 236 | 'x_conv1': 16, 237 | 'x_conv2': 32, 238 | 'x_conv3': 64, 239 | 'x_conv4': 128 240 | } 241 | 242 | def forward(self, batch_dict): 243 | """ 244 | Args: 245 | batch_dict: 246 | batch_size: int 247 | vfe_features: (num_voxels, C) 248 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 249 | Returns: 250 | batch_dict: 251 | encoded_spconv_tensor: sparse tensor 252 | """ 253 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords'] 254 | batch_size = batch_dict['batch_size'] 255 | input_sp_tensor = spconv.SparseConvTensor( 256 | features=voxel_features, 257 | indices=voxel_coords.int(), 258 | spatial_shape=self.sparse_shape, 259 | batch_size=batch_size 260 | ) 261 | x = self.conv_input(input_sp_tensor) 262 | 263 | x_conv1 = self.conv1(x) 264 | x_conv2 = self.conv2(x_conv1) 265 | x_conv3 = self.conv3(x_conv2) 266 | x_conv4 = self.conv4(x_conv3) 267 | 268 | # for detection head 269 | # [200, 176, 5] -> [200, 176, 2] 270 | out = self.conv_out(x_conv4) 271 | 272 | batch_dict.update({ 273 | 'encoded_spconv_tensor': out, 274 | 'encoded_spconv_tensor_stride': 8 275 | }) 276 | batch_dict.update({ 277 | 'multi_scale_3d_features': { 278 | 'x_conv1': x_conv1, 279 | 'x_conv2': x_conv2, 280 | 'x_conv3': x_conv3, 281 | 'x_conv4': x_conv4, 282 | } 283 | }) 284 | 285 | batch_dict.update({ 286 | 'multi_scale_3d_strides': { 287 | 'x_conv1': 1, 288 | 'x_conv2': 2, 289 | 'x_conv3': 4, 290 | 'x_conv4': 8, 291 | } 292 | }) 293 | 294 | return batch_dict 295 | -------------------------------------------------------------------------------- /models/UMF/utils/lit.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2022 Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
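# Editor's note (comments added for clarity, not part of the original source):
# LocalfeatureIntegrationTransformer (LiT) maps the H*W local features of a
# BxCxHxW feature map onto a fixed set of N learnable "SuperFeature" templates.
# For T iterations, each template attends over all locations (softmax over
# templates at every location, then L1-normalised over locations) and is
# updated with the attention-weighted sum of the projected values, followed by
# a small MLP. A minimal usage sketch, assuming the resnet50 FPN fine level
# (512 channels) and the defaults from UMFnet_superfeat.yml:
#
#   lit = LocalfeatureIntegrationTransformer(T=6, N=64, input_dim=512,
#                                            feat_dim=1024, out_dim=128)
#   x = torch.randn(2, 512, 28, 28)      # BxCxHxW local features
#   superfeats, attn = lit(x)            # superfeats: Bx1024x64, attn: Bx64x28x28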
3 | 4 | import torch 5 | from torch import nn 6 | import math 7 | from how import layers 8 | import torch.nn.functional as F 9 | 10 | 11 | 12 | class LocalfeatureIntegrationTransformer(nn.Module): 13 | """Map a set of local features to a fixed number of SuperFeatures """ 14 | 15 | def __init__(self, T, N, input_dim, feat_dim, out_dim): 16 | """ 17 | T: number of iterations 18 | N: number of SuperFeatures 19 | input_dim: dimension of input local features 20 | dim: dimension of SuperFeatures 21 | """ 22 | super().__init__() 23 | self.T = T 24 | self.N = N 25 | self.input_dim = input_dim 26 | dim = feat_dim 27 | 28 | # learnable initialization 29 | self.templates_init = nn.Parameter(torch.randn(1,self.N,dim)) 30 | 31 | # qkv 32 | self.project_q = nn.Linear(dim, dim, bias=False) 33 | self.project_k = nn.Linear(input_dim, dim, bias=False) 34 | self.project_v = nn.Linear(input_dim, dim, bias=False) 35 | 36 | self.norm_inputs = nn.LayerNorm(input_dim) 37 | self.norm_templates = nn.LayerNorm(dim) 38 | self.softmax = nn.Softmax(dim=-1) 39 | self.dropout = nn.Dropout(0.1) 40 | self.scale = dim ** -0.5 41 | # mlp 42 | self.norm_mlp = nn.LayerNorm(dim) 43 | mlp_dim = dim*2 44 | self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.LeakyReLU(), nn.Linear(mlp_dim, dim) ) 45 | 46 | self.reduction_layer = layers.dim_reduction.LinearDimReduction(dim=out_dim, input_dim=dim) 47 | 48 | 49 | def forward(self, x, dim_reduction=False): 50 | """ 51 | input: 52 | x has shape BxCxHxW 53 | output: 54 | template (output SuperFeatures): tensor of shape BxCxNx1 55 | attn (attention over local features at the last iteration): tensor of shape BxNxHxW 56 | """ 57 | # reshape inputs from BxCxHxW to Bx(H*W)xC 58 | B,C,H,W = x.size() 59 | x = x.reshape(B,C,H*W).permute(0,2,1) 60 | 61 | x = self.norm_inputs(x) 62 | k = self.project_k(x) 63 | v = self.project_v(x) 64 | 65 | # template initialization 66 | templates = torch.repeat_interleave(self.templates_init, B, dim=0) 67 | attn = None 68 | 69 | # main iteration loop 70 | for _ in range(self.T): 71 | templates_prev = templates 72 | 73 | # q projection 74 | templates = self.norm_templates(templates) 75 | q = self.project_q(templates) 76 | 77 | q = q * self.scale # Normalization. 78 | 79 | 80 | attn_logits = torch.einsum('bnd,bld->bln', q, k) 81 | attn = self.softmax(attn_logits) 82 | attn = attn + 1e-8 # to avoid zero when with the L1 norm below 83 | attn = attn / attn.sum(dim=-2, keepdim=True) 84 | 85 | # update template 86 | templates = templates_prev + torch.einsum('bld,bln->bnd', v, attn) 87 | 88 | # mlp 89 | templates = templates + self.mlp(self.norm_mlp(templates)) 90 | 91 | # reshape templates to BxDxNx1 92 | templates = templates.permute(0,2,1)[:,:,:] 93 | attn = attn.permute(0,2,1).view(B,self.N,H,W) 94 | 95 | if dim_reduction: 96 | with torch.no_grad(): 97 | templates = self.reduction_layer(templates) 98 | 99 | return templates, attn 100 | 101 | def __repr__(self): 102 | s = str(self.__class__.__name__) 103 | for k in ["T","N","input_dim","dim"]: 104 | s += "\n {:s}: {:d}".format(k, getattr(self,k)) 105 | return s -------------------------------------------------------------------------------- /models/UMF/utils/lit3d.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2021-2022 Naver Corporation. All rights reserved. 2 | # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). 
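# Editor's note (comments added for clarity, not part of the original source):
# 3D counterpart of lit.py: the dense voxel feature volume (BxCxXxYxZ) is first
# downsampled by a strided Conv3d ("project"), flattened to a sequence of
# locations, and then iteratively attended by N learnable templates exactly as
# in the 2D version. A minimal shape sketch, assuming the point-cloud defaults
# from UMFnet_superfeat.yml and an arbitrary 16^3 input volume:
#
#   lit3d = LocalfeatureIntegrationTransformer3D(T=6, N=128, input_dim=32,
#                                                feat_dim=256, out_dim=32)
#   x = torch.randn(2, 32, 16, 16, 16)   # BxCxXxYxZ voxel features
#   superfeats, attn = lit3d(x)          # superfeats: Bx256x128, attn: Bx128x7x7x7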
3 | 4 | import torch 5 | from torch import nn 6 | import math 7 | from how import layers 8 | import torch.nn.functional as F 9 | 10 | 11 | class LocalfeatureIntegrationTransformer3D(nn.Module): 12 | """Map a set of local features to a fixed number of SuperFeatures """ 13 | 14 | def __init__(self, T, N, input_dim, feat_dim, out_dim): 15 | """ 16 | T: number of iterations 17 | N: number of SuperFeatures 18 | input_dim: dimension of input local features 19 | dim: dimension of SuperFeatures 20 | """ 21 | super().__init__() 22 | self.T = T 23 | self.N = N 24 | self.input_dim = input_dim 25 | dim = feat_dim 26 | hidden_dim = input_dim*2 27 | 28 | # learnable initialization 29 | self.templates_init = nn.Parameter(torch.randn(1, self.N, dim)) 30 | self.project = nn.Sequential( 31 | nn.Conv3d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=0), 32 | nn.ReLU(), 33 | ) 34 | 35 | # qkv 36 | self.project_q = nn.Linear(dim, dim, bias=False) 37 | self.project_k = nn.Linear(hidden_dim, dim, bias=False) 38 | self.project_v = nn.Linear(hidden_dim, dim, bias=False) 39 | # layer norms 40 | self.norm_inputs = nn.LayerNorm(hidden_dim) 41 | self.norm_templates = nn.LayerNorm(dim) 42 | self.softmax = nn.Softmax(dim=-1) 43 | self.scale = dim ** -0.5 44 | # mlp 45 | self.norm_mlp = nn.LayerNorm(dim) 46 | mlp_dim = dim*2 47 | self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim) ) 48 | self.reduction_layer = layers.dim_reduction.LinearDimReduction(dim=out_dim, input_dim=dim) 49 | 50 | 51 | def forward(self, x, dim_reduction=False): 52 | """ 53 | input: 54 | x has shape BxCxXxYxZ 55 | output: 56 | template (output SuperFeatures): tensor of shape BxCxNx1 57 | attn (attention over local features at the last iteration): tensor of shape BxNxXxYxZ 58 | """ 59 | x = self.project(x) 60 | B, C, X, Y, Z = x.size() 61 | 62 | x = x.reshape(B, C, X * Y * Z).permute(0, 2, 1) 63 | 64 | x = self.norm_inputs(x) 65 | k = self.project_k(x) 66 | v = self.project_v(x) 67 | 68 | # template initialization 69 | templates = torch.repeat_interleave(self.templates_init, B, dim=0) 70 | attn = None 71 | 72 | # main iteration loop 73 | for _ in range(self.T): 74 | templates_prev = templates 75 | 76 | # q projection 77 | templates = self.norm_templates(templates) 78 | q = self.project_q(templates) 79 | 80 | q = q * self.scale # Normalization. 
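# (The attention computed below has shape B x L x N: a softmax over the N templates for each
#  voxel location, followed by an L1 normalization over locations, so each template aggregates
#  a normalized combination of the projected local features v.)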
81 | 82 | attn_logits = torch.einsum('bnd,bld->bln', q, k) 83 | attn = self.softmax(attn_logits) 84 | attn = attn + 1e-8 # to avoid zero when with the L1 norm below 85 | attn = attn / attn.sum(dim=-2, keepdim=True) 86 | 87 | # update template 88 | templates = templates_prev + torch.einsum('bld,bln->bnd', v, attn) 89 | 90 | # mlp 91 | templates = templates + self.mlp(self.norm_mlp(templates)) 92 | 93 | # reshape templates to BxDxNx1 94 | templates = templates.permute(0,2,1)[:,:,:] 95 | attn = attn.permute(0,2,1).view(B,self.N,X,Y,Z) 96 | 97 | if dim_reduction: 98 | with torch.no_grad(): 99 | templates = self.reduction_layer(templates) 100 | 101 | return templates, attn 102 | 103 | def __repr__(self): 104 | s = str(self.__class__.__name__) 105 | for k in ["T","N","input_dim","dim"]: 106 | s += "\n {:s}: {:d}".format(k, getattr(self,k)) 107 | return s -------------------------------------------------------------------------------- /models/UMF/utils/local_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .pooling import rgem, gemp, relup 5 | from .swin_transformer import TransformerEncoder 6 | from how import layers 7 | 8 | 9 | class LocalFeatureTransformer(nn.Module): 10 | def __init__(self, input_dim, out_dim, num_heads=8, depth=2, **kwargs): 11 | super().__init__() 12 | self.input_dim = input_dim 13 | self.out_dim = out_dim 14 | 15 | self.transformer_encoder = TransformerEncoder(input_dim, depth=depth, drop_path=0.0, 16 | num_heads=num_heads, window_size=(7, 7)) 17 | 18 | self.pool = nn.Sequential( 19 | nn.Conv2d(input_dim, 256, kernel_size=3, stride=2, padding=0), 20 | relup(), 21 | rgem(), 22 | gemp(emb_dim=1), 23 | nn.Flatten() 24 | ) 25 | self.reduction_layer = layers.dim_reduction.ConvDimReduction(dim=out_dim, input_dim=input_dim) 26 | 27 | def forward(self, x, dim_reduction=False, normalize=False): 28 | embed = x 29 | 30 | out, attn_maps = self.transformer_encoder(embed) 31 | embeddings = self.pool(out) 32 | 33 | if normalize: 34 | out = F.normalize(out, p=2, dim=-1) 35 | embeddings = F.normalize(embeddings, p=2, dim=-1) 36 | 37 | if dim_reduction: 38 | with torch.no_grad(): 39 | out = self.reduction_layer(out) 40 | 41 | return embeddings, out.detach(), attn_maps 42 | 43 | 44 | def __repr__(self): 45 | s = str(self.__class__.__name__) 46 | for k in ["input_dim","out_dim"]: 47 | s += "\n {:s}: {:d}".format(k, getattr(self,k)) 48 | return s -------------------------------------------------------------------------------- /models/UMF/utils/local_transformer3D.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from .pooling import relup, gemp3D 5 | from .swin_transformer import TransformerEncoder 6 | from how import layers 7 | 8 | 9 | class LocalFeatureTransformer3D(nn.Module): 10 | def __init__(self, input_dim, out_dim, num_heads=8, depth=1, **kwargs): 11 | super().__init__() 12 | self.input_dim = input_dim 13 | self.out_dim = out_dim 14 | 15 | self.pool = nn.Sequential( 16 | nn.Conv3d(input_dim, 256, kernel_size=3, stride=2, padding=0), 17 | relup(), 18 | gemp3D(emb_dim=1), 19 | nn.Flatten(), 20 | ) 21 | 22 | self.transformer_encoder = TransformerEncoder(input_dim, depth=depth, drop_path=0.0, 23 | num_heads=num_heads, window_size=(3, 3, 3)) 24 | 25 | self.reduction_layer = layers.dim_reduction.ConvDimReduction3D(dim=out_dim, 
input_dim=input_dim) 26 | 27 | def forward(self, x, dim_reduction=False, normalize=False): 28 | embed = x 29 | 30 | out, attn_maps = self.transformer_encoder(embed) 31 | embeddings = self.pool(out) 32 | 33 | if normalize: 34 | out = F.normalize(out, p=2, dim=-1) 35 | embeddings = F.normalize(out, p=2, dim=-1) 36 | 37 | if dim_reduction: 38 | with torch.no_grad(): 39 | out = self.reduction_layer(out) 40 | 41 | 42 | return embeddings, out.detach(), attn_maps 43 | 44 | def __repr__(self): 45 | s = str(self.__class__.__name__) 46 | for k in ["input_dim","out_dim"]: 47 | s += "\n {:s}: {:d}".format(k, getattr(self,k)) 48 | return s 49 | 50 | -------------------------------------------------------------------------------- /models/UMF/utils/multimodal_fusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.UMF.utils.pooling import rgem, relup, gemp, gemp3D 4 | 5 | 6 | class FusionEncoder(nn.Module): 7 | """Image and Voxel Fusion Encoder for multimodal fusion bnanch""" 8 | 9 | def __init__(self, visual_dim, lidar_dim, d_attn, num_heads): 10 | super().__init__() 11 | 12 | emb_dim = 64 13 | self.visual_pos_enc = nn.Parameter(torch.randn(emb_dim, d_attn)) 14 | self.lidar_pos_enc = nn.Parameter(torch.randn(emb_dim, d_attn)) 15 | 16 | # Encoder for visual modality 17 | self.visual_encoder = nn.Sequential( 18 | nn.Conv2d(visual_dim, d_attn, kernel_size=3, stride=2, padding=1), 19 | nn.BatchNorm2d(d_attn), 20 | relup(), 21 | rgem(), 22 | gemp(emb_dim=emb_dim), 23 | ) 24 | 25 | # Encoder for lidar modality 26 | self.lidar_encoder = nn.Sequential( 27 | nn.Conv3d(lidar_dim, d_attn, kernel_size=3, stride=2, padding=1), 28 | nn.BatchNorm3d(d_attn), 29 | relup(), 30 | gemp3D(emb_dim=emb_dim), 31 | ) 32 | 33 | self.cross_attn1 = nn.MultiheadAttention(d_attn, num_heads, batch_first=True) 34 | self.layer_norm1 = nn.LayerNorm(d_attn) 35 | self.layer_norm11 = nn.LayerNorm(d_attn) 36 | 37 | self.self_attn1 = nn.MultiheadAttention(d_attn, num_heads, batch_first=True) 38 | self.layer_norm2 = nn.LayerNorm(d_attn) 39 | 40 | self.cross_attn2 = nn.MultiheadAttention(d_attn, num_heads, batch_first=True) 41 | self.layer_norm3 = nn.LayerNorm(d_attn) 42 | 43 | self.self_attn2 = nn.MultiheadAttention(d_attn, num_heads, batch_first=True) 44 | self.layer_norm4 = nn.LayerNorm(d_attn) 45 | 46 | self.cross_attn3 = nn.MultiheadAttention(d_attn, num_heads, batch_first=True) 47 | self.layer_norm5 = nn.LayerNorm(d_attn) 48 | 49 | self.pool = nn.AdaptiveAvgPool1d(1) 50 | 51 | 52 | def forward(self, visual_input, lidar_input): 53 | im_feat = self.visual_encoder(visual_input) 54 | pc_feat = self.lidar_encoder(lidar_input) 55 | 56 | im_feat = im_feat.squeeze(-1).permute(0, 2, 1) 57 | pc_feat = pc_feat.squeeze(-1).squeeze(-1).permute(0, 2, 1) 58 | 59 | im_feat = im_feat + self.visual_pos_enc 60 | pc_feat = pc_feat + self.lidar_pos_enc 61 | 62 | pc_feat = self.layer_norm1(pc_feat) 63 | im_feat = self.layer_norm11(im_feat) 64 | 65 | x, _ = self.cross_attn1(pc_feat, im_feat, im_feat) 66 | fused_feat = x + pc_feat # skip connection 67 | 68 | fused_feat = self.layer_norm2(fused_feat) 69 | x, _ = self.self_attn1(fused_feat, fused_feat, fused_feat) 70 | fused_feat = x + fused_feat # skip connection 71 | pc_skip = fused_feat 72 | 73 | fused_feat = self.layer_norm3(fused_feat) 74 | x, _ = self.cross_attn2(fused_feat, im_feat, im_feat) 75 | fused_feat = x + fused_feat # skip connection 76 | 77 | fused_feat = self.layer_norm4(fused_feat) 78 | x, _ = 
self.self_attn2(fused_feat, fused_feat, fused_feat) 79 | fused_feat = x + fused_feat # skip connection 80 | 81 | fused_feat = self.layer_norm5(fused_feat) 82 | x, _ = self.cross_attn3(fused_feat, pc_skip, pc_skip) 83 | x = x + fused_feat # skip connection 84 | 85 | embedding = self.pool(x.permute(0, 2, 1)).squeeze(-1) # B x F 86 | return embedding 87 | -------------------------------------------------------------------------------- /models/UMF/utils/patch_matcher.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import cv2 4 | import torch.nn.functional as F 5 | import open3d as o3d 6 | import numpy as np 7 | import torch 8 | import cv2 9 | from sklearn.preprocessing import RobustScaler 10 | 11 | 12 | class PatchMatcher(object): 13 | """Patch matcher class, used to match keypoints and features from attention maps 14 | using strong geometric verification at different scales. 15 | """ 16 | 17 | def __init__(self, patch_sizes=[5], patch_size3D=[3], strides=[1], 18 | stride3D=[1], th_img=0.5, th_pc=0.5): 19 | """Initialize PatchMatcher 20 | Args: 21 | patch_sizes: list of patch sizes for 2D attention maps 22 | patch_size3D: list of patch sizes for 3D attention maps 23 | strides: list of strides for 2D attention maps 24 | stride3D: list of strides for 3D attention maps 25 | th_img: threshold for 2D attention maps 26 | th_pc: threshold for 3D attention maps 27 | """ 28 | assert len(patch_sizes) == len(strides) 29 | assert len(patch_size3D) == len(stride3D) 30 | assert th_img >= 0 and th_img <= 1 31 | assert th_pc >= 0 and th_pc <= 1 32 | self.patch_sizes = patch_sizes 33 | self.patch_sizes3D = patch_size3D 34 | self.delta_img = th_img 35 | self.delta_pc = th_pc 36 | self.strides = strides 37 | self.strides3D = stride3D 38 | 39 | 40 | def match(self, qfeats, dbfeats, qattn, dbattn): 41 | """Match query and database features using attention maps 42 | Args: 43 | qfeats: query features 44 | dbfeats: database features 45 | qattn: query attention maps 46 | dbattn: database attention maps 47 | Returns: 48 | scores: matching scores 49 | """ 50 | keypoints_q_l, filtered_qfeats_l, outputs = [], [], [] 51 | # Extract keypoints and features for query images at different scales 52 | for patch_size, stride in zip(self.patch_sizes, self.strides): 53 | # resize attn maps and features 54 | keypoints_q, filtered_qfeats = self.get_keypoints_from_attention_maps_2d(qattn, qfeats, patch_size, stride) 55 | keypoints_q_l.append(keypoints_q) 56 | filtered_qfeats_l.append(filtered_qfeats) 57 | # Similarly extract for database images 58 | out = [self.get_keypoints_from_attention_maps_2d(attn, feat, patch_size, stride) for attn, feat in zip(dbattn, dbfeats)] 59 | outputs.append(out) 60 | 61 | scores = np.zeros((len(outputs), len(outputs[0]))) 62 | for scale_idx, output in enumerate(outputs): 63 | for img_idx, (keypoints_db, db_feats) in enumerate(output): 64 | scores[scale_idx, img_idx] = self.compare_two_ransac(filtered_qfeats_l[scale_idx], db_feats, keypoints_q_l[scale_idx], keypoints_db, self.patch_sizes[scale_idx], self.strides[scale_idx]) 65 | 66 | final_scores = np.sum(scores, axis=0) 67 | return -final_scores 68 | 69 | def match_pc(self, qfeats, dbfeats, qattn, dbattn): 70 | """Match query and database 3D features using attention maps 71 | Args: 72 | qfeats: query features 73 | dbfeats: database features 74 | qattn: query attention maps 75 | dbattn: database attention maps 76 | Returns: 77 | scores: matching scores 78 | """ 79 | 
keypoints_q_l, filtered_qfeats_l, outputs = [], [], [] 80 | for patch_size, stride in zip(self.patch_sizes, self.strides): 81 | keypoints_q, filtered_qfeats = self.get_keypoints_from_attention_maps_3d(qattn, qfeats, patch_size, stride) 82 | keypoints_q_l.append(keypoints_q) 83 | filtered_qfeats_l.append(filtered_qfeats) 84 | out = [self.get_keypoints_from_attention_maps_3d(attn, feat, patch_size, stride) for attn, feat in zip(dbattn, dbfeats)] 85 | outputs.append(out) 86 | 87 | scores = np.zeros((len(outputs), len(outputs[0]))) 88 | for i, output in enumerate(outputs): 89 | for j, (keypoints_db, db_feats) in enumerate(output): 90 | scores[i, j] = self.compare_two_ransac_pc(filtered_qfeats_l[i], db_feats, keypoints_q_l[i], keypoints_db, self.patch_sizes[i], self.strides[i]) 91 | 92 | final_scores = np.sum(scores, axis=0) 93 | return -final_scores 94 | 95 | 96 | def get_keypoints_from_attention_maps_2d(self, attn, feat, patch_size, stride): 97 | attn = attn[0].unsqueeze(0).squeeze(-1) 98 | feat_i = feat[0].unsqueeze(0) 99 | 100 | attn = F.avg_pool2d(attn, kernel_size=patch_size, stride=stride, padding=0) 101 | feat_i = F.avg_pool2d(feat_i, kernel_size=patch_size, stride=stride, padding=0) 102 | attn = attn.squeeze(0).cpu().numpy() 103 | attn = cv2.GaussianBlur(attn, (5, 5), 0) 104 | feat_i = feat_i.squeeze(0) 105 | 106 | kp = filter_keypoint_attention(attn, th=self.delta_img) 107 | patches = feat_i[:, kp[:, 0], kp[:, 1]] 108 | return np.array(kp), np.array(patches) 109 | 110 | def get_keypoints_from_attention_maps_3d(self, attn, feat, patch_size, stride): 111 | attn = attn[0].cpu().squeeze(-1) 112 | 113 | attn = attn.unsqueeze(0).squeeze(-1) 114 | feat_i = feat[0].unsqueeze(0) 115 | attn = F.avg_pool3d(attn, kernel_size=patch_size, stride=stride, padding=0) 116 | feat_i = F.avg_pool3d(feat_i, kernel_size=patch_size, stride=stride, padding=0) 117 | 118 | attn = attn.squeeze(0).cpu().numpy() 119 | feat = feat_i.squeeze(0) 120 | 121 | kp = filter_keypoint_attention(attn, th=self.delta_pc) 122 | patches = feat[:, kp[:, 0], kp[:, 1], kp[:, 2]] 123 | return np.array(kp), np.array(patches) 124 | 125 | 126 | def compare_two_ransac_pc(self, qfeat, dbfeat, kpQ, kp2, patch_size, stride): 127 | if kpQ.shape[0] < 4 or kp2.shape[0] < 4: 128 | return 0 129 | 130 | if qfeat.shape[0] < 4 or dbfeat.shape[0] < 4: 131 | return 0 132 | 133 | MIN_MATCH_COUNT = 40 134 | threshold = 0.5 135 | 136 | qf = np.array(qfeat, dtype=np.float32).T 137 | dbf = np.array(dbfeat, dtype=np.float32).T 138 | num_total_kps = len(kp2) 139 | 140 | # find matches 141 | bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) 142 | good_matches = bf.match(qf, dbf) 143 | 144 | if len(good_matches) < MIN_MATCH_COUNT: 145 | return 0 146 | 147 | best_inliers = 0 148 | best_transformation = np.eye(4) 149 | all_src_pts = np.float32([kpQ[m.queryIdx] for m in good_matches]) 150 | all_dst_pts = np.float32([kp2[m.trainIdx] for m in good_matches]) 151 | for _ in range(120): # Number of RANSAC iterations 152 | 153 | subset_indices = np.random.choice(len(good_matches), 3, replace=False) 154 | matches = [good_matches[i] for i in subset_indices] 155 | src_pts = np.float32([kpQ[m.queryIdx] for m in matches]) 156 | dst_pts = np.float32([kp2[m.trainIdx] for m in matches]) 157 | 158 | transformation = estimate_rigid_transform(src_pts, dst_pts) 159 | inliers = count_inliers(all_src_pts, all_dst_pts, transformation, threshold=threshold) 160 | 161 | if inliers > best_inliers: 162 | best_inliers = inliers 163 | best_transformation = transformation 164 | 165 | # 
Refining the transformation using ICP 166 | refined_transformation, fitness_score = refine_with_icp(all_src_pts, all_dst_pts, best_transformation, threshold=threshold) 167 | fitness_score = fitness_score # / num_total_kps 168 | return fitness_score 169 | 170 | 171 | 172 | def compare_two_ransac(self, qfeat, dbfeat, kpQ, kp2, patch_size, stride): 173 | MIN_MATCH_COUNT = 10 174 | 175 | if kpQ.shape[0] < 4 or kp2.shape[0] < 4: 176 | return 0 177 | 178 | qf = np.array(qfeat, dtype=np.float32).T 179 | dbf = np.array(dbfeat, dtype=np.float32).T 180 | 181 | bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True) 182 | good_matches = bf.match(qf, dbf) 183 | 184 | if len(good_matches) > MIN_MATCH_COUNT: 185 | src_pts = np.float32([kpQ[m.queryIdx] for m in good_matches]).reshape(-1, 1, 2) 186 | dst_pts = np.float32([kp2[m.trainIdx] for m in good_matches]).reshape(-1, 1, 2) 187 | 188 | _, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, (patch_size)*1.5*4) 189 | 190 | inlier_index_keypoints = src_pts[mask.ravel() == 1] 191 | inlier_count = inlier_index_keypoints.shape[0] 192 | 193 | total_keypoints = len(kpQ) 194 | normalized_score = inlier_count / total_keypoints 195 | return normalized_score 196 | else: 197 | return 0 198 | 199 | 200 | def estimate_rigid_transform(src, dst): 201 | """ 202 | Estimate rigid transformation from src to dst. 203 | Args: 204 | src (np.array): Source points (Nx3). 205 | dst (np.array): Destination points (Nx3). 206 | 207 | Returns: 208 | np.array: 4x4 transformation matrix. 209 | """ 210 | # Compute centroids 211 | centroid_src = np.mean(src, axis=0) 212 | centroid_dst = np.mean(dst, axis=0) 213 | 214 | # Center the points 215 | src_centered = src - centroid_src 216 | dst_centered = dst - centroid_dst 217 | 218 | # Compute the covariance matrix 219 | H = np.dot(src_centered.T, dst_centered) 220 | 221 | U, S, Vt = np.linalg.svd(H) 222 | R = np.dot(Vt.T, U.T) 223 | 224 | # Special reflection case 225 | if np.linalg.det(R) < 0: 226 | Vt[2, :] *= -1 227 | R = np.dot(Vt.T, U.T) 228 | 229 | # Compute the translation 230 | t = centroid_dst.T - np.dot(R, centroid_src.T) 231 | 232 | transformation = np.identity(4) 233 | transformation[:3, :3] = R 234 | transformation[:3, 3] = t 235 | 236 | return transformation 237 | 238 | def count_inliers(src, dst, transformation, threshold): 239 | """ 240 | Count how many points in src are within 'threshold' distance to dst after applying transformation. 241 | Args: 242 | src (np.array): Source points (Nx3). 243 | dst (np.array): Destination points (Nx3). 244 | transformation (np.array): 4x4 transformation matrix. 245 | threshold (float): Distance threshold for counting inliers. 246 | 247 | Returns: 248 | int: Count of inliers. 
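Note: 'src' is lifted to homogeneous coordinates and mapped through the 4x4 transformation; a
point counts as an inlier when its Euclidean distance to the corresponding 'dst' point is
below 'threshold'.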
249 | """ 250 | src_homogeneous = np.hstack((src, np.ones((src.shape[0], 1)))) 251 | transformed_src = np.dot(transformation, src_homogeneous.T).T[:, :3] 252 | 253 | distances = np.sqrt(np.sum((transformed_src - dst) ** 2, axis=1)) 254 | inliers = np.sum(distances < threshold) 255 | 256 | return inliers 257 | 258 | def refine_with_icp(source_points, target_points, initial_transformation, threshold=0.02): 259 | if isinstance(source_points, torch.Tensor): 260 | source_points = source_points.cpu().numpy() 261 | if isinstance(target_points, torch.Tensor): 262 | target_points = target_points.cpu().numpy() 263 | 264 | # Ensure the data type and shape are correct 265 | source_points = np.ascontiguousarray(source_points, dtype=np.float64) 266 | target_points = np.ascontiguousarray(target_points, dtype=np.float64) 267 | 268 | # Ensure the arrays are 2D 269 | if source_points.ndim != 2 or source_points.shape[1] != 3: 270 | raise ValueError("source_points must be a 2D array of shape [N, 3]") 271 | if target_points.ndim != 2 or target_points.shape[1] != 3: 272 | raise ValueError("target_points must be a 2D array of shape [N, 3]") 273 | 274 | # Convert numpy arrays to Open3D PointCloud objects 275 | source = o3d.geometry.PointCloud() 276 | source.points = o3d.utility.Vector3dVector(source_points) 277 | 278 | target = o3d.geometry.PointCloud() 279 | target.points = o3d.utility.Vector3dVector(target_points) 280 | 281 | # Set the ICP convergence criteria 282 | criteria = o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000) 283 | 284 | # Perform ICP refinement 285 | result = o3d.pipelines.registration.registration_icp( 286 | source, target, max_correspondence_distance=threshold, 287 | init=initial_transformation, 288 | estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(), 289 | criteria=criteria 290 | ) 291 | 292 | return result.transformation, result.fitness 293 | 294 | 295 | def filter_keypoint_attention(attn, th=0.8): 296 | """Filter keypoints from attention map 297 | Args: 298 | attn: 2d or 2d attention map 299 | th: threshold 300 | Returns: 301 | keypoints: list of keypoints 302 | """ 303 | if attn.ndim == 3: 304 | attn = attn / attn.max() 305 | else: 306 | attn = cv2.GaussianBlur(attn, (5, 5), 0) 307 | attn = RobustScaler().fit_transform(attn) 308 | keypoints = np.argwhere(attn > th) 309 | return keypoints 310 | -------------------------------------------------------------------------------- /models/UMF/utils/pooling.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | class relup(nn.Module): 6 | """ Reranking with maximum descriptors aggregation """ 7 | def __init__(self, alpha=0.014): 8 | super(relup, self).__init__() 9 | self.alpha = alpha 10 | def forward(self, x): 11 | x = x.clamp(self.alpha) 12 | return x 13 | 14 | 15 | class gemp(nn.Module): 16 | """ Reranking with maximum descriptors aggregation, can substitute regular GeM pooling 17 | https://github.com/ShihaoShao-GH/SuperGlobal/blob/main/modules/coarse_retrieval/rgem.py 18 | """ 19 | def __init__(self, p=4.6, eps = 1e-8, emb_dim=None): 20 | super(gemp, self).__init__() 21 | self.p = p 22 | self.eps = eps 23 | self.avg_pool = nn.AdaptiveAvgPool2d((emb_dim, 1)) 24 | def forward(self, x): 25 | x = x.clamp(self.eps).pow(self.p) 26 | x = self.avg_pool(x).pow(1. 
/ (self.p) ) 27 | return x 28 | 29 | class rgem(nn.Module): 30 | """ Reranking with maximum descriptors aggregation """ 31 | def __init__(self, pr=2.5, size = 5): 32 | super(rgem, self).__init__() 33 | self.pr = pr 34 | self.size = size 35 | self.lppool = nn.LPPool2d(self.pr, int(self.size), stride=1) 36 | self.pad = nn.ReflectionPad2d(int((self.size-1)//2.)) 37 | def forward(self, x): 38 | nominater = (self.size**2) **(1./self.pr) 39 | x = 0.5*self.lppool(self.pad(x/nominater)) + 0.5*x 40 | return x 41 | 42 | 43 | class gemp3D(nn.Module): 44 | """ Reranking with maximum descriptors aggregation """ 45 | def __init__(self, p=4.6, eps = 1e-8, emb_dim=None): 46 | super(gemp3D, self).__init__() 47 | self.p = p 48 | self.eps = eps 49 | self.avg_pool = nn.AdaptiveAvgPool3d((emb_dim, 1, 1)) 50 | def forward(self, x): 51 | x = x.clamp(self.eps).pow(self.p) 52 | x = self.avg_pool(x).pow(1. / (self.p) ) 53 | return x 54 | 55 | -------------------------------------------------------------------------------- /models/UMF/utils/spconv_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Set 2 | 3 | try: 4 | import spconv.pytorch as spconv 5 | except: 6 | import spconv as spconv 7 | 8 | import torch.nn as nn 9 | 10 | 11 | def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]: 12 | """ 13 | Finds all spconv keys that need to have weight's transposed 14 | """ 15 | found_keys: Set[str] = set() 16 | for name, child in model.named_children(): 17 | new_prefix = f"{prefix}.{name}" if prefix != "" else name 18 | 19 | if isinstance(child, spconv.conv.SparseConvolution): 20 | new_prefix = f"{new_prefix}.weight" 21 | found_keys.add(new_prefix) 22 | 23 | found_keys.update(find_all_spconv_keys(child, prefix=new_prefix)) 24 | 25 | return found_keys 26 | 27 | 28 | def replace_feature(out, new_features): 29 | if "replace_feature" in out.__dir__(): 30 | # spconv 2.x behaviour 31 | return out.replace_feature(new_features) 32 | else: 33 | out.features = new_features 34 | return out 35 | -------------------------------------------------------------------------------- /models/UMF/utils/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .linear_attention import LinearMultiheadAttention 4 | 5 | 6 | class Embeddings(nn.Module): 7 | """Construct the embeddings from patch, position embeddings. 8 | """ 9 | def __init__(self, in_dim, hidden_size, dropout_rate=0.1): 10 | super(Embeddings, self).__init__() 11 | 12 | self.project = nn.Conv2d(in_dim, hidden_size, kernel_size=1, stride=1, padding=0) 13 | n_patches = 60 * 80 14 | 15 | 16 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size)) 17 | self.dropout = nn.Dropout(dropout_rate) 18 | 19 | def forward(self, x): 20 | B = x.shape[0] 21 | 22 | x = self.project(x) 23 | projected_shape = x.shape 24 | x = x.flatten(2) 25 | x = x.transpose(-1, -2) 26 | 27 | embeddings = x + self.position_embeddings 28 | embeddings = self.dropout(embeddings) 29 | return embeddings, projected_shape 30 | 31 | 32 | class Embeddings3D(nn.Module): 33 | """Construct the embeddings from patch, position embeddings. 
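Note: n_patches is hard-coded to 25*25*25, which presumes the stride-2 Conv3d projection yields a
25^3 patch grid (e.g. a 50^3 input volume); other input sizes would require changing this constant.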
34 | """ 35 | def __init__(self, in_dim, hidden_size, dropout_rate=0.1): 36 | super(Embeddings3D, self).__init__() 37 | 38 | self.project = nn.Conv3d(in_dim, hidden_size, kernel_size=2, stride=2, padding=0) 39 | n_patches = 25 * 25 * 25 40 | 41 | 42 | self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, hidden_size)) 43 | self.dropout = nn.Dropout(dropout_rate) 44 | 45 | def forward(self, x): 46 | B = x.shape[0] 47 | 48 | x = self.project(x) 49 | projected_shape = x.shape 50 | x = x.flatten(2) 51 | x = x.transpose(-1, -2) 52 | 53 | embeddings = x + self.position_embeddings 54 | embeddings = self.dropout(embeddings) 55 | return embeddings, projected_shape 56 | 57 | 58 | 59 | class Mlp(nn.Module): 60 | def __init__(self, hidden_size, mlp_dim, dropout_rate=0.1): 61 | super(Mlp, self).__init__() 62 | self.fc1 = nn.Linear(hidden_size, mlp_dim) 63 | self.fc2 = nn.Linear(mlp_dim, hidden_size) 64 | self.act_fn = nn.GELU() 65 | self.dropout = nn.Dropout(dropout_rate) 66 | 67 | self._init_weights() 68 | 69 | def _init_weights(self): 70 | nn.init.xavier_uniform_(self.fc1.weight) 71 | nn.init.xavier_uniform_(self.fc2.weight) 72 | nn.init.normal_(self.fc1.bias, std=1e-6) 73 | nn.init.normal_(self.fc2.bias, std=1e-6) 74 | 75 | def forward(self, x): 76 | x = self.fc1(x) 77 | x = self.act_fn(x) 78 | x = self.dropout(x) 79 | x = self.fc2(x) 80 | x = self.dropout(x) 81 | return x 82 | 83 | class TransformerBlock(nn.Module): 84 | def __init__(self, hidden_size, seq_len, num_heads=8, mpl_drop=0.1, attn_drop=0., qkv_bias=False,): 85 | super(TransformerBlock, self).__init__() 86 | self.hidden_size = hidden_size 87 | self.attention_norm = nn.LayerNorm(hidden_size, eps=1e-6) 88 | self.add_norm = nn.LayerNorm(hidden_size, eps=1e-6) 89 | 90 | self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6) 91 | self.ffn = Mlp(hidden_size, hidden_size*4, dropout_rate=mpl_drop) 92 | self.attn = LinearMultiheadAttention(hidden_size, num_heads=num_heads, batch_first=True, seq_len=seq_len, dropout=attn_drop) 93 | 94 | # Linear layers for Q, K, V projections 95 | self.q_linear = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 96 | self.k_linear = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 97 | self.v_linear = nn.Linear(hidden_size, hidden_size, bias=qkv_bias) 98 | self.apply(self._init_weights) 99 | 100 | def forward(self, x): 101 | h = x 102 | x_norm = self.attention_norm(x) 103 | 104 | # Project Q, K, V 105 | q = self.q_linear(x_norm) 106 | k = self.k_linear(x_norm) 107 | v = self.v_linear(x_norm) 108 | 109 | 110 | x, weights = self.attn(q, k, v, need_weights=True) 111 | 112 | x = x + h 113 | 114 | h = x 115 | x = self.ffn_norm(self.add_norm(x)) 116 | x = self.ffn(x) 117 | x = x + h 118 | return x, weights 119 | 120 | 121 | def _init_weights(self, m): 122 | if isinstance(m, nn.Linear): 123 | trunc_normal_(m.weight, std=.02) 124 | if isinstance(m, nn.Linear) and m.bias is not None: 125 | nn.init.constant_(m.bias, 0) 126 | elif isinstance(m, nn.LayerNorm): 127 | nn.init.constant_(m.bias, 0) 128 | nn.init.constant_(m.weight, 1.0) 129 | -------------------------------------------------------------------------------- /models/UMF/voxel_encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | 7 | from models.UMF.utils.spconv_utils import replace_feature, spconv 8 | from .spconv_backbone import post_act_block 9 | import torch.nn.functional as F 10 | import spconv as 
spconv_core 11 | spconv_core.constants.SPCONV_ALLOW_TF32 = True 12 | 13 | from functools import partial 14 | 15 | class SparseBasicBlock(spconv.SparseModule): 16 | expansion = 1 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, indice_key=None, norm_fn=None): 19 | super(SparseBasicBlock, self).__init__() 20 | self.conv1 = spconv.SubMConv3d( 21 | inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False, indice_key=indice_key, algo=spconv.ConvAlgo.Native 22 | ) 23 | self.bn1 = norm_fn(planes) 24 | self.relu = nn.ReLU() 25 | self.conv2 = spconv.SubMConv3d( 26 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False, indice_key=indice_key, algo=spconv.ConvAlgo.Native 27 | ) 28 | self.bn2 = norm_fn(planes) 29 | self.downsample = downsample 30 | self.stride = stride 31 | 32 | def forward(self, x): 33 | identity = x.features 34 | 35 | assert x.features.dim() == 2, 'x.features.dim()=%d' % x.features.dim() 36 | 37 | out = self.conv1(x) 38 | out = replace_feature(out, self.bn1(out.features)) 39 | out = replace_feature(out, self.relu(out.features)) 40 | 41 | out = self.conv2(out) 42 | out = replace_feature(out, self.bn2(out.features)) 43 | 44 | if self.downsample is not None: 45 | identity = self.downsample(x) 46 | 47 | out = replace_feature(out, out.features + identity) 48 | out = replace_feature(out, self.relu(out.features)) 49 | 50 | return out 51 | 52 | class GeM3D(nn.Module): 53 | def __init__(self, input_dim, p=3, eps=1e-6): 54 | super(GeM3D, self).__init__() 55 | self.input_dim = input_dim 56 | self.output_dim = self.input_dim 57 | self.p = nn.Parameter(torch.ones(1) * p) 58 | self.eps = eps 59 | 60 | def forward(self, x: spconv.SparseConvTensor): 61 | # Convert SparseConvTensor to dense tensor (B, C, Z, Y, X) format 62 | dense_x = x.dense() 63 | 64 | # Clamp, raise to power p, and average pool over spatial dimensions 65 | pooled = F.avg_pool3d(dense_x.clamp(min=self.eps).pow(self.p), 66 | kernel_size=dense_x.size()[2:]) 67 | 68 | # Squeeze spatial dimensions and raise to power 1/p 69 | output = pooled.squeeze(-1).squeeze(-1).squeeze(-1).pow(1./self.p) 70 | 71 | return output 72 | 73 | 74 | class Voxel_MAE(nn.Module): 75 | """ 76 | Sparse Convolution based UNet for point-wise feature learning. 77 | Reference Paper: https://arxiv.org/abs/1907.03670 (Shaoshuai Shi, et. 
al) 78 | From Points to Parts: 3D Object Detection from Point Cloud with Part-aware and Part-aggregation Network 79 | """ 80 | 81 | def __init__(self, input_channels, grid_size, voxel_size, point_cloud_range, fpn=False, **kwargs): 82 | print("Voxel_MAE init", input_channels, grid_size, voxel_size, point_cloud_range, fpn, kwargs) 83 | super().__init__() 84 | self.sparse_shape = np.array(grid_size[::-1]) 85 | self.voxel_size = voxel_size 86 | self.point_cloud_range = point_cloud_range 87 | 88 | self.fpn = fpn 89 | self.frozen = False 90 | self.freeze_fpn = fpn 91 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 92 | input_channels = 1 93 | 94 | norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01) 95 | 96 | self.conv_input = spconv.SparseSequential( 97 | spconv.SubMConv3d(input_channels, 16, 3, padding=1, bias=False, indice_key='subm1', algo=spconv.ConvAlgo.Native), 98 | norm_fn(16), 99 | nn.ReLU(), 100 | ) 101 | block = post_act_block 102 | 103 | self.conv1 = spconv.SparseSequential( 104 | # 200, 200, 200 105 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), 106 | ) 107 | 108 | self.conv2 = spconv.SparseSequential( 109 | # 100, 100, 100 110 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='spconv'), 111 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 112 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 113 | ) 114 | 115 | self.conv3 = spconv.SparseSequential( 116 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='spconv'), 117 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 118 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 119 | ) 120 | 121 | self.conv4 = spconv.SparseSequential( 122 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(1, 1, 1), indice_key='spconv4', conv_type='spconv'), 123 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 124 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 125 | ) 126 | 127 | last_pad = 0 128 | 129 | self.conv_out = spconv.SparseSequential( 130 | spconv.SparseConv3d(64, 128, (1, 1, 1), stride=(1, 1, 1), padding=last_pad, 131 | bias=False, indice_key='spconv_down2', algo=spconv.ConvAlgo.Native), 132 | norm_fn(128), 133 | nn.ReLU(), 134 | ) 135 | 136 | 137 | if self.fpn: 138 | self.conv_up_t4 = SparseBasicBlock(64, 64, indice_key='subm4', norm_fn=norm_fn) 139 | self.conv_up_m4 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4') 140 | self.inv_conv4 = block(64, 64, 3, norm_fn=norm_fn, indice_key='spconv4', conv_type='inverseconv') 141 | 142 | self.conv_up_t3 = SparseBasicBlock(64, 64, indice_key='subm3', norm_fn=norm_fn) 143 | self.conv_up_m3 = block(128, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3') 144 | self.inv_conv3 = block(64, 32, 3, norm_fn=norm_fn, indice_key='spconv3', conv_type='inverseconv') 145 | 146 | self.conv3_up = spconv.SparseSequential( 147 | block(64, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm3') 148 | ) 149 | 150 | self.num_point_features = 32 151 | 152 | 153 | def freeze_backbone(self): 154 | self.frozen = True 155 | for param in self.conv_input.parameters(): 156 | param.requires_grad = False 157 | 158 | for param in self.conv1.parameters(): 159 | param.requires_grad = False 160 | 161 | for param in self.conv2.parameters(): 162 | param.requires_grad = False 163 | 164 | for param in self.conv3.parameters(): 165 | param.requires_grad = False 166 | 167 | for 
param in self.conv4.parameters(): 168 | param.requires_grad = False 169 | 170 | for param in self.conv_out.parameters(): 171 | param.requires_grad = False 172 | 173 | def UR_block_forward(self, x_lateral, x_bottom, conv_t, conv_m, conv_inv): 174 | x_trans = conv_t(x_lateral) 175 | x = x_trans 176 | x = replace_feature(x, torch.cat((x_bottom.features, x_trans.features), dim=1)) 177 | x_m = conv_m(x) 178 | x = self.channel_reduction(x, x_m.features.shape[1]) 179 | x = replace_feature(x, x_m.features + x.features) 180 | x = conv_inv(x) 181 | return x 182 | 183 | @staticmethod 184 | def channel_reduction(x, out_channels): 185 | """ 186 | Args: 187 | x: x.features (N, C1) 188 | out_channels: C2 189 | 190 | Returns: 191 | 192 | """ 193 | features = x.features 194 | n, in_channels = features.shape 195 | assert (in_channels % out_channels == 0) and (in_channels >= out_channels) 196 | 197 | x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2)) 198 | return x 199 | 200 | def load_checkpoint(self, ckpt_path): 201 | print("Loading checkpoint from {}".format(ckpt_path)) 202 | checkpoint = torch.load(ckpt_path, map_location=self.device) 203 | self.load_state_dict(checkpoint, strict=False) 204 | print("Loaded checkpoint from {}".format(ckpt_path)) 205 | 206 | def forward(self, batch_dict): 207 | """ 208 | Args: 209 | batch_dict: 210 | batch_size: int 211 | vfe_features: (num_voxels, C) 212 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 213 | Returns: 214 | batch_dict: 215 | encoded_spconv_tensor: sparse tensor 216 | point_features: (N, C) 217 | """ 218 | voxel_features, voxel_coords = batch_dict['voxel_features'].to(self.device), batch_dict['coordinates'].to(self.device) 219 | 220 | batch_size = batch_dict['batch_size'] 221 | 222 | 223 | input_sp_tensor = spconv.SparseConvTensor( 224 | features=voxel_features, 225 | indices=voxel_coords.int(), 226 | spatial_shape=self.sparse_shape, 227 | batch_size=batch_size 228 | ) 229 | x = self.conv_input(input_sp_tensor) 230 | 231 | x_conv1 = self.conv1(x) 232 | x_conv2 = self.conv2(x_conv1) 233 | x_conv3 = self.conv3(x_conv2) 234 | x_conv4 = self.conv4(x_conv3) 235 | 236 | out = self.conv_out(x_conv4) 237 | batch_dict["coarse"] = out.dense() 238 | 239 | if not self.fpn: 240 | return batch_dict 241 | 242 | x_up4 = self.UR_block_forward(x_conv4, x_conv4, self.conv_up_t4, self.conv_up_m4, self.inv_conv4) 243 | x_up3 = self.UR_block_forward(x_conv3, x_up4, self.conv_up_t3, self.conv_up_m3, self.conv3_up) 244 | 245 | # 50, 50, 50 246 | #batch_dict["fine_1"] = x_up4.dense() 247 | batch_dict["fine_2"] = x_up3.dense() 248 | 249 | 250 | return batch_dict -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pytorch_metric_learning import losses, miners, reducers 4 | from pytorch_metric_learning.distances import LpDistance 5 | from models.losses.super_feature_losses import DecorrelationAttentionLoss, SuperfeatureLoss 6 | from models.losses.truncated_smoothap import TruncatedSmoothAP 7 | 8 | def make_loss(params): 9 | if params.loss == 'BatchHardTripletMarginLoss': 10 | # BatchHard mining with triplet margin loss 11 | # Expects input: embeddings, positives_mask, negatives_mask 12 | loss_fn = BatchHardTripletLossWithMasks(params.margin, params.normalize_embeddings) 13 | elif params.loss == 'MultiBatchHardTripletMarginLoss': 14 | # BatchHard mining with triplet margin loss 
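# (This variant applies separate BatchHard triplet losses to the fused, cloud-only and
#  image-only embeddings, weighted by params.weights; see MultiBatchHardTripletLossWithMasks below.)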
15 | # Expects input: embeddings, positives_mask, negatives_mask 16 | loss_fn = MultiBatchHardTripletLossWithMasks(params.margin, params.normalize_embeddings, params.weights) 17 | 18 | elif params.loss == 'MultiBatchHardTripletLossWithMasksAugmented': 19 | loss_fn = MultiBatchHardTripletLossWithMasksAugmented(params) 20 | 21 | elif params.loss == 'MultiBatchHardTripletLossWithMasksAP': 22 | loss_fn = MultiBatchHardTripletLossWithMasksAP(params) 23 | else: 24 | print('Unknown loss: {}'.format(params.loss)) 25 | raise NotImplementedError 26 | return loss_fn 27 | 28 | 29 | class HardTripletMinerWithMasks: 30 | # Hard triplet miner 31 | def __init__(self, distance): 32 | self.distance = distance 33 | # Stats 34 | self.max_pos_pair_dist = None 35 | self.max_neg_pair_dist = None 36 | self.mean_pos_pair_dist = None 37 | self.mean_neg_pair_dist = None 38 | self.min_pos_pair_dist = None 39 | self.min_neg_pair_dist = None 40 | 41 | def __call__(self, embeddings, positives_mask, negatives_mask): 42 | assert embeddings.dim() == 2 43 | d_embeddings = embeddings.detach() 44 | with torch.no_grad(): 45 | hard_triplets = self.mine(d_embeddings, positives_mask, negatives_mask) 46 | return hard_triplets 47 | 48 | def mine(self, embeddings, positives_mask, negatives_mask): 49 | # Based on pytorch-metric-learning implementation 50 | dist_mat = self.distance(embeddings.float()) 51 | (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask) 52 | (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 53 | a_keep_idx = torch.where(a1p_keep & a2n_keep) 54 | a = torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx] 55 | p = hardest_positive_indices[a_keep_idx] 56 | n1 = hardest_negative_indices[a_keep_idx] 57 | 58 | # second hardest negative 59 | # first remove the hardest negative 60 | dist_mat[a, n1] = float('inf') 61 | (second_hardest_negative_dist, second_hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 62 | n2 = second_hardest_negative_indices[a_keep_idx] 63 | 64 | # third hardest negative 65 | # first remove the second hardest negative 66 | dist_mat[a, n2] = float('inf') 67 | (third_hardest_negative_dist, third_hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 68 | n3 = third_hardest_negative_indices[a_keep_idx] 69 | 70 | #n = torch.stack([n1, n2, n3], dim=1) 71 | n = torch.stack([n1, n2, n3], dim=1) 72 | 73 | self.max_pos_pair_dist = torch.max(hardest_positive_dist).item() 74 | self.max_neg_pair_dist = torch.max(hardest_negative_dist).item() 75 | self.mean_pos_pair_dist = torch.mean(hardest_positive_dist).item() 76 | self.mean_neg_pair_dist = torch.mean(hardest_negative_dist).item() 77 | self.min_pos_pair_dist = torch.min(hardest_positive_dist).item() 78 | self.min_neg_pair_dist = torch.min(hardest_negative_dist).item() 79 | return a, p, n 80 | 81 | 82 | def get_max_per_row(mat, mask): 83 | non_zero_rows = torch.any(mask, dim=1) 84 | mat_masked = mat.clone() 85 | mat_masked[~mask] = 0 86 | return torch.max(mat_masked, dim=1), non_zero_rows 87 | 88 | 89 | def get_min_per_row(mat, mask): 90 | non_inf_rows = torch.any(mask, dim=1) 91 | mat_masked = mat.clone() 92 | mat_masked[~mask] = float('inf') 93 | return torch.min(mat_masked, dim=1), non_inf_rows 94 | 95 | 96 | 97 | class MultiBatchHardTripletLossWithMasks: 98 | def __init__(self, margin, normalize_embeddings, weights): 99 | assert len(weights) == 3 100 | self.weights = weights 
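# weights and margin are both 3-tuples indexed as (final/fused, cloud, image); each sub-loss
# below uses its own margin and is scaled by the matching weight in __call__.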
101 | self.final_loss = BatchHardTripletLossWithMasksHelper(margin[0], normalize_embeddings) 102 | self.cloud_loss = BatchHardTripletLossWithMasksHelper(margin[1], normalize_embeddings) 103 | self.image_loss = BatchHardTripletLossWithMasksHelper(margin[2], normalize_embeddings) 104 | print('MultiBatchHardTripletLossWithMasks') 105 | print('Weights (final/cloud/image): {}'.format(weights)) 106 | print('Margins (final/cloud/image): {}'.format(margin)) 107 | 108 | def __call__(self, x, positives_mask, negatives_mask): 109 | # Loss on the final global descriptor 110 | final_loss, final_stats, final_hard_triplets = self.final_loss(x['embedding'], positives_mask, negatives_mask) 111 | final_stats = {'final_{}'.format(e): final_stats[e] for e in final_stats} 112 | 113 | loss = 0. 114 | 115 | stats = final_stats 116 | if self.weights[0] > 0.: 117 | loss = self.weights[0] * final_loss + loss 118 | 119 | # Loss on the cloud-based descriptor 120 | if 'cloud_embedding' in x: 121 | cloud_loss, cloud_stats, _ = self.cloud_loss(x['cloud_embedding'], positives_mask, negatives_mask) 122 | cloud_stats = {'cloud_{}'.format(e): cloud_stats[e] for e in cloud_stats} 123 | stats.update(cloud_stats) 124 | if self.weights[1] > 0.: 125 | loss = self.weights[1] * cloud_loss + loss 126 | 127 | # Loss on the image-based descriptor 128 | if 'image_embedding' in x: 129 | image_loss, image_stats, _ = self.image_loss(x['image_embedding'], positives_mask, negatives_mask) 130 | image_stats = {'image_{}'.format(e): image_stats[e] for e in image_stats} 131 | stats.update(image_stats) 132 | if self.weights[2] > 0.: 133 | loss = self.weights[2] * image_loss + loss 134 | 135 | stats['loss'] = loss.item() 136 | return loss, stats, None 137 | 138 | class MultiBatchHardTripletLossWithMasksAP: 139 | def __init__(self, margin, normalize_embeddings, weights): 140 | assert len(weights) == 3 141 | self.weights = weights 142 | 143 | tau1 = 0.01 144 | positives_per_query = 4 145 | similarity = 'cosine' 146 | 147 | self.final_loss = TruncatedSmoothAP(tau1=tau1, similarity=similarity, 148 | positives_per_query=positives_per_query) 149 | self.cloud_loss = TruncatedSmoothAP(tau1=tau1, similarity=similarity, 150 | positives_per_query=positives_per_query) 151 | self.image_loss = TruncatedSmoothAP(tau1=tau1, similarity=similarity, 152 | positives_per_query=positives_per_query) 153 | 154 | print('MultiBatchHardTripletLossWithMasks') 155 | print('Tau1: {}'.format(tau1)) 156 | print('Similarity: {}'.format(similarity)) 157 | print('Positives per query: {}'.format(positives_per_query)) 158 | 159 | def __call__(self, x, positives_mask, negatives_mask): 160 | # Loss on the final global descriptor 161 | final_loss, final_stats, final_hard_triplets = self.final_loss(x['embedding'], positives_mask, negatives_mask) 162 | final_stats = {'final_{}'.format(e): final_stats[e] for e in final_stats} 163 | 164 | loss = 0. 
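# The total loss is a weighted sum over the available descriptors: the cloud/image terms below
# are added only when the corresponding embedding is present in x and its weight is positive.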
165 | 166 | stats = final_stats 167 | if self.weights[0] > 0.: 168 | loss = self.weights[0] * final_loss + loss 169 | 170 | # Loss on the cloud-based descriptor 171 | if 'cloud_embedding' in x: 172 | cloud_loss, cloud_stats, _ = self.cloud_loss(x['cloud_embedding'], positives_mask, negatives_mask) 173 | cloud_stats = {'cloud_{}'.format(e): cloud_stats[e] for e in cloud_stats} 174 | stats.update(cloud_stats) 175 | if self.weights[1] > 0.: 176 | loss = self.weights[1] * cloud_loss + loss 177 | 178 | # Loss on the image-based descriptor 179 | if 'image_embedding' in x: 180 | image_loss, image_stats, _ = self.image_loss(x['image_embedding'], positives_mask, negatives_mask) 181 | image_stats = {'image_{}'.format(e): image_stats[e] for e in image_stats} 182 | stats.update(image_stats) 183 | if self.weights[2] > 0.: 184 | loss = self.weights[2] * image_loss + loss 185 | 186 | stats['loss'] = loss.item() 187 | return loss, stats, None 188 | 189 | 190 | class MultiBatchHardTripletLossWithMasksAugmented(MultiBatchHardTripletLossWithMasks): 191 | def __init__(self, params): 192 | self.normalize_embeddings = params.normalize_embeddings 193 | self.margin = [params.margin] * 3 194 | self.mode = params.cfg.model.mode 195 | 196 | if self.mode.lower() == "ransac": 197 | self.margin[1] = params.cfg.model.local_feat_margin[0] 198 | self.margin[2] = params.cfg.model.local_feat_margin[1] 199 | 200 | self.weights = params.weights 201 | print('MultiBatchHardTripletLossWithMasksAugmented, model mode: ', self.mode) 202 | super().__init__( self.margin, self.normalize_embeddings, self.weights) 203 | 204 | if self.mode == "superfeatures": 205 | print('Superfeature loss') 206 | 207 | self.local_weights = params.cfg.model.local_feat_weights 208 | local_feat_margin = params.cfg.model.local_feat_margin 209 | self.local_attn_weights = params.cfg.model.local_attn_weights 210 | 211 | self.criterion_superfeatures = SuperfeatureLoss(margin=local_feat_margin[1], weight=self.local_weights[1]).to("cuda") 212 | self.criterion_superfeatures_pc = SuperfeatureLoss(margin=local_feat_margin[0], weight=self.local_weights[0]).to("cuda") 213 | 214 | self.criterion_attns = DecorrelationAttentionLoss(weight= self.local_attn_weights[1]).to("cuda") 215 | self.criterion_attns_3D = DecorrelationAttentionLoss(weight= self.local_attn_weights[0]).to("cuda") 216 | 217 | def __call__(self, x, positives_mask, negatives_mask): 218 | stats = {} 219 | loss = 0. 220 | 221 | embeddings = x['embedding'].float() 222 | final_loss, final_stats, final_hard_triplets = self.final_loss(embeddings, positives_mask, negatives_mask) 223 | final_stats = {'final_{}'.format(e): final_stats[e] for e in final_stats} 224 | device = x['embedding'][0].device 225 | 226 | loss = 0. 
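# Mode-dependent terms are added below: "fusion" keeps only the weighted global-descriptor loss,
# "ransac" adds a cloud triplet loss, and "superfeatures" adds superfeature contrastive losses
# plus attention-decorrelation losses, evaluated on the hard triplets mined for the global head.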
227 | stats = final_stats 228 | if self.weights[0] > 0.: 229 | loss = self.weights[0] * final_loss + loss 230 | 231 | if self.mode == "fusion": 232 | return loss, stats, None 233 | 234 | # Cloud - Loss on the cloud-based descriptor 235 | cloud_stats = {} 236 | image_stats = {} 237 | if self.mode == "ransac": 238 | cloud_loss, cloud_stats, pc_hard_triplets = self.cloud_loss(x['cloud_embedding'], positives_mask, negatives_mask) 239 | 240 | if self.weights[1] > 0.: 241 | loss = self.weights[1] * cloud_loss + loss 242 | 243 | if self.mode == "superfeatures": 244 | q_triplet = final_hard_triplets[0] 245 | p_triplet = final_hard_triplets[1] 246 | n_triplet = final_hard_triplets[2] 247 | 248 | a_pc_feat = x['pc_super_feat'][q_triplet] 249 | p_pc_feat = x['pc_super_feat'][p_triplet] 250 | feat_dim = a_pc_feat.shape[-1] 251 | n_pc_feat = x['pc_super_feat'][n_triplet.reshape(-1)] 252 | n_pc_feat = n_pc_feat.view(*n_triplet.shape, -1, feat_dim) 253 | 254 | a_img_feat = x['img_super_feat'][q_triplet] 255 | p_img_feat = x['img_super_feat'][p_triplet] 256 | feat_dim = a_img_feat.shape[-1] 257 | n_img_feat = x['img_super_feat'][n_triplet.reshape(-1)] 258 | n_img_feat = n_img_feat.view(*n_triplet.shape, -1, feat_dim) 259 | 260 | # Initialize loss 261 | loss_local_3d = 0.0 262 | loss_local_img = 0.0 263 | 264 | loss_attn_3d = self.criterion_attns_3D(x['pc_attns']) 265 | loss_attn = self.criterion_attns(x['img_attns']) 266 | 267 | for i in range(len(q_triplet)): 268 | # Construct target for one positive and five negatives 269 | target = torch.tensor([-1, 1] + [0] * len(n_pc_feat[i])).to(device) 270 | 271 | # Concatenate anchor, positive, and negatives along the feature dimension 272 | pc_superfeatures = torch.cat([a_pc_feat[i].unsqueeze(0), p_pc_feat[i].unsqueeze(0), n_pc_feat[i]], dim=0) 273 | img_superfeatures = torch.cat([a_img_feat[i].unsqueeze(0), p_img_feat[i].unsqueeze(0), n_img_feat[i]], dim=0) 274 | 275 | loss_local_3d += self.criterion_superfeatures_pc(pc_superfeatures, target) 276 | loss_local_img += self.criterion_superfeatures(img_superfeatures, target) 277 | 278 | loss_local_3d /= len(q_triplet) 279 | loss_local_img /= len(q_triplet) 280 | loss_attn_3d /= len(q_triplet) 281 | loss_attn /= len(q_triplet) 282 | loss = loss + loss_attn_3d + loss_local_3d + loss_local_img + loss_attn 283 | cloud_stats["loss_attn_3d"] = loss_attn_3d.item() 284 | cloud_stats["loss_super_pc"] = loss_local_3d.item() 285 | image_stats["loss_attn"] = loss_attn.item() 286 | image_stats["loss_super_im"] = loss_local_img.item() 287 | 288 | image_stats = {'image_{}'.format(e): image_stats[e] for e in image_stats} 289 | stats.update(image_stats) 290 | 291 | cloud_stats = {'cloud_{}'.format(e): cloud_stats[e] for e in cloud_stats} 292 | stats.update(cloud_stats) 293 | 294 | stats['loss'] = loss.item() 295 | return loss, stats, None 296 | 297 | 298 | class BatchHardTripletLossWithMasks: 299 | def __init__(self, margin, normalize_embeddings): 300 | self.loss_fn = BatchHardTripletLossWithMasksHelper(margin, normalize_embeddings) 301 | 302 | def __call__(self, x, positives_mask, negatives_mask): 303 | embeddings = x['embedding'] 304 | return self.loss_fn(embeddings, positives_mask, negatives_mask) 305 | 306 | 307 | class BatchHardTripletLossWithMasksHelper: 308 | def __init__(self, margin, normalize_embeddings): 309 | self.margin = margin 310 | self.distance = LpDistance(normalize_embeddings=normalize_embeddings, collect_stats=True) 311 | # We use triplet loss with Euclidean distance 312 | self.miner_fn = 
HardTripletMinerWithMasks(distance=self.distance) 313 | reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True) 314 | self.loss_fn = losses.TripletMarginLoss(margin=self.margin, swap=True, distance=self.distance, 315 | reducer=reducer_fn, collect_stats=True) 316 | 317 | def __call__(self, embeddings, positives_mask, negatives_mask): 318 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 319 | # select only first negative 320 | hard_triplets_cp = (hard_triplets[0], hard_triplets[1], hard_triplets[2][:, 0]) 321 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 322 | loss = self.loss_fn(embeddings.float(), dummy_labels, hard_triplets_cp) 323 | stats = {'loss': loss.item(), 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm, 324 | 'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter, 325 | 'num_triplets': len(hard_triplets_cp[0]), 326 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 327 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 328 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 329 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 330 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 331 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist, 332 | 'normalized_loss': loss.item() * self.loss_fn.reducer.triplets_past_filter, 333 | # total loss per batch 334 | 'total_loss': self.loss_fn.reducer.loss * self.loss_fn.reducer.triplets_past_filter 335 | } 336 | 337 | 338 | return loss, stats, hard_triplets -------------------------------------------------------------------------------- /models/losses/loss_utils.py: -------------------------------------------------------------------------------- 1 | # Functions and classes used by different loss functions 2 | import numpy as np 3 | import torch 4 | from torch import Tensor 5 | 6 | EPS = 1e-5 7 | 8 | 9 | def metrics_mean(l): 10 | # Compute the mean and return as Python number 11 | metrics = {} 12 | for e in l: 13 | for metric_name in e: 14 | if metric_name not in metrics: 15 | metrics[metric_name] = [] 16 | metrics[metric_name].append(e[metric_name]) 17 | 18 | for metric_name in metrics: 19 | metrics[metric_name] = np.mean(np.array(metrics[metric_name])) 20 | 21 | return metrics 22 | 23 | 24 | def squared_euclidean_distance(x: Tensor, y: Tensor) -> Tensor: 25 | ''' 26 | Compute squared Euclidean distance 27 | Input: x is Nxd matrix 28 | y is Mxd matirx 29 | Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:] 30 | i.e. 
dist[i,j] = ||x[i,:]-y[j,:]||^2 31 | Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3 32 | ''' 33 | x_norm = (x ** 2).sum(1).view(-1, 1) 34 | y_t = torch.transpose(y, 0, 1) 35 | y_norm = (y ** 2).sum(1).view(1, -1) 36 | dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t) 37 | return torch.clamp(dist, 0.0, np.inf) 38 | 39 | 40 | def sigmoid(tensor: Tensor, temp: float) -> Tensor: 41 | """ temperature controlled sigmoid 42 | takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp 43 | """ 44 | exponent = -tensor / temp 45 | # clamp the input tensor for stability 46 | exponent = torch.clamp(exponent, min=-50, max=50) 47 | y = 1.0 / (1.0 + torch.exp(exponent)) 48 | return y 49 | 50 | 51 | def compute_aff(x: Tensor, similarity: str = 'cosine') -> Tensor: 52 | """computes the affinity matrix between an input vector and itself""" 53 | if similarity == 'cosine': 54 | x = torch.mm(x, x.t()) 55 | elif similarity == 'euclidean': 56 | x = x.unsqueeze(0) 57 | x = torch.cdist(x, x, p=2) 58 | x = x.squeeze(0) 59 | # The greater the distance the smaller affinity 60 | x = -x 61 | else: 62 | raise NotImplementedError(f"Incorrect similarity measure: {similarity}") 63 | return x 64 | -------------------------------------------------------------------------------- /models/losses/super_feature_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from cirtorch.layers.loss import ContrastiveLoss 5 | 6 | class DecorrelationAttentionLoss(nn.Module): 7 | 8 | def __init__(self, weight=1.0): 9 | super().__init__() 10 | self.weight = weight 11 | 12 | def forward(self, attention_list): 13 | """ 14 | attention_list is a list of tensor of size N x H x W where N is the number of attention maps per image 15 | """ 16 | total_loss = 0.0 17 | for attn in attention_list: 18 | N = attn.size(0) 19 | attn = attn.view( N, -1) 20 | attnN = F.normalize(attn, dim=1) 21 | corr = torch.einsum("rn,sn -> rs", attnN, attnN) 22 | loss = (corr.sum() - torch.diagonal(corr, dim1=0, dim2=1).sum() ) / (N * (N-1) ) # sum over non-diagonal elements 23 | total_loss += loss 24 | return total_loss * self.weight 25 | 26 | def __repr__(self): 27 | return "{:s}(weight={:g})".format(self.__class__.__name__, self.weight) 28 | 29 | 30 | def match_super(query_feat, pos_feat, LoweRatioTh=0.9): 31 | # first perform reciprocal nn 32 | dist = torch.cdist(query_feat, pos_feat) 33 | best1 = torch.argmin(dist, dim=1) 34 | best2 = torch.argmin(dist, dim=0) 35 | arange = torch.arange(best2.size(0), device=best2.device) 36 | reciprocal = best1[best2]==arange 37 | # check Lowe ratio test 38 | dist2 = dist.clone() 39 | dist2[best2,arange] = float('Inf') 40 | dist2_second2 = torch.argmin(dist2, dim=0) 41 | ratio1to2 = dist[best2,arange] / dist2_second2 42 | valid = torch.logical_and(reciprocal, ratio1to2<=LoweRatioTh) 43 | pindices = torch.where(valid)[0] 44 | qindices = best2[pindices] 45 | # keep only the ones with same indices 46 | valid = pindices==qindices 47 | return pindices[valid] 48 | 49 | class SuperfeatureLoss(nn.Module): 50 | 51 | def __init__(self, margin=1.1, weight=1.0): 52 | super().__init__() 53 | self.weight = weight 54 | self.criterion = ContrastiveLoss(margin=margin) 55 | 56 | def forward(self, superfeatures_list, target): 57 | """ 58 | superfeatures_list is a list of tensor of size N x D containing the superfeatures for each image 59 | """ 60 | assert 
target[0]==-1 and target[1]==1 and torch.all(target[2:]==0), "Only implemented for one tuple where the first element is the query, the second one the positive, and the rest are negatives" 61 | N = superfeatures_list[0].size(0) 62 | assert all(s.size(0)==N for s in superfeatures_list[1:]), "All images should have the same number of features" 63 | query_feat = F.normalize(superfeatures_list[0], dim=1) 64 | pos_feat = F.normalize(superfeatures_list[1], dim=1) 65 | neg_feat_list = [F.normalize(neg, dim=1) for neg in superfeatures_list[2:]] 66 | # perform matching 67 | indices = match_super(query_feat, pos_feat) 68 | if indices.size(0)==0: 69 | return torch.sum(query_feat[:1,:1])*0.0 # for having a gradient that depends on the input to avoid torch error when using multiple processes 70 | # loss 71 | nneg = len(neg_feat_list) 72 | target = torch.Tensor( ([-1, 1]+[0]*nneg) * len(indices)).to(dtype=torch.int64, device=indices.device) 73 | catfeats = torch.cat([query_feat[indices, None, :], pos_feat[indices, None, :]] + \ 74 | [neg_feat[indices,None,:] for neg_feat in neg_feat_list], dim=1) # take qindices for the negatives 75 | catfeats = catfeats.view(-1, query_feat.size(1)) 76 | 77 | loss = self.criterion(catfeats.T, target.detach()) 78 | return loss * self.weight 79 | 80 | def __repr__(self): 81 | return "{:s}(margin={:g}, weight={:g})".format(self.__class__.__name__, self.criterion.margin, self.weight) -------------------------------------------------------------------------------- /models/losses/truncated_smoothap.py: -------------------------------------------------------------------------------- 1 | # Implemented as per "Recall@k Surrogate Loss with Large Batches and Similarity Mixup" paper 2 | # but only the fixed number of the closest positives is considered 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from models.losses.loss_utils import sigmoid, compute_aff 8 | 9 | 10 | class TruncatedSmoothAP: 11 | def __init__(self, tau1: float = 0.01, similarity: str = 'cosine', positives_per_query: int = 4): 12 | # We reversed the notation compared to the paper (tau1 is sigmoid on similarity differences) 13 | # tau1: sigmoid temperature applied on similarity differences 14 | # positives_per_query: number of positives per query to consider 15 | # negatives_only: if True in denominator we consider positives and negatives; if False we consider all elements 16 | # (with except to the anchor itself) 17 | 18 | self.tau1 = tau1 19 | self.similarity = similarity 20 | self.positives_per_query = positives_per_query 21 | 22 | def __call__(self, embeddings, positives_mask, negatives_mask): 23 | device = embeddings.device 24 | 25 | positives_mask = positives_mask.to(device) 26 | negatives_mask = negatives_mask.to(device) 27 | 28 | # Ranking of the retrieval set 29 | # For each element we ignore elements that are neither positives nor negatives 30 | 31 | # Compute cosine similarity scores 32 | # 1st dimension corresponds to q, 2nd dimension to z 33 | s_qz = compute_aff(embeddings, similarity=self.similarity) 34 | 35 | # Find the positives_per_query closest positives for each query 36 | s_positives = s_qz.detach().clone() 37 | s_positives.masked_fill_(torch.logical_not(positives_mask), np.NINF) 38 | #closest_positives_ndx = torch.argmax(s_positives, dim=1).view(-1, 1) # Indices of closests positives for each query 39 | closest_positives_ndx = torch.topk(s_positives, k=self.positives_per_query, dim=1, largest=True, sorted=True)[1] 40 | # closest_positives_ndx is (batch_size, positives_per_query) with 
positives_per_query closest positives 41 | # per each batch element 42 | 43 | n_positives = positives_mask.sum(dim=1) # Number of positives for each anchor 44 | 45 | # Compute the rank of each example x with respect to query element q as per Eq. (2) 46 | s_diff = s_qz.unsqueeze(1) - s_qz.gather(1, closest_positives_ndx).unsqueeze(2) 47 | s_sigmoid = sigmoid(s_diff, temp=self.tau1) 48 | 49 | # Compute the nominator in Eq. 2 and 5 - for q compute the ranking of each of its positives with respect to other positives of q 50 | # Filter out z not in Positives 51 | pos_mask = positives_mask.unsqueeze(1) 52 | pos_s_sigmoid = s_sigmoid * pos_mask 53 | 54 | # Filter out z on the same position as the positive (they have value = 0.5, as the similarity difference is zero) 55 | mask = torch.ones_like(pos_s_sigmoid).scatter(2, closest_positives_ndx.unsqueeze(2), 0.) 56 | pos_s_sigmoid = pos_s_sigmoid * mask 57 | 58 | # Compute the rank for each query and its positives_per_query closest positive examples with respect to other positives 59 | r_p = torch.sum(pos_s_sigmoid, dim=2) + 1. 60 | # r_p is (batch_size, positives_per_query) matrix 61 | 62 | # Consider only positives and negatives in the denominator 63 | # Compute the denominator in Eq. 5 - add sum of Indicator function for negatives (or non-positives) 64 | neg_mask = negatives_mask.unsqueeze(1) 65 | neg_s_sigmoid = s_sigmoid * neg_mask 66 | r_omega = r_p + torch.sum(neg_s_sigmoid, dim=2) 67 | 68 | # Compute R(i, S_p) / R(i, S_omega) ration in Eq. 2 69 | r = r_p / r_omega 70 | 71 | # Compute metrics mean ranking of the positive example, recall@1 72 | stats = {} 73 | # Mean number of positives per query 74 | stats['positives_per_query'] = n_positives.float().mean(dim=0).item() 75 | # Mean ranking of selected positive examples (closests positives) 76 | temp = s_diff.detach() > 0 77 | temp = torch.logical_and(temp[:, 0], negatives_mask) # Take the best positive 78 | hard_ranking = temp.sum(dim=1) 79 | stats['best_positive_ranking'] = hard_ranking.float().mean(dim=0).item() 80 | # Recall at 1 81 | stats['recall'] = {1: (hard_ranking <= 1).float().mean(dim=0).item()} 82 | 83 | # r is (N, positives_per_query) tensor 84 | # Zero entries not corresponding to real positives - this happens when the number of true positives is lower than positives_per_query 85 | valid_positives_mask = torch.gather(positives_mask, 1, closest_positives_ndx) # () tensor 86 | masked_r = r * valid_positives_mask 87 | n_valid_positives = valid_positives_mask.sum(dim=1) 88 | 89 | # Filter out rows (queries) without any positive to avoid division by zero 90 | valid_q_mask = n_valid_positives > 0 91 | masked_r = masked_r[valid_q_mask] 92 | 93 | ap = (masked_r.sum(dim=1) / n_valid_positives[valid_q_mask]).mean() 94 | loss = 1. 
- ap 95 | 96 | stats['loss'] = loss.item() 97 | stats['ap'] = ap.item() 98 | stats['avg_embedding_norm'] = embeddings.norm(dim=1).mean().item() 99 | 100 | # Identify hard negatives for each query: the most similar negatives 101 | s_negatives = s_qz.detach().clone() 102 | s_negatives.masked_fill_(torch.logical_not(negatives_mask), np.NINF) 103 | hard_negatives_ndx = torch.topk(s_negatives, k=5, dim=1, largest=True, sorted=True)[1] 104 | 105 | query_indices = torch.arange(embeddings.size(0), device=device).view(-1, 1) 106 | hard_triplets = torch.cat((query_indices, closest_positives_ndx[:, :1], hard_negatives_ndx), dim=1) 107 | 108 | 109 | return loss, stats, hard_triplets -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | from misc.utils import UMFParams 2 | import dotsi 3 | import yaml 4 | from models.UMF.UMFnet import UMFnet 5 | 6 | 7 | def load_config(config_file): 8 | with open(config_file, 'r') as ymlfile: 9 | cfg = yaml.load(ymlfile, Loader=yaml.FullLoader) 10 | return cfg 11 | 12 | def model_factory(params: UMFParams): 13 | in_channels = 1 14 | 15 | if params.model_params.model == 'UMF': 16 | # XMFnet baseline model 17 | cfg = load_config(params.model_params_path) 18 | cfg = dotsi.Dict(cfg) 19 | # add cfg to params 20 | params.cfg = cfg 21 | model = UMFnet(cfg, final_block=cfg.model.fusion.final_block) 22 | else: 23 | raise ValueError('Unknown model: {}'.format(params.model_params.model)) 24 | 25 | 26 | 27 | return model 28 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings("ignore") 3 | 4 | import argparse 5 | import torch 6 | from datasets.dataset_utils import make_dataloaders 7 | from misc.utils import UMFParams 8 | from training.trainer import do_train 9 | import wandb 10 | 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser( 14 | description='Train Minkowski Net embeddings using BatchHard negative mining') 15 | parser.add_argument('--config', type=str, required=True, 16 | help='Path to configuration file') 17 | parser.add_argument('--model_config', type=str, required=True, 18 | help='Path to the model-specific configuration file') 19 | parser.add_argument('--debug', dest='debug', action='store_true') 20 | # Run argument integer default 0 21 | parser.add_argument('--run', type=int, default=0, help='Run number') 22 | 23 | parser.set_defaults(debug=False) 24 | 25 | args = parser.parse_args() 26 | print('Training config path: {}'.format(args.config)) 27 | print('Model config path: {}'.format(args.model_config)) 28 | print('Debug mode: {}'.format(args.debug)) 29 | 30 | params = UMFParams(args.config, args.model_config) 31 | params.print() 32 | 33 | if args.debug: 34 | torch.autograd.set_detect_anomaly(True) 35 | 36 | dataloaders = make_dataloaders(params, debug=args.debug) 37 | params.debug = args.debug 38 | 39 | if not args.debug: 40 | wandb.init( 41 | # Set the project where this run will be logged 42 | project=params.model_params.model, 43 | # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10) 44 | name=f"experiment_{args.run}", 45 | # Track hyperparameters and run metadata 46 | config=params.to_dict() 47 | ) 48 | 49 | 50 | do_train(dataloaders, params, debug=args.debug) 51 | wandb.finish() 52 | 
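
Both mask-based losses above (the hard-triplet loss and `TruncatedSmoothAP`) expose the same call signature used by the trainer: `(embeddings, positives_mask, negatives_mask) -> (loss, stats, hard_triplets)`. The snippet below is a minimal toy sketch of that interface; the batch size, labels and mask construction are made up for illustration and do not come from the UMF dataloaders.

```python
# Toy sketch of the shared loss interface (hypothetical labels/masks, not the real dataloader output)
import torch
import torch.nn.functional as F
from models.losses.truncated_smoothap import TruncatedSmoothAP

B, D = 8, 256
embeddings = F.normalize(torch.randn(B, D), dim=1)        # stand-in for model outputs
labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])           # two toy "places"
same_place = labels.unsqueeze(0) == labels.unsqueeze(1)
self_mask = torch.eye(B, dtype=torch.bool)
positives_mask = same_place & ~self_mask                  # same place, excluding the anchor itself
negatives_mask = ~same_place                              # different place

loss_fn = TruncatedSmoothAP(tau1=0.01, similarity='cosine', positives_per_query=3)
loss, stats, hard_triplets = loss_fn(embeddings, positives_mask, negatives_mask)
print(stats['ap'], stats['loss'], hard_triplets.shape)
```

In the real pipeline the boolean masks arrive in the batch dictionaries produced by the dataloaders, and `make_loss(params)` decides which loss class is instantiated.
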
-------------------------------------------------------------------------------- /training/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from datetime import datetime 3 | import numpy as np 4 | import torch 5 | import pickle 6 | import tqdm 7 | import pathlib 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | from misc.lr_scheduler import LinearWarmupCosineAnnealingLR 11 | 12 | from misc.utils import UMFParams, get_datetime 13 | from models.loss import make_loss 14 | from models.model_factory import model_factory 15 | from models.UMF.UMFnet import UMFnet 16 | import os 17 | from matplotlib import pyplot as plt 18 | import wandb 19 | 20 | 21 | VERBOSE = False 22 | 23 | 24 | def print_stats(stats, phase): 25 | if 'num_triplets' in stats: 26 | # For triplet loss 27 | s = '{} - Loss (mean/total): {:.4f} / {:.4f} Avg. embedding norm: {:.4f} Triplets per batch (all/non-zero): {:.1f}/{:.1f}' 28 | print(s.format(phase, stats['loss'], stats['total_loss'], stats['avg_embedding_norm'], stats['num_triplets'], 29 | stats['num_non_zero_triplets'])) 30 | elif 'num_pos' in stats: 31 | s = '{} - Mean loss: {:.6f} Avg. embedding norm: {:.4f} #positives/negatives: {:.1f}/{:.1f}' 32 | print(s.format(phase, stats['loss'], stats['avg_embedding_norm'], stats['num_pos'], stats['num_neg'])) 33 | 34 | s = '' 35 | l = [] 36 | if 'mean_pos_pair_dist' in stats: 37 | s += 'Pos dist (min/mean/max): {:.4f}/{:.4f}/{:.4f} Neg dist (min/mean/max): {:.4f}/{:.4f}/{:.4f}' 38 | l += [stats['min_pos_pair_dist'], stats['mean_pos_pair_dist'], stats['max_pos_pair_dist'], 39 | stats['min_neg_pair_dist'], stats['mean_neg_pair_dist'], stats['max_neg_pair_dist']] 40 | if 'pos_loss' in stats: 41 | if len(s) > 0: 42 | s += ' ' 43 | s += 'Pos loss: {:.4f} Neg loss: {:.4f}' 44 | l += [stats['pos_loss'], stats['neg_loss']] 45 | if len(l) > 0: 46 | print(s.format(*l)) 47 | 48 | if 'final_loss' in stats: 49 | # Multi loss 50 | s1 = '{} - Loss (total/final'.format(phase) 51 | s2 = '{:.4f} / {:.4f}'.format(stats['loss'], stats['final_loss']) 52 | s3 = 'Active triplets (final ' 53 | s4 = '{:.1f}'.format(stats['final_num_non_zero_triplets']) 54 | if 'cloud_loss' in stats: 55 | s1 += '/cloud' 56 | s2 += '/ {:.4f}'.format(stats['cloud_loss']) 57 | s3 += '/cloud' 58 | s4 += '/ {:.1f}'.format(stats['cloud_num_non_zero_triplets'],) 59 | if 'image_loss' in stats: 60 | s1 += '/image' 61 | s2 += '/ {:.4f}'.format(stats['image_loss']) 62 | s3 += '/image' 63 | s4 += '/ {:.1f}'.format(stats['image_num_non_zero_triplets'],) 64 | 65 | s1 += '): ' 66 | s3 += '): ' 67 | print(s1 + s2) 68 | print(s3 + s4) 69 | 70 | 71 | def tensors_to_numbers(stats): 72 | stats = {e: stats[e].item() if torch.is_tensor(stats[e]) else stats[e] for e in stats} 73 | return stats 74 | 75 | 76 | def split_batch_pre_tensor(batch, minibatch_size): 77 | """Split a batch into minibatches before converting to tensors.""" 78 | minibatches = [] 79 | # Determine the size of the batch based on one of the elements 80 | batch_size = len(next(iter(batch.values()))) 81 | for start_idx in range(0, batch_size, minibatch_size): 82 | end_idx = min(start_idx + minibatch_size, batch_size) # Handle the case of the last minibatch being smaller 83 | minibatch = {} 84 | for key, value in batch.items(): 85 | if key not in ['coordinates', 'voxel_features']: 86 | minibatch[key] = value[start_idx:end_idx] 87 | 88 | if key == 'coordinates': 89 | coords = value.cpu() 90 | indices = torch.where((coords[:, 0] >= start_idx) & (coords[:, 0] 
< end_idx)) 91 | coords = coords[indices[0], :] 92 | coords[:, 0] = coords[:, 0] - start_idx 93 | coords[:, 0] = coords[:, 0].to(torch.int32) 94 | minibatch['coordinates'] = coords 95 | feats_batch = torch.ones((coords.shape[0], 1), dtype=torch.float32) 96 | minibatch['voxel_features'] = feats_batch 97 | 98 | minibatches.append(minibatch) 99 | return minibatches 100 | 101 | 102 | def train_step(batch, model, phase, device, optimizer, scheduler, loss_fn, params): 103 | assert phase in ['train', 'val'] 104 | if phase == 'train': 105 | model.train() 106 | else: 107 | model.eval() 108 | positives_mask = batch["positives_mask"] 109 | negatives_mask = batch["negatives_mask"] 110 | 111 | # Move batch to device 112 | batch = {e: torch.from_numpy(np.array(batch[e])).to(device) for e in batch} 113 | with torch.cuda.amp.autocast(enabled=phase == 'train'): 114 | with torch.set_grad_enabled(phase == 'train'): 115 | y = model(batch) 116 | loss, stats, _ = loss_fn(y, positives_mask, negatives_mask) 117 | 118 | stats = tensors_to_numbers(stats) 119 | stats['loss'] = loss.item() 120 | 121 | if phase == 'train': 122 | loss.backward() 123 | torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0) 124 | optimizer.step() 125 | if params.scheduler in ['OneCycleLR', 'WarmupCosineSchedule']: 126 | scheduler.step() 127 | optimizer.zero_grad() 128 | return stats, loss 129 | 130 | 131 | 132 | def multistaged_training_step(batch, model, phase, device, optimizer, loss_fn, mode, dataset): 133 | # Training step using multistaged backpropagation algorithm as per: 134 | # "Learning with Average Precision: Training Image Retrieval with a Listwise Loss" 135 | # This method will break when the model contains Dropout, as the same mini-batch will produce different embeddings. 
136 | # Make sure mini-batches in step 1 and step 3 are the same (so that BatchNorm produces the same results) 137 | # See some exemplary implementation here: https://gist.github.com/ByungSun12/ad964a08eba6a7d103dab8588c9a3774 138 | assert phase in ['train', 'val'] 139 | minibatch_size = 20 140 | 141 | positives_mask = batch['positives_mask'].to(device) 142 | negatives_mask = batch['negatives_mask'].to(device) 143 | 144 | if phase == 'train': 145 | model.train() 146 | else: 147 | model.eval() 148 | 149 | # Stage 1 - calculate descriptors of each batch element (with gradient turned off) 150 | # In training phase network is in the train mode to update BatchNorm stats 151 | img_attn_l = [] 152 | cloud_attn_l = [] 153 | embeddings = [] 154 | cloud_embeddings, image_embeddings = None, None 155 | embeddings_list, cloud_embeddings_list, image_embeddings_list = [], [], [] 156 | with torch.cuda.amp.autocast(enabled=phase == 'train'): 157 | with torch.set_grad_enabled(False): 158 | minibatches = split_batch_pre_tensor(batch, minibatch_size) 159 | for minibatch in minibatches: 160 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 161 | 162 | y = model(minibatch) 163 | embeddings_list.append(y['embedding']) 164 | if mode == 'superfeatures': 165 | cloud_embeddings_list.append(y['pc_local_feat']) 166 | image_embeddings_list.append(y['img_local_feat']) 167 | img_attn_l.append(y['img_attn']) 168 | cloud_attn_l.append(y['pc_attn']) 169 | elif mode == 'ransac': 170 | cloud_embeddings_list.append(y['cloud_embedding']) 171 | image_embeddings_list.append(y['image_embedding']) 172 | 173 | # Stage 2 - compute gradient of the loss w.r.t embeddings 174 | embeddings = torch.cat(embeddings_list, dim=0) 175 | if mode == 'superfeatures': 176 | cloud_embeddings = torch.cat(cloud_embeddings_list, dim=0) 177 | image_embeddings = torch.cat(image_embeddings_list, dim=0) 178 | pc_attn = torch.cat(cloud_attn_l, dim=0) 179 | img_attn = torch.cat(img_attn_l, dim=0) 180 | elif mode == 'ransac': 181 | cloud_embeddings = torch.cat(cloud_embeddings_list, dim=0) 182 | image_embeddings = torch.cat(image_embeddings_list, dim=0) 183 | 184 | with torch.cuda.amp.autocast(enabled=phase == 'train'): 185 | with torch.set_grad_enabled(phase == 'train'): 186 | if phase == 'train': 187 | embeddings.requires_grad_(True) 188 | if mode == 'superfeatures': 189 | cloud_embeddings.requires_grad_(True) 190 | image_embeddings.requires_grad_(True) 191 | pc_attn.requires_grad_(True) 192 | img_attn.requires_grad_(True) 193 | elif mode == 'ransac': 194 | 195 | cloud_embeddings.requires_grad_(True) 196 | image_embeddings.requires_grad_(True) 197 | 198 | out = {'embedding': embeddings} 199 | 200 | if mode == 'superfeatures' or mode == 'ransac': 201 | out.update({'cloud_embedding': cloud_embeddings, 'image_embedding': image_embeddings}) 202 | 203 | if mode == 'superfeatures': 204 | out['pc_attn'] = pc_attn 205 | out['img_attn'] = img_attn 206 | loss, stats, _ = loss_fn(out, positives_mask=positives_mask, negatives_mask=negatives_mask) 207 | stats = tensors_to_numbers(stats) 208 | stats['loss'] = loss.item() 209 | if phase == 'train': 210 | loss.backward() 211 | embeddings_grad = embeddings.grad 212 | if mode == 'superfeatures': 213 | cloud_embeddings_grad = cloud_embeddings.grad 214 | image_embeddings_grad = image_embeddings.grad 215 | pc_attn_grad = pc_attn.grad 216 | img_attn_grad = img_attn.grad 217 | elif mode == 'ransac': 218 | cloud_embeddings_grad = cloud_embeddings.grad 219 | image_embeddings_grad = image_embeddings.grad 220 | 221 | # 
Delete intermediary values
222 |     del embeddings_list, cloud_embeddings_list, image_embeddings_list, embeddings, cloud_embeddings, image_embeddings, y, loss
223 |     # Stage 3 - recompute descriptors with gradient enabled and compute the gradient of the loss w.r.t.
224 |     # network parameters using the cached gradient of the loss w.r.t. embeddings
225 |     if phase == 'train':
226 |         optimizer.zero_grad()
227 |         i = 0
228 |         with torch.cuda.amp.autocast(enabled=phase == 'train'):
229 |             with torch.set_grad_enabled(True):
230 |                 minibatches = split_batch_pre_tensor(batch, minibatch_size)
231 |                 for minibatch in minibatches:
232 |                     minibatch = {e: minibatch[e].to(device) for e in minibatch}
233 | 
234 |                     y = model(minibatch)
235 |                     embeddings = y['embedding']
236 |                     minibatch_size = len(embeddings)
237 |                     # Compute gradients of network params w.r.t. the loss using the chain rule (using the
238 |                     # gradient of the loss w.r.t. embeddings stored in embeddings_grad)
239 |                     # By default gradients are accumulated
240 |                     # For all but the last minibatch, retain the graph
241 |                     if mode == 'superfeatures':
242 |                         grads = torch.cat([embeddings_grad[i: i+minibatch_size],
243 |                                            cloud_embeddings_grad[i: i+minibatch_size],
244 |                                            image_embeddings_grad[i: i+minibatch_size],
245 |                                            pc_attn_grad[i: i+minibatch_size],
246 |                                            img_attn_grad[i: i+minibatch_size]], dim=1)
247 | 
248 |                         outputs = torch.cat([y['embedding'], y['pc_local_feat'],
249 |                                              y['img_local_feat'], y['pc_attn'], y['img_attn']], dim=1)
250 | 
251 |                     elif mode == 'ransac':
252 |                         grads = torch.cat([embeddings_grad[i: i+minibatch_size], cloud_embeddings_grad[i: i+minibatch_size], image_embeddings_grad[i: i+minibatch_size]], dim=1)
253 |                         outputs = torch.cat([y['embedding'], y['cloud_embedding'], y['image_embedding']], dim=1)
254 | 
255 |                     else:
256 |                         grads = embeddings_grad[i: i+minibatch_size]
257 |                         outputs = y['embedding']
258 | 
259 |                     # Backpropagate the cached gradients through the recomputed outputs (layout must match grads above)
260 |                     outputs.backward(gradient=grads)
261 |                     i += minibatch_size
262 |         torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
263 |         optimizer.step()
264 |     return stats
265 | 
266 | 
267 | def do_train(dataloaders, params: UMFParams, debug=False):
268 |     # Create model class
269 |     s = get_datetime()
270 |     model = model_factory(params)
271 | 
272 |     # Move the model to the proper device before configuring the optimizer
273 |     if torch.cuda.is_available():
274 |         device = "cuda"
275 |         model.to(device)
276 |     else:
277 |         device = "cpu"
278 | 
279 |     print('Model device: {}'.format(device))
280 | 
281 |     if params.load_weights is not None:
282 |         assert os.path.exists(params.load_weights), 'Cannot open network weights: {}'.format(params.load_weights)
283 |         print('Loading weights: {}'.format(params.load_weights))
284 |         model.load_state_dict(torch.load(params.load_weights, map_location=device))
285 | 
286 |     model_name = 'model_' + params.model_params.model + '_' + s
287 |     mode = params.model_params.params.get('mode', 'ransac')
288 | 
289 |     print('Model name: {}'.format(model_name))
290 |     weights_path = create_weights_folder()
291 | 
292 |     model_pathname = os.path.join(weights_path, model_name)
293 |     if hasattr(model, 'print_info'):
294 |         model.print_info()
295 |     else:
296 |         n_params = sum([param.nelement() for param in model.parameters()])
297 |         print('Number of model parameters: {}'.format(n_params))
298 | 
299 |     loss_fn = make_loss(params)
300 |     params_l = []
301 |     if isinstance(model, UMFnet):
302 |         # Different LR for image feature extractor (pretrained ResNet)
303 |         if model.image_fe is not None:
304 |             lower_lr = params.image_lr / 
10 305 | num_top_layers = 0 306 | # Set lower lr for top layers of resnet 307 | for i, (name, param) in enumerate(model.image_fe.named_parameters()): 308 | if i >= len(list(model.image_fe.parameters())) - num_top_layers: 309 | params_l.append({'params': param, 'lr': lower_lr}) 310 | else: 311 | params_l.append({'params': param, 'lr': params.image_lr}) 312 | 313 | if model.cloud_fe is not None: 314 | params_l.append({'params': model.cloud_fe.parameters(), 'lr': params.lr}) 315 | if model.final_block is not None: 316 | params_l.append({'params': model.fusion_encoder.parameters(), 'lr': params.lr}) 317 | else: 318 | # All parameters use the same lr 319 | params_l.append({'params': model.parameters(), 'lr': params.lr}) 320 | 321 | # Training elements 322 | if params.optimizer == 'Adam': 323 | if params.weight_decay is None or params.weight_decay == 0: 324 | optimizer = torch.optim.Adam(params_l) 325 | else: 326 | optimizer = torch.optim.Adam(params_l, weight_decay=params.weight_decay) 327 | elif params.optimizer == 'AdamW': 328 | if params.weight_decay is None or params.weight_decay == 0: 329 | optimizer = torch.optim.AdamW(params_l) 330 | else: 331 | optimizer = torch.optim.AdamW(params_l, weight_decay=params.weight_decay, eps=1e-4) 332 | elif params.optimizer == 'SGD': 333 | # SGD with momentum (default momentum = 0.9) 334 | if params.weight_decay is None or params.weight_decay == 0: 335 | optimizer = torch.optim.SGD(params_l, momentum=0.9 ) 336 | else: 337 | optimizer = torch.optim.SGD(params_l, weight_decay=params.weight_decay, momentum=0.9) 338 | else: 339 | raise NotImplementedError('Unsupported optimizer: {}'.format(params.optimizer)) 340 | 341 | 342 | if params.scheduler is None or params.optimizer == 'auto': 343 | scheduler = None 344 | else: 345 | if params.scheduler == 'CosineAnnealingLR': 346 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.epochs+1, 347 | eta_min=params.min_lr) 348 | 349 | elif params.scheduler == 'LinearWarmupCosineAnnealingLR': 350 | scheduler = LinearWarmupCosineAnnealingLR(optimizer, 351 | warmup_epochs=params.warmup_epochs, 352 | max_epochs=params.epochs) 353 | 354 | elif params.scheduler == 'OneCycleLR': 355 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=params.lr, 356 | total_steps=params.epochs*len(dataloaders['train'])) 357 | 358 | elif params.scheduler == 'ExpotentialLR': 359 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) 360 | 361 | elif params.scheduler == 'MultiStepLR': 362 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, params.scheduler_milestones, gamma=0.1) 363 | elif params.scheduler == 'ReduceLROnPlateau': 364 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=6, cooldown=2) 365 | 366 | else: 367 | raise NotImplementedError('Unsupported LR scheduler: {}'.format(params.scheduler)) 368 | 369 | 370 | ########################################################################### 371 | # Initialize TensorBoard writer 372 | ########################################################################### 373 | now = datetime.now() 374 | logdir = os.path.join("../tf_logs", now.strftime("%Y%m%d-%H%M%S")) 375 | writer = SummaryWriter(logdir) 376 | 377 | 378 | is_validation_set = 'val' in dataloaders 379 | if is_validation_set: 380 | phases = ['train', 'val'] 381 | else: 382 | phases = ['train'] 383 | 384 | # Training statistics 385 | stats = {'train': [], 'val': [], 'eval': []} 386 | best_val_loss = 1e10 387 | 
best_val_triplets = 1e10
388 | 
389 | 
390 |     for epoch in tqdm.tqdm(range(1, params.epochs + 1)):
391 |         metrics = {'train': {}, 'val': {}} # Metrics for wandb reporting
392 | 
393 |         first_iter = True
394 |         for phase in phases:
395 | 
396 |             if phase == 'train':
397 |                 model.train()
398 |             else:
399 |                 model.eval()
400 | 
401 |             running_stats = [] # running stats for the current epoch
402 |             for batch_idx, (batch) in enumerate(dataloaders[phase]):
403 |                 # Prepare data and masks for multi-staged training
404 |                 positives_mask = batch["positives_mask"]
405 |                 negatives_mask = batch["negatives_mask"]
406 |                 dataset = dataloaders[phase].dataset
407 | 
408 |                 # Skip batches without positives or negatives
409 |                 if torch.sum(positives_mask) == 0 or torch.sum(negatives_mask) == 0:
410 |                     print('WARNING: Skipping batch without positive or negative examples')
411 |                     continue
412 | 
413 |                 if params.train_step == 'multistaged':
414 |                     # Multi-staged training step
415 |                     batch_stats = multistaged_training_step(
416 |                         batch, model, phase, device, optimizer, loss_fn, mode, dataset
417 |                     )
418 |                 else:
419 |                     # Standard training step
420 |                     batch_stats, loss = train_step(
421 |                         batch, model, phase, device, optimizer, scheduler, loss_fn, params
422 |                     )
423 | 
424 |                 running_stats.append(batch_stats)
425 | 
426 |                 if params.debug and batch_idx > 10:
427 |                     break
428 | 
429 |             if params.scheduler in ['OneCycleLR']:
430 |                 scheduler.step()
431 | 
432 |             epoch_stats = {}
433 |             for key in running_stats[0]:
434 |                 temp = [e[key] for e in running_stats]
435 |                 if type(temp[0]) is dict:
436 |                     epoch_stats[key] = {key: np.mean([e[key] for e in temp]) for key in temp[0]}
437 |                 elif type(temp[0]) is np.ndarray:
438 |                     # Mean value per vector element
439 |                     epoch_stats[key] = np.mean(np.stack(temp), axis=0)
440 |                 else:
441 |                     epoch_stats[key] = np.mean(temp)
442 | 
443 |             stats[phase].append(epoch_stats)
444 |             print_stats(epoch_stats, phase)
445 | 
446 |             # Log metrics for wandb
447 |             for key in epoch_stats:
448 |                 if type(epoch_stats[key]) is dict:
449 |                     metrics[phase].update(epoch_stats[key])
450 |                 else:
451 |                     metrics[phase][key] = epoch_stats[key]
452 | 
453 |         if wandb.run is not None:
454 |             wandb.log(metrics)
455 | 
456 |         # ******* EPOCH END *******
457 |         if params.scheduler not in ['OneCycleLR']:
458 |             if params.scheduler == 'ReduceLROnPlateau':
459 |                 scheduler.step(epoch_stats['loss'])
460 |             else:
461 |                 scheduler.step()
462 | 
463 |         model.zero_grad()
464 |         loss_metrics = {'train': stats['train'][-1]['loss']}
465 |         if 'val' in phases:
466 |             loss_metrics['val'] = stats['val'][-1]['loss']
467 | 
468 |         print('Current lr: {:.7f}'.format(optimizer.param_groups[0]['lr']))
469 |         if wandb.run is not None:
470 |             wandb.log({"lr": optimizer.param_groups[0]['lr']})
471 |         epoch_val_stats = stats['val'][-1] if 'val' in phases else stats['train'][-1]
472 |         val_num_non_zero_triplets = epoch_val_stats.get("final_num_non_zero_triplets", 0)
473 | 
474 |         if val_num_non_zero_triplets < best_val_triplets:
475 |             if not params.debug:
476 |                 torch.save(model.state_dict(), model_pathname + '_best_triplets.pth')
477 |             best_val_triplets = val_num_non_zero_triplets
478 |         if 'val' in phases and loss_metrics['val'] < best_val_loss:
479 |             best_val_loss = loss_metrics['val']
480 |             #best_model_wts = copy.deepcopy(model.state_dict())
481 |             print('New best model found, loss: {:.6f}'.format(best_val_loss))
482 |             print('Saving model to {}'.format(model_pathname))
483 |             # Save model
484 |             best_model_path = model_pathname + '_best.pth'
485 |             torch.save(model.state_dict(), best_model_path)
486 |             if not params.debug:
487 |                 torch.save(model.state_dict(), os.path.join(wandb.run.dir, 
'best_model.pth')) 488 | 489 | 490 | writer.add_scalars('Loss', loss_metrics, epoch) 491 | 492 | if 'num_triplets' in stats['train'][-1]: 493 | nz_metrics = {'train': stats['train'][-1]['num_non_zero_triplets']} 494 | 495 | if 'val' in phases: 496 | nz_metrics['val'] = stats['val'][-1]['num_non_zero_triplets'] 497 | writer.add_scalars('Non-zero triplets', nz_metrics, epoch) 498 | 499 | elif 'num_pairs' in stats['train'][-1]: 500 | nz_metrics = {'train_pos': stats['train'][-1]['pos_pairs_above_threshold'], 501 | 'train_neg': stats['train'][-1]['neg_pairs_above_threshold']} 502 | if 'val' in phases: 503 | nz_metrics['val_pos'] = stats['val'][-1]['pos_pairs_above_threshold'] 504 | nz_metrics['val_neg'] = stats['val'][-1]['neg_pairs_above_threshold'] 505 | writer.add_scalars('Non-zero pairs', nz_metrics, epoch) 506 | 507 | 508 | if params.batch_expansion_th is not None: 509 | # Dynamic batch expansion of the training batch 510 | epoch_train_stats = stats['train'][-1] 511 | if 'num_non_zero_triplets' in epoch_train_stats: 512 | # Ratio of non-zero triplets 513 | rnz = epoch_train_stats['num_non_zero_triplets'] / epoch_train_stats['num_triplets'] 514 | if rnz < params.batch_expansion_th: 515 | dataloaders['train'].batch_sampler.expand_batch() 516 | elif 'final_num_non_zero_triplets' in epoch_train_stats: 517 | rnz = [] 518 | rnz.append(epoch_train_stats['final_num_non_zero_triplets'] / epoch_train_stats['final_num_triplets']) 519 | if 'image_num_non_zero_triplets' in epoch_train_stats: 520 | rnz.append(epoch_train_stats['image_num_non_zero_triplets'] / epoch_train_stats['image_num_triplets']) 521 | if 'cloud_num_non_zero_triplets' in epoch_train_stats: 522 | rnz.append(epoch_train_stats['cloud_num_non_zero_triplets'] / epoch_train_stats['cloud_num_triplets']) 523 | rnz = max(rnz) 524 | if rnz < params.batch_expansion_th: 525 | dataloaders['train'].batch_sampler.expand_batch() 526 | else: 527 | print('WARNING: Batch size expansion is enabled, but the loss function is not supported') 528 | print('') 529 | 530 | # Save final model weights 531 | final_model_path = model_pathname + '_final.pth' 532 | torch.save(model.state_dict(), final_model_path) 533 | print('Final model saved to: {}'.format(final_model_path)) 534 | 535 | 536 | def export_eval_stats(file_name, prefix, eval_stats): 537 | s = prefix 538 | ave_1p_recall_l = [] 539 | ave_recall_l = [] 540 | # Print results on the final model 541 | with open(file_name, "a") as f: 542 | for ds in ['etna', 'university', 'residential', 'business']: 543 | if ds not in eval_stats: 544 | continue 545 | ave_1p_recall = eval_stats[ds]['ave_one_percent_recall'] 546 | ave_1p_recall_l.append(ave_1p_recall) 547 | ave_recall = eval_stats[ds]['ave_recall'][0] 548 | ave_recall_l.append(ave_recall) 549 | s += ", {:0.2f}, {:0.2f}".format(ave_1p_recall, ave_recall) 550 | 551 | mean_1p_recall = np.mean(ave_1p_recall_l) 552 | mean_recall = np.mean(ave_recall_l) 553 | s += ", {:0.2f}, {:0.2f}\n".format(mean_1p_recall, mean_recall) 554 | f.write(s) 555 | 556 | 557 | def create_weights_folder(): 558 | # Create a folder to save weights of trained models 559 | this_file_path = pathlib.Path(__file__).parent.absolute() 560 | temp, _ = os.path.split(this_file_path) 561 | weights_path = os.path.join(temp, 'weights') 562 | if not os.path.exists(weights_path): 563 | os.mkdir(weights_path) 564 | assert os.path.exists(weights_path), 'Cannot create weights folder: {}'.format(weights_path) 565 | return weights_path 566 | 
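
The three-stage procedure in `multistaged_training_step` (embed the whole batch without gradients, differentiate the loss with respect to the cached embeddings, then replay each minibatch with gradients enabled and backpropagate the cached gradient) can be hard to follow inside the full trainer. Below is a self-contained toy sketch of the same pattern with a hypothetical linear encoder and toy loss; it illustrates the technique only and is not the UMF model or its loss functions.

```python
# Minimal sketch of multistaged backpropagation (hypothetical encoder/loss, illustrative only)
import torch
import torch.nn as nn
import torch.nn.functional as F

encoder = nn.Linear(16, 8)                        # stand-in for the full network
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.01)
batch = torch.randn(12, 16)                       # full batch, assumed too large to backprop in one pass
minibatch_size = 4

# Stage 1: compute all embeddings with gradients turned off (low memory)
with torch.no_grad():
    embeddings = torch.cat([encoder(batch[i:i + minibatch_size])
                            for i in range(0, len(batch), minibatch_size)], dim=0)

# Stage 2: gradient of the loss w.r.t. the embeddings only
embeddings.requires_grad_(True)
loss = (1.0 - F.cosine_similarity(embeddings[0::2], embeddings[1::2])).mean()  # toy loss
loss.backward()
embeddings_grad = embeddings.grad

# Stage 3: recompute each minibatch with gradients on and backpropagate the cached gradient
optimizer.zero_grad()
for i in range(0, len(batch), minibatch_size):
    out = encoder(batch[i:i + minibatch_size])    # same minibatch split as in Stage 1
    out.backward(gradient=embeddings_grad[i:i + minibatch_size])  # chain rule via cached grads
optimizer.step()
print('toy loss:', loss.item())
```

In `trainer.py` the same idea is applied per modality: the gradients of the fused, cloud and image embeddings (and, in the superfeatures mode, the attention maps) are concatenated in the same order as the recomputed outputs before the final backward call.
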
--------------------------------------------------------------------------------
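
Finally, a hedged sketch of how the checkpoints written by `do_train` (the `*_best.pth` / `*_final.pth` files placed in the `weights/` folder created by `create_weights_folder`) can be restored. The config paths and the checkpoint filename below are illustrative placeholders; the evaluation entry point in `eval/evaluate.py` may expect a different invocation.

```python
# Sketch: restoring a checkpoint saved by do_train (paths/filenames are illustrative placeholders)
import torch
from misc.utils import UMFParams
from models.model_factory import model_factory

params = UMFParams('../config/config_base.yaml', '../models/UMF/UMFnet.yml')  # assumed config paths
model = model_factory(params)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
state_dict = torch.load('weights/model_UMF_<timestamp>_best.pth', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
```
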