├── teaser.png
├── models
│   ├── minklocrgb.txt
│   ├── minkloc3d.txt
│   ├── minklocmultimodal.txt
│   ├── __pycache__
│   │   ├── loss.cpython-38.pyc
│   │   ├── minkfpn.cpython-38.pyc
│   │   ├── minkloc.cpython-38.pyc
│   │   ├── resnet.cpython-38.pyc
│   │   ├── loss_utils.cpython-38.pyc
│   │   ├── minkloconly.cpython-38.pyc
│   │   ├── model_factory.cpython-38.pyc
│   │   └── minkloc_multimodal.cpython-38.pyc
│   ├── model_factory.py
│   ├── loss_utils.py
│   ├── minkloc.py
│   ├── minkloconly.py
│   ├── resnet.py
│   ├── minkfpn.py
│   ├── loss.py
│   └── minkloc_multimodal.py
├── hyptorch
│   ├── __pycache__
│   │   ├── nn.cpython-38.pyc
│   │   └── pmath.cpython-38.pyc
│   ├── delta.py
│   └── nn.py
├── network
│   ├── __pycache__
│   │   ├── ffb.cpython-38.pyc
│   │   ├── adaptor.cpython-38.pyc
│   │   ├── grids.cpython-38.pyc
│   │   ├── pooling.cpython-38.pyc
│   │   ├── univpr.cpython-38.pyc
│   │   ├── ffb_local.cpython-38.pyc
│   │   ├── swinblock.cpython-38.pyc
│   │   ├── univpr_v2.cpython-38.pyc
│   │   ├── gatt_block.cpython-38.pyc
│   │   ├── pointnet2mlp.cpython-38.pyc
│   │   ├── image_pool_fns.cpython-38.pyc
│   │   ├── minklocsimple.cpython-38.pyc
│   │   ├── distil_imagefes.cpython-38.pyc
│   │   ├── gatt_image_block.cpython-38.pyc
│   │   ├── general_imagefes.cpython-38.pyc
│   │   ├── general_minkfpn.cpython-38.pyc
│   │   ├── pointnet_simple.cpython-38.pyc
│   │   ├── resnetfpn_simple.cpython-38.pyc
│   │   ├── swin_transformer.cpython-38.pyc
│   │   ├── later_cloud_branch.cpython-38.pyc
│   │   ├── later_image_branch.cpython-38.pyc
│   │   ├── graph_attention_layer.cpython-38.pyc
│   │   └── graph_attention_layer_fusion.cpython-38.pyc
│   ├── general_pointnet.py
│   ├── pooling.py
│   ├── minklocsimple.py
│   ├── general_minkfpn.py
│   ├── general_imagefes.py
│   ├── image_pool_fns.py
│   ├── distil_imagefes.py
│   └── resnetfpn_simple.py
├── selfagent
│   ├── __pycache__
│   │   ├── kd.cpython-38.pyc
│   │   ├── afd.cpython-38.pyc
│   │   ├── csd.cpython-38.pyc
│   │   ├── mkd.cpython-38.pyc
│   │   ├── rkd.cpython-38.pyc
│   │   ├── rkdg.cpython-38.pyc
│   │   ├── rkdg2.cpython-38.pyc
│   │   ├── epcnet.cpython-38.pyc
│   │   └── lsdnet.cpython-38.pyc
│   └── rkdg.py
├── tools
│   ├── __pycache__
│   │   ├── utils.cpython-38.pyc
│   │   ├── options.cpython-38.pyc
│   │   ├── utils_adafusion.cpython-38.pyc
│   │   └── utils_minkloc3dv2.cpython-38.pyc
│   ├── utils_adafusion.py
│   ├── utils.py
│   └── utils_minkloc3dv2.py
├── datasets
│   ├── __pycache__
│   │   ├── oxford.cpython-38.pyc
│   │   ├── samplers.cpython-38.pyc
│   │   ├── augmentation.cpython-38.pyc
│   │   ├── dataset_utils.cpython-38.pyc
│   │   ├── make_collate_fn.cpython-38.pyc
│   │   ├── make_dataloaders.cpython-38.pyc
│   │   ├── dataloader_dataset.cpython-38.pyc
│   │   └── dataloader_dataset_kitti.cpython-38.pyc
│   ├── make_collate_fn.py
│   ├── make_dataloaders.py
│   ├── samplers.py
│   ├── dataloader_dataset.py
│   └── augmentation.py
├── layers
│   ├── __pycache__
│   │   ├── pooling.cpython-38.pyc
│   │   └── eca_block.cpython-38.pyc
│   ├── eca_block.py
│   └── pooling.py
├── config
│   ├── config_refined.txt
│   ├── config_baseline.txt
│   ├── config_baseline_rgb.txt
│   └── config_baseline_multimodal.txt
├── third_party
│   └── robotcardatasetsdk
│       └── image.py
├── README.md
└── multi_stage_train.py

/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/teaser.png
--------------------------------------------------------------------------------
/models/minklocrgb.txt:
--------------------------------------------------------------------------------
 1 | # MinkLocRGB model
 2 | [MODEL]
 3 | model = MinkLocRGB
 4 | 
--------------------------------------------------------------------------------
/models/minkloc3d.txt:
--------------------------------------------------------------------------------
 1 | # MinkLoc3D model
 2 | [MODEL]
 3 | model = MinkLoc3D
 4 | mink_quantization_size = 0.01
 5 | 
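These model files are plain INI-style configs. A minimal reading sketch (this mirrors how the `ModelParams` class in `tools/utils.py` parses them; the path below is just an example):

```
import configparser

config = configparser.ConfigParser()
config.read('models/minkloc3d.txt')
params = config['MODEL']

model_name = params.get('model')                              # 'MinkLoc3D'
quant_size = params.getfloat('mink_quantization_size', 0.01)  # falls back to 0.01
print(model_name, quant_size)
```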
--------------------------------------------------------------------------------
/models/minklocmultimodal.txt:
--------------------------------------------------------------------------------
 1 | # MinkLocMultimodal model
 2 | [MODEL]
 3 | model = MinkLocMultimodal
 4 | mink_quantization_size = 0.01
 5 | 
--------------------------------------------------------------------------------
/hyptorch/__pycache__/nn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/hyptorch/__pycache__/nn.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/loss.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/loss.cpython-38.pyc
--------------------------------------------------------------------------------
/network/__pycache__/ffb.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/ffb.cpython-38.pyc
--------------------------------------------------------------------------------
/selfagent/__pycache__/kd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/kd.cpython-38.pyc
--------------------------------------------------------------------------------
/tools/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/tools/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/oxford.cpython-38.pyc
--------------------------------------------------------------------------------
/hyptorch/__pycache__/pmath.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/hyptorch/__pycache__/pmath.cpython-38.pyc
--------------------------------------------------------------------------------
/layers/__pycache__/pooling.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/layers/__pycache__/pooling.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/minkfpn.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/minkfpn.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/minkloc.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/minkloc.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/resnet.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/adaptor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/adaptor.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/grids.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/grids.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/pooling.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/pooling.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/univpr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/univpr.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/afd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/afd.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/csd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/csd.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/mkd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/mkd.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/rkd.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/rkd.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/rkdg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/rkdg.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/rkdg2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/rkdg2.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/tools/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- 
/datasets/__pycache__/samplers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/samplers.cpython-38.pyc -------------------------------------------------------------------------------- /layers/__pycache__/eca_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/layers/__pycache__/eca_block.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/loss_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/loss_utils.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/ffb_local.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/ffb_local.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/swinblock.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/swinblock.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/univpr_v2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/univpr_v2.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/epcnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/epcnet.cpython-38.pyc -------------------------------------------------------------------------------- /selfagent/__pycache__/lsdnet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/selfagent/__pycache__/lsdnet.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/minkloconly.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/minkloconly.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/model_factory.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/model_factory.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/gatt_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/gatt_block.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/pointnet2mlp.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/pointnet2mlp.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/augmentation.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/augmentation.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataset_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/image_pool_fns.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/image_pool_fns.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/minklocsimple.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/minklocsimple.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils_adafusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/tools/__pycache__/utils_adafusion.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/make_collate_fn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/make_collate_fn.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/make_dataloaders.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/make_dataloaders.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/minkloc_multimodal.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/models/__pycache__/minkloc_multimodal.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/distil_imagefes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/distil_imagefes.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/gatt_image_block.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/gatt_image_block.cpython-38.pyc -------------------------------------------------------------------------------- 
/network/__pycache__/general_imagefes.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/general_imagefes.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/general_minkfpn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/general_minkfpn.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/pointnet_simple.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/pointnet_simple.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/resnetfpn_simple.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/resnetfpn_simple.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/swin_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/swin_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /tools/__pycache__/utils_minkloc3dv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/tools/__pycache__/utils_minkloc3dv2.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataloader_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/dataloader_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/later_cloud_branch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/later_cloud_branch.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/later_image_branch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/later_image_branch.cpython-38.pyc -------------------------------------------------------------------------------- /network/__pycache__/graph_attention_layer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/graph_attention_layer.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dataloader_dataset_kitti.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/datasets/__pycache__/dataloader_dataset_kitti.cpython-38.pyc 
-------------------------------------------------------------------------------- /network/__pycache__/graph_attention_layer_fusion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sijieaaa/DistilVPR/HEAD/network/__pycache__/graph_attention_layer_fusion.cpython-38.pyc -------------------------------------------------------------------------------- /config/config_refined.txt: -------------------------------------------------------------------------------- 1 | # Config for training a single-modal model with point clouds only on Refined dataset (RobotCar and Inhouse) 2 | [DEFAULT] 3 | num_points = 4096 4 | dataset_folder = /data3/pointnetvlad/benchmark_datasets 5 | 6 | [TRAIN] 7 | num_workers = 8 8 | batch_size = 16 9 | val_batch_size = 256 10 | batch_size_limit = 256 11 | batch_expansion_rate = 1.4 12 | batch_expansion_th = 0.7 13 | 14 | lr = 1e-3 15 | epochs = 80 16 | scheduler_milestones = 60 17 | 18 | aug_mode = 1 19 | weight_decay = 1e-4 20 | 21 | loss = BatchHardTripletMarginLoss 22 | normalize_embeddings = False 23 | margin = 0.2 24 | 25 | train_file = training_queries_refine.pickle 26 | val_file = test_queries_baseline.pickle 27 | -------------------------------------------------------------------------------- /config/config_baseline.txt: -------------------------------------------------------------------------------- 1 | # Config for training a single-modal model with point clouds only on Baseline dataset (RobotCar) 2 | [DEFAULT] 3 | num_points = 4096 4 | 5 | dataset_folder = /data/sijie/vpr/MinkLocMultimodal/benchmark_datasets 6 | 7 | [TRAIN] 8 | num_workers = 8 9 | batch_size = 8 10 | val_batch_size = 256 11 | batch_size_limit = 256 12 | batch_expansion_rate = 1.4 13 | batch_expansion_th = 0.7 14 | 15 | lr = 1e-3 16 | epochs = 60 17 | scheduler_milestones = 40 18 | 19 | aug_mode = 1 20 | weight_decay = 1e-4 21 | 22 | loss = BatchHardTripletMarginLoss 23 | normalize_embeddings = False 24 | margin = 0.2 25 | 26 | train_file = training_queries_baseline.pickle 27 | val_file = test_queries_baseline.pickle 28 | -------------------------------------------------------------------------------- /config/config_baseline_rgb.txt: -------------------------------------------------------------------------------- 1 | # Config for training a single-modal model with RGB images only on Baseline dataset (RobotCar) 2 | [DEFAULT] 3 | dataset_folder = /data3/pointnetvlad/benchmark_datasets 4 | image_path = /data/sijie/vpr/RobotCar_checked_image 5 | use_cloud = False 6 | 7 | [TRAIN] 8 | num_workers = 8 9 | batch_size = 8 10 | val_batch_size = 256 11 | batch_size_limit = 256 12 | batch_expansion_rate = 1.4 13 | batch_expansion_th = 0.7 14 | 15 | lr = 1e-4 16 | image_lr = 1e-4 17 | epochs = 60 18 | scheduler_milestones = 40 19 | 20 | aug_mode = 1 21 | weight_decay = 1e-4 22 | 23 | loss = BatchHardTripletMarginLoss 24 | normalize_embeddings = False 25 | margin = 0.2 26 | 27 | train_file = training_queries_baseline.pickle 28 | val_file = test_queries_baseline.pickle -------------------------------------------------------------------------------- /config/config_baseline_multimodal.txt: -------------------------------------------------------------------------------- 1 | # Config for training a multi-modal model with point clouds and RGB images on Baseline dataset (RobotCar) 2 | [DEFAULT] 3 | num_points = 4096 4 | 5 | #dataset_folder = /home/sijie/vpr/benchmark_datasets 6 | #image_path = /home/sijie/vpr/RobotCar_checked_image 7 | 8 | 
#dataset_folder = /scratch/users/ntu/wang1679/vpr/benchmark_datasets 9 | #image_path = /scratch/users/ntu/wang1679/vpr/RobotCar_checked_image 10 | 11 | dataset_folder = /scratch/users/ntu/wang1679/vpr/benchmark_datasets 12 | image_path = /scratch/users/ntu/wang1679/vpr/RobotCar_checked_image 13 | 14 | 15 | [TRAIN] 16 | num_workers = 8 17 | batch_size = 80 18 | val_batch_size = 160 19 | batch_size_limit = 80 20 | batch_expansion_rate = 1.4 21 | batch_expansion_th = 0.7 22 | 23 | 24 | lr = 1e-3 25 | image_lr = 1e-4 26 | epochs = 60 27 | scheduler_milestones = 40 28 | 29 | aug_mode = 1 30 | weight_decay = 1e-4 31 | 32 | loss = MultiBatchHardTripletMarginLoss 33 | weights = 0.5, 0.5, 0.0 34 | normalize_embeddings = False 35 | margin = 0.2 36 | 37 | 38 | #train_file = training_queries_baseline.pickle 39 | #val_file = test_queries_baseline.pickle 40 | train_file = training_queries_baseline.pickle -------------------------------------------------------------------------------- /third_party/robotcardatasetsdk/image.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # 3 | # Copyright (c) 2017 University of Oxford 4 | # Authors: 5 | # Geoff Pascoe (gmp@robots.ox.ac.uk) 6 | # 7 | # This work is licensed under the Creative Commons 8 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 9 | # To view a copy of this license, visit 10 | # http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to 11 | # Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. 12 | # 13 | ############################################################################### 14 | 15 | import re 16 | from PIL import Image 17 | from colour_demosaicing import demosaicing_CFA_Bayer_bilinear as demosaic 18 | import numpy as np 19 | 20 | BAYER_STEREO = 'gbrg' 21 | BAYER_MONO = 'rggb' 22 | 23 | 24 | def load_image(image_path, model=None): 25 | """Loads and rectifies an image from file. 26 | 27 | Args: 28 | image_path (str): path to an image from the dataset. 29 | model (camera_model.CameraModel): if supplied, model will be used to undistort image. 
 30 | 
 31 |     Returns:
 32 |         numpy.ndarray: demosaiced and optionally undistorted image
 33 | 
 34 |     """
 35 |     if model:
 36 |         camera = model.camera
 37 |     else:
 38 |         camera = re.search('(stereo|mono_(left|right|rear))', image_path).group(0)
 39 |     if camera == 'stereo':
 40 |         pattern = BAYER_STEREO
 41 |     else:
 42 |         pattern = BAYER_MONO
 43 | 
 44 |     img = Image.open(image_path)
 45 |     img = demosaic(img, pattern)
 46 |     if model:
 47 |         img = model.undistort(img)
 48 | 
 49 |     return np.array(img).astype(np.uint8)
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
 1 | # Author: Jacek Komorowski
 2 | # Warsaw University of Technology
 3 | 
 4 | from models.minkloc import MinkLoc
 5 | from models.minkloc_multimodal import MinkLocMultimodal, ResnetFPN
 6 | from models.minkloc_multimodal import ResNetFPNv2
 7 | 
 8 | from tools.utils import set_seed
 9 | set_seed(7)
10 | # from tools.options import Options
11 | # args = Options().parse()
12 | 
13 | def model_factory(
14 |         fuse_method, cloud_fe_size, image_fe_size,
15 |         cloud_planes, cloud_layers, cloud_topdown,
16 |         image_useallstages, image_fe,
17 | ):
18 | 
19 | 
20 | 
21 |     cloud_fe = MinkLoc(in_channels=1, feature_size=cloud_fe_size, output_dim=cloud_fe_size,
22 |                        planes=cloud_planes, layers=cloud_layers, num_top_down=cloud_topdown,
23 |                        conv0_kernel_size=5, block='ECABasicBlock', pooling_method='GeM')
24 | 
25 | 
26 | 
27 | 
28 | 
29 |     # image_fe = ResnetFPN(out_channels=image_fe_size, lateral_dim=image_fe_size,
30 |     #                      fh_num_bottom_up=4, fh_num_top_down=0,
31 |     #                      add_basicblock=resnetfpn_add_basicblock)
32 |     image_fe = ResNetFPNv2(
33 |         image_fe=image_fe,
34 |         image_pool_method='GeM',
35 |         image_useallstages=image_useallstages,
36 |         output_dim=image_fe_size,
37 |     )
38 | 
39 | 
40 |     model = MinkLocMultimodal(
41 |         cloud_fe, cloud_fe_size, image_fe, image_fe_size,
42 |         fuse_method=fuse_method
43 |     )
44 | 
45 | 
46 | 
47 | 
48 |     return model
49 | 
--------------------------------------------------------------------------------
/models/loss_utils.py:
--------------------------------------------------------------------------------
 1 | # Functions and classes used by different loss functions
 2 | import numpy as np
 3 | import torch
 4 | from torch import Tensor
 5 | 
 6 | EPS = 1e-5
 7 | 
 8 | 
 9 | def metrics_mean(l):
10 |     # Compute the mean and return as Python number
11 |     metrics = {}
12 |     for e in l:
13 |         for metric_name in e:
14 |             if metric_name not in metrics:
15 |                 metrics[metric_name] = []
16 |             metrics[metric_name].append(e[metric_name])
17 | 
18 |     for metric_name in metrics:
19 |         metrics[metric_name] = np.mean(np.array(metrics[metric_name]))
20 | 
21 |     return metrics
22 | 
23 | 
24 | def squared_euclidean_distance(x: Tensor, y: Tensor) -> Tensor:
25 |     '''
26 |     Compute squared Euclidean distance
27 |     Input: x is Nxd matrix
28 |            y is Mxd matrix
29 |     Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:]
30 |     i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
31 |     Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/3
32 |     '''
33 |     x_norm = (x ** 2).sum(1).view(-1, 1)
34 |     y_t = torch.transpose(y, 0, 1)
35 |     y_norm = (y ** 2).sum(1).view(1, -1)
36 |     dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
37 |     return torch.clamp(dist, 0.0, np.inf)
38 | 
39 | 
40 | def sigmoid(tensor: Tensor, temp: float) -> Tensor:
41 |     """ temperature controlled sigmoid
42 |     takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
43 |     """
44 |     exponent = -tensor / temp
45 |     # clamp the input tensor for stability
46 |     exponent = torch.clamp(exponent, min=-50, max=50)
47 |     y = 1.0 / (1.0 + torch.exp(exponent))
48 |     return y
49 | 
50 | 
51 | def compute_aff(x: Tensor, similarity: str = 'cosine') -> Tensor:
52 |     """computes the affinity matrix between an input vector and itself"""
53 |     if similarity == 'cosine':
54 |         x = torch.mm(x, x.t())
55 |     elif similarity == 'euclidean':
56 |         x = x.unsqueeze(0)
57 |         x = torch.cdist(x, x, p=2)
58 |         x = x.squeeze(0)
59 |         # The greater the distance the smaller affinity
60 |         x = -x
61 |     else:
62 |         raise NotImplementedError(f"Incorrect similarity measure: {similarity}")
63 |     return x
--------------------------------------------------------------------------------
/hyptorch/delta.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn as nn
 3 | import torchvision
 4 | from scipy.spatial import distance_matrix
 5 | import numpy as np
 6 | from tqdm import tqdm
 7 | 
 8 | 
 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10 | 
11 | 
12 | def delta_hyp(dismat):
13 |     """
14 |     computes delta hyperbolicity value from distance matrix
15 |     """
16 | 
17 |     p = 0
18 |     row = dismat[p, :][np.newaxis, :]
19 |     col = dismat[:, p][:, np.newaxis]
20 |     XY_p = 0.5 * (row + col - dismat)
21 | 
22 |     maxmin = np.max(np.minimum(XY_p[:, :, None], XY_p[None, :, :]), axis=1)
23 |     return np.max(maxmin - XY_p)
24 | 
25 | 
26 | def batched_delta_hyp(X, n_tries=10, batch_size=1500):
27 |     vals = []
28 |     for i in tqdm(range(n_tries)):
29 |         idx = np.random.choice(len(X), batch_size)
30 |         X_batch = X[idx]
31 |         distmat = distance_matrix(X_batch, X_batch)
32 |         diam = np.max(distmat)
33 |         delta_rel = delta_hyp(distmat) / diam
34 |         vals.append(delta_rel)
35 |     return np.mean(vals), np.std(vals)
36 | 
37 | 
38 | class Flatten(nn.Module):
39 |     def __init__(self):
40 |         super().__init__()
41 | 
42 |     def forward(self, x):
43 |         B = x.shape[0]
44 |         return x.view(B, -1)
45 | 
46 | 
47 | def get_delta(loader):
48 |     """
49 |     computes delta value for image data by extracting features using VGG network;
50 |     input -- data loader for images
51 |     """
52 |     vgg = torchvision.models.vgg16(pretrained=True)
53 |     vgg_feats = vgg.features
54 |     vgg_classifier = nn.Sequential(*list(vgg.classifier.children())[:-1])
55 | 
56 |     vgg_part = nn.Sequential(vgg_feats, Flatten(), vgg_classifier).to(device)
57 |     vgg_part.eval()
58 | 
59 |     all_features = []
60 |     for i, (batch, _) in enumerate(loader):
61 |         with torch.no_grad():
62 |             batch = batch.to(device)
63 |             all_features.append(vgg_part(batch).detach().cpu().numpy())
64 | 
65 |     all_features = np.concatenate(all_features)
66 |     idx = np.random.choice(len(all_features), 1500)
67 |     all_features_small = all_features[idx]
68 | 
69 |     dists = distance_matrix(all_features_small, all_features_small)
70 |     delta = delta_hyp(dists)
71 |     diam = np.max(dists)
72 |     return delta, diam
73 | 
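74 | 
75 | 
76 | if __name__ == '__main__':
77 |     # Usage sketch added for illustration (not part of the original file): estimate
78 |     # the relative delta-hyperbolicity of a random Euclidean feature set. Smaller
79 |     # values indicate a geometry closer to hyperbolic (more tree-like).
80 |     X = np.random.randn(2000, 64)
81 |     mean_rel, std_rel = batched_delta_hyp(X, n_tries=3, batch_size=500)
82 |     print('relative delta: {:.3f} +/- {:.3f}'.format(mean_rel, std_rel))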
-------------------------------------------------------------------------------- /layers/eca_block.py: -------------------------------------------------------------------------------- 1 | # Implementation of Efficient Channel Attention ECA block 2 | 3 | import numpy as np 4 | import torch.nn as nn 5 | 6 | import MinkowskiEngine as ME 7 | 8 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 9 | 10 | from tools.utils import set_seed 11 | set_seed(7) 12 | 13 | class ECALayer(nn.Module): 14 | def __init__(self, channels, gamma=2, b=1): 15 | super().__init__() 16 | t = int(abs((np.log2(channels) + b) / gamma)) 17 | k_size = t if t % 2 else t + 1 18 | self.avg_pool = ME.MinkowskiGlobalPooling() 19 | self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 20 | self.sigmoid = nn.Sigmoid() 21 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() 22 | 23 | def forward(self, x): 24 | # feature descriptor on the global spatial information 25 | y_sparse = self.avg_pool(x) 26 | 27 | # Apply 1D convolution along the channel dimension 28 | y = self.conv(y_sparse.F.unsqueeze(-1).transpose(-1, -2)).transpose(-1, -2).squeeze(-1) 29 | # y is (batch_size, channels) tensor 30 | 31 | # Multi-scale information fusion 32 | y = self.sigmoid(y) 33 | # y is (batch_size, channels) tensor 34 | 35 | y_sparse = ME.SparseTensor(y, coordinate_manager=y_sparse.coordinate_manager, 36 | coordinate_map_key=y_sparse.coordinate_map_key) 37 | # y must be features reduced to the origin 38 | # return self.broadcast_mul(x, y_sparse) 39 | 40 | output = self.broadcast_mul(x, y_sparse) 41 | 42 | return output 43 | 44 | 45 | class ECABasicBlock(BasicBlock): 46 | def __init__(self, 47 | inplanes, 48 | planes, 49 | stride=1, 50 | dilation=1, 51 | downsample=None, 52 | dimension=3): 53 | super(ECABasicBlock, self).__init__( 54 | inplanes, 55 | planes, 56 | stride=stride, 57 | dilation=dilation, 58 | downsample=downsample, 59 | dimension=dimension) 60 | self.eca = ECALayer(planes, gamma=2, b=1) 61 | 62 | def forward(self, x): 63 | residual = x 64 | 65 | out = self.conv1(x) 66 | out = self.norm1(out) 67 | out = self.relu(out) 68 | 69 | out = self.conv2(out) 70 | out = self.norm2(out) 71 | out = self.eca(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /models/minkloc.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import torch 5 | import torch.nn as nn 6 | import MinkowskiEngine as ME 7 | 8 | from models.minkfpn import MinkFPN 9 | from models.minkfpn import GeneralMinkFPN 10 | 11 | import layers.pooling as pooling 12 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 13 | from layers.eca_block import ECABasicBlock 14 | 15 | from layers.pooling import MinkGeM as MinkGeM 16 | 17 | # from tools.options import Options 18 | # args = Options().parse() 19 | 20 | from tools.utils import set_seed 21 | set_seed(7) 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | class MinkLoc(torch.nn.Module): 30 | def __init__(self, in_channels, feature_size, output_dim, planes, layers, num_top_down, conv0_kernel_size, 31 | block='BasicBlock', pooling_method='GeM', linear_block=False, dropout_p=None, 32 | minkfpn='minkfpn'): 33 | # block: Type of the network building block: BasicBlock or 
ECABasicBlock
34 |         # add_linear_layers: Add linear layers at the end
35 |         # dropout_p: dropout probability (None = no dropout)
36 | 
37 |         super().__init__()
38 |         self.in_channels = in_channels
39 |         self.feature_size = feature_size    # Size of local features produced by local feature extraction block
40 |         self.output_dim = output_dim    # Dimensionality of the global descriptor produced by pooling layer
41 |         self.block = block
42 | 
43 |         if block == 'BasicBlock':
44 |             block_module = BasicBlock
45 |         elif block == 'Bottleneck':
46 |             block_module = Bottleneck
47 |         elif block == 'ECABasicBlock':
48 |             block_module = ECABasicBlock
49 |         else:
50 |             raise NotImplementedError('Unsupported network block: {}'.format(block))
51 | 
52 |         self.pooling_method = pooling_method
53 | 
54 | 
55 |         self.backbone = MinkFPN(in_channels=in_channels, out_channels=self.feature_size, num_top_down=num_top_down,
56 |                                 conv0_kernel_size=conv0_kernel_size, block=block_module, layers=layers, planes=planes)
57 | 
58 | 
59 |         self.pooling = pooling.PoolingWrapper(pool_method=pooling_method, in_dim=self.feature_size,
60 |                                               output_dim=output_dim)
61 | 
62 | 
63 |     def forward(self, batch):
64 |         # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
65 |         x = ME.SparseTensor(features=batch['features'], coordinates=batch['coords'])
66 |         x = self.backbone(x)
67 | 
68 | 
69 | 
70 |         x_feat = x
71 |         x_gem = self.pooling(x)
72 | 
73 | 
74 |         return {
75 |             'output_cloud_feat': x_feat,
76 |             'output_cloud_gem': x_gem,
77 | 
78 |             'embedding': x_gem
79 |         }
80 | 
81 | 
82 | 
83 | 
84 | 
85 | 
86 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # DistilVPR
 2 | 
 3 | (AAAI 2024) DistilVPR: Cross-Modal Knowledge Distillation for Visual Place Recognition 🚀🚀🚀
 4 | 
 5 | [ArXiv](https://arxiv.org/abs/2312.10616)
 6 | 
 7 | 
 8 | 
 9 | 
10 | ## Installation
11 | 
12 | - Platform
13 | 
14 | ```
15 | Ubuntu 20.04
16 | python 3.8
17 | CUDA >= 11.8
18 | PyTorch >= 2.0
19 | ```
20 | 
21 | - PyTorch
22 | 
23 | ```
24 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
25 | ```
26 | 
27 | - MinkowskiEngine https://github.com/NVIDIA/MinkowskiEngine
28 | 
29 | ```
30 | conda install openblas-devel -c anaconda
31 | pip install pip==22.3.1
32 | pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=${CONDA_PREFIX}/include" --install-option="--blas=openblas"
33 | ```
34 | 
35 | - Others
36 | 
37 | ```
38 | pip install scikit-learn
39 | pip install tqdm
40 | pip install pytorch-metric-learning==1.1
41 | pip install tensorboard
42 | ```
43 | 
44 | 
45 | 
46 | ## Dataset
47 | 
48 | The datasets are uploaded at [Google Drive](https://drive.google.com/drive/folders/13-3hhL0XzhXzhPULlbhuvYE6vnwxP3tE?usp=sharing). Please download and unzip them, then set the following arguments in `tools/options.py` to your local directories:
49 | 
50 | ```
51 | --dataset
52 | --dataset_folder
53 | --image_path
54 | ```
55 | 
56 | The teachers' weights are stored in `teacher_weights/`, which are also uploaded at [Google Drive](https://drive.google.com/drive/folders/13-3hhL0XzhXzhPULlbhuvYE6vnwxP3tE?usp=sharing).
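A minimal sketch of how these might be set when launching training (this assumes `tools/options.py` exposes them as command-line flags; the two paths are placeholders for wherever you unzipped the data, and `oxford` is one of the supported dataset names):

```
python train.py --dataset oxford \
    --dataset_folder /path/to/benchmark_datasets \
    --image_path /path/to/RobotCar_checked_image
```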
57 | 
58 | 
59 | 
60 | ## Run
61 | 
62 | We currently provide examples where the teacher is MinkLoc++ and the student is ResNet18+GeM (MinkLoc++2D):
63 | 
64 | ```
65 | # oxford
66 | python train.py --model minklocmmcat \
67 |     --teacher_weights_path teacher_weights/oxford__T:minklocmmcat__resnet18__img256__pc128__32_64_64__1_1_1__1__allstgF__b128__trainteacher/models/r1_best_ep57_97.24.pth \
68 |     --rkdgloss_weight 10 --crosslogitdistloss_weight_st2ss 0.1 --crosslogitsimloss_weight_st2ss 0.1 --crosslogitgeodistloss_weight_st2ss 0.1;
69 | 
70 | 
71 | # boreas
72 | python train.py --model minklocmmcat \
73 |     --teacher_weights_path teacher_weights/boreas__T:minklocmmcat__resnet18__img256__pc128__32_64_64__1_1_1__1__allstgF__b128__trainteacher/models/r1_best_ep48_93.05.pth \
74 |     --rkdgloss_weight 1 --crosslogitdistloss_weight_st2ss 0.1 --crosslogitsimloss_weight_st2ss 0.1 --crosslogitgeodistloss_weight_st2ss 0.1;
75 | ```
76 | 
77 | 
78 | 
79 | ## Citation
80 | 
81 | ```
82 | @inproceedings{wang2024distilvpr,
83 |   title={DistilVPR: Cross-Modal Knowledge Distillation for Visual Place Recognition},
84 |   author={Wang, Sijie and She, Rui and Kang, Qiyu and Jian, Xingchao and Zhao, Kai and Song, Yang and Tay, Wee Peng},
85 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
86 |   volume={38},
87 |   number={9},
88 |   pages={10377--10385},
89 |   year={2024}
90 | }
91 | ```
92 | 
--------------------------------------------------------------------------------
/network/general_pointnet.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import torch.nn as nn
 3 | from pointnet2.pointnet2_modules import PointnetSAModule
 4 | 
 5 | import torch
 6 | 
 7 | 
 8 | 
 9 | 
10 | 
11 | 
12 | class PointNetSimple(nn.Module):
13 |     r"""
14 |     Backbone network for point cloud feature learning.
15 |     Based on Pointnet++ single-scale grouping network.
16 | 
17 |     Parameters
18 |     ----------
19 |     input_feature_dim: int
20 |         Number of input channels in the feature descriptor for each point.
21 |         e.g. 3 for RGB.
22 |     """
23 | 
24 | 
25 |     def __init__(self, input_feature_dim=0, use_xyz=True):
26 |         super(PointNetSimple, self).__init__()
27 | 
28 |         self.sa1 = PointnetSAModule(
29 |             npoint=1024,
30 |             radius=0.1,
31 |             nsample=32,
32 |             mlp=[input_feature_dim, 32, 32],
33 |             use_xyz=use_xyz
34 |         )
35 | 
36 | 
37 |         self.sa2 = PointnetSAModule(
38 |             npoint=512,
39 |             radius=0.2,
40 |             nsample=32,
41 |             mlp=[32, 64, 64],
42 |             use_xyz=use_xyz
43 |         )
44 | 
45 | 
46 |         self.sa3 = PointnetSAModule(
47 |             npoint=256,
48 |             radius=0.4,
49 |             nsample=16,
50 |             mlp=[64, 64, 64, 128],
51 |             use_xyz=use_xyz
52 |         )
53 | 
54 | 
55 | 
56 | 
57 |     def _break_up_pc(self, pc):
58 |         xyz = pc[..., 0:3].contiguous()
59 |         features = (
60 |             pc[..., 3:].transpose(1, 2).contiguous()
61 |             if pc.size(-1) > 3 else None
62 |         )
63 |         return xyz, features
64 | 
65 | 
66 | 
67 |     def forward(self, feed_dict):
68 |         r"""
69 |         Forward pass of the network
70 | 
71 |         Parameters
72 |         ----------
73 |         pointcloud: Variable(torch.cuda.FloatTensor)
74 |             (B, N, 3 + input_feature_dim) tensor
75 |             Point cloud to run predictions on
76 |             Each point in the point-cloud MUST
77 |             be formatted as (x, y, z, features...)
78 | 
79 |         Returns
80 |         ----------
81 |         end_points: {XXX_xyz, XXX_features, XXX_inds}
82 |             XXX_xyz: float32 Tensor of shape (B,K,3)
83 |             XXX_features: float32 Tensor of shape (B,K,D)
84 |             XXX_inds: int64 Tensor of shape (B,K) values in [0,N-1]
85 |         """
86 | 
87 |         clouds = feed_dict['clouds']  # [b, n, 3]
88 | 
89 | 
90 |         xyz, features = self._break_up_pc(clouds)
91 | 
92 |         # --------- 3 SET ABSTRACTION LAYERS ---------
93 |         xyz, features = self.sa1(xyz, features)
94 | 
95 |         xyz, features = self.sa2(xyz, features)
96 | 
97 |         xyz, features = self.sa3(xyz, features)
98 | 
99 | 
100 | 
101 |         return features
102 | 
103 | 
104 | if __name__ == '__main__':
105 |     backbone_net = PointNetSimple(input_feature_dim=0).to("cuda")
106 |     print(backbone_net)
107 |     backbone_net.eval()
108 |     out = backbone_net({'clouds': torch.rand(160, 4096, 3).to("cuda")})
109 |     print(out.shape)
110 | 
--------------------------------------------------------------------------------
/network/pooling.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import torch
 3 | import torch.nn as nn
 4 | from torch.nn.parameter import Parameter
 5 | 
 6 | import numpy as np
 7 | import torch.nn.functional as F
 8 | 
 9 | 
10 | 
11 | 
12 | 
13 | 
14 | 
15 | class GeM(nn.Module):
16 |     def __init__(self, p=3, eps=1e-6):
17 |         super().__init__()
18 |         self.p = Parameter(torch.ones(1)*p)
19 |         self.eps = eps
20 | 
21 | 
22 |     def gem(self, x, p=3, eps=1e-6):
23 |         out = F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
24 |         return out
25 | 
26 | 
27 |     def forward(self, x):
28 |         out = self.gem(x, p=self.p, eps=self.eps)
29 |         return out
30 | 
31 | 
32 |     def __repr__(self):
33 |         return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | 
41 | 
42 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py
43 | class NetVLAD(nn.Module):
44 |     """NetVLAD layer implementation"""
45 | 
46 |     def __init__(self, num_clusters=64, dim=128, alpha=100.0,
47 |                  normalize_input=True):
48 |         """
49 |         Args:
50 |             num_clusters : int
51 |                 The number of clusters
52 |             dim : int
53 |                 Dimension of descriptors
54 |             alpha : float
55 |                 Parameter of initialization. Larger value is harder assignment.
56 |             normalize_input : bool
57 |                 If true, descriptor-wise L2 normalization is applied to input.
58 | """ 59 | super(NetVLAD, self).__init__() 60 | self.num_clusters = num_clusters 61 | self.dim = dim 62 | self.alpha = alpha 63 | self.normalize_input = normalize_input 64 | self.conv = nn.Conv2d(dim, num_clusters, kernel_size=(1, 1), bias=True) 65 | self.centroids = nn.Parameter(torch.rand(num_clusters, dim)) 66 | # self._init_params() 67 | 68 | def _init_params(self): 69 | self.conv.weight = nn.Parameter( 70 | (2.0 * self.alpha * self.centroids).unsqueeze(-1).unsqueeze(-1) 71 | ) 72 | self.conv.bias = nn.Parameter( 73 | - self.alpha * self.centroids.norm(dim=1) 74 | ) 75 | 76 | def forward(self, x): 77 | N, C = x.shape[:2] 78 | 79 | if self.normalize_input: 80 | x = F.normalize(x, p=2, dim=1) # across descriptor dim 81 | 82 | # soft-assignment 83 | soft_assign = self.conv(x).view(N, self.num_clusters, -1) 84 | soft_assign = F.softmax(soft_assign, dim=1) 85 | 86 | x_flatten = x.view(N, C, -1) 87 | 88 | # calculate residuals to each clusters 89 | residual = x_flatten.expand(self.num_clusters, -1, -1, -1).permute(1, 0, 2, 3) - \ 90 | self.centroids.expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 91 | residual *= soft_assign.unsqueeze(2) 92 | vlad = residual.sum(dim=-1) 93 | 94 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 95 | vlad = vlad.view(x.size(0), -1) # flatten 96 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 97 | 98 | return vlad 99 | 100 | -------------------------------------------------------------------------------- /selfagent/rkdg.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from hyptorch.nn import ToPoincare 6 | from hyptorch.pmath import dist, dist_matrix 7 | 8 | 9 | 10 | from tools.options import Options 11 | args = Options().parse() 12 | from tools.utils import set_seed 13 | set_seed(7) 14 | 15 | 16 | 17 | 18 | def _pdist(e, squared, eps): 19 | e_square = e.pow(2).sum(dim=1) 20 | prod = e @ e.t() 21 | res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps) 22 | 23 | if not squared: 24 | res = res.sqrt() 25 | 26 | res = res.clone() 27 | res[range(len(e)), range(len(e))] = 0 28 | return res 29 | 30 | 31 | 32 | 33 | 34 | def compute_rkdg_loss(output_dict_stu, output_dict_tea, squared=False, eps=1e-12, distance_weight=1, angle_weight=1, geodesic_weight=1): 35 | f_s = output_dict_stu['embedding'] 36 | f_t = output_dict_tea['embedding'] 37 | 38 | stu = f_s.view(f_s.shape[0], -1) 39 | tea = f_t.view(f_t.shape[0], -1) 40 | 41 | 42 | 43 | 44 | 45 | # RKD distance loss 46 | with torch.no_grad(): 47 | t_d = _pdist(tea, squared, eps) 48 | mean_td = t_d[t_d > 0].mean() 49 | t_d = t_d / mean_td 50 | 51 | d = _pdist(stu, squared, eps) 52 | mean_d = d[d > 0].mean() 53 | d = d / mean_d 54 | 55 | loss_d = F.smooth_l1_loss(d, t_d) 56 | 57 | 58 | 59 | # RKD Angle loss 60 | with torch.no_grad(): 61 | td = tea.unsqueeze(0) - tea.unsqueeze(1) 62 | norm_td = F.normalize(td, p=2, dim=2) 63 | t_angle = torch.bmm(norm_td, norm_td.transpose(1, 2)).view(-1) 64 | 65 | sd = stu.unsqueeze(0) - stu.unsqueeze(1) 66 | norm_sd = F.normalize(sd, p=2, dim=2) 67 | s_angle = torch.bmm(norm_sd, norm_sd.transpose(1, 2)).view(-1) 68 | 69 | loss_a = F.smooth_l1_loss(s_angle, t_angle) 70 | 71 | 72 | 73 | 74 | # -- RKD geodesic distance loss 75 | to_poincare = ToPoincare(c=args.curvature, ball_dim=args.student_output_dim, riemannian=False, clip_r=None) 76 | # to_poincare = ToPoincare(c=1, ball_dim=args.student_output_dim, riemannian=False, 
clip_r=None) 77 | logit_s_poincare = to_poincare(stu) 78 | logit_t_poincare = to_poincare(tea) 79 | 80 | geodistmat_ss = dist_matrix(logit_s_poincare, logit_s_poincare, c=args.curvature) # [b,b] 81 | geodistmat_tt = dist_matrix(logit_t_poincare, logit_t_poincare, c=args.curvature) 82 | 83 | mean_ss = geodistmat_ss[geodistmat_ss > 0].mean() 84 | mean_tt = geodistmat_tt[geodistmat_tt > 0].mean() 85 | 86 | geodistmat_ss = geodistmat_ss / mean_ss 87 | geodistmat_tt = geodistmat_tt / mean_tt 88 | 89 | loss_g = F.smooth_l1_loss(geodistmat_ss, geodistmat_tt.detach()) 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | loss = distance_weight * loss_d \ 98 | + angle_weight * loss_a \ 99 | + geodesic_weight * loss_g 100 | 101 | 102 | return loss 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | 112 | 113 | b = 10 114 | c = 128 115 | 116 | output_dict_stu = {'embedding': torch.randn(b, c)} 117 | 118 | output_dict_tea = {'embedding': torch.randn(b, c)} 119 | 120 | 121 | loss = compute_rkdg_loss(output_dict_stu, output_dict_tea, squared=False, eps=1e-12, distance_weight=1, angle_weight=1) 122 | -------------------------------------------------------------------------------- /network/minklocsimple.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import torch 5 | import torch.nn as nn 6 | import MinkowskiEngine as ME 7 | 8 | from models.minkfpn import MinkFPN 9 | import layers.pooling as pooling 10 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 11 | from layers.eca_block import ECABasicBlock 12 | 13 | import numpy as np 14 | 15 | 16 | from tools.options import Options 17 | args = Options().parse() 18 | from tools.utils import set_seed 19 | set_seed(7) 20 | 21 | 22 | class MinkLocSimple(torch.nn.Module): 23 | def __init__(self, in_channels, feature_size, output_dim, planes, layers, num_top_down, conv0_kernel_size, 24 | block='BasicBlock', pooling_method='GeM', linear_block=False, dropout_p=None): 25 | # block: Type of the network building block: BasicBlock or SEBasicBlock 26 | # add_linear_layers: Add linear layers at the end 27 | # dropout_p: dropout probability (None = no dropout) 28 | 29 | super().__init__() 30 | self.in_channels = in_channels 31 | self.feature_size = feature_size # Size of local features produced by local feature extraction block 32 | self.output_dim = output_dim # Dimensionality of the global descriptor produced by pooling layer 33 | self.block = block 34 | 35 | if block == 'BasicBlock': 36 | block_module = BasicBlock 37 | elif block == 'Bottleneck': 38 | block_module = Bottleneck 39 | elif block == 'ECABasicBlock': 40 | block_module = ECABasicBlock 41 | else: 42 | raise NotImplementedError('Unsupported network block: {}'.format(block)) 43 | 44 | self.pooling_method = pooling_method 45 | self.linear_block = linear_block 46 | self.dropout_p = dropout_p 47 | self.backbone = MinkFPN(in_channels=in_channels, out_channels=self.feature_size, num_top_down=num_top_down, 48 | conv0_kernel_size=conv0_kernel_size, block=block_module, layers=layers, planes=planes) 49 | 50 | self.pooling = pooling.PoolingWrapper(pool_method=pooling_method, in_dim=self.feature_size, 51 | output_dim=output_dim) 52 | self.pooled_feature_size = self.pooling.output_dim # Number of channels returned by pooling layer 53 | 54 | if self.dropout_p is not None: 55 | self.dropout = nn.Dropout(p=self.dropout_p) 56 | else: 57 | self.dropout = None 58 | 59 | if 
self.linear_block: 60 | # At least output_dim neurons in intermediary layer 61 | int_channels = self.output_dim 62 | self.linear = nn.Sequential(nn.Linear(self.pooled_feature_size, int_channels, bias=False), 63 | nn.BatchNorm1d(int_channels, affine=True), 64 | nn.ReLU(inplace=True), nn.Linear(int_channels, output_dim)) 65 | else: 66 | self.linear = None 67 | 68 | 69 | 70 | def forward(self, batch_dict): 71 | 72 | 73 | # # -- new 74 | # in_field = batch_dict['in_field'] 75 | # x = in_field.sparse() 76 | # x = self.backbone(x) 77 | 78 | 79 | 80 | # -- old 81 | x = ME.SparseTensor(features=batch_dict['features'], coordinates=batch_dict['coords']) 82 | x = self.backbone(x) 83 | 84 | 85 | 86 | 87 | 88 | 89 | return x 90 | 91 | -------------------------------------------------------------------------------- /tools/utils_adafusion.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | import numpy as np 6 | 7 | 8 | import torch 9 | 10 | 11 | import open3d 12 | import numpy as np 13 | import open3d as o3d 14 | 15 | 16 | 17 | def viz_lidar_open3d(pointcloud, colors=None, width=None, height=None): 18 | 19 | x = pointcloud[:,0] # x position of point 20 | y = pointcloud[:,1] # y position of point 21 | z = pointcloud[:,2] # z position of point 22 | 23 | pcd = open3d.geometry.PointCloud() 24 | points = np.hstack([x[:,None],y[:,None],z[:,None]]) 25 | points = open3d.utility.Vector3dVector(points) 26 | pcd.points = points 27 | 28 | 29 | if colors is not None: 30 | pcd.colors = open3d.utility.Vector3dVector(colors) 31 | 32 | 33 | 34 | FOR1 = o3d.geometry.TriangleMesh.create_coordinate_frame(size=15, origin=[0, 0, 0]) 35 | 36 | 37 | 38 | if (width is not None) & (height is not None): 39 | open3d.visualization.draw_geometries([pcd,FOR1], width=width, height=height) 40 | else: 41 | 42 | open3d.visualization.draw_geometries([pcd,FOR1]) 43 | 44 | 45 | 46 | 47 | 48 | 49 | def pc_array_to_voxel(pc_ndarray: np.ndarray): 50 | """Convert (N,3) np.ndarray point cloud to voxel grid with shape 51 | `voxel_shape`. Point boundary is determined `VOXEL_LOWER_POINT` 52 | and `VOXEL_UPPER_POINT`. 
53 | Args: 54 | pc_ndarray: The input (N,3) np.ndarray point cloud 55 | Returns: 56 | pc_tensor: The output point cloud voxel of type torch.Tensor 57 | [1, voxel_shape[0], voxel_shape[1], voxel_shape[2]] 58 | """ 59 | 60 | voxel_shape = np.array((72, 72, 48), np.int32) # (x,y,z) 61 | 62 | VOXEL_LOWER_VALUE = np.array((-0.8, -0.4, -0.2), np.float32) 63 | VOXEL_UPPER_VALUE = np.array((0.8, 0.4, 0.4), np.float32) # (x,y,z) 64 | VOXEL_PER_VALUE = voxel_shape / ( 65 | VOXEL_UPPER_VALUE - VOXEL_LOWER_VALUE 66 | ) # (x,y,z) 67 | 68 | 69 | 70 | 71 | voxel_index = ( 72 | pc_ndarray - np.tile(VOXEL_LOWER_VALUE, (pc_ndarray.shape[0], 1)) 73 | ) * VOXEL_PER_VALUE 74 | voxel_index = np.round(voxel_index).astype(np.int32) # raw index 75 | 76 | # filter out out-of-boundary points 77 | valid_mask = (voxel_index >= np.array([0, 0, 0], np.int32)) & ( 78 | voxel_index < voxel_shape 79 | ) # only True in (x,y,z) means valid index 80 | valid_mask = valid_mask[:, 0] & valid_mask[:, 1] & valid_mask[:, 2] 81 | voxel_index = voxel_index[valid_mask] # valid voxel index(inside voxel range) 82 | 83 | # deal with voxel according to index 84 | pc_voxel = np.zeros(voxel_shape, np.float32) # [72, 72, 48] 85 | pc_voxel[voxel_index[:, 0], voxel_index[:, 1], voxel_index[:, 2]] = 1.0 86 | pc_tensor = torch.unsqueeze(torch.from_numpy(pc_voxel), 0) # insert channel 1 87 | return pc_tensor 88 | 89 | 90 | 91 | 92 | if __name__ == '__main__': 93 | 94 | 95 | 96 | 97 | pc = np.fromfile('/data/sijie/distil/distil_v33/1400505893170765.bin', dtype=np.float64) 98 | pc = pc.reshape((-1, 3)) 99 | pc = pc.astype(np.float32) 100 | 101 | 102 | # viz_lidar_open3d(pc.copy()*100) 103 | 104 | voxel = pc_array_to_voxel(pc) 105 | 106 | 107 | 108 | voxel = voxel.squeeze(0).numpy() 109 | 110 | voxel_ids = np.where(voxel==1) 111 | 112 | voxel_ids = np.stack(voxel_ids,axis=1) # (N,3) 113 | 114 | 115 | 116 | # viz_lidar_open3d(voxel_ids) 117 | 118 | a=1 119 | 120 | 121 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import os 5 | import configparser 6 | import time 7 | import pickle 8 | import torch 9 | import numpy as np 10 | import random 11 | 12 | 13 | import matplotlib.pyplot as plt 14 | import cv2 15 | from tools.options import Options 16 | args = Options().parse() 17 | 18 | 19 | 20 | 21 | 22 | def get_datetime(): 23 | return time.strftime("%Y%m%d_%H%M") 24 | 25 | 26 | 27 | 28 | 29 | class ModelParams: 30 | def __init__(self, model_params_path): 31 | assert os.path.exists(model_params_path), 'Cannot find model-specific configuration file: {}'.format(model_params_path) 32 | config = configparser.ConfigParser() 33 | config.read(model_params_path) 34 | params = config['MODEL'] 35 | 36 | self.model_params_path = model_params_path 37 | self.model = params.get('model') 38 | self.mink_quantization_size = params.getfloat('mink_quantization_size', 0.01) 39 | 40 | def print(self): 41 | print('Model parameters:') 42 | param_dict = vars(self) 43 | for e in param_dict: 44 | print('{}: {}'.format(e, param_dict[e])) 45 | 46 | print('') 47 | 48 | 49 | 50 | class MinkLocParams: 51 | """ 52 | Params for training MinkLoc models on Oxford dataset 53 | """ 54 | def __init__(self, params_path, model_params_path=None): 55 | """ 56 | Configuration files 57 | :param path: General configuration file 58 | :param model_params: Model-specific configuration 59 | """ 60 
61 |         assert os.path.exists(params_path), 'Cannot find configuration file: {}'.format(params_path)
62 |         self.params_path = params_path
63 |         self.model_params_path = model_params_path
64 | 
65 |         config = configparser.ConfigParser()
66 | 
67 |         config.read(self.params_path)
68 |         params = config['DEFAULT']
69 | 
70 | 
71 |         if args.dataset in ['oxford', 'oxfordadafusion']:
72 |             self.num_points = params.getint('num_points', 4096)
73 |         elif args.dataset == 'boreas':
74 |             self.num_points = args.n_points_boreas
75 |         else:
76 |             raise NotImplementedError('Unsupported dataset: {}'.format(args.dataset))
77 | 
78 | 
79 | 
80 | 
81 |         self.dataset_folder = args.dataset_folder
82 |         self.use_cloud = params.getboolean('use_cloud', True)
83 | 
84 | 
85 | 
86 | 
87 |         self._check_params()
88 | 
89 | 
90 | 
91 |     def _check_params(self):
92 |         assert os.path.exists(self.dataset_folder), 'Cannot access dataset: {}'.format(self.dataset_folder)
93 | 
94 | 
95 | 
96 |     def print(self):
97 |         print('Parameters:')
98 |         param_dict = vars(self)
99 |         for e in param_dict:
100 |             if e not in ['model_params']:
101 |                 print('{}: {}'.format(e, param_dict[e]))
102 | 
103 |         # if self.model_params is not None:
104 |         #     self.model_params.print()
105 |         print('')
106 | 
107 | 
108 | 
109 | 
110 | 
111 | 
112 | 
113 | 
114 | def set_seed(seed=7):
115 |     # seed = 7
116 |     random.seed(seed)
117 |     np.random.seed(seed)
118 |     os.environ['PYTHONHASHSEED'] = str(seed)
119 |     torch.manual_seed(seed)
120 |     torch.cuda.manual_seed(seed)
121 |     torch.cuda.manual_seed_all(seed)
122 |     torch.backends.cudnn.benchmark = False
123 |     torch.backends.cudnn.deterministic = True
124 | 
125 | set_seed(7)
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
-------------------------------------------------------------------------------- /models/minkloconly.py: --------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import MinkowskiEngine as ME
7 | 
8 | from models.minkfpn import MinkFPN
9 | from models.minkfpn import GeneralMinkFPN
10 | 
11 | import layers.pooling as pooling
12 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
13 | from layers.eca_block import ECABasicBlock
14 | 
15 | from layers.pooling import MinkGeM as MinkGeM
16 | 
17 | # from tools.options import Options
18 | # args = Options().parse()
19 | 
20 | from tools.utils import set_seed
21 | set_seed(7)
22 | 
23 | 
24 | 
25 | 
26 | 
27 | 
28 | 
29 | class MinkLocOnly(torch.nn.Module):
30 |     def __init__(self, in_channels, feature_size, output_dim, planes, layers, num_top_down, conv0_kernel_size,
31 |                  block='BasicBlock', pooling_method='GeM', linear_block=False, dropout_p=None,
32 |                  minkfpn='minkfpn'):
33 |         # block: Type of the network building block: BasicBlock, Bottleneck or ECABasicBlock
34 |         # linear_block: Add linear layers at the end
35 |         # dropout_p: dropout probability (None = no dropout)
36 | 
37 |         super().__init__()
38 |         self.in_channels = in_channels
39 |         self.feature_size = feature_size  # Size of local features produced by local feature extraction block
40 |         self.output_dim = output_dim  # Dimensionality of the global descriptor produced by pooling layer
41 |         self.block = block
42 | 
43 |         if block == 'BasicBlock':
44 |             block_module = BasicBlock
45 |         elif block == 'Bottleneck':
46 |             block_module = Bottleneck
47 |         elif block == 'ECABasicBlock':
48 |             block_module = ECABasicBlock
49 |         else:
50 |             raise NotImplementedError('Unsupported network block: {}'.format(block))
51 | 
52 |         self.pooling_method = pooling_method
53 | 
54 | 
55 |         self.backbone = 
MinkFPN(in_channels=in_channels, out_channels=feature_size, num_top_down=num_top_down, 56 | conv0_kernel_size=conv0_kernel_size, block=block_module, layers=layers, planes=planes) 57 | 58 | # self.conv1x1 = nn.Linear(feature_size, output_dim) 59 | 60 | 61 | self.pooling = pooling.PoolingWrapper(pool_method=pooling_method, 62 | in_dim=feature_size, 63 | output_dim=feature_size) 64 | 65 | # self.conv1x1 = nn.Linear(output_dim, output_dim) 66 | 67 | 68 | def forward(self, batch): 69 | # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation 70 | x = ME.SparseTensor(features=batch['features'], coordinates=batch['coords']) 71 | x = self.backbone(x) 72 | 73 | 74 | 75 | 76 | 77 | x_feat = x 78 | # x = self.conv1x1(x.F) 79 | # x = ME.SparseTensor(features=x, 80 | # # coordinates=x_feat.C, 81 | # coordinate_manager=x_feat.coordinate_manager, 82 | # coordinate_map_key=x_feat.coordinate_map_key, 83 | # device=x_feat.device) 84 | x_gem = self.pooling(x) 85 | # x_gem = self.conv1x1(x_gem) 86 | 87 | 88 | return { 89 | 'output_cloud_feat': x_feat, 90 | 'output_cloud_gem': x_gem, 91 | 92 | 'embedding':x_gem 93 | } 94 | 95 | 96 | 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /network/general_minkfpn.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | import torch.nn as nn 12 | import MinkowskiEngine as ME 13 | from layers.eca_block import ECABasicBlock 14 | 15 | from models.resnet import ResNetBase 16 | 17 | 18 | from layers.pooling import MinkGeM as MinkGeM 19 | 20 | 21 | 22 | 23 | 24 | 25 | class GeneralMinkFPN(ResNetBase): 26 | # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks 27 | # in_channels=1, out_channels=128, 1, ECABasicBlock, [1,1,1], [32,64,64] 28 | def __init__(self, in_channels, out_channels, num_top_down=1, conv0_kernel_size=5, block=ECABasicBlock, 29 | layers=(1, 1, 1), planes=(32, 64, 64)): 30 | # assert len(layers) == len(planes) 31 | # assert 1 <= len(layers) 32 | # assert 0 <= num_top_down <= len(layers) 33 | # self.out_channels = out_channels 34 | # self.num_bottom_up = len(layers) 35 | # self.num_top_down = num_top_down 36 | self.conv0_kernel_size = conv0_kernel_size 37 | self.block = block 38 | self.layers = layers 39 | self.planes = planes 40 | # self.lateral_dim = out_channels 41 | # self.init_dim = planes[0] 42 | ResNetBase.__init__(self, in_channels, out_channels, D=3) 43 | 44 | def network_initialization(self, in_channels, out_channels, D): 45 | # assert len(self.layers) == len(self.planes) 46 | # assert len(self.planes) == self.num_bottom_up 47 | 48 | self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 49 | self.bns = nn.ModuleList() # Bottom-up BatchNorms 50 | self.blocks = nn.ModuleList() # Bottom-up blocks 51 | 52 | 53 | # The first convolution is special case, with kernel size = 5 54 | self.inplanes = self.planes[0] 55 | self.conv1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, dimension=D) 56 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 57 | self.relu = ME.MinkowskiReLU(inplace=True) 58 | 59 | for plane, layer in zip(self.planes, self.layers): 60 | self.convs.append(ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)) 61 | self.bns.append(ME.MinkowskiBatchNorm(self.inplanes)) 62 | self.blocks.append(self._make_layer(self.block, plane, layer)) 63 | 64 | 65 | 
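        # The 1x1 convolution below projects the last bottom-up stage
        # (self.planes[-1] channels) to a 128-dim local feature, which the
        # MinkGeM layer then pools into the global descriptor; note that 128 is
        # hard-coded here rather than derived from out_channels.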
self.conv1x1 = ME.MinkowskiConvolution(self.planes[-1], 128, kernel_size=1, stride=1, dimension=D) 66 | 67 | self.mink_gem = MinkGeM(input_dim=128) 68 | 69 | 70 | def forward_backbone(self, x): 71 | feature_maps = [] 72 | 73 | x = self.conv1(x) 74 | x = self.bn1(x) 75 | x = self.relu(x) 76 | 77 | for layer_id, (conv, bn, block) in enumerate(zip(self.convs, self.bns, self.blocks)): 78 | x = conv(x) # Decreases spatial resolution (conv stride=2) 79 | x = bn(x) 80 | x = self.relu(x) 81 | x = block(x) 82 | feature_maps.append(x) 83 | 84 | 85 | return x, feature_maps 86 | 87 | 88 | 89 | def forward(self, data_dict): 90 | 91 | x = ME.SparseTensor(features=data_dict['features'], coordinates=data_dict['coords']) 92 | 93 | 94 | x, feature_maps = self.forward_backbone(x) 95 | 96 | x = self.conv1x1(x) 97 | 98 | x_gem = self.mink_gem(x) 99 | 100 | 101 | 102 | return { 103 | 'x_feat':x, 104 | 'embedding':x_gem 105 | } 106 | 107 | 108 | -------------------------------------------------------------------------------- /datasets/make_collate_fn.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | import numpy as np 6 | from datasets.oxford import OxfordDataset 7 | import MinkowskiEngine as ME 8 | 9 | 10 | from tools.utils_adafusion import pc_array_to_voxel 11 | 12 | 13 | from tools.options import Options 14 | args = Options().parse() 15 | from tools.utils import set_seed 16 | set_seed(7) 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | def in_sorted_array(e: int, array: np.ndarray) -> bool: 25 | pos = np.searchsorted(array, e) 26 | if pos == len(array) or pos == -1: 27 | return False 28 | else: 29 | return array[pos] == e 30 | 31 | 32 | 33 | 34 | 35 | def make_collate_fn_bak(dataset: OxfordDataset): 36 | 37 | # set_transform: the transform to be applied to all batch elements 38 | def collate_fn_bak(data_list): 39 | # Constructs a batch object 40 | labels = [e['ndx'] for e in data_list] 41 | 42 | # Compute positives and negatives mask 43 | positives_mask = [[in_sorted_array(e, dataset.queries[label].positives) for e in labels] for label in labels] 44 | negatives_mask = [[not in_sorted_array(e, dataset.queries[label].non_negatives) for e in labels] for label in labels] 45 | positives_mask = torch.tensor(positives_mask) 46 | negatives_mask = torch.tensor(negatives_mask) 47 | 48 | # Returns (batch_size, n_points, 3) tensor and positives_mask and 49 | # negatives_mask which are batch_size x batch_size boolean tensors 50 | 51 | filenames = [e['filename'] for e in data_list] 52 | 53 | result = { 54 | 'positives_mask': positives_mask, 55 | 'negatives_mask': negatives_mask, 56 | 'filenames': filenames 57 | } 58 | 59 | if 'clouds' in data_list[0]: 60 | 61 | coords = [e['coords'] for e in data_list] 62 | clouds = [e['clouds'] for e in data_list] 63 | 64 | coords = ME.utils.batched_coordinates(coords) 65 | clouds = torch.cat(clouds, dim=0) 66 | assert coords.shape[0]==clouds.shape[0] 67 | feats = torch.ones((coords.shape[0], 1), dtype=torch.float32) 68 | 69 | result['coords'] = coords 70 | result['clouds'] = clouds 71 | result['features'] = feats 72 | 73 | if 'image' in data_list[0]: 74 | images = [e['image'] for e in data_list] 75 | result['images'] = torch.stack(images, dim=0) # Produces (N, C, H, W) tensor 76 | 77 | return result 78 | 79 | 80 | return collate_fn_bak 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | def make_collate_fn(dataset: OxfordDataset): 90 | 91 | 92 | 93 | 94 | def collate_fn(data_list): 95 | # Constructs a batch object 96 | labels = 
[e['ndx'] for e in data_list] 97 | 98 | # Compute positives and negatives mask 99 | positives_mask = [[in_sorted_array(e, dataset.queries[label].positives) for e in labels] for label in labels] 100 | negatives_mask = [[not in_sorted_array(e, dataset.queries[label].non_negatives) for e in labels] for label in labels] 101 | positives_mask = torch.tensor(positives_mask) 102 | negatives_mask = torch.tensor(negatives_mask) 103 | 104 | # Returns (batch_size, n_points, 3) tensor and positives_mask and 105 | # negatives_mask which are batch_size x batch_size boolean tensors 106 | 107 | filenames = [e['filename'] for e in data_list] 108 | 109 | # result = { 110 | # 'positives_mask': positives_mask, 111 | # 'negatives_mask': negatives_mask, 112 | # 'filenames': filenames 113 | # } 114 | 115 | 116 | if 'clouds' in data_list[0]: 117 | coords = [e['coords'] for e in data_list] 118 | clouds = [e['clouds'] for e in data_list] 119 | 120 | if 'image' in data_list[0]: 121 | images = [e['image'] for e in data_list] 122 | 123 | # if 'voxels' in data_list[0]: 124 | # voxels = [e['voxels'] for e in data_list] 125 | 126 | 127 | 128 | big_batch = [] 129 | batch_split_size = args.train_batch_split_size 130 | for i in range(0, len(data_list), batch_split_size): 131 | temp = coords[i:i + batch_split_size] 132 | imgs = images[i:i + batch_split_size] 133 | imgs = torch.stack(imgs, dim=0) 134 | c = ME.utils.batched_coordinates(temp) 135 | f = torch.ones((c.shape[0], 1), dtype=torch.float32) 136 | # v = voxels[i:i + batch_split_size] 137 | # v = torch.stack(v, dim=0) 138 | minibatch = { 139 | 'coords': c, 140 | 'features': f, 141 | 'images': imgs, 142 | # 'voxels': v 143 | } 144 | big_batch.append(minibatch) 145 | 146 | 147 | 148 | 149 | return big_batch, positives_mask, negatives_mask, filenames 150 | 151 | 152 | 153 | return collate_fn 154 | 155 | -------------------------------------------------------------------------------- /layers/pooling.py: -------------------------------------------------------------------------------- 1 | # Pooling methods code based on: https://github.com/filipradenovic/cnnimageretrieval-pytorch 2 | # Global covariance pooling methods implementation taken from: 3 | # https://github.com/jiangtaoxie/fast-MPN-COV 4 | # and ported to MinkowskiEngine by Jacek Komorowski 5 | 6 | import torch 7 | import torch.nn as nn 8 | import MinkowskiEngine as ME 9 | import spconv.pytorch as spconv 10 | from tools.utils import set_seed 11 | set_seed(7) 12 | 13 | class PoolingWrapper(nn.Module): 14 | def __init__(self, pool_method, in_dim, output_dim): 15 | super().__init__() 16 | 17 | self.pool_method = pool_method 18 | self.in_dim = in_dim 19 | self.output_dim = output_dim 20 | # Requires conversion of Minkowski sparse tensor to a batch 21 | self.convert_to_batch = False 22 | 23 | if pool_method == 'MAC': 24 | # Global max pooling 25 | assert in_dim == output_dim 26 | self.pooling = MAC(input_dim=in_dim) 27 | elif pool_method == 'SPoC': 28 | # Global average pooling 29 | assert in_dim == output_dim 30 | self.pooling = SPoC(input_dim=in_dim) 31 | elif pool_method == 'GeM': 32 | # Generalized mean pooling 33 | assert in_dim == output_dim 34 | self.pooling = MinkGeM(input_dim=in_dim) 35 | else: 36 | raise NotImplementedError('Unknown pooling method: {}'.format(pool_method)) 37 | 38 | def forward(self, x: ME.SparseTensor): 39 | if self.convert_to_batch: 40 | x = make_feature_batch(x) 41 | 42 | return self.pooling(x) 43 | 44 | 45 | class MAC(nn.Module): 46 | def __init__(self, input_dim): 47 | super().__init__() 48 
|         self.input_dim = input_dim
49 |         # Same output number of channels as input number of channels
50 |         self.output_dim = self.input_dim
51 |         self.f = ME.MinkowskiGlobalMaxPooling()
52 | 
53 |     def forward(self, x: ME.SparseTensor):
54 |         x = self.f(x)
55 |         return x.F  # Return (batch_size, n_features) tensor
56 | 
57 | 
58 | class SPoC(nn.Module):
59 |     def __init__(self, input_dim):
60 |         super().__init__()
61 |         self.input_dim = input_dim
62 |         # Same output number of channels as input number of channels
63 |         self.output_dim = self.input_dim
64 |         self.f = ME.MinkowskiGlobalAvgPooling()
65 | 
66 |     def forward(self, x: ME.SparseTensor):
67 |         x = self.f(x)
68 |         return x.F  # Return (batch_size, n_features) tensor
69 | 
70 | 
71 | class MinkGeM(nn.Module):
72 |     def __init__(self, input_dim, p=3, eps=1e-6):
73 |         super(MinkGeM, self).__init__()
74 |         self.input_dim = input_dim
75 |         # Same output number of channels as input number of channels
76 |         self.output_dim = self.input_dim
77 |         self.p = nn.Parameter(torch.ones(1) * p)
78 |         self.eps = eps
79 |         self.f = ME.MinkowskiGlobalAvgPooling()
80 | 
81 |     def forward(self, x: ME.SparseTensor):
82 |         assert isinstance(x, ME.SparseTensor)
83 |         # This implicitly applies ReLU on x (clamps negative values)
84 |         temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C)
85 |         temp = self.f(temp)  # Apply ME.MinkowskiGlobalAvgPooling
86 |         # return temp.F.pow(1./self.p)  # Return (batch_size, n_features) tensor
87 |         output = temp.F.pow(1./self.p)  # Return (batch_size, n_features) tensor
88 |         return output
89 | 
90 | 
91 | 
92 | class MinkSpconvGeM(nn.Module):
93 |     def __init__(self, input_dim, p=3, eps=1e-6):
94 |         super(MinkSpconvGeM, self).__init__()
95 |         self.input_dim = input_dim
96 |         # Same output number of channels as input number of channels
97 |         self.output_dim = self.input_dim
98 |         self.p = nn.Parameter(torch.ones(1) * p)
99 |         self.eps = eps
100 |         self.f = ME.MinkowskiGlobalAvgPooling()
101 | 
102 |     def forward(self, x):
103 |         assert isinstance(x, ME.SparseTensor) or isinstance(x, spconv.SparseConvTensor)
104 | 
105 |         if isinstance(x, ME.SparseTensor):
106 |             temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C)
107 | 
108 |         elif isinstance(x, spconv.SparseConvTensor):
109 |             temp = ME.SparseTensor(x.features.clamp(min=self.eps).pow(self.p), coordinates=x.indices)
110 | 
111 |         temp = self.f(temp)
112 |         output = temp.F.pow(1./self.p)
113 |         return output
114 | 
115 | 
116 | 
117 | 
118 | 
119 | 
120 | def make_feature_batch(x: ME.SparseTensor):
121 |     # Convert sparse features into a batch of size (batch_size, N, channels) padded with zeros to ensure the same
122 |     # number of features in each element
123 |     features = x.decomposed_features
124 | 
125 |     # features is a list of (n_features, channels) tensors with variable number of points
126 |     batch_size = len(features)
127 |     features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
128 |     # features is (batch_size, n_features, n_channels) tensor padded with zeros
129 |     # features = features.permute(0, 2, 1).contiguous()
130 |     # features is (batch_size, n_channels, n_features) tensor padded with zeros
131 |     assert features.ndim == 3
132 |     return features
133 | 
-------------------------------------------------------------------------------- /datasets/make_dataloaders.py: --------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 | 
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import MinkowskiEngine as ME
8 | 
9 | from datasets.oxford import OxfordDataset
10 | from datasets.augmentation import TrainTransform, TrainSetTransform, TrainRGBTransform, ValRGBTransform
11 | from datasets.samplers import BatchSampler
12 | from tools.utils import MinkLocParams
13 | try: from viz_lidar_mayavi_open3d import *
14 | except ImportError: pass
15 | 
16 | import matplotlib.pyplot as plt
17 | 
18 | import torchvision
19 | 
20 | import os
21 | 
22 | from datasets.make_collate_fn import make_collate_fn
23 | from datasets.make_collate_fn import make_collate_fn_bak
24 | 
25 | 
26 | from tools.options import Options
27 | args = Options().parse()
28 | from tools.utils import set_seed
29 | set_seed(7)
30 | 
31 | 
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | 
39 | 
40 | def make_datasets():
41 |     # Create training and validation datasets
42 |     datasets = {}
43 |     # train_transform = TrainTransform(params.aug_mode)
44 |     # train_set_transform = TrainSetTransform(params.aug_mode)
45 |     train_transform = TrainTransform(aug_mode=1)
46 |     train_set_transform = TrainSetTransform(aug_mode=1)
47 | 
48 | 
49 |     # if params.use_rgb:
50 |     #     image_train_transform = TrainRGBTransform(aug_mode=1)
51 |     #     image_val_transform = ValRGBTransform()
52 |     # else:
53 |     #     image_train_transform = None
54 |     #     image_val_transform = None
55 |     image_train_transform = TrainRGBTransform(aug_mode=1)
56 |     image_val_transform = ValRGBTransform()
57 | 
58 | 
59 |     if args.dataset == 'oxford':
60 |         lidar2image_ndx_path = os.path.join(args.image_path, 'lidar2image_ndx.pickle')
61 | 
62 |         datasets['train'] = OxfordDataset(
63 |             args.dataset_folder,
64 |             # query_filename=params.train_file,
65 |             query_filename='training_queries_baseline.pickle',
66 |             image_path=args.image_path,
67 |             # lidar2image_ndx_path=params.lidar2image_ndx_path,
68 |             lidar2image_ndx_path=lidar2image_ndx_path,
69 |             transform=train_transform,
70 |             set_transform=train_set_transform,
71 |             image_transform=image_train_transform,
72 |             use_cloud=True
73 |         )
74 | 
75 | 
76 | 
77 | 
78 | 
79 |     elif args.dataset == 'oxfordadafusion':
80 |         lidar2image_ndx_path = os.path.join(args.image_path, 'oxfordadafusion_lidar2image_ndx.pickle')
81 | 
82 |         datasets['train'] = OxfordDataset(
83 |             args.dataset_folder,
84 |             query_filename='oxfordadafusion_training_queries_baseline.pickle',
85 |             image_path=args.image_path,
86 |             # lidar2image_ndx_path=params.lidar2image_ndx_path,
87 |             lidar2image_ndx_path=lidar2image_ndx_path,
88 |             transform=train_transform,
89 |             set_transform=train_set_transform,
90 |             image_transform=image_train_transform,
91 |             use_cloud=True
92 |         )
93 | 
94 | 
95 | 
96 | 
97 | 
98 |     elif args.dataset == 'boreas':
99 |         lidar2image_ndx_path = os.path.join(args.dataset_folder, 'boreas_lidar2image_ndx.pickle')
100 |         assert os.path.exists(lidar2image_ndx_path)
101 | 
102 |         datasets['train'] = OxfordDataset(
103 |             args.dataset_folder,
104 |             query_filename='boreas_training_queries_baseline.pickle',
105 |             image_path=args.image_path,
106 |             # lidar2image_ndx_path=params.lidar2image_ndx_path,
107 |             lidar2image_ndx_path=lidar2image_ndx_path,
108 |             transform=train_transform,
109 |             set_transform=train_set_transform,
110 |             image_transform=image_train_transform,
111 |             use_cloud=True
112 |         )
113 | 
114 | 
115 | 
116 |     else:
117 |         raise NotImplementedError('Unsupported dataset: {}'.format(args.dataset))
118 | 
119 | 
120 |     # a = len(datasets['train'])
121 |     # b = len(datasets['val'])
122 | 
123 | 
124 | 
125 |     return datasets
126 | 
127 | 
128 | 
129 | 
130 | 
131 | 
132 | def make_dataloaders():
133 |     """
134 |     Create training and validation dataloaders that return groups of k=2 similar elements.
135 |     The function takes no parameters; batch sizes and worker counts come from
136 |     the global Options() args.
137 |     :return: dict with 'train' and 'train_preloading' dataloaders
138 |     """
139 |     datasets = make_datasets()
140 | 
141 | 
142 | 
143 | 
144 |     dataloaders = {}
145 |     train_sampler = BatchSampler(datasets['train'],
146 |                                  batch_size = args.train_batch_size,
147 |                                  batch_size_limit = args.train_batch_size,
148 |                                  batch_expansion_rate = None
149 |                                  )
150 | 
151 |     # ---- for multi-stage training
152 |     train_collate_fn = make_collate_fn(
153 |         datasets['train'],
154 |     )
155 |     dataloaders['train'] = DataLoader(datasets['train'],
156 |                                       batch_sampler=train_sampler,
157 |                                       collate_fn=train_collate_fn,
158 |                                       num_workers=args.num_workers,
159 |                                       pin_memory=True)
160 | 
161 | 
162 | 
163 | 
164 |     # ---- for single stage training
165 |     train_preloading_collate_fn = make_collate_fn_bak(
166 |         datasets['train'],
167 |     )
168 |     dataloaders['train_preloading'] = DataLoader(
169 |         datasets['train'],
170 |         batch_size=args.train_batch_split_size,
171 |         collate_fn=train_preloading_collate_fn,
172 |         num_workers=args.num_workers,
173 |         pin_memory=True)
174 | 
175 | 
176 | 
177 | 
178 | 
179 |     return dataloaders
180 | 
181 | 
-------------------------------------------------------------------------------- /datasets/samplers.py: --------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 | 
4 | import random
5 | import copy
6 | 
7 | from torch.utils.data import Sampler
8 | 
9 | from datasets.oxford import OxfordDataset
10 | 
11 | from tools.utils import set_seed
12 | set_seed(7)
13 | 
14 | class ListDict(object):
15 |     def __init__(self, items=None):
16 |         if items is not None:
17 |             self.items = copy.deepcopy(items)
18 |             self.item_to_position = {item: ndx for ndx, item in enumerate(items)}
19 |         else:
20 |             self.items = []
21 |             self.item_to_position = {}
22 | 
23 |     def add(self, item):
24 |         if item in self.item_to_position:
25 |             return
26 |         self.items.append(item)
27 |         self.item_to_position[item] = len(self.items)-1
28 | 
29 |     def remove(self, item):
30 |         position = self.item_to_position.pop(item)
31 |         last_item = self.items.pop()
32 |         if position != len(self.items):
33 |             self.items[position] = last_item
34 |             self.item_to_position[last_item] = position
35 | 
36 |     def choose_random(self):
37 |         return random.choice(self.items)
38 | 
39 |     def __contains__(self, item):
40 |         return item in self.item_to_position
41 | 
42 |     def __iter__(self):
43 |         return iter(self.items)
44 | 
45 |     def __len__(self):
46 |         return len(self.items)
47 | 
48 | 
49 | class BatchSampler(Sampler):
50 |     # Sampler returning list of indices to form a mini-batch
51 |     # Samples elements in groups consisting of k=2 similar elements (positives)
52 |     # Batch has the following structure: item1_1, ..., item1_k, item2_1, ..., item2_k, ..., itemn_1, ..., itemn_k
53 |     def __init__(self, dataset: OxfordDataset, batch_size: int, batch_size_limit: int = None,
54 |                  batch_expansion_rate: float = None, max_batches: int = None):
55 |         if batch_expansion_rate is not None:
56 |             assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1'
57 |             assert batch_size <= batch_size_limit, 'batch_size_limit must be greater than or equal to batch_size'
58 | 
59 |         self.batch_size = batch_size
60 |         self.batch_size_limit = batch_size_limit
61 |         self.batch_expansion_rate = batch_expansion_rate
62 |         self.max_batches = max_batches
63 |         self.dataset = dataset
64 |         self.k = 2  # Number of positive examples per group must be 2
65 |         if self.batch_size < 2 * self.k:
66 |             self.batch_size = 2 * self.k
67 |             print('WARNING: Batch too small. Batch size increased to {}.'.format(self.batch_size))
68 | 
69 |         self.batch_idx = []  # Index of elements in each batch (re-generated every epoch)
70 |         self.elems_ndx = list(self.dataset.queries)  # List of point cloud indexes
71 | 
72 |     def __iter__(self):
73 |         # Re-generate batches every epoch
74 |         self.generate_batches()
75 |         for batch in self.batch_idx:
76 |             yield batch
77 | 
78 |     def __len__(self):
79 |         return len(self.batch_idx)
80 | 
81 |     def expand_batch(self):
82 |         if self.batch_expansion_rate is None:
83 |             print('WARNING: batch_expansion_rate is None')
84 |             return
85 | 
86 |         if self.batch_size >= self.batch_size_limit:
87 |             return
88 | 
89 |         old_batch_size = self.batch_size
90 |         self.batch_size = int(self.batch_size * self.batch_expansion_rate)
91 |         self.batch_size = min(self.batch_size, self.batch_size_limit)
92 |         print('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size))
93 | 
94 |     def generate_batches(self):
95 |         # Generate training/evaluation batches.
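        # Illustration (hypothetical indices): with batch_size=4 and k=2, a
        # generated batch could be [12, 57, 103, 8], i.e. two (anchor, positive)
        # pairs, where each pair supplies in-batch negatives for the other.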
96 |         # batch_idx holds indexes of elements in each batch as a list of lists
97 |         self.batch_idx = []
98 | 
99 |         unused_elements_ndx = ListDict(self.elems_ndx)
100 |         current_batch = []
101 | 
102 |         assert self.k == 2, 'sampler can sample only k=2 elements from the same class'
103 | 
104 |         while True:
105 |             if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0:
106 |                 # Flush out the batch when it has the desired size, or a smaller batch when there are no more
107 |                 # elements to process
108 |                 if len(current_batch) >= 2*self.k:
109 |                     # Ensure there are at least two groups of similar elements; otherwise it would not be possible
110 |                     # to find negative examples in the batch
111 |                     assert len(current_batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(current_batch))
112 |                     self.batch_idx.append(current_batch)
113 |                     current_batch = []
114 |                     if (self.max_batches is not None) and (len(self.batch_idx) >= self.max_batches):
115 |                         break
116 |                 if len(unused_elements_ndx) == 0:
117 |                     break
118 | 
119 |             # Add k=2 similar elements to the batch
120 |             selected_element = unused_elements_ndx.choose_random()
121 |             unused_elements_ndx.remove(selected_element)
122 |             positives = self.dataset.get_positives(selected_element)
123 |             if len(positives) == 0:
124 |                 # Broken dataset element without any positives
125 |                 continue
126 | 
127 |             unused_positives = [e for e in positives if e in unused_elements_ndx]
128 |             # If there are unused elements similar to selected_element, sample from them;
129 |             # otherwise sample from all similar elements
130 |             if len(unused_positives) > 0:
131 |                 second_positive = random.choice(unused_positives)
132 |                 unused_elements_ndx.remove(second_positive)
133 |             else:
134 |                 second_positive = random.choice(list(positives))
135 | 
136 |             current_batch += [selected_element, second_positive]
137 | 
138 |         for batch in self.batch_idx:
139 |             assert len(batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(batch))
140 | 
141 | 
-------------------------------------------------------------------------------- /models/resnet.py: --------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code.
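# Note: ResNetBase below reads the lowercase `block` / `layers` attributes,
# which the FPN subclasses in this repo set in their own __init__ before
# delegating to ResNetBase.__init__; the uppercase BLOCK / LAYERS constants on
# ResNet14..ResNet101 are not consumed by ResNetBase as written.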
24 | 25 | import torch.nn as nn 26 | 27 | import MinkowskiEngine as ME 28 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 29 | 30 | from tools.utils import set_seed 31 | set_seed(7) 32 | 33 | class ResNetBase(nn.Module): 34 | block = None 35 | layers = () 36 | init_dim = 64 37 | planes = (64, 128, 256, 512) 38 | 39 | def __init__(self, in_channels, out_channels, D=3): 40 | nn.Module.__init__(self) 41 | self.D = D 42 | assert self.block is not None 43 | 44 | self.network_initialization(in_channels, out_channels, D) 45 | self.weight_initialization() 46 | 47 | def network_initialization(self, in_channels, out_channels, D): 48 | self.inplanes = self.init_dim 49 | self.conv1 = ME.MinkowskiConvolution( 50 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 51 | 52 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 53 | self.relu = ME.MinkowskiReLU(inplace=True) 54 | 55 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 56 | 57 | self.layer1 = self._make_layer( 58 | self.block, self.planes[0], self.layers[0], stride=2) 59 | self.layer2 = self._make_layer( 60 | self.block, self.planes[1], self.layers[1], stride=2) 61 | self.layer3 = self._make_layer( 62 | self.block, self.planes[2], self.layers[2], stride=2) 63 | self.layer4 = self._make_layer( 64 | self.block, self.planes[3], self.layers[3], stride=2) 65 | 66 | self.conv5 = ME.MinkowskiConvolution( 67 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 68 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 69 | 70 | self.glob_avg = ME.MinkowskiGlobalMaxPooling() 71 | 72 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 73 | 74 | def weight_initialization(self): 75 | for m in self.modules(): 76 | if isinstance(m, ME.MinkowskiConvolution): 77 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 78 | 79 | if isinstance(m, ME.MinkowskiBatchNorm): 80 | nn.init.constant_(m.bn.weight, 1) 81 | nn.init.constant_(m.bn.bias, 0) 82 | 83 | def _make_layer(self, 84 | block, 85 | planes, 86 | blocks, 87 | stride=1, 88 | dilation=1, 89 | bn_momentum=0.1): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | ME.MinkowskiConvolution( 94 | self.inplanes, 95 | planes * block.expansion, 96 | kernel_size=1, 97 | stride=stride, 98 | dimension=self.D), 99 | ME.MinkowskiBatchNorm(planes * block.expansion)) 100 | layers = [] 101 | layers.append( 102 | block( 103 | self.inplanes, 104 | planes, 105 | stride=stride, 106 | dilation=dilation, 107 | downsample=downsample, 108 | dimension=self.D)) 109 | self.inplanes = planes * block.expansion 110 | for i in range(1, blocks): 111 | layers.append( 112 | block( 113 | self.inplanes, 114 | planes, 115 | stride=1, 116 | dilation=dilation, 117 | dimension=self.D)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.pool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.conv5(x) 133 | x = self.bn5(x) 134 | x = self.relu(x) 135 | 136 | x = self.glob_avg(x) 137 | return self.final(x) 138 | 139 | 140 | class ResNet14(ResNetBase): 141 | BLOCK = BasicBlock 142 | LAYERS = (1, 1, 1, 1) 143 | 144 | 145 | class ResNet18(ResNetBase): 146 | BLOCK = BasicBlock 147 | LAYERS = (2, 2, 2, 2) 148 | 149 | 150 | class ResNet34(ResNetBase): 151 | BLOCK = BasicBlock 152 | 
LAYERS = (3, 4, 6, 3) 153 | 154 | 155 | class ResNet50(ResNetBase): 156 | BLOCK = Bottleneck 157 | LAYERS = (3, 4, 6, 3) 158 | 159 | 160 | class ResNet101(ResNetBase): 161 | BLOCK = Bottleneck 162 | LAYERS = (3, 4, 23, 3) 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | def make_layer(inplanes, 171 | D, 172 | block, 173 | planes, 174 | blocks, 175 | stride=1, 176 | dilation=1, 177 | bn_momentum=0.1): 178 | 179 | downsample = None 180 | if stride != 1 or inplanes != planes * block.expansion: 181 | downsample = nn.Sequential( 182 | ME.MinkowskiConvolution( 183 | inplanes, 184 | planes * block.expansion, 185 | kernel_size=1, 186 | stride=stride, 187 | dimension=D), 188 | ME.MinkowskiBatchNorm(planes * block.expansion)) 189 | layers = [] 190 | layers.append( 191 | block( 192 | inplanes, 193 | planes, 194 | stride=stride, 195 | dilation=dilation, 196 | downsample=downsample, 197 | dimension=D)) 198 | inplanes = planes * block.expansion 199 | for i in range(1, blocks): 200 | layers.append( 201 | block( 202 | inplanes, 203 | planes, 204 | stride=1, 205 | dilation=dilation, 206 | dimension=D)) 207 | 208 | return nn.Sequential(*layers) 209 | -------------------------------------------------------------------------------- /hyptorch/nn.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | 7 | import hyptorch.pmath as pmath 8 | 9 | 10 | class HyperbolicMLR(nn.Module): 11 | r""" 12 | Module which performs softmax classification 13 | in Hyperbolic space. 14 | """ 15 | 16 | def __init__(self, ball_dim, n_classes, c): 17 | super(HyperbolicMLR, self).__init__() 18 | self.a_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim)) 19 | self.p_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim)) 20 | self.c = c 21 | self.n_classes = n_classes 22 | self.ball_dim = ball_dim 23 | self.reset_parameters() 24 | 25 | def forward(self, x, c=None): 26 | if c is None: 27 | c = torch.as_tensor(self.c).type_as(x) 28 | else: 29 | c = torch.as_tensor(c).type_as(x) 30 | p_vals_poincare = pmath.expmap0(self.p_vals, c=c) 31 | conformal_factor = 1 - c * p_vals_poincare.pow(2).sum(dim=1, keepdim=True) 32 | a_vals_poincare = self.a_vals * conformal_factor 33 | logits = pmath._hyperbolic_softmax(x, a_vals_poincare, p_vals_poincare, c) 34 | return logits 35 | 36 | def extra_repr(self): 37 | return "Poincare ball dim={}, n_classes={}, c={}".format( 38 | self.ball_dim, self.n_classes, self.c 39 | ) 40 | 41 | def reset_parameters(self): 42 | init.kaiming_uniform_(self.a_vals, a=math.sqrt(5)) 43 | init.kaiming_uniform_(self.p_vals, a=math.sqrt(5)) 44 | 45 | 46 | class HypLinear(nn.Module): 47 | def __init__(self, in_features, out_features, c, bias=True): 48 | super(HypLinear, self).__init__() 49 | self.in_features = in_features 50 | self.out_features = out_features 51 | self.c = c 52 | self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) 53 | if bias: 54 | self.bias = nn.Parameter(torch.Tensor(out_features)) 55 | else: 56 | self.register_parameter("bias", None) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 61 | if self.bias is not None: 62 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight) 63 | bound = 1 / math.sqrt(fan_in) 64 | init.uniform_(self.bias, -bound, bound) 65 | 66 | def forward(self, x, c=None): 67 | if c is None: 68 | c = self.c 69 | mv = pmath.mobius_matvec(self.weight, x, c=c) 70 | if 
self.bias is None: 71 | return pmath.project(mv, c=c) 72 | else: 73 | bias = pmath.expmap0(self.bias, c=c) 74 | return pmath.project(pmath.mobius_add(mv, bias), c=c) 75 | 76 | def extra_repr(self): 77 | return "in_features={}, out_features={}, bias={}, c={}".format( 78 | self.in_features, self.out_features, self.bias is not None, self.c 79 | ) 80 | 81 | 82 | class ConcatPoincareLayer(nn.Module): 83 | def __init__(self, d1, d2, d_out, c): 84 | super(ConcatPoincareLayer, self).__init__() 85 | self.d1 = d1 86 | self.d2 = d2 87 | self.d_out = d_out 88 | 89 | self.l1 = HypLinear(d1, d_out, bias=False, c=c) 90 | self.l2 = HypLinear(d2, d_out, bias=False, c=c) 91 | self.c = c 92 | 93 | def forward(self, x1, x2, c=None): 94 | if c is None: 95 | c = self.c 96 | return pmath.mobius_add(self.l1(x1), self.l2(x2), c=c) 97 | 98 | def extra_repr(self): 99 | return "dims {} and {} ---> dim {}".format(self.d1, self.d2, self.d_out) 100 | 101 | 102 | class HyperbolicDistanceLayer(nn.Module): 103 | def __init__(self, c): 104 | super(HyperbolicDistanceLayer, self).__init__() 105 | self.c = c 106 | 107 | def forward(self, x1, x2, c=None): 108 | if c is None: 109 | c = self.c 110 | return pmath.dist(x1, x2, c=c, keepdim=True) 111 | 112 | def extra_repr(self): 113 | return "c={}".format(self.c) 114 | 115 | 116 | class ToPoincare(nn.Module): 117 | r""" 118 | Module which maps points in n-dim Euclidean space 119 | to n-dim Poincare ball 120 | Also implements clipping from https://arxiv.org/pdf/2107.11472.pdf 121 | """ 122 | 123 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None, riemannian=True, clip_r=None): 124 | super(ToPoincare, self).__init__() 125 | if train_x: 126 | if ball_dim is None: 127 | raise ValueError( 128 | "if train_x=True, ball_dim has to be integer, got {}".format( 129 | ball_dim 130 | ) 131 | ) 132 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 133 | else: 134 | self.register_parameter("xp", None) 135 | 136 | if train_c: 137 | self.c = nn.Parameter(torch.Tensor([c,])) 138 | else: 139 | self.c = c 140 | 141 | self.train_x = train_x 142 | 143 | self.riemannian = pmath.RiemannianGradient 144 | self.riemannian.c = c 145 | 146 | self.clip_r = clip_r 147 | 148 | if riemannian: 149 | self.grad_fix = lambda x: self.riemannian.apply(x) 150 | else: 151 | self.grad_fix = lambda x: x 152 | 153 | def forward(self, x): 154 | if self.clip_r is not None: 155 | x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5 156 | fac = torch.minimum( 157 | torch.ones_like(x_norm), 158 | self.clip_r / x_norm 159 | ) 160 | x = x * fac 161 | 162 | if self.train_x: 163 | xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c) 164 | return self.grad_fix(pmath.project(pmath.expmap(xp, x, c=self.c), c=self.c)) 165 | return self.grad_fix(pmath.project(pmath.expmap0(x, c=self.c), c=self.c)) 166 | 167 | def extra_repr(self): 168 | return "c={}, train_x={}".format(self.c, self.train_x) 169 | 170 | 171 | class FromPoincare(nn.Module): 172 | r""" 173 | Module which maps points in n-dim Poincare ball 174 | to n-dim Euclidean space 175 | """ 176 | 177 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None): 178 | 179 | super(FromPoincare, self).__init__() 180 | 181 | if train_x: 182 | if ball_dim is None: 183 | raise ValueError( 184 | "if train_x=True, ball_dim has to be integer, got {}".format( 185 | ball_dim 186 | ) 187 | ) 188 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 189 | else: 190 | self.register_parameter("xp", None) 191 | 192 | if train_c: 193 | self.c = 
nn.Parameter(torch.Tensor([c,])) 194 | else: 195 | self.c = c 196 | 197 | self.train_c = train_c 198 | self.train_x = train_x 199 | 200 | def forward(self, x): 201 | if self.train_x: 202 | xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c) 203 | return pmath.logmap(xp, x, c=self.c) 204 | return pmath.logmap0(x, c=self.c) 205 | 206 | def extra_repr(self): 207 | return "train_c={}, train_x={}".format(self.train_c, self.train_x) 208 | 209 | 210 | -------------------------------------------------------------------------------- /network/general_imagefes.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as TVmodels 5 | import MinkowskiEngine as ME 6 | 7 | from network.image_pool_fns import ImageGeM 8 | from network.image_pool_fns import ImageCosPlace 9 | from network.image_pool_fns import ImageNetVLAD 10 | from network.image_pool_fns import ImageConvAP 11 | 12 | import torch.nn.functional as F 13 | 14 | from models.minkloc import MinkLoc 15 | from tools.utils import set_seed 16 | set_seed(7) 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | class GeneralImageFE(torch.nn.Module): 28 | def __init__(self, 29 | image_fe, 30 | num_other_stage_blocks, 31 | num_stage3_blocks, 32 | image_pool_method, # GeM 33 | image_useallstages, # True 34 | output_dim, 35 | ): 36 | super().__init__() 37 | ''' 38 | resnet [64,64,128,256,512] 39 | convnext [96,96,192,384,768] 40 | swin [96,96,192,384,768] 41 | swin_v2 [96,96,192,384,768] 42 | ''' 43 | 44 | 45 | self.image_fe = image_fe 46 | self.num_other_stage_blocks = num_other_stage_blocks 47 | self.num_stage3_blocks = num_stage3_blocks 48 | self.image_pool_method = image_pool_method 49 | self.image_useallstages = image_useallstages 50 | self.output_dim = output_dim 51 | 52 | 53 | 54 | # -- resnet 55 | if self.image_fe == 'resnet18': 56 | self.model = TVmodels.resnet18(weights='IMAGENET1K_V1') 57 | if self.image_useallstages: 58 | self.last_dim = 512 59 | else: 60 | self.last_dim = 256 61 | elif self.image_fe == 'resnet34': 62 | self.model = TVmodels.resnet34(weights='IMAGENET1K_V1') 63 | if self.image_useallstages: 64 | self.last_dim = 512 65 | else: 66 | self.last_dim = 256 67 | elif self.image_fe == 'resnet50': 68 | self.model = TVmodels.resnet50(weights='IMAGENET1K_V2') 69 | if self.image_useallstages: 70 | self.last_dim = 2048 71 | else: 72 | self.last_dim = 1024 73 | 74 | 75 | # -- convnext 76 | elif self.image_fe == 'convnext_tiny': 77 | self.model = TVmodels.convnext_tiny(weights='IMAGENET1K_V1') 78 | if self.image_useallstages: 79 | self.last_dim = 768 80 | else: 81 | self.last_dim = 384 82 | elif self.image_fe == 'convnext_small': 83 | self.model = TVmodels.convnext_small(weights='IMAGENET1K_V1') 84 | self.last_dim = 384 85 | 86 | 87 | # -- swin 88 | elif self.image_fe == 'swin_t': 89 | self.model = TVmodels.swin_t(weights='IMAGENET1K_V1') 90 | self.last_dim = 384 91 | elif self.image_fe == 'swin_s': 92 | self.model = TVmodels.swin_s(weights='IMAGENET1K_V1') 93 | self.last_dim = 384 94 | elif self.image_fe == 'swin_v2_t': 95 | self.model = TVmodels.swin_v2_t(weights='IMAGENET1K_V1') 96 | self.last_dim = 384 97 | elif self.image_fe == 'swin_v2_s': 98 | self.model = TVmodels.swin_v2_s(weights='IMAGENET1K_V1') 99 | self.last_dim = 384 100 | 101 | 102 | 103 | 104 | 105 | self.conv1x1 = nn.Conv2d(self.last_dim, output_dim, kernel_size=1) 106 | 107 | 108 | self.image_gem = ImageGeM() # *1 109 | self.imagecosplace = 
ImageCosPlace(output_dim, output_dim) # *1 110 | self.imageconvap = ImageConvAP(output_dim, output_dim) # *4 111 | self.imagenetvlad = ImageNetVLAD(clusters_num=64, 112 | dim=output_dim) # *4 113 | 114 | 115 | 116 | 117 | def forward_resnet(self, x): 118 | fe_output_dict = {} 119 | x = self.model.conv1(x) 120 | x = self.model.bn1(x) 121 | x = self.model.relu(x) 122 | x = self.model.maxpool(x) 123 | 124 | x = self.model.layer1(x) 125 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 126 | fe_output_dict['image_layer1'] = x 127 | fe_output_dict['image_layer1_avgpool'] = x_avgpool 128 | 129 | x = self.model.layer2(x) 130 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 131 | fe_output_dict['image_layer2'] = x 132 | fe_output_dict['image_layer2_avgpool'] = x_avgpool 133 | 134 | x = self.model.layer3(x) 135 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 136 | fe_output_dict['image_layer3'] = x 137 | fe_output_dict['image_layer3_avgpool'] = x_avgpool 138 | 139 | if self.image_useallstages: 140 | x = self.model.layer4(x) 141 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 142 | fe_output_dict['image_layer4'] = x 143 | fe_output_dict['image_layer4_avgpool'] = x_avgpool 144 | 145 | return x, fe_output_dict 146 | 147 | 148 | 149 | def forward_convnext(self, x): 150 | fe_output_dict = {} 151 | layers_list = list(self.model.features.children()) 152 | assert len(layers_list)==8 153 | if not self.image_useallstages: 154 | layers_list[1] = layers_list[1] 155 | layers_list[3] = layers_list[3] 156 | layers_list[5] = layers_list[5] 157 | layers_list = layers_list[:-2] 158 | else: 159 | layers_list = layers_list 160 | 161 | for i in range(len(layers_list)): 162 | layer = layers_list[i] 163 | x = layer(x) 164 | return x, fe_output_dict 165 | 166 | 167 | def forward_swin(self, x): 168 | fe_output_dict = {} 169 | layers_list = list(self.model.features.children()) 170 | if not self.image_useallstages: 171 | layers_list = layers_list[:-2] 172 | else: 173 | layers_list = layers_list 174 | for i in range(len(layers_list)): 175 | layer = layers_list[i] 176 | x = layer(x) 177 | x = x.permute(0,3,1,2) 178 | return x, fe_output_dict 179 | 180 | 181 | 182 | 183 | 184 | 185 | def forward(self, data_dict): 186 | 187 | 188 | x = data_dict['images'] 189 | 190 | 191 | if self.image_fe in ['resnet18','resnet34','resnet50']: 192 | x, fe_output_dict = self.forward_resnet(x) 193 | elif self.image_fe in ['convnext_tiny','convnext_small']: 194 | x, fe_output_dict = self.forward_convnext(x) 195 | elif self.image_fe in ['swin_t','swin_s']: 196 | x, fe_output_dict = self.forward_swin(x) 197 | elif self.image_fe in ['swin_v2_t','swin_v2_s']: 198 | x, fe_output_dict = self.forward_swin(x) 199 | else: 200 | raise NotImplementedError 201 | 202 | 203 | x_feat_256 = x 204 | x_feat_256 = self.conv1x1(x_feat_256) 205 | 206 | 207 | if self.image_pool_method == 'GeM': 208 | embedding = self.image_gem(x_feat_256) 209 | 210 | elif self.image_pool_method == 'ConvAP': 211 | embedding = self.imageconvap(x_feat_256) 212 | 213 | elif self.image_pool_method == 'CosPlace': 214 | embedding = self.imagecosplace(x_feat_256) 215 | 216 | elif self.image_pool_method == 'NetVLAD': 217 | embedding = self.imagenetvlad(x_feat_256) 218 | 219 | else: 220 | raise NotImplementedError 221 | 222 | 223 | 224 | 225 | output_dict = { 226 | 'output_image_feat': x_feat_256, 227 | 'output_image_gem': embedding, 228 | 229 | 'embedding': embedding, 230 | } 231 
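        # Merge per-stage backbone activations (for ResNet backbones, keys like
        # 'image_layer1' and 'image_layer1_avgpool' filled in forward_resnet
        # above) into the output dict, so downstream distillation code can read
        # intermediate features alongside the final embedding.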
|
232 |         for _k, _v in fe_output_dict.items():
233 |             output_dict[_k] = _v
234 | 
235 | 
236 |         return output_dict
237 | 
238 | 
239 | 
240 | 
241 | 
242 | 
243 | 
244 | 
-------------------------------------------------------------------------------- /models/minkfpn.py: --------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 | 
4 | import torch.nn as nn
5 | import MinkowskiEngine as ME
6 | from MinkowskiEngine.modules.resnet_block import BasicBlock
7 | from layers.eca_block import ECABasicBlock
8 | 
9 | from models.resnet import ResNetBase
10 | from models.resnet import make_layer
11 | import torch
12 | 
13 | from tools.utils import set_seed
14 | set_seed(7)
15 | 
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
22 | class MinkFPN(ResNetBase):
23 |     # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks
24 |     # in_channels=1, out_channels=128, 1, ECABasicBlock, [1,1,1], [32,64,64]
25 |     def __init__(self, in_channels, out_channels, num_top_down=1, conv0_kernel_size=5, block=BasicBlock,
26 |                  layers=(1, 1, 1), planes=(32, 64, 64)):
27 |         assert len(layers) == len(planes)
28 |         assert 1 <= len(layers)
29 |         assert 0 <= num_top_down <= len(layers)
30 |         self.num_bottom_up = len(layers)
31 |         self.num_top_down = num_top_down
32 |         self.conv0_kernel_size = conv0_kernel_size
33 |         self.block = block
34 |         self.layers = layers
35 |         self.planes = planes
36 |         self.lateral_dim = out_channels
37 |         self.init_dim = planes[0]
38 |         ResNetBase.__init__(self, in_channels, out_channels, D=3)
39 | 
40 |     def network_initialization(self, in_channels, out_channels, D):
41 |         assert len(self.layers) == len(self.planes)
42 |         assert len(self.planes) == self.num_bottom_up
43 | 
44 |         self.convs = nn.ModuleList()  # Bottom-up convolutional blocks with stride=2
45 |         self.bns = nn.ModuleList()  # Bottom-up BatchNorms
46 |         self.blocks = nn.ModuleList()  # Bottom-up blocks
47 |         self.tconvs = nn.ModuleList()  # Top-down transposed convolutions
48 |         self.conv1x1s = nn.ModuleList()  # 1x1 convolutions in lateral connections
49 | 
50 |         # The first convolution is a special case, with kernel size = 5
51 |         self.inplanes = self.planes[0]
52 |         self.conv0 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size,
53 |                                              dimension=D)
54 |         self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)
55 | 
56 |         for plane, layer in zip(self.planes, self.layers):
57 |             self.convs.append(ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D))
58 |             self.bns.append(ME.MinkowskiBatchNorm(self.inplanes))
59 |             self.blocks.append(self._make_layer(self.block, plane, layer))
60 | 
61 |         # Lateral connections
62 |         for i in range(self.num_top_down):
63 |             self.conv1x1s.append(ME.MinkowskiConvolution(self.planes[-1 - i], self.lateral_dim, kernel_size=1,
64 |                                                          stride=1, dimension=D))
65 |             self.tconvs.append(ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2,
66 |                                                                 stride=2, dimension=D))
67 |         # There's one more lateral connection than top-down TConv blocks
68 |         if self.num_top_down < self.num_bottom_up:
69 |             # Lateral connection from Conv block 1 or above
70 |             self.conv1x1s.append(ME.MinkowskiConvolution(self.planes[-1 - self.num_top_down], self.lateral_dim, kernel_size=1,
71 |                                                          stride=1, dimension=D))
72 |         else:
73 |             # Lateral connection from Conv0 block
74 |             self.conv1x1s.append(ME.MinkowskiConvolution(self.planes[0], self.lateral_dim, kernel_size=1,
75 |                                                          stride=1, dimension=D))
76 | 
77 | 
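        # In total: num_top_down transposed convolutions and num_top_down + 1
        # lateral 1x1 convolutions; conv1x1s[0] is applied to the coarsest
        # bottom-up map and conv1x1s[1:] to the stored feature maps during the
        # top-down pass (see forward below).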
self.relu = ME.MinkowskiReLU(inplace=True) 78 | 79 | 80 | 81 | 82 | def forward(self, x): 83 | identity = x 84 | # *** BOTTOM-UP PASS *** 85 | # First bottom-up convolution is special (with bigger stride) 86 | feature_maps = [] 87 | x = self.conv0(x) # 32 88 | x = self.bn0(x) 89 | x = self.relu(x) 90 | if self.num_top_down == self.num_bottom_up: 91 | feature_maps.append(x) 92 | 93 | # BOTTOM-UP PASS 94 | for ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bns, self.blocks)): 95 | x = conv(x) # 32 # Decreases spatial resolution (conv stride=2) 96 | x = bn(x) 97 | x = self.relu(x) 98 | x = block(x) 99 | if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1: 100 | feature_maps.append(x) 101 | 102 | assert len(feature_maps) == self.num_top_down 103 | 104 | x = self.conv1x1s[0](x) 105 | 106 | # TOP-DOWN PASS 107 | for ndx, tconv in enumerate(self.tconvs): 108 | x = tconv(x) # Upsample using transposed convolution 109 | x = x + self.conv1x1s[ndx+1](feature_maps[-ndx - 1]) 110 | 111 | return x 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | class GeneralMinkFPN(ResNetBase): 124 | # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks 125 | # in_channels=1, out_channels=128, 1, ECABasicBlock, [1,1,1], [32,64,64] 126 | def __init__(self, in_channels, out_channels, num_top_down=1, conv0_kernel_size=5, block=ECABasicBlock, 127 | layers=(1, 1, 1, 1), planes=(32, 64, 64, 64)): 128 | assert len(layers) == len(planes) 129 | assert 1 <= len(layers) 130 | assert 0 <= num_top_down <= len(layers) 131 | self.out_channels = out_channels 132 | self.num_bottom_up = len(layers) 133 | self.num_top_down = num_top_down 134 | self.conv0_kernel_size = conv0_kernel_size 135 | self.block = block 136 | self.layers = layers 137 | self.planes = planes 138 | self.lateral_dim = out_channels 139 | self.init_dim = planes[0] 140 | ResNetBase.__init__(self, in_channels, out_channels, D=3) 141 | 142 | def network_initialization(self, in_channels, out_channels, D): 143 | assert len(self.layers) == len(self.planes) 144 | assert len(self.planes) == self.num_bottom_up 145 | 146 | self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 147 | self.bns = nn.ModuleList() # Bottom-up BatchNorms 148 | self.blocks = nn.ModuleList() # Bottom-up blocks 149 | 150 | 151 | # The first convolution is special case, with kernel size = 5 152 | self.inplanes = self.planes[0] 153 | self.conv1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, dimension=D) 154 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 155 | self.relu = ME.MinkowskiReLU(inplace=True) 156 | 157 | for plane, layer in zip(self.planes, self.layers): 158 | self.convs.append(ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)) 159 | self.bns.append(ME.MinkowskiBatchNorm(self.inplanes)) 160 | self.blocks.append(self._make_layer(self.block, plane, layer)) 161 | 162 | 163 | self.conv1x1 = ME.MinkowskiConvolution(self.planes[-1], self.out_channels, kernel_size=1, stride=1, dimension=D) 164 | 165 | 166 | 167 | def forward_backbone(self, x): 168 | feature_maps = [] 169 | 170 | x = self.conv1(x) 171 | x = self.bn1(x) 172 | x = self.relu(x) 173 | 174 | for layer_id, (conv, bn, block) in enumerate(zip(self.convs, self.bns, self.blocks)): 175 | x = conv(x) # Decreases spatial resolution (conv stride=2) 176 | x = bn(x) 177 | x = self.relu(x) 178 | x = block(x) 179 | feature_maps.append(x) 180 | 181 | 
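        # feature_maps collects one sparse tensor per bottom-up stage at
        # progressively coarser resolution (each stride-2 conv halves the
        # spatial extent); forward() below consumes only the last entry.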
182 |         return feature_maps
183 | 
184 | 
185 | 
186 |     def forward(self, x):
187 | 
188 | 
189 |         feature_maps = self.forward_backbone(x)
190 | 
191 |         x = self.conv1x1(feature_maps[-1])
192 | 
193 |         # decomposed_features = x.decomposed_features
194 |         # a = torch.nn.utils.rnn.pad_sequence(decomposed_features, batch_first=True)
195 | 
196 | 
197 |         return x
198 | 
199 | 
200 | 
-------------------------------------------------------------------------------- /network/image_pool_fns.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | 
6 | 
7 | 
8 | 
9 | 
10 | class ImageGeM(nn.Module):
11 |     def __init__(self, p=3, eps=1e-6):
12 |         super(ImageGeM, self).__init__()
13 |         self.p = nn.Parameter(torch.ones(1) * p)
14 |         self.eps = eps
15 | 
16 |     def forward(self, x):
17 |         assert len(x.shape) == 4
18 |         output = nn.functional.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
19 | 
20 | 
21 |         b, c, h, w = output.shape
22 |         assert [h, w] == [1, 1]
23 | 
24 | 
25 |         output = output.view(b, c)
26 | 
27 | 
28 |         return output
29 | 
30 | 
31 | 
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
38 | # ---------------------------- CosPlace ----------------------------
39 | class GeM(nn.Module):
40 |     """Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch
41 |     """
42 |     def __init__(self, p=3, eps=1e-6):
43 |         super().__init__()
44 |         self.p = nn.Parameter(torch.ones(1)*p)
45 |         self.eps = eps
46 | 
47 |     def forward(self, x):
48 |         return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
49 | 
50 | class ImageCosPlace(nn.Module):
51 |     """
52 |     CosPlace aggregation layer as implemented in https://github.com/gmberton/CosPlace/blob/main/model/network.py
53 | 
54 |     Args:
55 |         in_dim: number of channels of the input
56 |         out_dim: dimension of the output descriptor
57 |     """
58 |     def __init__(self, in_dim, out_dim):
59 |         super().__init__()
60 |         self.gem = GeM()
61 |         self.fc = nn.Linear(in_dim, out_dim)
62 | 
63 |     def forward(self, x):
64 |         x = F.normalize(x, p=2, dim=1)
65 |         x = self.gem(x)
66 |         x = x.flatten(1)
67 |         x = self.fc(x)
68 |         x = F.normalize(x, p=2, dim=1)
69 |         return x
70 | 
71 | 
72 | 
73 | 
74 | # ---------------------------- MixVPR ----------------------------
75 | class FeatureMixerLayer(nn.Module):
76 |     def __init__(self, in_dim, mlp_ratio=1):
77 |         super().__init__()
78 |         self.mix = nn.Sequential(
79 |             nn.LayerNorm(in_dim),
80 |             nn.Linear(in_dim, int(in_dim * mlp_ratio)),
81 |             nn.ReLU(),
82 |             nn.Linear(int(in_dim * mlp_ratio), in_dim),
83 |         )
84 | 
85 |         for m in self.modules():
86 |             if isinstance(m, (nn.Linear)):
87 |                 nn.init.trunc_normal_(m.weight, std=0.02)
88 |                 if m.bias is not None:
89 |                     nn.init.zeros_(m.bias)
90 | 
91 |     def forward(self, x):
92 |         return x + self.mix(x)
93 | 
94 | 
95 | class ImageMixVPR(nn.Module):
96 |     def __init__(self,
97 |                  in_channels=1024,
98 |                  in_h=20,
99 |                  in_w=20,
100 |                  out_channels=512,
101 |                  mix_depth=1,
102 |                  mlp_ratio=1,
103 |                  out_rows=4,
104 |                  ) -> None:
105 |         super().__init__()
106 | 
107 |         self.in_h = in_h  # height of input feature maps
108 |         self.in_w = in_w  # width of input feature maps
109 |         self.in_channels = in_channels  # depth of input feature maps
110 | 
111 |         self.out_channels = out_channels  # depth-wise projection dimension
112 |         self.out_rows = out_rows  # row-wise projection dimension
113 | 
114 |         self.mix_depth = mix_depth  # L: the number of stacked FeatureMixers
115 |         self.mlp_ratio = 
mlp_ratio # ratio of the mid projection layer in the mixer block 116 | 117 | hw = in_h*in_w 118 | self.mix = nn.Sequential(*[ 119 | FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio) 120 | for _ in range(self.mix_depth) 121 | ]) 122 | self.channel_proj = nn.Linear(in_channels, out_channels) 123 | self.row_proj = nn.Linear(hw, out_rows) 124 | 125 | def forward(self, x): 126 | x = x.flatten(2) 127 | x = self.mix(x) 128 | x = x.permute(0, 2, 1) 129 | x = self.channel_proj(x) 130 | x = x.permute(0, 2, 1) 131 | x = self.row_proj(x) 132 | x = F.normalize(x.flatten(1), p=2, dim=-1) 133 | return x 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | # ---------------------------- ConvAP ---------------------------- 144 | 145 | class ImageConvAP(nn.Module): 146 | """Implementation of ConvAP as of https://arxiv.org/pdf/2210.10239.pdf 147 | 148 | Args: 149 | in_channels (int): number of channels in the input of ConvAP 150 | out_channels (int, optional): number of channels that ConvAP outputs. Defaults to 512. 151 | s1 (int, optional): spatial height of the adaptive average pooling. Defaults to 2. 152 | s2 (int, optional): spatial width of the adaptive average pooling. Defaults to 2. 153 | """ 154 | def __init__(self, in_channels, out_channels=512, s1=2, s2=2): 155 | super(ImageConvAP, self).__init__() 156 | self.channel_pool = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True) 157 | self.AAP = nn.AdaptiveAvgPool2d((s1, s2)) 158 | 159 | def forward(self, x): 160 | x = self.channel_pool(x) 161 | x = self.AAP(x) 162 | x = F.normalize(x.flatten(1), p=2, dim=1) 163 | return x 164 | 165 | 166 | 167 | def print_nb_params(m): 168 | model_parameters = filter(lambda p: p.requires_grad, m.parameters()) 169 | params = sum([np.prod(p.size()) for p in model_parameters]) 170 | print(f'Trainable parameters: {params/1e6:.3}M') 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | # ---------------------------- NetVLAD ---------------------------- 182 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py 183 | class ImageNetVLAD(nn.Module): 184 | """NetVLAD layer implementation""" 185 | 186 | def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False): 187 | """ 188 | Args: 189 | clusters_num : int 190 | The number of clusters 191 | dim : int 192 | Dimension of descriptors 193 | alpha : float 194 | Parameter of initialization. Larger value is harder assignment. 195 | normalize_input : bool 196 | If true, descriptor-wise L2 normalization is applied to input. 
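work_with_tokens : bool
    If true, the input is treated as a sequence of tokens of shape (N, tokens, dim) (e.g. from a transformer backbone) and the soft-assignment convolution is a Conv1d instead of a Conv2d.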
197 | """ 198 | super().__init__() 199 | self.clusters_num = clusters_num 200 | self.dim = dim 201 | self.alpha = 0 202 | self.normalize_input = normalize_input 203 | self.work_with_tokens = work_with_tokens 204 | if work_with_tokens: 205 | self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False) 206 | else: 207 | self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False) 208 | self.centroids = nn.Parameter(torch.rand(clusters_num, dim)) 209 | 210 | def init_params(self, centroids, descriptors): 211 | centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True) 212 | dots = np.dot(centroids_assign, descriptors.T) 213 | dots.sort(0) 214 | dots = dots[::-1, :] # sort, descending 215 | 216 | self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item() 217 | self.centroids = nn.Parameter(torch.from_numpy(centroids)) 218 | if self.work_with_tokens: 219 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2)) 220 | else: 221 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3)) 222 | self.conv.bias = None 223 | 224 | def forward(self, x): 225 | if self.work_with_tokens: 226 | x = x.permute(0, 2, 1) 227 | N, D, _ = x.shape[:] 228 | else: 229 | N, D, H, W = x.shape[:] 230 | if self.normalize_input: 231 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim 232 | x_flatten = x.view(N, D, -1) 233 | soft_assign = self.conv(x).view(N, self.clusters_num, -1) 234 | soft_assign = F.softmax(soft_assign, dim=1) 235 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) 236 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage 237 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 238 | self.centroids[D:D+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 239 | residual = residual * soft_assign[:,D:D+1,:].unsqueeze(2) 240 | vlad[:,D:D+1,:] = residual.sum(dim=-1) 241 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 242 | vlad = vlad.view(N, -1) # Flatten 243 | vlad = F.normalize(vlad, p=2, dim=1) # L2 normalize 244 | return vlad 245 | 246 | 247 | 248 | -------------------------------------------------------------------------------- /tools/utils_minkloc3dv2.py: -------------------------------------------------------------------------------- 1 | # Warsaw University of Technology 2 | 3 | import os 4 | import configparser 5 | import time 6 | import numpy as np 7 | 8 | # from datasets.quantization import PolarQuantizer, CartesianQuantizer 9 | import torch 10 | import random 11 | 12 | 13 | 14 | 15 | 16 | def get_datetime(): 17 | return time.strftime("%Y%m%d_%H%M") 18 | 19 | 20 | 21 | def set_seed(seed=7): 22 | # seed = 7 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | os.environ['PYTHONHASHSEED'] = str(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | torch.backends.cudnn.benchmark = False 30 | torch.backends.cudnn.deterministic = True 31 | 32 | set_seed(7) 33 | 34 | 35 | 36 | 37 | 38 | class ModelParams: 39 | def __init__(self, model_params_path): 40 | config = configparser.ConfigParser() 41 | config.read(model_params_path) 42 | params = config['MODEL'] 43 | 44 | self.model_params_path = model_params_path 45 | self.model = params.get('model') 46 | self.output_dim = params.getint('output_dim', 256) # Size of the final descriptor 47 | 48 | 
####################################################################### 49 | # Model dependent 50 | ####################################################################### 51 | 52 | self.coordinates = params.get('coordinates', 'polar') 53 | assert self.coordinates in ['polar', 'cartesian'], f'Unsupported coordinates: {self.coordinates}' 54 | 55 | 56 | 57 | # if 'polar' in self.coordinates: 58 | # # 3 quantization steps for polar coordinates: for sectors (in degrees), rings (in meters) and z coordinate (in meters) 59 | # self.quantization_step = tuple([float(e) for e in params['quantization_step'].split(',')]) 60 | # assert len(self.quantization_step) == 3, f'Expected 3 quantization steps: for sectors (degrees), rings (meters) and z coordinate (meters)' 61 | # self.quantizer = PolarQuantizer(quant_step=self.quantization_step) 62 | # elif 'cartesian' in self.coordinates: 63 | # # Single quantization step for cartesian coordinates 64 | # self.quantization_step = params.getfloat('quantization_step') 65 | # self.quantizer = CartesianQuantizer(quant_step=self.quantization_step) 66 | # else: 67 | # raise NotImplementedError(f"Unsupported coordinates: {self.coordinates}") 68 | 69 | 70 | 71 | 72 | # Use cosine similarity instead of Euclidean distance 73 | # When Euclidean distance is used, embedding normalization is optional 74 | self.normalize_embeddings = params.getboolean('normalize_embeddings', False) 75 | 76 | # Size of the local features from backbone network (only for MinkNet based models) 77 | self.feature_size = params.getint('feature_size', 256) 78 | if 'planes' in params: 79 | self.planes = tuple([int(e) for e in params['planes'].split(',')]) 80 | else: 81 | self.planes = tuple([32, 64, 64]) 82 | 83 | if 'layers' in params: 84 | self.layers = tuple([int(e) for e in params['layers'].split(',')]) 85 | else: 86 | self.layers = tuple([1, 1, 1]) 87 | 88 | self.num_top_down = params.getint('num_top_down', 1) 89 | self.conv0_kernel_size = params.getint('conv0_kernel_size', 5) 90 | self.block = params.get('block', 'BasicBlock') 91 | self.pooling = params.get('pooling', 'GeM') 92 | 93 | def print(self): 94 | print('Model parameters:') 95 | param_dict = vars(self) 96 | for e in param_dict: 97 | if e == 'quantization_step': 98 | s = param_dict[e] 99 | if self.coordinates == 'polar': 100 | print(f'quantization_step - sector: {s[0]} [deg] / ring: {s[1]} [m] / z: {s[2]} [m]') 101 | else: 102 | print(f'quantization_step: {s} [m]') 103 | else: 104 | print('{}: {}'.format(e, param_dict[e])) 105 | 106 | print('') 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | class TrainingParams: 115 | """ 116 | Parameters for model training 117 | """ 118 | def __init__(self, params_path: str, model_params_path: str, debug: bool = False): 119 | """ 120 | Configuration files 121 | :param path: Training configuration file 122 | :param model_params: Model-specific configuration file 123 | """ 124 | 125 | assert os.path.exists(params_path), 'Cannot find configuration file: {}'.format(params_path) 126 | assert os.path.exists(model_params_path), 'Cannot find model-specific configuration file: {}'.format(model_params_path) 127 | self.params_path = params_path 128 | self.model_params_path = model_params_path 129 | self.debug = debug 130 | 131 | config = configparser.ConfigParser() 132 | 133 | config.read(self.params_path) 134 | params = config['DEFAULT'] 135 | self.dataset_folder = params.get('dataset_folder') 136 | 137 | params = config['TRAIN'] 138 | self.save_freq = params.getint('save_freq', 0) # Model saving frequency (in 
epochs) 139 | self.num_workers = params.getint('num_workers', 0) 140 | 141 | # Initial batch size for global descriptors (for both main and secondary dataset) 142 | self.batch_size = params.getint('batch_size', 64) 143 | # When batch_split_size is non-zero, multistage backpropagation is enabled 144 | self.batch_split_size = params.getint('batch_split_size', None) 145 | 146 | # Set batch_expansion_th to turn on dynamic batch sizing 147 | # When number of non-zero triplets falls below batch_expansion_th, expand batch size 148 | self.batch_expansion_th = params.getfloat('batch_expansion_th', None) 149 | if self.batch_expansion_th is not None: 150 | assert 0. < self.batch_expansion_th < 1., 'batch_expansion_th must be between 0 and 1' 151 | self.batch_size_limit = params.getint('batch_size_limit', 256) 152 | # Batch size expansion rate 153 | self.batch_expansion_rate = params.getfloat('batch_expansion_rate', 1.5) 154 | assert self.batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 155 | else: 156 | self.batch_size_limit = self.batch_size 157 | self.batch_expansion_rate = None 158 | 159 | self.val_batch_size = params.getint('val_batch_size', self.batch_size_limit) 160 | 161 | self.lr = params.getfloat('lr', 1e-3) 162 | self.epochs = params.getint('epochs', 20) 163 | self.optimizer = params.get('optimizer', 'Adam') 164 | self.scheduler = params.get('scheduler', 'MultiStepLR') 165 | if self.scheduler is not None: 166 | if self.scheduler == 'CosineAnnealingLR': 167 | self.min_lr = params.getfloat('min_lr') 168 | elif self.scheduler == 'MultiStepLR': 169 | if 'scheduler_milestones' in params: 170 | scheduler_milestones = params.get('scheduler_milestones') 171 | self.scheduler_milestones = [int(e) for e in scheduler_milestones.split(',')] 172 | else: 173 | self.scheduler_milestones = [self.epochs+1] 174 | else: 175 | raise NotImplementedError('Unsupported LR scheduler: {}'.format(self.scheduler)) 176 | 177 | self.weight_decay = params.getfloat('weight_decay', None) 178 | self.loss = params.get('loss').lower() 179 | if 'contrastive' in self.loss: 180 | self.pos_margin = params.getfloat('pos_margin', 0.2) 181 | self.neg_margin = params.getfloat('neg_margin', 0.65) 182 | elif 'triplet' in self.loss: 183 | self.margin = params.getfloat('margin', 0.4) # Margin used in loss function 184 | elif self.loss in ['truncatedsmoothap']: 185 | # Number of best positives (closest to the query) to consider 186 | self.positives_per_query = params.getint("positives_per_query", 4) 187 | # Temperatures (annealing parameter) and numbers of nearest neighbours to consider 188 | self.tau1 = params.getfloat('tau1', 0.01) 189 | self.margin = params.getfloat('margin', None) # Margin used in loss function 190 | 191 | # Similarity measure: based on cosine similarity or Euclidean distance 192 | self.similarity = params.get('similarity', 'euclidean') 193 | assert self.similarity in ['cosine', 'euclidean'] 194 | 195 | self.aug_mode = params.getint('aug_mode', 1) # Augmentation mode (1 is default) 196 | self.set_aug_mode = params.getint('set_aug_mode', 1) # Augmentation mode (1 is default) 197 | self.train_file = params.get('train_file') 198 | self.val_file = params.get('val_file', None) 199 | self.test_file = params.get('test_file', None) 200 | 201 | # Read model parameters 202 | self.model_params = ModelParams(self.model_params_path) 203 | # self._check_params() 204 | 205 | # def _check_params(self): 206 | # assert os.path.exists(self.dataset_folder), 'Cannot access dataset: {}'.format(self.dataset_folder) 
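# Illustrative training configuration consumed by this class (a sketch only:
# the key names mirror the params.get* calls in __init__ above, while the
# values are made-up examples rather than settings shipped with this repo):
#
#   [DEFAULT]
#   dataset_folder = /data/benchmark_datasets
#
#   [TRAIN]
#   num_workers = 4
#   batch_size = 64
#   # a non-zero batch_split_size enables multistaged backpropagation
#   batch_split_size = 16
#   lr = 1e-3
#   epochs = 40
#   optimizer = Adam
#   scheduler = MultiStepLR
#   scheduler_milestones = 20,30
#   weight_decay = 1e-4
#   loss = BatchHardTripletMarginLoss
#   margin = 0.2
#   similarity = euclidean
#   aug_mode = 1
#   train_file = training_queries.pickle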
207 | 208 | def print(self): 209 | print('Parameters:') 210 | param_dict = vars(self) 211 | for e in param_dict: 212 | if e != 'model_params': 213 | print('{}: {}'.format(e, param_dict[e])) 214 | 215 | self.model_params.print() 216 | print('') 217 | 218 | -------------------------------------------------------------------------------- /datasets/dataloader_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch.utils.data as data 5 | import MinkowskiEngine as ME 6 | import torch 7 | import os 8 | 9 | from datasets.augmentation import ValRGBTransform 10 | import numpy as np 11 | import random 12 | 13 | from tools.utils_adafusion import pc_array_to_voxel 14 | 15 | from PIL import Image 16 | from datasets.oxford import ts_from_filename 17 | from tools.utils import * 18 | from tools.options import * 19 | args = Options().parse() 20 | set_seed(7) 21 | 22 | 23 | def image4lidar(filename, image_path, image_ext, lidar2image_ndx, k=None): 24 | # Return an image corresponding to the given lidar point cloud (given as a path to .bin file) 25 | # k: Number of closest images to randomly select from 26 | lidar_ts, traversal = ts_from_filename(filename) 27 | assert lidar_ts in lidar2image_ndx, 'Unknown lidar timestamp: {}'.format(lidar_ts) 28 | 29 | # Randomly select one of images linked with the point cloud 30 | if k is None or k > len(lidar2image_ndx[lidar_ts]): 31 | k = len(lidar2image_ndx[lidar_ts]) 32 | 33 | image_ts = random.choice(lidar2image_ndx[lidar_ts][:k]) 34 | 35 | 36 | if args.dataset in ['oxford','oxfordadafusion']: 37 | image_file_path = os.path.join(image_path, traversal, str(image_ts) + image_ext) 38 | 39 | elif args.dataset == 'boreas': 40 | image_file_path = os.path.join(image_path, traversal, 'camera_lidar_interval10', str(image_ts) + image_ext) 41 | 42 | 43 | #image_file_path = '/media/sf_Datasets/images4lidar/2014-05-19-13-20-57/1400505893134088.png' 44 | img = Image.open(image_file_path) 45 | return img 46 | 47 | 48 | 49 | 50 | def load_data_item(file_name, lidar2image_ndx): 51 | # returns Nx3 matrix 52 | file_path = os.path.join(args.dataset_folder, file_name) 53 | 54 | result = {} 55 | 56 | 57 | if args.dataset in ['oxford','oxfordadafusion']: 58 | pc = np.fromfile(file_path, dtype=np.float64) 59 | # coords are within -1..1 range in each dimension 60 | assert pc.shape[0] == args.num_points * 3, "Error in point cloud shape: {}".format(file_path) 61 | pc = np.reshape(pc, (pc.shape[0] // 3, 3)) 62 | pc = torch.tensor(pc, dtype=torch.float) 63 | result['coords'] = pc 64 | result['clouds'] = pc.detach().clone() 65 | 66 | elif args.dataset == 'boreas': 67 | pc = np.load(file_path, allow_pickle=True) 68 | # coords are within -1..1 range in each dimension 69 | assert pc.shape[0] == args.num_points, "Error in point cloud shape: {}".format(file_path) 70 | pc = torch.tensor(pc, dtype=torch.float) 71 | result['coords'] = pc 72 | result['clouds'] = pc.detach().clone() 73 | 74 | 75 | P0_camera_path = file_path.replace('lidar_1_4096_interval10','P0_camera_interval10').replace('.npy','.txt') 76 | P0_camera = np.loadtxt(P0_camera_path) 77 | T_camera_lidar_basedon_pose_path = file_path.replace('lidar_1_4096_interval10','T_camera_lidar_basedon_pose_interval10').replace('.npy','.txt') 78 | T_camera_lidar_basedon_pose = np.loadtxt(T_camera_lidar_basedon_pose_path) 79 | 80 | 81 | P0_camera = torch.tensor(P0_camera).float() 82 | T_camera_lidar_basedon_pose = torch.tensor(T_camera_lidar_basedon_pose).float() 83 | result['P0_camera'] = 
P0_camera 84 | result['T_camera_lidar_basedon_pose'] = T_camera_lidar_basedon_pose 85 | 86 | 87 | 88 | 89 | 90 | # Get the first closest image for each LiDAR scan 91 | assert os.path.exists(args.lidar2image_ndx_path), f"Cannot find lidar2image_ndx pickle: {args.lidar2image_ndx_path}" 92 | # lidar2image_ndx = pickle.load(open(params.lidar2image_ndx_path, 'rb')) 93 | img = image4lidar(file_name, args.image_path, '.png', lidar2image_ndx, k=1) 94 | transform = ValRGBTransform() 95 | # Convert to tensor and normalize 96 | result['image'] = transform(img) 97 | 98 | 99 | 100 | 101 | 102 | 103 | return result 104 | 105 | 106 | 107 | 108 | 109 | def collate_fn(batch_list): 110 | 111 | batch_dict = {} 112 | coords_list = [] 113 | images_list = [] 114 | clouds_list = [] 115 | voxels_list = [] 116 | # sph_clouds_list = [] 117 | T_camera_lidar_basedon_pose_list = [] 118 | P0_camera_list = [] 119 | 120 | 121 | for each_batch in batch_list: 122 | coords = each_batch['coords'] # [4096,3] 123 | images = each_batch['images'] 124 | clouds = each_batch['clouds'] 125 | voxels = each_batch['voxels'] 126 | 127 | 128 | # if args.dataset == 'boreas': 129 | # if args.sph_cloud_fe is not None: 130 | # sph_clouds = each_batch['sph_cloud'] 131 | 132 | 133 | if args.dataset in ['oxford','oxfordadafusion']: 134 | coords = ME.utils.sparse_quantize(coordinates=coords, quantization_size=args.oxford_quantization_size) 135 | elif args.dataset == 'boreas': 136 | coords = ME.utils.sparse_quantize(coordinates=coords, quantization_size=args.boreas_quantization_size) 137 | 138 | 139 | 140 | coords_list.append(coords) 141 | images_list.append(images) 142 | clouds_list.append(clouds) 143 | voxels_list.append(voxels) 144 | 145 | # if args.dataset == 'boreas': 146 | # if args.sph_cloud_fe is not None: 147 | # sph_clouds_list.append(sph_clouds) 148 | 149 | # T_camera_lidar_basedon_pose = each_batch['T_camera_lidar_basedon_pose'] 150 | # P0_camera = each_batch['P0_camera'] 151 | # T_camera_lidar_basedon_pose_list.append(T_camera_lidar_basedon_pose) 152 | # P0_camera_list.append(P0_camera) 153 | 154 | 155 | 156 | coords_list = ME.utils.batched_coordinates(coords_list) 157 | features_list = torch.ones([len(coords_list), 1]) 158 | images_list = torch.stack(images_list) 159 | clouds_list = torch.stack(clouds_list) 160 | voxels_list = torch.stack(voxels_list) # [B, 100, 100, 100] 161 | # voxels_list = voxels_list.unsqueeze(1) # [B, 1, 100, 100, 100] 162 | 163 | 164 | 165 | 166 | # if args.dataset == 'boreas': 167 | # if args.sph_cloud_fe is not None: 168 | # sph_clouds_list = torch.stack(sph_clouds_list) 169 | 170 | # T_camera_lidar_basedon_pose_list = torch.stack(T_camera_lidar_basedon_pose_list) 171 | # P0_camera_list = torch.stack(P0_camera_list) 172 | 173 | 174 | if args.dataset == 'boreas': 175 | # if args.sph_cloud_fe is not None: 176 | # batch_dict = { 177 | # 'coords': coords_list, 178 | # 'features': features_list, 179 | # 'images': images_list, 180 | # 'clouds': clouds_list, 181 | # # 'sph_cloud':sph_clouds_list 182 | # } 183 | # return batch_dict 184 | 185 | 186 | batch_dict = { 187 | 'coords': coords_list, 188 | 'features': features_list, 189 | 'images': images_list, 190 | 'clouds': clouds_list, 191 | # 'T_camera_lidar_basedon_pose':T_camera_lidar_basedon_pose_list, 192 | # 'P0_camera':P0_camera_list 193 | } 194 | return batch_dict 195 | 196 | 197 | 198 | 199 | 200 | batch_dict = { 201 | 'coords': coords_list, 202 | 'features': features_list, 203 | 'images': images_list, 204 | 'clouds': clouds_list, 205 | 'voxels': voxels_list 
206 | } 207 | 208 | return batch_dict 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | class DataloaderDataset(data.Dataset): 218 | def __init__(self, set_dict, device, lidar2image_ndx): 219 | 220 | self.set_dict = set_dict 221 | 222 | # self.params = params 223 | self.device = device 224 | self.lidar2image_ndx = lidar2image_ndx 225 | # self.lidar2image_ndx = pickle.load(open(params.lidar2image_ndx_path, 'rb')) 226 | 227 | 228 | def __len__(self): 229 | length = len(self.set_dict) 230 | return length 231 | 232 | 233 | def __getitem__(self, index): 234 | 235 | 236 | data_dict = {} 237 | 238 | 239 | x = load_data_item(self.set_dict[index]["query"], self.lidar2image_ndx) 240 | 241 | 242 | # quantize in collate_fn 243 | data_dict['coords'] = x['coords'] 244 | 245 | # if args.dataset == 'boreas': 246 | # if args.sph_cloud_fe is not None: 247 | # data_dict['sph_cloud'] = x['sph_cloud'] 248 | 249 | 250 | data_dict['images'] = x['image'] 251 | 252 | data_dict['clouds'] = x['clouds'] 253 | 254 | 255 | assert len(x['coords']) == len(x['clouds']) # [-1,1] 256 | # ---- voxel 257 | # # viz_lidar_open3d(coords.numpy()) 258 | # voxel_ids = x['clouds'] + 1 # [0,2] 259 | # voxel_ids = voxel_ids * 48 # [0,100] 260 | # voxel_ids = voxel_ids.int() 261 | # # viz_lidar_open3d(voxel_ids.numpy()) 262 | # voxels = torch.zeros(100,100,100).float() 263 | # voxels[voxel_ids[:,0],voxel_ids[:,1],voxel_ids[:,2]] = 1 264 | # # _voxel_ids = torch.where(voxels>0) 265 | # # _voxel_ids = torch.stack(_voxel_ids,dim=1) 266 | # # viz_lidar_open3d(_voxel_ids.numpy()) 267 | # data_dict['voxels'] = voxels 268 | 269 | # ---- voxel 270 | voxels = pc_array_to_voxel(x['clouds'].numpy()) # [1,shape0,shape1,shape2] Tensor 271 | data_dict['voxels'] = voxels 272 | 273 | 274 | 275 | 276 | 277 | 278 | # if args.dataset in ['boreas']: 279 | # data_dict['P0_camera'] = x['P0_camera'] 280 | # data_dict['T_camera_lidar_basedon_pose'] = x['T_camera_lidar_basedon_pose'] 281 | 282 | 283 | 284 | return data_dict 285 | 286 | 287 | -------------------------------------------------------------------------------- /network/distil_imagefes.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Model processing LiDAR point clouds and RGB images 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision.models as TVmodels 9 | # from TV_offline_models.swin_transformer import swin_v2_t,swin_v2_s 10 | import MinkowskiEngine as ME 11 | 12 | from models.minkloc import MinkLoc 13 | from network.resnetfpn_simple import ImageGeM 14 | from network.resnetfpn_simple import ImageConvAP 15 | from network.resnetfpn_simple import ImageMixVPR 16 | from network.resnetfpn_simple import ImageCosPlace 17 | from network.resnetfpn_simple import ImageNetVLAD 18 | 19 | from tools.utils import set_seed 20 | set_seed(7) 21 | # from tools.options import Options 22 | # args = Options().parse() 23 | 24 | 25 | 26 | 27 | 28 | class GeM(nn.Module): 29 | def __init__(self, p=3, eps=1e-6): 30 | super(GeM, self).__init__() 31 | self.p = nn.Parameter(torch.ones(1) * p) 32 | self.eps = eps 33 | 34 | def forward(self, x): 35 | return nn.functional.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p) 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | class DistilImageFE(torch.nn.Module): 45 | def __init__(self, 46 | image_fe, 47 | num_other_stage_blocks, 48 | num_stage3_blocks, 49 | input_type, 50 | pool_method, # GeM 51 | useallstages, #
True 52 | dataset, # oxford 53 | out_channels: int=None, 54 | lateral_dim: int=None, 55 | layers=[64, 64, 128, 256, 512], 56 | fh_num_bottom_up: int = 5, 57 | fh_num_top_down: int = 2, 58 | add_fc_block: bool = False, 59 | 60 | ): 61 | super().__init__() 62 | ''' 63 | resnet [64,64,128,256,512] 64 | convnext [96,96,192,384,768] 65 | swin [96,96,192,384,768] 66 | swin_v2 [96,96,192,384,768] 67 | ''' 68 | 69 | assert input_type in ['image','sph_cloud'] 70 | assert dataset in ['oxford','oxfordadafusion'] 71 | 72 | self.image_fe = image_fe 73 | self.num_other_stage_blocks = num_other_stage_blocks 74 | self.num_stage3_blocks = num_stage3_blocks 75 | self.input_type = input_type 76 | 77 | 78 | self.out_channels = out_channels 79 | self.pool_method = pool_method 80 | 81 | 82 | self.useallstages = useallstages 83 | 84 | # -- resnet 85 | if self.image_fe == 'resnet18': 86 | self.model = TVmodels.resnet18(weights='IMAGENET1K_V1') 87 | if self.useallstages: 88 | self.last_dim = 512 89 | else: 90 | self.last_dim = 256 91 | elif self.image_fe == 'resnet34': 92 | self.model = TVmodels.resnet34(weights='IMAGENET1K_V1') 93 | self.last_dim = 256 94 | elif self.image_fe == 'resnet50': 95 | self.model = TVmodels.resnet50(weights='IMAGENET1K_V2') 96 | self.last_dim = 1024 97 | 98 | # -- convnext 99 | elif self.image_fe == 'convnext_tiny': 100 | self.model = TVmodels.convnext_tiny(weights='IMAGENET1K_V1') 101 | if self.useallstages: 102 | self.last_dim = 768 103 | else: 104 | self.last_dim = 384 105 | elif self.image_fe == 'convnext_small': 106 | self.model = TVmodels.convnext_small(weights='IMAGENET1K_V1') 107 | self.last_dim = 384 108 | 109 | # -- swin 110 | elif self.image_fe == 'swin_t': 111 | self.model = TVmodels.swin_t(weights='IMAGENET1K_V1') 112 | self.last_dim = 384 113 | elif self.image_fe == 'swin_s': 114 | self.model = TVmodels.swin_s(weights='IMAGENET1K_V1') 115 | self.last_dim = 384 116 | elif self.image_fe == 'swin_v2_t': 117 | self.model = TVmodels.swin_v2_t(weights='IMAGENET1K_V1') 118 | self.last_dim = 384 119 | elif self.image_fe == 'swin_v2_s': 120 | self.model = TVmodels.swin_v2_s(weights='IMAGENET1K_V1') 121 | self.last_dim = 384 122 | 123 | 124 | 125 | # self.conv1x1 = nn.Conv2d(self.last_dim, 128, kernel_size=1) 126 | 127 | 128 | 129 | 130 | pool_dim = self.last_dim 131 | 132 | self.image_gem = ImageGeM() # *1 133 | self.imagecosplace = ImageCosPlace(pool_dim, pool_dim) # *1 134 | self.imageconvap = ImageConvAP(pool_dim, pool_dim) # *4 135 | self.dataset = dataset 136 | self.image_fe = image_fe 137 | 138 | 139 | 140 | if self.dataset in ['oxford','oxfordadafusion']: 141 | if self.image_fe in ['resnet18']: 142 | if self.useallstages: 143 | mixvpr_h, mixvpr_w = (8,10) 144 | else: 145 | mixvpr_h, mixvpr_w = (15,20) 146 | elif self.image_fe in ['convnext_tiny']: 147 | if self.useallstages: 148 | mixvpr_h, mixvpr_w = (7,10) 149 | else: 150 | mixvpr_h, mixvpr_w = (15,20) 151 | 152 | elif self.dataset in ['boreas']: 153 | if self.image_fe in ['resnet18']: 154 | if self.useallstages: 155 | mixvpr_h, mixvpr_w = (8,10) 156 | else: 157 | mixvpr_h, mixvpr_w = (16,20) 158 | elif self.image_fe in ['convnext_tiny']: 159 | if self.useallstages: 160 | mixvpr_h, mixvpr_w = (8,9) 161 | else: 162 | mixvpr_h, mixvpr_w = (16,19) 163 | 164 | 165 | 166 | self.imagemixvpr = ImageMixVPR( # *4 167 | in_channels=pool_dim, 168 | in_h=mixvpr_h, 169 | in_w=mixvpr_w, 170 | out_channels=pool_dim, 171 | mix_depth=4, 172 | mlp_ratio=1, 173 | out_rows=4) # [h=16,w=20] for boreas 174 | self.imagenetvlad = 
ImageNetVLAD(clusters_num=64, 175 | dim=pool_dim) # *4 176 | 177 | 178 | 179 | 180 | 181 | 182 | def forward_resnet(self, x): 183 | x = self.model.conv1(x) 184 | x = self.model.bn1(x) 185 | x = self.model.relu(x) 186 | x = self.model.maxpool(x) 187 | 188 | x = self.model.layer1(x) 189 | x = self.model.layer2(x) 190 | x = self.model.layer3(x) 191 | if self.useallstages: 192 | x = self.model.layer4(x) 193 | 194 | return x 195 | 196 | 197 | 198 | def forward_convnext(self, x): 199 | layers_list = list(self.model.features.children()) 200 | assert len(layers_list)==8 201 | if not self.useallstages: 202 | layers_list[1] = layers_list[1][:self.num_other_stage_blocks] 203 | layers_list[3] = layers_list[3][:self.num_other_stage_blocks] 204 | layers_list[5] = layers_list[5][:self.num_stage3_blocks] 205 | layers_list = layers_list[:-2] 206 | else: 207 | layers_list = layers_list 208 | 209 | for i in range(len(layers_list)): 210 | layer = layers_list[i] 211 | x = layer(x) 212 | return x 213 | 214 | 215 | def forward_swin(self, x): 216 | layers_list = list(self.model.features.children()) 217 | if not self.useallstages: 218 | layers_list = layers_list[:-2] 219 | else: 220 | layers_list = layers_list 221 | for i in range(len(layers_list)): 222 | layer = layers_list[i] 223 | x = layer(x) 224 | x = x.permute(0,3,1,2) 225 | return x 226 | 227 | 228 | 229 | 230 | 231 | 232 | def forward(self, data_dict): 233 | if self.input_type == 'image': 234 | x = data_dict['images'] 235 | elif self.input_type == 'sph_cloud': 236 | x = data_dict['sph_cloud'] 237 | 238 | 239 | 240 | if self.image_fe in ['resnet18','resnet34','resnet50']: 241 | x = self.forward_resnet(x) 242 | elif self.image_fe in ['convnext_tiny','convnext_small']: 243 | x = self.forward_convnext(x) 244 | elif self.image_fe in ['swin_t','swin_s']: 245 | x = self.forward_swin(x) 246 | elif self.image_fe in ['swin_v2_t','swin_v2_s']: 247 | x = self.forward_swin(x) 248 | elif self.image_fe in ['efficientnet_b0','efficientnet_b1','efficientnet_b2','efficientnet_v2_s']: 249 | x = self.forward_efficientnet(x) 250 | elif self.image_fe in ['regnet_x_3_2gf','regnet_y_1_6gf','regnet_y_3_2gf']: 251 | x = self.forward_regnet(x) 252 | else: 253 | raise NotImplementedError 254 | 255 | 256 | x_feat_256 = x 257 | 258 | 259 | 260 | 261 | if self.pool_method == 'GeM': 262 | x_gem_256 = self.image_gem(x_feat_256) 263 | 264 | elif self.pool_method == 'ConvAP': 265 | x_gem_256 = self.imageconvap(x_feat_256) 266 | 267 | elif self.pool_method == 'MixVPR': 268 | x_gem_256 = self.imagemixvpr(x_feat_256) 269 | 270 | elif self.pool_method == 'CosPlace': 271 | x_gem_256 = self.imagecosplace(x_feat_256) 272 | 273 | elif self.pool_method == 'NetVLAD': 274 | x_gem_256 = self.imagenetvlad(x_feat_256) 275 | 276 | else: 277 | raise NotImplementedError 278 | 279 | 280 | 281 | output_dict = { 282 | 'embedding': x_gem_256, 283 | 'x_feat': x_feat_256, 284 | } 285 | 286 | 287 | return output_dict 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import numpy as np 5 | import math 6 | from scipy.linalg import expm, norm 7 | import random 8 | import torch 9 | 10 | import torchvision.transforms as transforms 11 | 12 | from tools.options import Options 13 | args = Options().parse() 14 | from tools.utils import set_seed 15 | set_seed(7) 16 
| 17 | 18 | 19 | class TrainTransform: 20 | def __init__(self, aug_mode): 21 | # aug_mode 0 disables point cloud augmentation; 1 applies the default augmentation pipeline 22 | self.aug_mode = aug_mode 23 | if self.aug_mode == 0: 24 | self.transform = None 25 | elif self.aug_mode == 1: 26 | if args.dataset in ['oxford','oxfordadafusion']: 27 | t = [ 28 | JitterPoints(sigma=0.001, clip=0.002), 29 | RemoveRandomPoints(r=(0.0, 0.1)), 30 | RandomTranslation(max_delta=0.01), 31 | RemoveRandomBlock(p=0.4) 32 | ] 33 | else: 34 | t = [ 35 | JitterPoints(sigma=0.001, clip=0.002), 36 | RemoveRandomPoints(r=(0.0, 0.1)), 37 | RemoveRandomBlock(p=0.4) 38 | ] 39 | 40 | else: 41 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 42 | self.transform = transforms.Compose(t) 43 | 44 | def __call__(self, e): 45 | if self.transform is not None: 46 | e = self.transform(e) 47 | return e 48 | 49 | 50 | 51 | 52 | 53 | class TrainSetTransform: 54 | def __init__(self, aug_mode): 55 | self.aug_mode = aug_mode 56 | self.transform = None 57 | if aug_mode == 0: 58 | t = None 59 | elif aug_mode == 1: 60 | if args.dataset in ['oxford','oxfordadafusion']: 61 | t = [ 62 | RandomRotation(max_theta=5, max_theta2=0, axis=np.array([0, 0, 1])), 63 | RandomFlip([0.25, 0.25, 0.]) 64 | ] 65 | else: 66 | t = [] 67 | else: 68 | raise NotImplementedError('Unknown aug_mode: {}'.format(aug_mode)) 69 | if t is None: 70 | self.transform = None 71 | else: 72 | self.transform = transforms.Compose(t) 73 | 74 | def __call__(self, e): 75 | if self.transform is not None: 76 | e = self.transform(e) 77 | return e 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | class TrainRGBTransform: 87 | def __init__(self, aug_mode): 88 | # aug_mode 0 applies only tensor conversion and normalization; aug_mode > 0 adds color jitter and random erasing 89 | self.aug_mode = aug_mode 90 | if self.aug_mode == 0: 91 | t = [ 92 | transforms.ToTensor(), 93 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 94 | ] 95 | elif self.aug_mode > 0: 96 | t = [ 97 | transforms.ColorJitter(brightness=args.bcs_aug_rate, contrast=args.bcs_aug_rate, saturation=args.bcs_aug_rate, hue=args.hue_aug_rate), 98 | transforms.ToTensor(), 99 | transforms.RandomErasing(scale=(0.1, 0.4)), 100 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 101 | ] 102 | else: 103 | raise NotImplementedError('Unknown aug_mode: {}'.format(self.aug_mode)) 104 | self.transform = transforms.Compose(t) 105 | 106 | def __call__(self, e): 107 | if self.transform is not None: 108 | e = self.transform(e) 109 | return e 110 | 111 | 112 | class ValRGBTransform: 113 | def __init__(self): 114 | # Deterministic validation transform: tensor conversion and ImageNet normalization 115 | t = [ 116 | transforms.ToTensor(), 117 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 118 | ] 119 | self.transform = transforms.Compose(t) 120 | 121 | def __call__(self, e): 122 | e = self.transform(e) 123 | return e 124 | 125 | 126 | class RandomFlip: 127 | def __init__(self, p): 128 | # p = [p_x, p_y, p_z] probability of flipping each axis 129 | assert len(p) == 3 130 | assert 0 < sum(p) <= 1, 'sum(p) must be in (0, 1] range, is: {}'.format(sum(p)) 131 | self.p = p 132 | self.p_cum_sum = np.cumsum(p) 133 | 134 | def __call__(self, coords): 135 | r = random.random() 136 | if r <= self.p_cum_sum[0]: 137 | # Flip the first axis 138 | coords[..., 0] = -coords[..., 0] 139 | elif r <= self.p_cum_sum[1]: 140 | # Flip the second axis 141 | coords[..., 1] = -coords[..., 1] 142 | elif r <= self.p_cum_sum[2]: 143 | # Flip the third axis 144 | coords[..., 2] = -coords[..., 2] 145 | 146 | return coords 147 | 148 | 149 | class RandomRotation: 150 | def
__init__(self, axis=None, max_theta=180, max_theta2=15): 151 | self.axis = axis 152 | self.max_theta = max_theta # Rotation around axis 153 | self.max_theta2 = max_theta2 # Smaller rotation in random direction 154 | 155 | def _M(self, axis, theta): 156 | return expm(np.cross(np.eye(3), axis / norm(axis) * theta)).astype(np.float32) 157 | 158 | def __call__(self, coords): 159 | if self.axis is not None: 160 | axis = self.axis 161 | else: 162 | axis = np.random.rand(3) - 0.5 163 | R = self._M(axis, (np.pi * self.max_theta / 180) * 2 * (np.random.rand(1) - 0.5)) 164 | if self.max_theta2 is None: 165 | coords = coords @ R 166 | else: 167 | R_n = self._M(np.random.rand(3) - 0.5, (np.pi * self.max_theta2 / 180) * 2 * (np.random.rand(1) - 0.5)) 168 | coords = coords @ R @ R_n 169 | 170 | return coords 171 | 172 | 173 | class RandomTranslation: 174 | def __init__(self, max_delta=0.05): 175 | self.max_delta = max_delta 176 | 177 | def __call__(self, coords): 178 | trans = self.max_delta * np.random.randn(1, 3) 179 | return coords + trans.astype(np.float32) 180 | 181 | 182 | class RandomScale: 183 | def __init__(self, min, max): 184 | self.scale = max - min 185 | self.bias = min 186 | 187 | def __call__(self, coords): 188 | s = self.scale * np.random.rand(1) + self.bias 189 | return coords * s.astype(np.float32) 190 | 191 | 192 | class RandomShear: 193 | def __init__(self, delta=0.1): 194 | self.delta = delta 195 | 196 | def __call__(self, coords): 197 | T = np.eye(3) + self.delta * np.random.randn(3, 3) 198 | return coords @ T.astype(np.float32) 199 | 200 | 201 | class JitterPoints: 202 | def __init__(self, sigma=0.01, clip=None, p=1.): 203 | assert 0 < p <= 1. 204 | assert sigma > 0. 205 | 206 | self.sigma = sigma 207 | self.clip = clip 208 | self.p = p 209 | 210 | def __call__(self, e): 211 | """ Randomly jitter points. jittering is per point. 212 | Input: 213 | BxNx3 array, original batch of point clouds 214 | Return: 215 | BxNx3 array, jittered batch of point clouds 216 | """ 217 | 218 | sample_shape = (e.shape[0],) 219 | if self.p < 1.: 220 | # Create a mask for points to jitter 221 | m = torch.distributions.categorical.Categorical(probs=torch.tensor([1 - self.p, self.p])) 222 | mask = m.sample(sample_shape=sample_shape) 223 | else: 224 | mask = torch.ones(sample_shape, dtype=torch.int64 ) 225 | 226 | mask = mask == 1 227 | jitter = self.sigma * torch.randn_like(e[mask]) 228 | 229 | if self.clip is not None: 230 | jitter = torch.clamp(jitter, min=-self.clip, max=self.clip) 231 | 232 | e[mask] = e[mask] + jitter 233 | return e 234 | 235 | 236 | class RemoveRandomPoints: 237 | def __init__(self, r): 238 | if type(r) is list or type(r) is tuple: 239 | assert len(r) == 2 240 | assert 0 <= r[0] <= 1 241 | assert 0 <= r[1] <= 1 242 | self.r_min = float(r[0]) 243 | self.r_max = float(r[1]) 244 | else: 245 | assert 0 <= r <= 1 246 | self.r_min = None 247 | self.r_max = float(r) 248 | 249 | def __call__(self, e): 250 | n = len(e) 251 | if self.r_min is None: 252 | r = self.r_max 253 | else: 254 | # Randomly select removal ratio 255 | r = random.uniform(self.r_min, self.r_max) 256 | 257 | mask = np.random.choice(range(n), size=int(n*r), replace=False) # select elements to remove 258 | e[mask] = torch.zeros_like(e[mask]) 259 | return e 260 | 261 | 262 | class RemoveRandomBlock: 263 | """ 264 | Randomly remove part of the point cloud. Similar to PyTorch RandomErasing but operating on 3D point clouds. 265 | Erases fronto-parallel cuboid. 
266 | Instead of erasing we set coords of removed points to (0, 0, 0) to retain the same number of points 267 | """ 268 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 269 | self.p = p 270 | self.scale = scale 271 | self.ratio = ratio 272 | 273 | def get_params(self, coords): 274 | # Find point cloud 3D bounding box 275 | flattened_coords = coords.view(-1, 3) 276 | min_coords, _ = torch.min(flattened_coords, dim=0) 277 | max_coords, _ = torch.max(flattened_coords, dim=0) 278 | span = max_coords - min_coords 279 | area = span[0] * span[1] 280 | erase_area = random.uniform(self.scale[0], self.scale[1]) * area 281 | aspect_ratio = random.uniform(self.ratio[0], self.ratio[1]) 282 | 283 | h = math.sqrt(erase_area * aspect_ratio) 284 | w = math.sqrt(erase_area / aspect_ratio) 285 | 286 | x = min_coords[0] + random.uniform(0, 1) * (span[0] - w) 287 | y = min_coords[1] + random.uniform(0, 1) * (span[1] - h) 288 | 289 | return x, y, w, h 290 | 291 | def __call__(self, coords): 292 | if random.random() < self.p: 293 | x, y, w, h = self.get_params(coords) # Fronto-parallel cuboid to remove 294 | mask = (x < coords[..., 0]) & (coords[..., 0] < x+w) & (y < coords[..., 1]) & (coords[..., 1] < y+h) 295 | coords[mask] = torch.zeros_like(coords[mask]) 296 | return coords 297 | 298 | 299 | def tensor2img(x): 300 | t = transforms.Compose([transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 301 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225]), 302 | transforms.ToPILImage()]) 303 | return t(x) 304 | -------------------------------------------------------------------------------- /multi_stage_train.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | import torch 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | # def multistaged_training_step(global_iter, model, phase, device, optimizer, loss_fn): 16 | def multistaged_training_step_distil(batch, positives_mask, negatives_mask, model, phase, device, optimizer, loss_fn_stu, 17 | compute_all_loss, output_dict_tea): 18 | """ 19 | multi-stage training step for distillation 20 | """ 21 | assert phase in ['train', 'val'] 22 | # batch: {{'coords':, 'features':}*16} 23 | # batch, positives_mask, negatives_mask = next(global_iter) 24 | 25 | if phase == 'train': 26 | model.train() 27 | else: 28 | model.eval() 29 | 30 | # Stage 1 - calculate descriptors of each batch element (with gradient turned off) 31 | # In training phase network is in the train mode to update BatchNorm stats 32 | embeddings_l = [] 33 | 34 | with torch.set_grad_enabled(False): 35 | for minibatch in batch: 36 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 37 | y = model(minibatch) 38 | embeddings_l.append(y['embedding']) 39 | 40 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 41 | 42 | # Stage 2 - compute gradient of the loss w.r.t embeddings 43 | embeddings = torch.cat(embeddings_l, dim=0) 44 | 45 | 46 | with torch.set_grad_enabled(phase == 'train'): 47 | if phase == 'train': 48 | embeddings.requires_grad_(True) 49 | 50 | _embeddings_dict = { 51 | 'embedding': embeddings, 52 | } 53 | 54 | 55 | 56 | # -- distil loss_fn 57 | loss = compute_all_loss( 58 | output_dict_stu=_embeddings_dict, 59 | output_dict_tea=output_dict_tea, 60 | positives_mask=positives_mask, 61 | negatives_mask=negatives_mask, 62 | adaptor=None, 63 | task_loss_fn_stu=loss_fn_stu, 64 | ) 65 | 66 | 67 | 68 | # stats = tensors_to_numbers(stats) 69 | if phase == 'train': 70 | loss.backward() 71 | 
embeddings_grad = embeddings.grad 72 | 73 | 74 | 75 | # Stage 3 - recompute descriptors with gradient enabled and compute the gradient of the loss w.r.t. 76 | # network parameters using cached gradient of the loss w.r.t embeddings 77 | if phase == 'train': 78 | optimizer.zero_grad() 79 | i = 0 80 | with torch.set_grad_enabled(True): 81 | for minibatch in batch: 82 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 83 | y = model(minibatch) 84 | embeddings = y['embedding'] 85 | 86 | minibatch_size = len(embeddings) 87 | # Compute gradients of network params w.r.t. the loss using the chain rule (using the 88 | # gradient of the loss w.r.t. embeddings stored in embeddings_grad) 89 | # By default gradients are accumulated 90 | embeddings.backward(gradient=embeddings_grad[i: i+minibatch_size]) 91 | 92 | i += minibatch_size 93 | 94 | optimizer.step() 95 | 96 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | # def multistaged_training_step(global_iter, model, phase, device, optimizer, loss_fn): 111 | def multistaged_training_step(batch, positives_mask, negatives_mask, model, phase, device, optimizer, loss_fn): 112 | # Training step using multistaged backpropagation algorithm as per: 113 | # "Learning with Average Precision: Training Image Retrieval with a Listwise Loss" 114 | # This method will break when the model contains Dropout, as the same mini-batch will produce different embeddings. 115 | # Make sure mini-batches in step 1 and step 3 are the same (so that BatchNorm produces the same results) 116 | # See some exemplary implementation here: https://gist.github.com/ByungSun12/ad964a08eba6a7d103dab8588c9a3774 117 | 118 | assert phase in ['train', 'val'] 119 | # batch: {{'coords':, 'features':}*16} 120 | # batch, positives_mask, negatives_mask = next(global_iter) 121 | 122 | if phase == 'train': 123 | model.train() 124 | else: 125 | model.eval() 126 | 127 | # Stage 1 - calculate descriptors of each batch element (with gradient turned off) 128 | # In training phase network is in the train mode to update BatchNorm stats 129 | embeddings_l = [] 130 | 131 | with torch.set_grad_enabled(False): 132 | for minibatch in batch: 133 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 134 | y = model(minibatch) 135 | embeddings_l.append(y['embedding']) 136 | 137 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 138 | 139 | # Stage 2 - compute gradient of the loss w.r.t embeddings 140 | embeddings = torch.cat(embeddings_l, dim=0) 141 | 142 | 143 | with torch.set_grad_enabled(phase == 'train'): 144 | if phase == 'train': 145 | embeddings.requires_grad_(True) 146 | 147 | _embeddings_dict = { 148 | 'embedding': embeddings, 149 | } 150 | 151 | 152 | # -- vanilla loss_fn 153 | loss, stats, _ = loss_fn(_embeddings_dict, positives_mask, negatives_mask) 154 | 155 | 156 | # stats = tensors_to_numbers(stats) 157 | if phase == 'train': 158 | loss.backward() 159 | embeddings_grad = embeddings.grad 160 | 161 | 162 | # # Delete intermediary values 163 | # embeddings_l, embeddings, y, loss = None, None, None, None 164 | 165 | # Stage 3 - recompute descriptors with gradient enabled and compute the gradient of the loss w.r.t. 
166 | # network parameters using cached gradient of the loss w.r.t embeddings 167 | if phase == 'train': 168 | optimizer.zero_grad() 169 | i = 0 170 | with torch.set_grad_enabled(True): 171 | for minibatch in batch: 172 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 173 | y = model(minibatch) 174 | embeddings = y['embedding'] 175 | 176 | minibatch_size = len(embeddings) 177 | # Compute gradients of network params w.r.t. the loss using the chain rule (using the 178 | # gradient of the loss w.r.t. embeddings stored in embeddings_grad) 179 | # By default gradients are accumulated 180 | embeddings.backward(gradient=embeddings_grad[i: i+minibatch_size]) 181 | 182 | i += minibatch_size 183 | 184 | optimizer.step() 185 | 186 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | # def multistaged_training_step(global_iter, model, phase, device, optimizer, loss_fn): 201 | def multistaged_training_step_multimodal(batch, positives_mask, negatives_mask, model, phase, device, optimizer, loss_fn, 202 | train_step_type): 203 | # Training step using multistaged backpropagation algorithm as per: 204 | # "Learning with Average Precision: Training Image Retrieval with a Listwise Loss" 205 | # This method will break when the model contains Dropout, as the same mini-batch will produce different embeddings. 206 | # Make sure mini-batches in step 1 and step 3 are the same (so that BatchNorm produces the same results) 207 | # See some exemplary implementation here: https://gist.github.com/ByungSun12/ad964a08eba6a7d103dab8588c9a3774 208 | 209 | assert phase in ['train', 'val'] 210 | # batch: {{'coords':, 'features':}*16} 211 | # batch, positives_mask, negatives_mask = next(global_iter) 212 | 213 | if phase == 'train': 214 | model.train() 215 | else: 216 | model.eval() 217 | 218 | # Stage 1 - calculate descriptors of each batch element (with gradient turned off) 219 | # In training phase network is in the train mode to update BatchNorm stats 220 | embeddings_l = [] 221 | embeddings_cloud_l = [] 222 | embeddings_image_l = [] 223 | 224 | with torch.set_grad_enabled(False): 225 | for minibatch in batch: 226 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 227 | y = model(minibatch) 228 | embeddings_l.append(y['embedding']) 229 | embeddings_cloud_l.append(y['cloud_embedding']) 230 | embeddings_image_l.append(y['image_embedding']) 231 | 232 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 233 | 234 | # Stage 2 - compute gradient of the loss w.r.t embeddings 235 | embeddings = torch.cat(embeddings_l, dim=0) 236 | embeddings_cloud = torch.cat(embeddings_cloud_l, dim=0) 237 | embeddings_image = torch.cat(embeddings_image_l, dim=0) 238 | 239 | with torch.set_grad_enabled(phase == 'train'): 240 | if phase == 'train': 241 | embeddings.requires_grad_(True) 242 | embeddings_cloud.requires_grad_(True) 243 | embeddings_image.requires_grad_(True) 244 | 245 | _embeddings_dict = { 246 | 'embedding': embeddings, 247 | 'cloud_embedding': embeddings_cloud, 248 | 'image_embedding': embeddings_image 249 | } 250 | 251 | loss, stats, _ = loss_fn(_embeddings_dict, positives_mask, negatives_mask) 252 | # stats = tensors_to_numbers(stats) 253 | if phase == 'train': 254 | loss.backward() 255 | embeddings_grad = embeddings.grad 256 | embeddings_cloud_grad = embeddings_cloud.grad 257 | embeddings_image_grad = embeddings_image.grad 258 | 259 | # # Delete intermediary 
values 260 | # embeddings_l, embeddings, y, loss = None, None, None, None 261 | 262 | # Stage 3 - recompute descriptors with gradient enabled and compute the gradient of the loss w.r.t. 263 | # network parameters using cached gradient of the loss w.r.t embeddings 264 | if phase == 'train': 265 | optimizer.zero_grad() 266 | i = 0 267 | with torch.set_grad_enabled(True): 268 | for minibatch in batch: 269 | minibatch = {e: minibatch[e].to(device) for e in minibatch} 270 | y = model(minibatch) 271 | embeddings = y['embedding'] 272 | embeddings_cloud = y['cloud_embedding'] 273 | embeddings_image = y['image_embedding'] 274 | 275 | minibatch_size = len(embeddings) 276 | # Compute gradients of network params w.r.t. the loss using the chain rule (using the 277 | # gradient of the loss w.r.t. embeddings stored in embeddings_grad) 278 | # By default gradients are accumulated 279 | if train_step_type == 'multi_sep': 280 | # ---- separate backward ---- 281 | embeddings.backward(gradient=embeddings_grad[i: i+minibatch_size]) 282 | # embeddings_cloud.backward(gradient=embeddings_cloud_grad[i: i+minibatch_size]) 283 | # embeddings_image.backward(gradient=embeddings_image_grad[i: i+minibatch_size]) 284 | elif train_step_type == 'multi_joint': 285 | # ---- joint backward ---- 286 | embeddings_tobackward = torch.cat([ 287 | embeddings, 288 | embeddings_cloud, 289 | # embeddings_image, 290 | ], dim=-1) 291 | embeddings_tobackward_grad = torch.cat([ 292 | embeddings_grad[i: i+minibatch_size], 293 | embeddings_cloud_grad[i: i+minibatch_size], 294 | # embeddings_image_grad[i: i+minibatch_size] 295 | ], dim=-1) 296 | embeddings_tobackward.backward(gradient=embeddings_tobackward_grad) 297 | else: 298 | raise NotImplementedError 299 | 300 | 301 | 302 | i += minibatch_size 303 | 304 | optimizer.step() 305 | 306 | torch.cuda.empty_cache() # Prevent excessive GPU memory consumption by SparseTensors 307 | 308 | -------------------------------------------------------------------------------- /models/loss.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import numpy as np 5 | import torch 6 | from pytorch_metric_learning import losses, miners, reducers 7 | from pytorch_metric_learning.distances import LpDistance 8 | 9 | 10 | from models.loss_utils import sigmoid, compute_aff 11 | 12 | from tools.options import Options 13 | args = Options().parse() 14 | from tools.utils import set_seed 15 | set_seed(7) 16 | 17 | 18 | 19 | 20 | 21 | def make_loss(): 22 | if args.loss == 'BatchHardTripletMarginLoss': 23 | # BatchHard mining with triplet margin loss 24 | # Expects input: embeddings, positives_mask, negatives_mask 25 | loss_fn = BatchHardTripletLossWithMasks(args.margin, args.normalize_embeddings) 26 | elif args.loss == 'MultiBatchHardTripletMarginLoss': 27 | # BatchHard mining with triplet margin loss 28 | # Expects input: embeddings, positives_mask, negatives_mask 29 | loss_fn = MultiBatchHardTripletLossWithMasks(args.margin, args.normalize_embeddings, args.weights) 30 | print('MultiBatchHardTripletLossWithMasks') 31 | print('Weights (final/cloud/image): {}'.format(args.weights)) 32 | elif args.loss == 'TruncatedSmoothAP': 33 | loss_fn = TruncatedSmoothAP(tau1=args.ap_tau1, 34 | similarity=args.ap_similarity, 35 | positives_per_query=args.ap_positives_per_query) 36 | else: 37 | raise NotImplementedError 38 | 39 | return loss_fn 40 | 41 | 42 | 43 | 44 | class HardTripletMinerWithMasks: 45 | # Hard triplet miner 46 
| def __init__(self, distance): 47 | self.distance = distance 48 | # Stats 49 | self.max_pos_pair_dist = None 50 | self.max_neg_pair_dist = None 51 | self.mean_pos_pair_dist = None 52 | self.mean_neg_pair_dist = None 53 | self.min_pos_pair_dist = None 54 | self.min_neg_pair_dist = None 55 | 56 | def __call__(self, embeddings, positives_mask, negatives_mask): 57 | assert embeddings.dim() == 2 58 | d_embeddings = embeddings.detach() 59 | with torch.no_grad(): 60 | hard_triplets = self.mine(d_embeddings, positives_mask, negatives_mask) 61 | return hard_triplets 62 | 63 | def mine(self, embeddings, positives_mask, negatives_mask): 64 | # Based on pytorch-metric-learning implementation 65 | dist_mat = self.distance(embeddings) 66 | (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask) 67 | (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 68 | a_keep_idx = torch.where(a1p_keep & a2n_keep) 69 | a = torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx] 70 | p = hardest_positive_indices[a_keep_idx] 71 | n = hardest_negative_indices[a_keep_idx] 72 | self.max_pos_pair_dist = torch.max(hardest_positive_dist).item() 73 | self.max_neg_pair_dist = torch.max(hardest_negative_dist).item() 74 | self.mean_pos_pair_dist = torch.mean(hardest_positive_dist).item() 75 | self.mean_neg_pair_dist = torch.mean(hardest_negative_dist).item() 76 | self.min_pos_pair_dist = torch.min(hardest_positive_dist).item() 77 | self.min_neg_pair_dist = torch.min(hardest_negative_dist).item() 78 | return a, p, n 79 | 80 | 81 | def get_max_per_row(mat, mask): 82 | non_zero_rows = torch.any(mask, dim=1) 83 | mat_masked = mat.clone() 84 | mat_masked[~mask] = 0 85 | return torch.max(mat_masked, dim=1), non_zero_rows 86 | 87 | 88 | def get_min_per_row(mat, mask): 89 | non_inf_rows = torch.any(mask, dim=1) 90 | mat_masked = mat.clone() 91 | mat_masked[~mask] = float('inf') 92 | return torch.min(mat_masked, dim=1), non_inf_rows 93 | 94 | 95 | class MultiBatchHardTripletLossWithMasks: 96 | def __init__(self, margin, normalize_embeddings, weights): 97 | assert len(weights) == 3 98 | self.weights = weights 99 | self.final_loss = BatchHardTripletLossWithMasksHelper(margin, normalize_embeddings) 100 | self.cloud_loss = BatchHardTripletLossWithMasksHelper(margin, normalize_embeddings) 101 | self.image_loss = BatchHardTripletLossWithMasksHelper(margin, normalize_embeddings) 102 | 103 | 104 | 105 | def __call__(self, x, positives_mask, negatives_mask): 106 | # Loss on the final global descriptor 107 | final_loss, final_stats, final_hard_triplets = self.final_loss(x['embedding'], positives_mask, negatives_mask) 108 | final_stats = {'final_{}'.format(e): final_stats[e] for e in final_stats} 109 | 110 | loss = 0. 
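# Weighting scheme (descriptive note): with weights = [w_final, w_cloud, w_image],
# the code below accumulates loss = w_final * final_loss + w_cloud * cloud_loss
# + w_image * image_loss, where the cloud / image terms are added only when the
# corresponding embedding is present in x and its weight is positive; e.g.
# weights = [1., 1., 1.] would train all three descriptor heads with equal strength.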
111 | 112 | stats = final_stats 113 | if self.weights[0] > 0.: 114 | loss = self.weights[0] * final_loss + loss 115 | 116 | # Loss on the cloud-based descriptor 117 | if 'cloud_embedding' in x: 118 | cloud_loss, cloud_stats, _ = self.cloud_loss(x['cloud_embedding'], positives_mask, negatives_mask) 119 | cloud_stats = {'cloud_{}'.format(e): cloud_stats[e] for e in cloud_stats} 120 | stats.update(cloud_stats) 121 | if self.weights[1] > 0.: 122 | loss = self.weights[1] * cloud_loss + loss 123 | 124 | # Loss on the image-based descriptor 125 | if 'image_embedding' in x: 126 | image_loss, image_stats, _ = self.image_loss(x['image_embedding'], positives_mask, negatives_mask) 127 | image_stats = {'image_{}'.format(e): image_stats[e] for e in image_stats} 128 | stats.update(image_stats) 129 | if self.weights[2] > 0.: 130 | loss = self.weights[2] * image_loss + loss 131 | 132 | stats['loss'] = loss.item() 133 | return loss, stats, None 134 | 135 | 136 | 137 | 138 | 139 | 140 | class BatchHardTripletLossWithMasks: 141 | def __init__(self, margin, normalize_embeddings): 142 | self.loss_fn = BatchHardTripletLossWithMasksHelper(margin, normalize_embeddings) 143 | 144 | def __call__(self, x, positives_mask, negatives_mask): 145 | embeddings = x['embedding'] 146 | return self.loss_fn(embeddings, positives_mask, negatives_mask) 147 | 148 | 149 | 150 | 151 | 152 | 153 | class BatchHardTripletLossWithMasksHelper: 154 | def __init__(self, margin, normalize_embeddings): 155 | self.margin = margin 156 | self.distance = LpDistance(normalize_embeddings=normalize_embeddings, collect_stats=True) 157 | # We use triplet loss with Euclidean distance 158 | self.miner_fn = HardTripletMinerWithMasks(distance=self.distance) 159 | reducer_fn = reducers.AvgNonZeroReducer(collect_stats=True) 160 | self.loss_fn = losses.TripletMarginLoss(margin=self.margin, swap=True, distance=self.distance, 161 | reducer=reducer_fn, collect_stats=True) 162 | 163 | def __call__(self, embeddings, positives_mask, negatives_mask): 164 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 165 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 166 | loss = self.loss_fn(embeddings, dummy_labels, hard_triplets) 167 | stats = {'loss': loss.item(), 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm, 168 | 'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter, 169 | 'num_triplets': len(hard_triplets[0]), 170 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 171 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 172 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 173 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 174 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 175 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist, 176 | 'normalized_loss': loss.item() * self.loss_fn.reducer.triplets_past_filter, 177 | # total loss per batch 178 | 'total_loss': self.loss_fn.reducer.loss * self.loss_fn.reducer.triplets_past_filter 179 | } 180 | 181 | return loss, stats, hard_triplets 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | class TruncatedSmoothAP: 190 | def __init__(self, tau1: float = 0.01, similarity: str = 'cosine', positives_per_query: int = 4): 191 | # We reversed the notation compared to the paper (tau1 is sigmoid on similarity differences) 192 | # tau1: sigmoid temperature applied on similarity differences 193 | # positives_per_query: number of positives per query to consider 194 | # negatives_only: if True in denominator we consider 
positives and negatives; if False we consider all elements 195 | # (except for the anchor itself) 196 | 197 | self.tau1 = tau1 198 | self.similarity = similarity 199 | self.positives_per_query = positives_per_query 200 | 201 | def __call__(self, embeddings, positives_mask, negatives_mask): 202 | embeddings = embeddings['embedding'] 203 | 204 | device = embeddings.device 205 | 206 | positives_mask = positives_mask.to(device) 207 | negatives_mask = negatives_mask.to(device) 208 | 209 | # Ranking of the retrieval set 210 | # For each element we ignore elements that are neither positives nor negatives 211 | 212 | # Compute cosine similarity scores 213 | # 1st dimension corresponds to q, 2nd dimension to z 214 | s_qz = compute_aff(embeddings, similarity=self.similarity) 215 | 216 | # Find the positives_per_query closest positives for each query 217 | s_positives = s_qz.detach().clone() 218 | s_positives.masked_fill_(torch.logical_not(positives_mask), float('-inf')) 219 | #closest_positives_ndx = torch.argmax(s_positives, dim=1).view(-1, 1) # Indices of closest positives for each query 220 | closest_positives_ndx = torch.topk(s_positives, k=self.positives_per_query, dim=1, largest=True, sorted=True)[1] 221 | # closest_positives_ndx is (batch_size, positives_per_query) with positives_per_query closest positives 222 | # for each batch element 223 | 224 | n_positives = positives_mask.sum(dim=1) # Number of positives for each anchor 225 | 226 | # Compute the rank of each example x with respect to query element q as per Eq. (2) 227 | s_diff = s_qz.unsqueeze(1) - s_qz.gather(1, closest_positives_ndx).unsqueeze(2) 228 | s_sigmoid = sigmoid(s_diff, temp=self.tau1) 229 | 230 | # Compute the numerator in Eq. 2 and 5 - for q compute the ranking of each of its positives with respect to other positives of q 231 | # Filter out z not in Positives 232 | pos_mask = positives_mask.unsqueeze(1) 233 | pos_s_sigmoid = s_sigmoid * pos_mask 234 | 235 | # Filter out z on the same position as the positive (they have value = 0.5, as the similarity difference is zero) 236 | mask = torch.ones_like(pos_s_sigmoid).scatter(2, closest_positives_ndx.unsqueeze(2), 0.) 237 | pos_s_sigmoid = pos_s_sigmoid * mask 238 | 239 | # Compute the rank for each query and its positives_per_query closest positive examples with respect to other positives 240 | r_p = torch.sum(pos_s_sigmoid, dim=2) + 1. 241 | # r_p is a (batch_size, positives_per_query) matrix 242 | 243 | # Consider only positives and negatives in the denominator 244 | # Compute the denominator in Eq. 5 - add sum of the indicator function for negatives (or non-positives) 245 | neg_mask = negatives_mask.unsqueeze(1) 246 | neg_s_sigmoid = s_sigmoid * neg_mask 247 | r_omega = r_p + torch.sum(neg_s_sigmoid, dim=2) 248 | 249 | # Compute R(i, S_p) / R(i, S_omega) ratio in Eq. 2
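# For reference, the quantities built above, restated compactly (a paraphrase
# of the Smooth-AP formulation that the Eq. 2 / Eq. 5 comments refer to,
# truncated to the positives_per_query closest positives i of each query q):
#   R(i, S) = 1 + sum over z in S, z != i, of sigmoid((s_qz - s_qi) / tau1)
#   AP_q   ~= (1 / |P_q|) * sum over i in P_q of R(i, S_p) / R(i, S_omega)
# with S_p the positives of q and S_omega = S_p union S_n; the loss is 1 - AP.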
250 | r = r_p / r_omega 251 | 252 | # Compute metrics: mean ranking of the positive example, recall@1 253 | stats = {} 254 | # Mean number of positives per query 255 | stats['positives_per_query'] = n_positives.float().mean(dim=0).item() 256 | # Mean ranking of selected positive examples (closest positives) 257 | temp = s_diff.detach() > 0 258 | temp = torch.logical_and(temp[:, 0], negatives_mask) # Take the best positive 259 | hard_ranking = temp.sum(dim=1) 260 | stats['best_positive_ranking'] = hard_ranking.float().mean(dim=0).item() 261 | # Recall at 1 262 | stats['recall'] = {1: (hard_ranking <= 1).float().mean(dim=0).item()} 263 | 264 | # r is (N, positives_per_query) tensor 265 | # Zero entries not corresponding to real positives - this happens when the number of true positives is lower than positives_per_query 266 | valid_positives_mask = torch.gather(positives_mask, 1, closest_positives_ndx) # (batch_size, positives_per_query) tensor 267 | masked_r = r * valid_positives_mask 268 | n_valid_positives = valid_positives_mask.sum(dim=1) 269 | 270 | # Filter out rows (queries) without any positive to avoid division by zero 271 | valid_q_mask = n_valid_positives > 0 272 | masked_r = masked_r[valid_q_mask] 273 | 274 | ap = (masked_r.sum(dim=1) / n_valid_positives[valid_q_mask]).mean() 275 | loss = 1. - ap 276 | 277 | stats['loss'] = loss.item() 278 | stats['ap'] = ap.item() 279 | stats['avg_embedding_norm'] = embeddings.norm(dim=1).mean().item() 280 | 281 | return loss, stats, None 282 |
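# The TruncatedSmoothAP code above calls compute_aff() and sigmoid(), which are
# defined elsewhere in this repo (loss_utils.py) and not dumped in this section.
# A minimal sketch consistent with how they are used here - the exact bodies
# and signatures are an assumption:
import torch
import torch.nn.functional as F

def compute_aff(embeddings, similarity='cosine'):
    # Pairwise affinity matrix: entry [q, z] holds s_qz.
    if similarity == 'cosine':
        e = F.normalize(embeddings, p=2, dim=1)
        return e @ e.t()
    raise NotImplementedError(similarity)

def sigmoid(tensor, temp):
    # Temperature-scaled sigmoid used as a soft rank indicator; a zero
    # similarity difference gives 0.5, matching the comment in the code above.
    return torch.sigmoid(tensor / temp)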
-------------------------------------------------------------------------------- /network/resnetfpn_simple.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.models as models 5 | import numpy as np 6 | import torch.nn.functional as F 7 | 8 | 9 | 10 | from tools.options import Options 11 | args = Options().parse() 12 | from tools.utils import set_seed 13 | set_seed(7) 14 | 15 | 16 | 17 | class ResnetFPNSimple(torch.nn.Module): 18 | def __init__(self, out_channels: int, lateral_dim: int, layers=[64, 64, 128, 256, 512], fh_num_bottom_up: int = 5, 19 | fh_num_top_down: int = 2, add_fc_block: bool = False, pool_method='gem'): 20 | # Pooling types: GeM, sum-pooled convolution (SPoC), maximum activations of convolutions (MAC) 21 | super().__init__() 22 | assert 0 < fh_num_bottom_up <= 5 23 | assert 0 <= fh_num_top_down < fh_num_bottom_up 24 | 25 | self.out_channels = out_channels 26 | self.lateral_dim = lateral_dim 27 | self.fh_num_bottom_up = fh_num_bottom_up 28 | self.fh_num_top_down = fh_num_top_down 29 | self.add_fc_block = add_fc_block 30 | self.layers = layers # Number of channels in output from each ResNet block 31 | self.pool_method = pool_method.lower() 32 | # model = models.resnet18(pretrained=True) 33 | model = models.resnet34(pretrained=True) 34 | # Last 2 blocks are AdaptiveAvgPool2d and Linear (get rid of them) 35 | self.resnet_fe = nn.ModuleList(list(model.children())[:3+self.fh_num_bottom_up]) 36 | # self.resnet_fe = list(model.children()) 37 | 38 | # print(self.resnet_fe) 39 | 40 | 41 | # Lateral connections and top-down pass for the feature extraction head 42 | self.fh_tconvs = nn.ModuleDict() # Top-down transposed convolutions in feature head 43 | self.fh_conv1x1 = nn.ModuleDict() # 1x1 convolutions in lateral connections to the feature head 44 | for i in range(self.fh_num_bottom_up - self.fh_num_top_down, self.fh_num_bottom_up): 45 | self.fh_conv1x1[str(i + 1)] = nn.Conv2d(in_channels=layers[i], out_channels=self.lateral_dim, kernel_size=1) 46 | self.fh_tconvs[str(i + 1)] = torch.nn.ConvTranspose2d(in_channels=self.lateral_dim, 47 | out_channels=self.lateral_dim, 48 | kernel_size=2, stride=2) 49 | 50 | # One more lateral connection 51 | temp = self.fh_num_bottom_up - self.fh_num_top_down 52 | self.fh_conv1x1[str(temp)] = nn.Conv2d(in_channels=layers[temp-1], out_channels=self.lateral_dim, kernel_size=1) 53 | 54 | # Pooling types: GeM, sum-pooled convolution (SPoC), maximum activations of convolutions (MAC) 55 | if self.pool_method == 'gem': 56 | self.pool = ImageGeM() 57 | elif self.pool_method == 'spoc': 58 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 59 | elif self.pool_method == 'max': 60 | self.pool = nn.AdaptiveMaxPool2d((1, 1)) 61 | else: 62 | raise NotImplementedError("Unknown pooling method: {}".format(self.pool_method)) 63 | 64 | if self.add_fc_block: 65 | self.fc = torch.nn.Linear(in_features=self.lateral_dim, out_features=self.out_channels) 66 | 67 | def forward(self, batch): 68 | output_dict = {} 69 | x = batch['images'] 70 | feature_maps = {} 71 | 72 | # 0, 1, 2, 3 = first layers: Conv2d, BatchNorm, ReLU, MaxPool2d 73 | x = self.resnet_fe[0](x) 74 | x = self.resnet_fe[1](x) 75 | x = self.resnet_fe[2](x) 76 | x = self.resnet_fe[3](x) 77 | feature_maps["1"] = x 78 | 79 | # sequential blocks, built from BasicBlock or Bottleneck blocks 80 | for i in range(4, self.fh_num_bottom_up+3): 81 | x = self.resnet_fe[i](x) 82 | feature_maps[str(i-2)] = x 83 | 84 | assert len(feature_maps) == self.fh_num_bottom_up 85 | # x is (batch_size, 512, H=20, W=15) for 640x480 input image 86 | 87 | # FEATURE HEAD TOP-DOWN PASS 88 | xf = self.fh_conv1x1[str(self.fh_num_bottom_up)](feature_maps[str(self.fh_num_bottom_up)]) 89 | for i in range(self.fh_num_bottom_up, self.fh_num_bottom_up - self.fh_num_top_down, -1): 90 | xf = self.fh_tconvs[str(i)](xf) # Upsample using transposed convolution 91 | xf = xf + self.fh_conv1x1[str(i-1)](feature_maps[str(i - 1)]) 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | return xf 101 |
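# Note: this forward() returns the top-down FPN feature map itself; the pool
# and (optional) fc modules built in __init__ are not applied here, so global
# pooling is presumably left to the caller.
# For the GeM pooling implemented by ImageGeM below, the per-channel formula is
#   f_c = ( mean over (h, w) of max(x_chw, eps)^p )^(1/p)
# where p is learnable; p = 1 recovers average pooling (SPoC) and p -> inf
# approaches max pooling (MAC).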
102 | 103 | 104 | 105 | 106 | class ImageGeM(nn.Module): 107 | def __init__(self, p=3, eps=1e-6): 108 | super(ImageGeM, self).__init__() 109 | self.p = nn.Parameter(torch.ones(1) * p) 110 | self.eps = eps 111 | 112 | def forward(self, x): 113 | assert len(x.shape) == 4 114 | output = nn.functional.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p) 115 | 116 | 117 | b,c,h,w = output.shape 118 | assert [h,w]==[1,1] 119 | 120 | 121 | output = output.view(b,c) 122 | 123 | 124 | return output 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | # ---------------------------- CosPlace ---------------------------- 135 | class GeM(nn.Module): 136 | """Implementation of GeM as in https://github.com/filipradenovic/cnnimageretrieval-pytorch 137 | """ 138 | def __init__(self, p=3, eps=1e-6): 139 | super().__init__() 140 | self.p = nn.Parameter(torch.ones(1)*p) 141 | self.eps = eps 142 | 143 | def forward(self, x): 144 | return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p) 145 | 146 | class ImageCosPlace(nn.Module): 147 | """ 148 | CosPlace aggregation layer as implemented in https://github.com/gmberton/CosPlace/blob/main/model/network.py 149 | 150 | Args: 151 | in_dim: number of channels of the input 152 | out_dim: dimension of the output descriptor 153 | """ 154 | def __init__(self, in_dim, out_dim): 155 | super().__init__() 156 | self.gem = GeM() 157 | self.fc = nn.Linear(in_dim, out_dim) 158 | 159 | def forward(self, x): 160 | x = F.normalize(x, p=2, dim=1) 161 | x = self.gem(x) 162 | x = x.flatten(1) 163 | x = self.fc(x) 164 | x = F.normalize(x, p=2, dim=1) 165 | return x 166 | 167 | 168 | 169 | 170 | # ---------------------------- MixVPR ---------------------------- 171 | class FeatureMixerLayer(nn.Module): 172 | def __init__(self, in_dim, mlp_ratio=1): 173 | super().__init__() 174 | self.mix = nn.Sequential( 175 | nn.LayerNorm(in_dim), 176 | nn.Linear(in_dim, int(in_dim * mlp_ratio)), 177 | nn.ReLU(), 178 | nn.Linear(int(in_dim * mlp_ratio), in_dim), 179 | ) 180 | 181 | for m in self.modules(): 182 | if isinstance(m, nn.Linear): 183 | nn.init.trunc_normal_(m.weight, std=0.02) 184 | if m.bias is not None: 185 | nn.init.zeros_(m.bias) 186 | 187 | def forward(self, x): 188 | return x + self.mix(x) 189 | 190 | 191 | class ImageMixVPR(nn.Module): 192 | def __init__(self, 193 | in_channels=1024, 194 | in_h=20, 195 | in_w=20, 196 | out_channels=512, 197 | mix_depth=1, 198 | mlp_ratio=1, 199 | out_rows=4, 200 | ) -> None: 201 | super().__init__() 202 | 203 | self.in_h = in_h # height of input feature maps 204 | self.in_w = in_w # width of input feature maps 205 | self.in_channels = in_channels # depth of input feature maps 206 | 207 | self.out_channels = out_channels # depth-wise projection dimension 208 | self.out_rows = out_rows # row-wise projection dimension 209 | 210 | self.mix_depth = mix_depth # L, the number of stacked FeatureMixers 211 | self.mlp_ratio = mlp_ratio # ratio of the mid projection layer in the mixer block 212 | 213 | hw = in_h*in_w 214 | self.mix = nn.Sequential(*[ 215 | FeatureMixerLayer(in_dim=hw, mlp_ratio=mlp_ratio) 216 | for _ in range(self.mix_depth) 217 | ]) 218 | self.channel_proj = nn.Linear(in_channels, out_channels) 219 | self.row_proj = nn.Linear(hw, out_rows) 220 | 221 | def forward(self, x): 222 | x = x.flatten(2) 223 | x = self.mix(x) 224 | x = x.permute(0, 2, 1) 225 | x = self.channel_proj(x) 226 | x = x.permute(0, 2, 1) 227 | x = self.row_proj(x) 228 | x = F.normalize(x.flatten(1), p=2, dim=-1) 229 | return x 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | # ---------------------------- ConvAP ---------------------------- 240 | 241 | class ImageConvAP(nn.Module): 242 | """Implementation of ConvAP as in https://arxiv.org/pdf/2210.10239.pdf 243 | 244 | Args: 245 | in_channels (int): number of channels in the input of ConvAP 246 | out_channels (int, optional): number of channels that ConvAP outputs. Defaults to 512. 247 | s1 (int, optional): spatial height of the adaptive average pooling. Defaults to 2. 248 | s2 (int, optional): spatial width of the adaptive average pooling. Defaults to 2.
249 | """ 250 | def __init__(self, in_channels, out_channels=512, s1=2, s2=2): 251 | super(ImageConvAP, self).__init__() 252 | self.channel_pool = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=True) 253 | self.AAP = nn.AdaptiveAvgPool2d((s1, s2)) 254 | 255 | def forward(self, x): 256 | x = self.channel_pool(x) 257 | x = self.AAP(x) 258 | x = F.normalize(x.flatten(1), p=2, dim=1) 259 | return x 260 | 261 | 262 | 263 | def print_nb_params(m): 264 | model_parameters = filter(lambda p: p.requires_grad, m.parameters()) 265 | params = sum([np.prod(p.size()) for p in model_parameters]) 266 | print(f'Trainable parameters: {params/1e6:.3}M') 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | # ---------------------------- NetVLAD ---------------------------- 278 | # based on https://github.com/lyakaap/NetVLAD-pytorch/blob/master/netvlad.py 279 | class ImageNetVLAD(nn.Module): 280 | """NetVLAD layer implementation""" 281 | 282 | def __init__(self, clusters_num=64, dim=128, normalize_input=True, work_with_tokens=False): 283 | """ 284 | Args: 285 | clusters_num : int 286 | The number of clusters 287 | dim : int 288 | Dimension of descriptors 289 | alpha : float 290 | Parameter of initialization. Larger value is harder assignment. 291 | normalize_input : bool 292 | If true, descriptor-wise L2 normalization is applied to input. 293 | """ 294 | super().__init__() 295 | self.clusters_num = clusters_num 296 | self.dim = dim 297 | self.alpha = 0 298 | self.normalize_input = normalize_input 299 | self.work_with_tokens = work_with_tokens 300 | if work_with_tokens: 301 | self.conv = nn.Conv1d(dim, clusters_num, kernel_size=1, bias=False) 302 | else: 303 | self.conv = nn.Conv2d(dim, clusters_num, kernel_size=(1, 1), bias=False) 304 | self.centroids = nn.Parameter(torch.rand(clusters_num, dim)) 305 | 306 | def init_params(self, centroids, descriptors): 307 | centroids_assign = centroids / np.linalg.norm(centroids, axis=1, keepdims=True) 308 | dots = np.dot(centroids_assign, descriptors.T) 309 | dots.sort(0) 310 | dots = dots[::-1, :] # sort, descending 311 | 312 | self.alpha = (-np.log(0.01) / np.mean(dots[0,:] - dots[1,:])).item() 313 | self.centroids = nn.Parameter(torch.from_numpy(centroids)) 314 | if self.work_with_tokens: 315 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha * centroids_assign).unsqueeze(2)) 316 | else: 317 | self.conv.weight = nn.Parameter(torch.from_numpy(self.alpha*centroids_assign).unsqueeze(2).unsqueeze(3)) 318 | self.conv.bias = None 319 | 320 | def forward(self, x): 321 | if self.work_with_tokens: 322 | x = x.permute(0, 2, 1) 323 | N, D, _ = x.shape[:] 324 | else: 325 | N, D, H, W = x.shape[:] 326 | if self.normalize_input: 327 | x = F.normalize(x, p=2, dim=1) # Across descriptor dim 328 | x_flatten = x.view(N, D, -1) 329 | soft_assign = self.conv(x).view(N, self.clusters_num, -1) 330 | soft_assign = F.softmax(soft_assign, dim=1) 331 | vlad = torch.zeros([N, self.clusters_num, D], dtype=x_flatten.dtype, device=x_flatten.device) 332 | for D in range(self.clusters_num): # Slower than non-looped, but lower memory usage 333 | residual = x_flatten.unsqueeze(0).permute(1, 0, 2, 3) - \ 334 | self.centroids[D:D+1, :].expand(x_flatten.size(-1), -1, -1).permute(1, 2, 0).unsqueeze(0) 335 | residual = residual * soft_assign[:,D:D+1,:].unsqueeze(2) 336 | vlad[:,D:D+1,:] = residual.sum(dim=-1) 337 | vlad = F.normalize(vlad, p=2, dim=2) # intra-normalization 338 | vlad = vlad.view(N, -1) # Flatten 339 | vlad = F.normalize(vlad, p=2, 
dim=1) # L2 normalize 340 | return vlad 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | if __name__ == '__main__': 352 | x = torch.randn(4, 2048, 10, 10) 353 | m = ImageCosPlace(2048, 512) 354 | r = m(x) 355 | print(r.shape) 356 | 357 | 358 | 359 | x = torch.randn(4, 2048, 10, 10) 360 | m = ImageConvAP(2048, 512) 361 | r = m(x) 362 | print(r.shape) 363 | 364 | 365 | 366 | x = torch.randn(1, 1024, 20, 20) 367 | agg = ImageMixVPR( 368 | in_channels=1024, 369 | in_h=20, 370 | in_w=20, 371 | out_channels=1024, 372 | mix_depth=4, 373 | mlp_ratio=1, 374 | out_rows=4) 375 | 376 | print_nb_params(agg) 377 | output = agg(x) 378 | print(output.shape) 379 | 380 | 381 | -------------------------------------------------------------------------------- /models/minkloc_multimodal.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | # Model processing LiDAR point clouds and RGB images 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision.models as TVmodels 9 | from network.image_pool_fns import ImageGeM 10 | from network.image_pool_fns import ImageCosPlace 11 | from network.image_pool_fns import ImageNetVLAD 12 | from network.image_pool_fns import ImageConvAP 13 | 14 | import torch.nn.functional as F 15 | 16 | from tools.utils import set_seed 17 | set_seed(7) 18 | 19 | class MinkLocMultimodal(torch.nn.Module): 20 | def __init__(self, cloud_fe, cloud_fe_size, image_fe, image_fe_size, 21 | fuse_method, dropout_p: float = None, final_block: str = None): 22 | super().__init__() 23 | 24 | 25 | 26 | self.cloud_fe = cloud_fe 27 | 28 | 29 | self.image_fe = image_fe 30 | 31 | 32 | self.fuse_method = fuse_method 33 | 34 | 35 | 36 | 37 | def forward(self, batch): 38 | y = {} 39 | if self.image_fe is not None: 40 | image_embedding, imagefe_output_dict = self.image_fe(batch) 41 | assert image_embedding.dim() == 2 42 | y['image_embedding'] = image_embedding 43 | for _k, _v in imagefe_output_dict.items(): 44 | y[_k] = _v 45 | 46 | if self.cloud_fe is not None: 47 | cloud_embedding = self.cloud_fe(batch)['embedding'] 48 | assert cloud_embedding.dim() == 2 49 | y['cloud_embedding'] = cloud_embedding 50 | 51 | 52 | assert cloud_embedding.shape[0] == image_embedding.shape[0] 53 | 54 | 55 | if self.fuse_method == 'cat': 56 | x = torch.cat([cloud_embedding, image_embedding], dim=1) 57 | elif self.fuse_method == 'add': 58 | assert cloud_embedding.shape == image_embedding.shape 59 | x = cloud_embedding + image_embedding 60 | else: 61 | raise NotImplementedError 62 | 63 | 64 | y['embedding'] = x 65 | 66 | 67 | return y 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | class ResnetFPN(torch.nn.Module): 80 | def __init__(self, out_channels: int, lateral_dim: int, layers=[64, 64, 128, 256, 512], fh_num_bottom_up: int = 5, 81 | fh_num_top_down: int = 2, add_fc_block: bool = False, pool_method='gem', 82 | add_basicblock: bool = True, num_basicblocks: int = 2): 83 | # Pooling types: GeM, sum-pooled convolution (SPoC), maximum activations of convolutions (MAC) 84 | super().__init__() 85 | assert 0 < fh_num_bottom_up <= 5 86 | assert 0 <= fh_num_top_down < fh_num_bottom_up 87 | 88 | self.out_channels = out_channels 89 | self.lateral_dim = lateral_dim 90 | self.fh_num_bottom_up = fh_num_bottom_up 91 | self.fh_num_top_down = fh_num_top_down 92 | self.add_fc_block = add_fc_block 93 | self.layers = layers # Number of channels in output from each ResNet block 94 | self.pool_method = 
pool_method.lower() 95 | self.add_basicblock = add_basicblock 96 | self.num_basicblocks = num_basicblocks 97 | 98 | 99 | model = TVmodels.resnet18(pretrained=True) 100 | 101 | 102 | 103 | # Last 2 blocks are AdaptiveAvgPool2d and Linear (get rid of them) 104 | self.resnet_fe = nn.ModuleList(list(model.children())[:3+self.fh_num_bottom_up]) 105 | 106 | # Lateral connections and top-down pass for the feature extraction head 107 | self.fh_conv1x1 = nn.ModuleDict() # 1x1 convolutions in lateral connections to the feature head 108 | self.fh_tconvs = nn.ModuleDict() # Top-down transposed convolutions in feature head 109 | self.fh_tbasicblocks = nn.ModuleDict() # Top-down basic blocks in feature head 110 | for i in range(self.fh_num_bottom_up - self.fh_num_top_down, self.fh_num_bottom_up): 111 | self.fh_conv1x1[str(i + 1)] = nn.Conv2d(in_channels=layers[i], out_channels=self.lateral_dim, kernel_size=1) 112 | self.fh_tconvs[str(i + 1)] = torch.nn.ConvTranspose2d(in_channels=self.lateral_dim, 113 | out_channels=self.lateral_dim, 114 | kernel_size=2, stride=2) 115 | 116 | 117 | 118 | # One more lateral connection 119 | temp = self.fh_num_bottom_up - self.fh_num_top_down 120 | self.fh_conv1x1[str(temp)] = nn.Conv2d(in_channels=layers[temp-1], out_channels=self.lateral_dim, kernel_size=1) 121 | 122 | 123 | 124 | # Pooling types: GeM, sum-pooled convolution (SPoC), maximum activations of convolutions (MAC) 125 | if self.pool_method == 'gem': 126 | self.pool = GeM() 127 | elif self.pool_method == 'spoc': 128 | self.pool = nn.AdaptiveAvgPool2d((1, 1)) 129 | elif self.pool_method == 'max': 130 | self.pool = nn.AdaptiveMaxPool2d((1, 1)) 131 | else: 132 | raise NotImplementedError("Unknown pooling method: {}".format(self.pool_method)) 133 | 134 | if self.add_fc_block: 135 | self.fc = torch.nn.Linear(in_features=self.lateral_dim, out_features=self.out_channels) 136 | 137 | def forward(self, batch): 138 | x = batch['images'] 139 | feature_maps = {} 140 | 141 | # 0, 1, 2, 3 = first layers: Conv2d, BatchNorm, ReLU, MaxPool2d 142 | x = self.resnet_fe[0](x) 143 | x = self.resnet_fe[1](x) 144 | x = self.resnet_fe[2](x) 145 | x = self.resnet_fe[3](x) 146 | feature_maps["1"] = x 147 | 148 | # sequential blocks, built from BasicBlock or Bottleneck blocks 149 | for i in range(4, self.fh_num_bottom_up+3): 150 | x = self.resnet_fe[i](x) 151 | feature_maps[str(i-2)] = x 152 | 153 | assert len(feature_maps) == self.fh_num_bottom_up 154 | # x is (batch_size, 512, H=20, W=15) for 640x480 input image 155 | 156 | # FEATURE HEAD TOP-DOWN PASS 157 | xf = self.fh_conv1x1[str(self.fh_num_bottom_up)](feature_maps[str(self.fh_num_bottom_up)]) 158 | if self.add_basicblock: 159 | xf = self.fh_tbasicblocks[str(self.fh_num_bottom_up)](xf) 160 | 161 | for i in range(self.fh_num_bottom_up, self.fh_num_bottom_up - self.fh_num_top_down, -1): 162 | xf = self.fh_tconvs[str(i)](xf) # Upsample using transposed convolution 163 | xf = xf + self.fh_conv1x1[str(i-1)](feature_maps[str(i - 1)]) 164 | 165 | x = self.pool(xf) 166 | # x is (batch_size, 512, 1, 1) tensor 167 | 168 | x = torch.flatten(x, 1) 169 | # x is (batch_size, 512) tensor 170 | 171 | if self.add_fc_block: 172 | x = self.fc(x) 173 | 174 | # (batch_size, feature_size) 175 | assert x.shape[1] == self.out_channels 176 | 177 | 178 | return x 179 |
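# Caution: self.fh_tbasicblocks above is constructed as an empty nn.ModuleDict
# and never populated in __init__, so calling forward() with add_basicblock=True
# (the constructor default) would raise a KeyError at
# self.fh_tbasicblocks[str(self.fh_num_bottom_up)]; pass add_basicblock=False
# or populate the dict before calling forward().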
180 | 181 | # GeM code adapted from: https://github.com/filipradenovic/cnnimageretrieval-pytorch 182 | 183 | class GeM(nn.Module): 184 | def __init__(self, p=3, eps=1e-6): 185 | super(GeM, self).__init__() 186 | self.p = nn.Parameter(torch.ones(1) * p) 187 | self.eps = eps 188 | 189 | def forward(self, x): 190 | return nn.functional.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p) 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | # ---------------------------------- ResNetFPNv2 ---------------------------------- 204 | class ResNetFPNv2(torch.nn.Module): 205 | def __init__(self, 206 | image_fe, 207 | image_pool_method, # GeM 208 | image_useallstages, # True 209 | output_dim, 210 | ): 211 | super().__init__() 212 | ''' 213 | resnet [64,64,128,256,512] 214 | convnext [96,96,192,384,768] 215 | swin [96,96,192,384,768] 216 | swin_v2 [96,96,192,384,768] 217 | ''' 218 | 219 | 220 | self.image_fe = image_fe 221 | self.image_pool_method = image_pool_method 222 | 223 | 224 | 225 | 226 | self.image_useallstages = image_useallstages 227 | 228 | # -- resnet 229 | if self.image_fe == 'resnet18': 230 | self.model = TVmodels.resnet18(weights='IMAGENET1K_V1') 231 | if self.image_useallstages: 232 | self.last_dim = 512 233 | else: 234 | self.last_dim = 256 235 | elif self.image_fe == 'resnet34': 236 | self.model = TVmodels.resnet34(weights='IMAGENET1K_V1') 237 | if self.image_useallstages: 238 | self.last_dim = 512 239 | else: 240 | self.last_dim = 256 241 | elif self.image_fe == 'resnet50': 242 | self.model = TVmodels.resnet50(weights='IMAGENET1K_V2') 243 | if self.image_useallstages: 244 | self.last_dim = 2048 245 | else: 246 | self.last_dim = 1024 247 | elif self.image_fe == 'resnet101': 248 | self.model = TVmodels.resnet101(weights='IMAGENET1K_V2') 249 | if self.image_useallstages: 250 | self.last_dim = 2048 251 | else: 252 | self.last_dim = 1024 253 | elif self.image_fe == 'resnet152': 254 | self.model = TVmodels.resnet152(weights='IMAGENET1K_V2') 255 | if self.image_useallstages: 256 | self.last_dim = 2048 257 | else: 258 | self.last_dim = 1024 259 | 260 | 261 | # -- convnext 262 | elif self.image_fe == 'convnext_tiny': 263 | self.model = TVmodels.convnext_tiny(weights='IMAGENET1K_V1') 264 | if self.image_useallstages: 265 | self.last_dim = 768 266 | else: 267 | self.last_dim = 384 268 | elif self.image_fe == 'convnext_small': 269 | self.model = TVmodels.convnext_small(weights='IMAGENET1K_V1') 270 | if self.image_useallstages: 271 | self.last_dim = 768 272 | else: 273 | self.last_dim = 384 274 | 275 | 276 | # -- swin 277 | elif self.image_fe == 'swin_t': 278 | self.model = TVmodels.swin_t(weights='IMAGENET1K_V1') 279 | if self.image_useallstages: 280 | self.last_dim = 768 281 | else: 282 | self.last_dim = 384 283 | elif self.image_fe == 'swin_s': 284 | self.model = TVmodels.swin_s(weights='IMAGENET1K_V1') 285 | self.last_dim = 768 if self.image_useallstages else 384 286 | elif self.image_fe == 'swin_v2_t': 287 | self.model = TVmodels.swin_v2_t(weights='IMAGENET1K_V1') 288 | if self.image_useallstages: 289 | self.last_dim = 768 290 | else: 291 | self.last_dim = 384 292 | elif self.image_fe == 'swin_v2_s': 293 | self.model = TVmodels.swin_v2_s(weights='IMAGENET1K_V1') 294 | self.last_dim = 768 if self.image_useallstages else 384 295 | 296 | 297 | 298 | 299 | 300 | self.conv1x1 = nn.Conv2d(self.last_dim, output_dim, kernel_size=1) 301 | 302 | 303 | self.image_gem = ImageGeM() # *1 304 | self.imagecosplace = ImageCosPlace(output_dim, output_dim) # *1 305 | self.imageconvap = ImageConvAP(output_dim, output_dim) # *4 306 | self.imagenetvlad = ImageNetVLAD(clusters_num=64, 307 | dim=output_dim) # *4 308 | 309 | 310 |
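# Descriptor sizes implied by the aggregation heads wired above (from their
# definitions as dumped in network/resnetfpn_simple.py):
#   ImageGeM              -> output_dim
#   ImageCosPlace         -> output_dim       (GeM + fc)
#   ImageConvAP (s1=s2=2) -> output_dim * 4   (1x1 conv, 2x2 adaptive pool, flatten)
#   ImageNetVLAD (K=64)   -> output_dim * 64  (K clusters x dim, flattened)
# The '# *1' / '# *4' tags above appear to track this multiplier (the NetVLAD
# tag seems to understate it).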
311 | def forward_resnet(self, x): 312 | fe_output_dict = {} 313 | x = self.model.conv1(x) 314 | x = self.model.bn1(x) 315 | x = self.model.relu(x) 316 | x = self.model.maxpool(x) 317 | 318 | x = self.model.layer1(x) 319 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 320 | fe_output_dict['image_layer1'] = x 321 | fe_output_dict['image_layer1_avgpool'] = x_avgpool 322 | 323 | x = self.model.layer2(x) 324 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 325 | fe_output_dict['image_layer2'] = x 326 | fe_output_dict['image_layer2_avgpool'] = x_avgpool 327 | 328 | x = self.model.layer3(x) 329 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 330 | fe_output_dict['image_layer3'] = x 331 | fe_output_dict['image_layer3_avgpool'] = x_avgpool 332 | 333 | if self.image_useallstages: 334 | x = self.model.layer4(x) 335 | x_avgpool = F.avg_pool2d(x, kernel_size=x.size()[2:]).squeeze(3).squeeze(2) 336 | fe_output_dict['image_layer4'] = x 337 | fe_output_dict['image_layer4_avgpool'] = x_avgpool 338 | 339 | return x, fe_output_dict 340 | 341 | 342 | 343 | def forward_convnext(self, x): 344 | layers_list = list(self.model.features.children()) 345 | assert len(layers_list)==8 346 | if not self.image_useallstages: 347 | layers_list = layers_list[:-2] 348 | 349 | 350 | 351 | for layer in layers_list: 352 | x = layer(x) 353 | 354 | return x 355 | 356 | 357 | def forward_swin(self, x): 358 | layers_list = list(self.model.features.children()) 359 | if not self.image_useallstages: 360 | layers_list = layers_list[:-2] 361 | 362 | 363 | for layer in layers_list: 364 | x = layer(x) 365 | 366 | x = x.permute(0,3,1,2) 367 | return x 368 | 369 | 370 | 371 | 372 | 373 | 374 | def forward(self, data_dict): 375 | 376 | 377 | x = data_dict['images'] 378 | fe_output_dict = {} 379 | 380 | 381 | if self.image_fe in ['resnet18','resnet34','resnet50','resnet101','resnet152']: 382 | x, fe_output_dict = self.forward_resnet(x) 383 | elif self.image_fe in ['convnext_tiny','convnext_small']: 384 | x = self.forward_convnext(x) 385 | elif self.image_fe in ['swin_t','swin_s']: 386 | x = self.forward_swin(x) 387 | elif self.image_fe in ['swin_v2_t','swin_v2_s']: 388 | x = self.forward_swin(x) 389 | else: 390 | raise NotImplementedError 391 | 392 | 393 | x_feat_256 = x 394 | x_feat_256 = self.conv1x1(x_feat_256) 395 | 396 | 397 | if self.image_pool_method == 'GeM': 398 | embedding = self.image_gem(x_feat_256) 399 | 400 | elif self.image_pool_method == 'ConvAP': 401 | embedding = self.imageconvap(x_feat_256) 402 | 403 | elif self.image_pool_method == 'CosPlace': 404 | embedding = self.imagecosplace(x_feat_256) 405 | 406 | elif self.image_pool_method == 'NetVLAD': 407 | embedding = self.imagenetvlad(x_feat_256) 408 | 409 | else: 410 | raise NotImplementedError 411 | 412 | 413 | 414 | return embedding, fe_output_dict 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | --------------------------------------------------------------------------------
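# A minimal smoke test for ResNetFPNv2 above (a sketch, not part of the repo;
# the constructor arguments and the 'images' batch key follow the dumped code,
# while the input resolution and output_dim=256 are arbitrary choices):
import torch

model = ResNetFPNv2(image_fe='resnet18', image_pool_method='GeM',
                    image_useallstages=True, output_dim=256)
model.eval()
with torch.no_grad():
    embedding, fe_dict = model({'images': torch.randn(2, 3, 224, 224)})
print(embedding.shape)      # GeM head -> torch.Size([2, 256])
print(sorted(fe_dict.keys()))   # per-stage feature maps and their avg-pooled vectors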