├── .gitignore ├── LICENSE ├── README.md ├── config ├── config_baseline.txt ├── config_baseline_gtav5.txt └── models │ └── vlpdnet.txt ├── dataloader ├── dataset_utils.py ├── gta5.py ├── oxford.py └── samplers.py ├── evaluate.py ├── generating_queries ├── generate_test_sets.py ├── generate_test_sets_gtav5.py ├── generate_training_tuples_baseline.py ├── generate_training_tuples_baseline_gtav5.py └── registration.py ├── loss ├── __init__.py ├── d_loss.py ├── metric_loss.py ├── mmd_loss.py ├── ops │ └── emd │ │ ├── LICENSE │ │ ├── README.md │ │ ├── emd.cpp │ │ ├── emd_cuda.cu │ │ ├── emd_module.py │ │ └── setup.py └── reg_loss.py ├── media └── pipeline.png ├── misc ├── log.py ├── point_utils.py └── utils.py ├── models ├── __init__.py ├── discriminator │ ├── __init__.py │ ├── resnet.py │ └── resnetD.py ├── minkloc3d │ ├── __init__.py │ ├── minkfpn.py │ ├── minkpool.py │ ├── pooling.py │ ├── resnet_mink.py │ └── senet_block.py ├── model_factory.py ├── vcrnet │ ├── __init__.py │ ├── transformer.py │ └── vcrnet.py └── vlpdnet │ ├── __init__.py │ ├── lpdnet_model.py │ └── vlpdnet.py ├── train.py ├── training ├── optimizer_factory.py ├── reg_train.py └── trainer.py └── weights └── vlpdnet-registration.t7 /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | generating_queries/pickles 10 | 11 | # Distribution / packaging 12 | .Python 13 | build 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
 96 | __pypackages__/
 97 |
 98 | # Celery stuff
 99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # PyCharm settings
133 | .idea
134 |
135 | *.pth
136 | *.pickle
137 | results*
138 | checkpoints/
139 | data/
140 | *.zip
141 | *.backup
142 | *.npz
143 | *.npy
144 | *.h5
145 | .fuse*
146 | .~lock*
147 | *.ckpt.*
148 | *.out.*
149 | *benchmark_datasets*
150 | tf_logs
151 | *gpu_mem_track.txt
152 | *.ckpt
153 | *.log
154 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 |
 3 | Copyright (c) 2021 Intelligent Robotics and Machine Vision Lab
 4 |
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # vLPD-Net: Registration-aided 3D Point Cloud Learning for Large-Scale Place Recognition
 2 |
 3 | Our IROS 2021 work: Registration-aided 3D Point Cloud Learning for Large-Scale Place Recognition.
 4 |
 5 | [Paper](https://arxiv.org/abs/2012.05018) [Video](https://www.youtube.com/watch?v=Nnj9KUXIFns&t=11s)
 6 |
 7 | Authors: Zhijian Qiao<sup>1†</sup>, Hanjiang Hu<sup>1,2†</sup>, Weiang Shi<sup>1</sup>, Siyuan Chen<sup>1</sup>, Zhe Liu<sup>1,3</sup>, Hesheng Wang<sup>1</sup>
 8 |
 9 | <sup>1</sup>Shanghai Jiao Tong University, <sup>2</sup>Carnegie Mellon University, <sup>3</sup>University of Cambridge
10 |
11 |
12 | ![Overview](media/pipeline.png)
13 |
14 | ### Introduction
15 | This is our IROS 2021 work. We use GTA-V to build a synthetic virtual dataset that covers multi-environment traverses without laborious human effort, while accurate registration ground truth is obtained at the same time.
16 |
17 | To this end, we propose vLPD-Net, a novel registration-aided 3D domain adaptation network for point cloud based place recognition. A structure-aware registration network is proposed to leverage geometric properties and co-contextual information between two input point clouds, and an adversarial training pipeline is used to reduce the gap between the synthetic and real-world domains.
18 |
19 | ### Citation
20 | If you find this work useful, please cite:
21 | ```
22 | @INPROCEEDINGS{9635878, author={Qiao, Zhijian and Hu, Hanjiang and Shi, Weiang and Chen, Siyuan and Liu, Zhe and Wang, Hesheng}, booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, title={A Registration-aided Domain Adaptation Network for 3D Point Cloud Based Place Recognition}, year={2021}, volume={}, number={}, pages={1317-1322}, doi={10.1109/IROS51168.2021.9635878}}
23 | ```
24 |
25 | ### Environment and Dependencies
26 | Code was tested using Python 3.8 with PyTorch 1.7.1 and MinkowskiEngine 0.5.0 on Ubuntu 18.04 with CUDA 10.2.
27 |
28 | The following Python packages are required:
29 | * PyTorch (version 1.7)
30 | * MinkowskiEngine (version 0.5.0)
31 | * pytorch_metric_learning (version 0.9.94 or above)
32 | * tensorboard
33 | * pandas
34 | * psutil
35 | * bitarray
36 |
37 | ### Datasets
38 | **vLPD-Net** is trained on a subset of the Oxford RobotCar dataset and a synthetic virtual dataset built on GTAV5. The processed Oxford RobotCar dataset is available from [PointNetVLAD](https://github.com/mikacuy/pointnetvlad.git). To build the synthetic virtual dataset, we use the plugin [DeepGTAV](https://github.com/aitorzip/DeepGTAV) and its extension [A-virtual-LiDAR-for-DeepGTAV](https://github.com/gdpinchina/A-virtual-LiDAR-for-DeepGTAV), which equip the vehicle with cameras and LiDAR. In addition, we use [VPilot](https://github.com/aitorzip/VPilot) for easier interaction with DeepGTAV. The data collected by the plugins is saved in txt files, then processed and stored in the Oxford RobotCar format.
39 |
40 |
41 | The directory structure should look like this:
42 | ```
43 | vLPD-Net
44 | └── benchmark_datasets
45 |     ├── GTA5
46 |     └── oxford
47 | ```
48 |
49 | To generate pickles, run:
50 | ```generate pickles
51 | cd generating_queries/
52 |
53 | # Generate training tuples for the Oxford Dataset
54 | python generate_training_tuples_baseline.py
55 |
56 | # Generate evaluation Oxford tuples
57 | python generate_test_sets.py
58 |
59 | # In addition, if you also want to experiment on GTAV5, we have open-sourced the code for processing this dataset.
60 | # Generate training tuples for the GTAV5 Dataset
61 | python generate_training_tuples_baseline_gtav5.py
62 |
63 | # Generate evaluation GTAV5 tuples
64 | python generate_test_sets_gtav5.py
65 | ```
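Each training pickle maps an integer index to a query tuple; positives lie within 10 m of the query and negatives beyond 50 m (see `generate_training_tuples_baseline.py`). A minimal sketch of inspecting one, assuming you run it from the repository root (the GTAV5 variant additionally stores `positives_T`, a list of 4x4 ground-truth transforms, one per positive):

```python
import pickle

with open('generating_queries/pickles/training_queries_baseline.pickle', 'rb') as f:
    queries = pickle.load(f)

# Each entry: {'query': <relative .bin path>, 'positives': [...], 'negatives': [...]}
print(queries[0]['query'])
print(len(queries[0]['positives']), len(queries[0]['negatives']))
```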
66 |
67 | ### Training
68 | We train vLPD-Net on one 2080Ti GPU.
69 |
70 | To train the network, run:
71 |
72 | ```train baseline
73 | # To train the vLPD-Net model on the Oxford Dataset
74 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python train.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt
75 |
76 | # To train the vLPD-Net model on the GTAV5 Dataset
77 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python train.py --config ./config/config_baseline_gtav5.txt --model_config ./config/models/vlpdnet.txt
78 | ```
79 |
80 | For registration on the Oxford dataset, we transform the point clouds ourselves to generate ground truth. The registration training pipelines on the two datasets also differ slightly: on the Oxford dataset we use EPCOR only at inference time, while on GTAV5 we train with EPCOR on top of a pretrained whole-to-whole registration model such as [VCR-Net](https://github.com/qiaozhijian/VCR-Net.git). This is because registration on the Oxford dataset is naturally whole-to-whole.
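As a concrete illustration, a registration pair with exact ground truth can be synthesized from a single submap by applying a random rigid transform; the angle and translation ranges below mirror those in `dataloader/gta5.py`, but this is a simplified sketch (with a random stand-in cloud), not the full training pipeline:

```python
import numpy as np
from scipy.spatial.transform import Rotation

source = np.random.rand(4096, 3)  # stand-in for a loaded (N, 3) submap

# Small tilts about x/y, a larger yaw about z, as in the dataloader
anglex = np.random.uniform() * 5.0 / 180.0 * np.pi
angley = np.random.uniform() * 5.0 / 180.0 * np.pi
anglez = np.random.uniform() * 35.0 / 180.0 * np.pi
R_ab = Rotation.from_euler('zyx', [anglez, angley, anglex]).as_matrix()
t_ab = np.random.uniform(-0.5, 0.5, size=3)

# Transformed copy of the cloud; (R_ab, t_ab) is the exact registration GT
target = source @ R_ab.T + t_ab
```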
81 | ### Pre-trained Models
82 |
83 | Pretrained models are available in the `weights` directory:
84 | - `vlpdnet-oxford.pth`: trained on the Oxford Dataset
85 | - `vlpdnet-gtav5.pth`: trained on the GTAV5 Dataset with domain adaptation
86 | - `vlpdnet-registration.t7`: trained on the Oxford Dataset for registration
87 |
88 | ### Evaluation
89 |
90 | To evaluate pretrained models, run the following commands:
91 |
92 | ```eval baseline
93 |
94 | # To evaluate the model trained on the Oxford Dataset
95 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --weights=./weights/vlpdnet-oxford.pth
96 |
97 | # To evaluate the model trained on the GTAV5 Dataset
98 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --weights=./weights/vlpdnet-gtav5.pth
99 |
100 | # To evaluate the model trained on the Oxford Dataset for registration
101 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --eval_reg=./weights/vlpdnet-registration.t7
102 | ```
103 |
104 | ### Acknowledgment
105 |
106 | [MinkLoc3D](https://github.com/jac99/MinkLoc3D.git)
107 |
108 | [VCR-Net](https://github.com/qiaozhijian/VCR-Net.git)
109 |
110 | [LPD-Net-Pytorch](https://github.com/qiaozhijian/LPD-Net-Pytorch.git)
111 |
112 | ### License
113 | Our code is released under the MIT License (see LICENSE file for details).
114 | -------------------------------------------------------------------------------- /config/config_baseline.txt: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | num_points = 4096 3 | 4 | dataset_folder = ./benchmark_datasets 5 | queries_folder = ./generating_queries/pickles 6 | 7 | [TRAIN] 8 | num_workers = 16 9 | batch_size = 56 10 | batch_size_limit = 56 11 | batch_expansion_rate = 1.4 12 | batch_expansion_th = 0.7 13 | 14 | lr = 1e-3 15 | epochs = 50 16 | scheduler_milestones = 40 17 | 18 | aug_mode = 2 19 | weight_decay = 1e-3 20 | eval_simple = True 21 | 22 | loss = BatchHardTripletMarginLoss 23 | normalize_embeddings = False 24 | margin = 0.2 25 | swap = True 26 | lamda = 1 27 | lamda_reg = 0.001 28 | lpd_fixed = True 29 | 30 | train_file = training_queries_baseline.pickle 31 | 32 | [REGISTARTION] 33 | batch_size = 30 34 | epochs = 100 35 | num_points = 4096 36 | iter = 1 37 | overlap = 1.0 -------------------------------------------------------------------------------- /config/config_baseline_gtav5.txt: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | num_points = 4096 3 | 4 | dataset_folder = ./benchmark_datasets 5 | queries_folder = ./generating_queries/pickles 6 | 7 | [TRAIN] 8 | num_workers = 16 9 | batch_size = 32 10 | batch_size_limit = 42 11 | batch_expansion_rate = 1.4 12 | batch_expansion_th = 0.7 13 | 14 | lr = 1e-3 15 | epochs = 40 16 | scheduler_milestones = 30 17 | 18 | aug_mode = 2 19 | weight_decay = 1e-3 20 | eval_simple = True 21 | 22 | loss = BatchHardTripletMarginLoss 23 | normalize_embeddings = False 24 | margin = 0.2 25 | swap = True 26 | lamda = 1 27 | lamda_reg = 0.001 28 | lpd_fixed = True 29 | 30 | train_file = gta5_training_queries_baseline.pickle 31 | 32 | [REGISTARTION] 33 | batch_size = 30 34 | epochs = 100 35 | num_points = 4096 36 | iter = 1 37 | overlap = 1.0 38 | 39 | [DOMAIN_ADAPT] 40 | model = FCDor 41 | lr = 1e-4 42 | epochs = 30 43 | scheduler_milestones = 20 44 | weight_decay = 1e-3 45 | 46 | lamda_gd = 0.01 47 | lamda_d = 0.01 48 | loss = MSEWithMask 49 | 50 | train_file = training_queries_baseline.pickle -------------------------------------------------------------------------------- /config/models/vlpdnet.txt: -------------------------------------------------------------------------------- 1 | # MinkLoc3D model 2 | [MODEL] 3 | model = vLPDNet 4 | mink_quantization_size = 0.01 5 | planes = 32,64,64 6 | layers = 1,1,1 7 | num_top_down = 1 8 | conv0_kernel_size = 5 9 | feature_size = 256 10 | separa_mode = Simple 11 | 12 | [LPD] 13 | emb_dims = 128 14 | featnet = lpdnet 15 | lpd_channels = 16,32,64 16 | 17 | 18 | [REGRESS] 19 | model = vcr 20 | loss = PointMSE 21 | ff_dims = 128 -------------------------------------------------------------------------------- /dataloader/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Author: Zhijian Qiao 2 | # Shanghai Jiao Tong University 3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git 4 | 5 | import MinkowskiEngine as ME 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from dataloader.gta5 import GTAV5 11 | from dataloader.oxford import OxfordDataset, TrainTransform, TrainSetTransform, Oxford 12 | from dataloader.samplers import BatchSampler 13 | from misc.utils import MinkLocParams 14 | 15 | 16 | def make_datasets(params: MinkLocParams, debug=False): 17 | # Create training 
and validation datasets 18 | datasets = {} 19 | train_transform = TrainTransform(params.aug_mode) 20 | train_set_transform = TrainSetTransform(params.aug_mode) 21 | if debug: 22 | max_elems = 1000 23 | else: 24 | max_elems = None 25 | 26 | datasets['train'] = OxfordDataset(params, params.train_file, train_transform, set_transform=train_set_transform, 27 | max_elems=max_elems) 28 | val_transform = None 29 | if params.val_file is not None: 30 | datasets['val'] = OxfordDataset(params, params.val_file, val_transform) 31 | return datasets 32 | 33 | 34 | def make_eval_dataset(params: MinkLocParams): 35 | # Create evaluation datasets 36 | dataset = OxfordDataset(params, params.test_file, transform=None) 37 | return dataset 38 | 39 | 40 | def sparse_quantize(xyz_batch, params=None): 41 | coords = [ 42 | ME.utils.sparse_quantize(e, quantization_size=params.model_params.mink_quantization_size, return_index=True) 43 | for e in xyz_batch] 44 | coords_idx = [e[1] + i * e[1].shape[0] for i, e in enumerate(coords)] 45 | coords_idx = torch.cat(coords_idx) 46 | coords = [e[0] for e in coords] 47 | coords = ME.utils.batched_coordinates(coords) 48 | 49 | # Assign a dummy feature equal to 1 to each point 50 | # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation 51 | feats = torch.ones((coords.shape[0], 1), dtype=torch.float32) 52 | batch = {'coords': coords, 'features': feats, 'index': coords_idx, 'cloud': xyz_batch} 53 | 54 | return batch 55 | 56 | 57 | def make_collate_fn(dataset: OxfordDataset, params=None): 58 | # set_transform: the transform to be applied to all batch elements 59 | def collate_fn(data_list): 60 | # Constructs a batch object 61 | batch_size = len(data_list) 62 | clouds = [e[0] for e in data_list] 63 | labels = [e[1] for e in data_list] 64 | xyz_batch = torch.stack(clouds, dim=0) # Produces (batch_size, n_points, 3) tensor 65 | if dataset.set_transform is not None: 66 | # Apply the same transformation on all dataset elements 67 | xyz_batch, R_set, t_set = dataset.set_transform(xyz_batch) 68 | else: 69 | R_set = np.eye(3) 70 | t_set = np.zeros((3, 1)) 71 | 72 | batch = sparse_quantize(xyz_batch, params) 73 | 74 | clouds_source = [e[5] for e in data_list] 75 | xyz_batch_source = torch.stack(clouds_source, dim=0) # Produces (batch_size, n_points, 3) tensor 76 | source_batch = sparse_quantize(xyz_batch_source, params) 77 | 78 | if data_list[0][2] is not None: 79 | clouds_da = [e[2][0] for e in data_list] 80 | xyz_batch_da = torch.stack(clouds_da, dim=0) # Produces (batch_size, n_points, 3) tensor 81 | da_batch = sparse_quantize(xyz_batch_da, params) 82 | else: 83 | da_batch = None 84 | 85 | # Compute positives and negatives mask 86 | # dataset.queries[label]['positives'] is bitarray 87 | positives_mask = [[dataset.queries[label]['positives'][e] for e in labels] for label in labels] 88 | negatives_mask = [[dataset.queries[label]['negatives'][e] for e in labels] for label in labels] 89 | 90 | positives_mask = torch.tensor(positives_mask) 91 | negatives_mask = torch.tensor(negatives_mask) 92 | 93 | R2 = [e[3].reshape(1, 3, 3) for e in data_list] 94 | R2 = torch.tensor(np.concatenate(R2, axis=0), dtype=torch.float32) 95 | t2 = [e[4].reshape(1, 3, 1) for e in data_list] 96 | t2 = torch.tensor(np.concatenate(t2, axis=0), dtype=torch.float32) 97 | 98 | R_set = torch.tensor(R_set.reshape(1, 3, 3), dtype=torch.float32).repeat(batch_size, 1, 1) 99 | t_set = torch.tensor(t_set.reshape(1, 3, 1), dtype=torch.float32).repeat(batch_size, 1, 1) 100 | 101 | R = torch.bmm(R_set, 
R2) 102 | t = t_set + torch.bmm(R_set, t2) 103 | 104 | gt_T = torch.eye(4).unsqueeze(0).repeat(batch_size, 1, 1) 105 | gt_T[:, :3, :3] = R 106 | gt_T[:, :3, 3] = t.reshape(batch_size, 3) 107 | 108 | # Returns (batch_size, n_points, 3) tensor and positives_mask and 109 | # negatives_mask which are batch_size x batch_size boolean tensors 110 | return batch, positives_mask, negatives_mask, da_batch, R.float(), t.float(), source_batch, gt_T.float() 111 | 112 | return collate_fn 113 | 114 | 115 | def make_dataloaders(params: MinkLocParams, debug=False): 116 | """ 117 | Create training and validation dataloaders that return groups of k=2 similar elements 118 | :param train_params: 119 | :param model_params: 120 | :return: 121 | """ 122 | datasets = make_datasets(params, debug=debug) 123 | 124 | dataloders = {} 125 | train_sampler = BatchSampler(datasets['train'], batch_size=params.batch_size, 126 | batch_size_limit=params.batch_size_limit, 127 | batch_expansion_rate=params.batch_expansion_rate) 128 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 129 | train_collate_fn = make_collate_fn(datasets['train'], params) 130 | dataloders['train'] = DataLoader(datasets['train'], batch_sampler=train_sampler, collate_fn=train_collate_fn, 131 | num_workers=params.num_workers, pin_memory=True) 132 | 133 | if 'val' in datasets: 134 | val_sampler = BatchSampler(datasets['val'], batch_size=params.batch_size) 135 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 136 | # Currently validation dataset has empty set_transform function, but it may change in the future 137 | val_collate_fn = make_collate_fn(datasets['val'], params) 138 | dataloders['val'] = DataLoader(datasets['val'], batch_sampler=val_sampler, collate_fn=val_collate_fn, 139 | num_workers=params.num_workers, pin_memory=True) 140 | if params.is_register: 141 | if not "gta" in params.train_file.lower(): 142 | train_loader = DataLoader( 143 | Oxford(params=params, partition='train'), 144 | batch_size=params.reg.batch_size, shuffle=True, drop_last=True, num_workers=16) 145 | test_loader = DataLoader( 146 | Oxford(params=params, partition='test'), 147 | batch_size=int(params.reg.batch_size * 1.2), shuffle=False, drop_last=False, num_workers=16) 148 | else: 149 | train_loader = DataLoader( 150 | GTAV5(params=params, partition='train'), 151 | batch_size=params.reg.batch_size, shuffle=True, drop_last=True, num_workers=16) 152 | test_loader = DataLoader( 153 | GTAV5(params=params, partition='test'), 154 | batch_size=int(params.reg.batch_size * 1.2), shuffle=False, drop_last=False, num_workers=16) 155 | dataloders['reg_train'] = train_loader 156 | dataloders['reg_test'] = test_loader 157 | 158 | return dataloders 159 | -------------------------------------------------------------------------------- /dataloader/gta5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | from scipy.spatial.transform import Rotation 7 | from sklearn.neighbors import NearestNeighbors 8 | from torch.utils.data import Dataset 9 | 10 | from misc.utils import MinkLocParams 11 | 12 | 13 | def load_data(params, partition): 14 | query_filepath = os.path.join(params.queries_folder, params.train_file) 15 | with open(query_filepath, 'rb') as handle: 16 | # key:{'query':file,'positives':[files],'negatives:[files], 'neighbors':[keys]} 17 | queries = pickle.load(handle) 18 | num = 
len(queries) 19 | all_data = queries 20 | # if partition=='train': 21 | # for idx in range(num): 22 | # if idx % 4!=0: 23 | # all_data.append(queries[idx]) 24 | # else: 25 | # for idx in range(num): 26 | # if idx % 4==0: 27 | # all_data.append(queries[idx]) 28 | return all_data 29 | 30 | 31 | class GTAV5(Dataset): 32 | def __init__(self, params: MinkLocParams, partition='train'): 33 | self.num_points = params.reg.num_points 34 | self.overlap = params.reg.overlap 35 | self.partition = partition 36 | print('Load GTAV5 Dataset') 37 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 38 | self.DATA_DIR = os.path.join(BASE_DIR, '../../dataset/benchmark_datasets/') 39 | self.data = load_data(params, partition) 40 | # self.partial = True 41 | 42 | def __getitem__(self, item): 43 | 44 | if self.partition != 'train': 45 | np.random.seed(item) 46 | 47 | # [num,num_dim] 48 | query_dict = self.data[item] 49 | 50 | filename = os.path.join(self.DATA_DIR, query_dict['query']) 51 | pointcloud = np.fromfile(filename, dtype=np.float64) 52 | pointcloud = np.reshape(pointcloud, (pointcloud.shape[0] // 3, 3)) 53 | pointcloud1 = (np.random.permutation(pointcloud)[: self.num_points]).T 54 | 55 | num_samples = len(query_dict['positives']) 56 | pos_idx = np.random.randint(num_samples) 57 | positive_dict = self.data[query_dict['positives'][pos_idx]] 58 | filename = os.path.join(self.DATA_DIR, positive_dict['query']) 59 | pointcloud = np.fromfile(filename, dtype=np.float64) 60 | pointcloud = np.reshape(pointcloud, (pointcloud.shape[0] // 3, 3)) 61 | pointcloud2 = (np.random.permutation(pointcloud)[: self.num_points]).T 62 | 63 | positive_T = query_dict['positives_T'][pos_idx] 64 | 65 | rotation_ab = positive_T[:3, :3] 66 | translation_ab = positive_T[:3, 3].reshape(3, 1) 67 | pointcloud1 = rotation_ab @ pointcloud1 + translation_ab 68 | 69 | # self.visual_pcl_simple(pcl1=pointcloud1.T, pcl2=pointcloud2.T) 70 | 71 | anglex = (np.random.uniform()) * 5.0 / 180.0 * np.pi 72 | angley = (np.random.uniform()) * 5.0 / 180.0 * np.pi 73 | anglez = (np.random.uniform()) * 35.0 / 180.0 * np.pi 74 | 75 | cosx = np.cos(anglex) 76 | cosy = np.cos(angley) 77 | cosz = np.cos(anglez) 78 | sinx = np.sin(anglex) 79 | siny = np.sin(angley) 80 | sinz = np.sin(anglez) 81 | Rx = np.array([[1, 0, 0], 82 | [0, cosx, -sinx], 83 | [0, sinx, cosx]]) 84 | Ry = np.array([[cosy, 0, siny], 85 | [0, 1, 0], 86 | [-siny, 0, cosy]]) 87 | Rz = np.array([[cosz, -sinz, 0], 88 | [sinz, cosz, 0], 89 | [0, 0, 1]]) 90 | R_ab = Rx.dot(Ry).dot(Rz) 91 | R_ba = R_ab.T 92 | 93 | if self.partition == 'train': 94 | translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5), 95 | np.random.uniform(-0.5, 0.5)]) 96 | else: 97 | translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5), 98 | np.random.uniform(-0.5, 0.5)]) 99 | 100 | translation_ba = -R_ba.dot(translation_ab) 101 | 102 | rotation_ab = Rotation.from_euler('zyx', [anglez, angley, anglex]) 103 | pointcloud2 = rotation_ab.apply(pointcloud2.T).T + np.expand_dims(translation_ab, axis=1) 104 | 105 | euler_ab = np.asarray([anglez, angley, anglex]) 106 | 107 | euler_ba = -euler_ab[::-1] 108 | 109 | pointcloud1 = np.random.permutation(pointcloud1.T).T 110 | pointcloud2 = np.random.permutation(pointcloud2.T).T 111 | 112 | # [3,num_points] (3,) 113 | return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \ 114 | translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \ 115 | 
euler_ab.astype('float32'), euler_ba.astype('float32'), 0 116 | 117 | def __len__(self): 118 | return len(self.data) 119 | # return 100 120 | 121 | def visual_pcl_simple(self, pcl1, pcl2, name='Open3D Origin'): 122 | # pcl: N,3 123 | pcd1 = o3d.geometry.PointCloud() 124 | pcd1.points = o3d.utility.Vector3dVector(pcl1[:, :3]) 125 | pcd2 = o3d.geometry.PointCloud() 126 | pcd2.points = o3d.utility.Vector3dVector(pcl2[:, :3]) 127 | pcd1.paint_uniform_color([1, 0.706, 0]) 128 | pcd2.paint_uniform_color([0, 0.651, 0.929]) 129 | o3d.visualization.draw_geometries([pcd1, pcd2], window_name=name, width=1920, height=1080, 130 | left=50, 131 | top=50, 132 | point_show_normal=False, mesh_show_wireframe=False, 133 | mesh_show_back_face=False) 134 | 135 | def nearest_neighbor(self, dst, reserve): 136 | dst = dst.T 137 | num = np.max([dst.shape[0], dst.shape[1]]) 138 | num = int(num * reserve) 139 | src = dst[-1, :].reshape(1, -1) 140 | neigh = NearestNeighbors(n_neighbors=num) 141 | neigh.fit(dst) 142 | indices = neigh.kneighbors(src, return_distance=False) 143 | indices = indices.ravel() 144 | return dst[indices, :].T 145 | -------------------------------------------------------------------------------- /dataloader/samplers.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski 2 | # Warsaw University of Technology 3 | 4 | import copy 5 | import random 6 | 7 | from torch.utils.data import DataLoader, Sampler 8 | 9 | from dataloader.oxford import OxfordDataset 10 | from misc.log import log_string 11 | 12 | VERBOSE = False 13 | 14 | 15 | class BatchSampler(Sampler): 16 | # Sampler returning list of indices to form a mini-batch 17 | # Samples elements in groups consisting of k=2 similar elements (positives) 18 | # Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k 19 | def __init__(self, dataset: OxfordDataset, batch_size: int, batch_size_limit: int = None, 20 | batch_expansion_rate: float = None): 21 | if batch_expansion_rate is not None: 22 | assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1' 23 | assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size' 24 | 25 | self.batch_size = batch_size 26 | self.batch_size_limit = batch_size_limit 27 | self.batch_expansion_rate = batch_expansion_rate 28 | self.dataset = dataset 29 | self.k = 2 # Number of positive examples per group must be 2 30 | if self.batch_size < 2 * self.k: 31 | self.batch_size = 2 * self.k 32 | print('WARNING: Batch too small. 
Batch size increased to {}.'.format(self.batch_size))
 33 |
 34 |         self.batch_idx = []  # Index of elements in each batch (re-generated every epoch)
 35 |
 36 |         self.elems_ndx = {}  # Dictionary of point cloud indexes
 37 |         for ndx in self.dataset.queries:
 38 |             self.elems_ndx[ndx] = True
 39 |         # self.generate_batches()  # initialize batch_idx
 40 |
 41 |     def __iter__(self):
 42 |         # Re-generate batches every epoch
 43 |         self.generate_batches()
 44 |         for batch in self.batch_idx:
 45 |             yield batch
 46 |
 47 |
 48 |
 49 |
 50 |     def __len__(self) -> int:
 51 |         return len(self.batch_idx)
 52 |
 53 |     def expand_batch(self):
 54 |         if self.batch_expansion_rate is None:
 55 |             print('WARNING: batch_expansion_rate is None')
 56 |             return
 57 |
 58 |         if self.batch_size >= self.batch_size_limit:
 59 |             return
 60 |
 61 |         old_batch_size = self.batch_size
 62 |         self.batch_size = int(self.batch_size * self.batch_expansion_rate)
 63 |         self.batch_size = min(self.batch_size, self.batch_size_limit)
 64 |         log_string('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size))
 65 |
 66 |     def generate_batches(self):
 67 |         # Generate training/evaluation batches.
 68 |         # batch_idx holds indexes of elements in each batch as a list of lists
 69 |         self.batch_idx = []
 70 |
 71 |         unused_elements_ndx = copy.deepcopy(self.elems_ndx)
 72 |         current_batch = []
 73 |
 74 |         assert self.k == 2, 'sampler can sample only k=2 elements from the same class'
 75 |
 76 |         while True:
 77 |             if len(current_batch) >= self.batch_size or len(unused_elements_ndx) == 0:
 78 |                 # Flush out a new batch and reinitialize the list of available locations
 79 |                 # Flush out batch, when it has a desired size, or a smaller batch, when there's no more
 80 |                 # elements to process
 81 |                 if len(current_batch) >= 2 * self.k:
 82 |                     # Ensure there are at least two groups of similar elements, otherwise it would not be
 83 |                     # possible to find negative examples in the batch
 84 |                     assert len(current_batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(current_batch))
 85 |                     self.batch_idx.append(current_batch)
 86 |                     current_batch = []
 87 |                 if len(unused_elements_ndx) == 0:
 88 |                     break
 89 |
 90 |             # Add k=2 similar elements to the batch
 91 |             selected_element = random.choice(list(unused_elements_ndx))
 92 |             unused_elements_ndx.pop(selected_element)
 93 |             positives = self.dataset.get_positives_ndx(selected_element)
 94 |             if len(positives) == 0:
 95 |                 # Broken dataset element without any positives
 96 |                 continue
 97 |
 98 |             unused_positives = [e for e in positives if e in unused_elements_ndx]
 99 |             # If there are unused elements similar to selected_element, sample from them
100 |             # otherwise sample from all similar elements
101 |             if len(unused_positives) > 0:
102 |                 second_positive = random.choice(unused_positives)
103 |                 unused_elements_ndx.pop(second_positive)
104 |             else:
105 |                 second_positive = random.choice(positives)
106 |
107 |             # Each batch contains batch_size/2 pairs of positives
108 |             current_batch += [selected_element, second_positive]
109 |
110 |         for batch in self.batch_idx:
111 |             assert len(batch) % self.k == 0, 'Incorrect batch size: {}'.format(len(batch))
112 |
113 |
114 | if __name__ == '__main__':
115 |     dataset_path = '/media/sf_Datasets/PointNetVLAD'
116 |     query_filename = 'test_queries_baseline.pickle'
117 |
118 |     from configparser import ConfigParser
119 |
120 |     config = ConfigParser()
121 |     config.dataset_path = dataset_path
122 |
123 |     ds = OxfordDataset(config, query_filename)
124 |     sampler = BatchSampler(ds, batch_size=16)
125 |     dataloader = DataLoader(ds, batch_sampler=sampler)
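    # The sampler yields one list of dataset indices per batch, laid out as
    # [a1, a2, b1, b2, ...]: consecutive entries are positive pairs (k=2), so a
    # batch of size B supplies B/2 positive pairs for metric-learning losses;
    # the DataLoader then loads and collates those elements.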
126 | e = ds[0] 127 | res = next(iter(dataloader)) 128 | log_string(res) 129 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # Author: Zhijian Qiao 2 | # Shanghai Jiao Tong University 3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git 4 | 5 | import argparse 6 | import json 7 | import os 8 | import pickle 9 | 10 | import numpy as np 11 | import torch 12 | from sklearn.neighbors import KDTree 13 | from tqdm import tqdm 14 | 15 | from misc.log import log_string 16 | from misc.utils import MinkLocParams 17 | from models.model_factory import model_factory, load_weights 18 | from training.reg_train import testVCRNet 19 | from torch.utils.data import DataLoader 20 | from dataloader.oxford import Oxford 21 | DEBUG = False 22 | 23 | 24 | def evaluate(model, device, params, log=False): 25 | # Run evaluation on all eval datasets 26 | 27 | if DEBUG: 28 | params.eval_database_files = params.eval_database_files[0:1] 29 | params.eval_query_files = params.eval_query_files[0:1] 30 | 31 | assert len(params.eval_database_files) == len(params.eval_query_files) 32 | 33 | stats = {} 34 | for database_file, query_file in zip(params.eval_database_files, params.eval_query_files): 35 | # Extract location name from query and database files 36 | location_name = database_file.split('_')[0] 37 | temp = query_file.split('_')[0] 38 | assert location_name == temp, 'Database location: {} does not match query location: {}'.format(database_file, 39 | query_file) 40 | 41 | p = os.path.join(params.queries_folder, database_file) 42 | with open(p, 'rb') as f: 43 | database_sets = pickle.load(f) 44 | 45 | p = os.path.join(params.queries_folder, query_file) 46 | with open(p, 'rb') as f: 47 | query_sets = pickle.load(f) 48 | if log: 49 | print('Evaluation:{} on {}'.format(database_file, query_file)) 50 | temp = evaluate_dataset(model, device, params, database_sets, query_sets, log=log) 51 | stats[location_name] = temp 52 | 53 | for database_name in stats: 54 | log_string('Dataset: {} '.format(database_name), end='') 55 | t = 'Avg. top 1 recall: {:.2f} Avg. top 1% recall: {:.2f} Avg. 
similarity: {:.4f}' 56 | log_string(t.format(stats[database_name]['ave_recall'][0], 57 | stats[database_name]['ave_one_percent_recall'], 58 | stats[database_name]['average_similarity'])) 59 | return stats 60 | 61 | 62 | def evaluate_dataset(model, device, params, database_sets, query_sets, log=False): 63 | # Run evaluation on a single dataset 64 | recall = np.zeros(25) 65 | count = 0 66 | similarity = [] 67 | one_percent_recall = [] 68 | 69 | database_embeddings = [] 70 | query_embeddings = [] 71 | 72 | model.eval() 73 | if log: 74 | tqdm_ = lambda x, desc: tqdm(x, desc=desc) 75 | else: 76 | tqdm_ = lambda x, desc: x 77 | 78 | torch.cuda.empty_cache() 79 | for set in tqdm_(database_sets, 'Database'): 80 | database_embeddings.append(get_latent_vectors(model, set, device, params)) 81 | 82 | for set in tqdm_(query_sets, ' Query'): 83 | query_embeddings.append(get_latent_vectors(model, set, device, params)) 84 | 85 | for i in tqdm_(range(len(query_sets)), ' Test'): 86 | for j in range(len(query_sets)): 87 | if i == j: 88 | continue 89 | pair_recall, pair_similarity, pair_opr = get_recall(i, j, database_embeddings, query_embeddings, query_sets, 90 | database_sets, log=log) 91 | recall += np.array(pair_recall) 92 | count += 1 93 | one_percent_recall.append(pair_opr) 94 | for x in pair_similarity: 95 | similarity.append(x) 96 | 97 | ave_recall = recall / count 98 | average_similarity = np.mean(similarity) 99 | ave_one_percent_recall = np.mean(one_percent_recall) 100 | stats = {'ave_one_percent_recall': ave_one_percent_recall, 'ave_recall': ave_recall, 101 | 'average_similarity': average_similarity, 'Loc rebuild': 1} 102 | return stats 103 | 104 | 105 | def load_pc(file_name, params, make_tensor=True): 106 | # returns Nx3 matrix 107 | file_path = os.path.join(params.dataset_folder, file_name) 108 | pc = np.fromfile(file_path, dtype=np.float64) 109 | # coords are within -1..1 range in each dimension 110 | assert pc.shape[0] == params.num_points * 3, "Error in point cloud shape: {}".format(file_path) 111 | pc = np.reshape(pc, (pc.shape[0] // 3, 3)) 112 | if make_tensor: 113 | pc = torch.tensor(pc, dtype=torch.float) 114 | return pc 115 | 116 | 117 | def load_pc_files(elem_ndxs, set, params): 118 | pcs = [] 119 | for elem_ndx in elem_ndxs: 120 | pc = load_pc(set[elem_ndx]["query"], params, make_tensor=False) 121 | if (pc.shape[0] != 4096): 122 | assert 0, 'pc.shape[0] != 4096' 123 | pcs.append(pc) 124 | pcs = np.asarray(pcs) 125 | 126 | pcs = torch.tensor(pcs, dtype=torch.float) 127 | return pcs 128 | 129 | 130 | def genrate_batch(num, batch_size): 131 | sets = np.arange(0, num, batch_size) 132 | sets = sets.tolist() 133 | if sets[-1] != num: 134 | sets.append(num) 135 | return sets 136 | 137 | 138 | def get_latent_vectors(model, set, device, params: MinkLocParams): 139 | if DEBUG: 140 | embeddings = np.random.rand(len(set), 256) 141 | return embeddings 142 | 143 | model.eval() 144 | embeddings_l = [] 145 | 146 | batch_set = genrate_batch(len(set), int(params.batch_size * 1.5)) 147 | for batch_id in range(len(batch_set) - 1): 148 | elem_ndx = np.arange(batch_set[batch_id], batch_set[batch_id + 1]) 149 | x = load_pc_files(elem_ndx, set, params) 150 | with torch.no_grad(): 151 | batch = {'cloud': x.cuda()} 152 | embedding = model(target_batch=batch)['embeddings'] 153 | # embedding is (1, 1024) tensor 154 | if params.normalize_embeddings: 155 | embedding = torch.nn.functional.normalize(embedding, p=2, dim=1) # Normalize embeddings 156 | 157 | embedding = embedding.detach().cpu().numpy() 158 | 
embeddings_l.append(embedding) 159 | 160 | embeddings = np.vstack(embeddings_l) 161 | return embeddings 162 | 163 | 164 | def get_recall(m, n, database_vectors, query_vectors, query_sets, database_sets, log=False): 165 | # Original PointNetVLAD code 166 | database_output = database_vectors[m] 167 | queries_output = query_vectors[n] 168 | 169 | database_nbrs = KDTree(database_output) 170 | 171 | num_neighbors = 25 172 | recall = [0] * num_neighbors 173 | 174 | top1_similarity_score = [] 175 | one_percent_retrieved = 0 176 | threshold = max(int(round(len(database_output) / 100.0)), 1) 177 | 178 | num_evaluated = 0 179 | for i in range(len(queries_output)): 180 | # i is query element ndx 181 | query_details = query_sets[n][i] # {'query': path, 'northing': , 'easting': } 182 | true_neighbors = query_details[m] 183 | if len(true_neighbors) == 0: 184 | continue 185 | num_evaluated += 1 186 | distances, indices = database_nbrs.query(np.array([queries_output[i]]), k=num_neighbors) 187 | 188 | for j in range(len(indices[0])): 189 | if indices[0][j] in true_neighbors: 190 | if j == 0: 191 | similarity = np.dot(queries_output[i], database_output[indices[0][j]]) 192 | top1_similarity_score.append(similarity) 193 | recall[j] += 1 194 | break 195 | 196 | if len(list(set(indices[0][0:threshold]).intersection(set(true_neighbors)))) > 0: 197 | one_percent_retrieved += 1 198 | 199 | one_percent_recall = (one_percent_retrieved / float(num_evaluated)) * 100 200 | recall = (np.cumsum(recall) / float(num_evaluated)) * 100 201 | # log_string(recall) 202 | # log_string(np.mean(top1_similarity_score)) 203 | # log_string(one_percent_recall) 204 | return recall, top1_similarity_score, one_percent_recall 205 | 206 | 207 | def export_eval_stats(file_name, prefix, eval_stats): 208 | s = prefix 209 | ave_1p_recall_l = [] 210 | ave_recall_l = [] 211 | # Print results on the final model 212 | with open(file_name, "a") as f: 213 | for ds in ['oxford', 'university', 'residential', 'business']: 214 | ave_1p_recall = eval_stats[ds]['ave_one_percent_recall'] 215 | ave_1p_recall_l.append(ave_1p_recall) 216 | ave_recall = eval_stats[ds]['ave_recall'][0] 217 | ave_recall_l.append(ave_recall) 218 | s += ", {:0.2f}, {:0.2f}".format(ave_1p_recall, ave_recall) 219 | 220 | mean_1p_recall = np.mean(ave_1p_recall_l) 221 | mean_recall = np.mean(ave_recall_l) 222 | s += ", {:0.2f}, {:0.2f}\n".format(mean_1p_recall, mean_recall) 223 | f.write(s) 224 | 225 | 226 | if __name__ == "__main__": 227 | parser = argparse.ArgumentParser(description='Evaluate model on PointNetVLAD (Oxford) dataset') 228 | parser.add_argument('--config', type=str, required=True, help='Path to configuration file') 229 | parser.add_argument('--model_config', type=str, required=True, help='Path to the model-specific configuration file') 230 | parser.add_argument('--weights', type=str, required=False, help='Trained model weights') 231 | parser.add_argument('--debug', dest='debug', action='store_true') 232 | parser.set_defaults(debug=False) 233 | parser.add_argument('--visualize', dest='visualize', action='store_true') 234 | parser.set_defaults(visualize=False) 235 | parser.add_argument('--savejson', type=str, default='', help='') 236 | parser.add_argument('--eval_reg', type=str, default="") 237 | 238 | args = parser.parse_args() 239 | log_string('Config path: {}'.format(args.config)) 240 | log_string('Model config path: {}'.format(args.model_config)) 241 | if args.weights is None: 242 | w = 'RANDOM WEIGHTS' 243 | else: 244 | w = args.weights 245 | log_string('Weights: 
{}'.format(w)) 246 | log_string('Debug mode: {}'.format(args.debug)) 247 | log_string('Visualize: {}'.format(args.visualize)) 248 | 249 | params = MinkLocParams(args.config, args.model_config) 250 | params.print() 251 | 252 | model, device, d_model, vcr_model = model_factory(params) 253 | 254 | load_weights(args.weights, model) 255 | 256 | if args.eval_reg != "": 257 | test_loader = DataLoader( 258 | Oxford(params=params, partition='test'), 259 | batch_size=int(params.reg.batch_size * 1.2), shuffle=False, drop_last=False, num_workers=16) 260 | checkpoint_dict = torch.load(args.eval_reg, map_location=torch.device('cpu')) 261 | vcr_model.load_state_dict(checkpoint_dict, strict=True) 262 | log_string('load vcr_model with {}'.format(args.eval_reg)) 263 | testVCRNet(1, vcr_model, test_loader) 264 | else: 265 | stats = evaluate(model, device, params, True) 266 | 267 | for database_name in stats: 268 | log_string(' Avg. recall @N:') 269 | log_string(str(stats[database_name]['ave_recall'])) 270 | 271 | if len(args.savejson) > 0: 272 | result = {} 273 | result['trainfile'] = params.train_file 274 | result['weightfile'] = args.weights 275 | result['lr'] = params.lr 276 | result['lamda_g'] = params.lamda_gd 277 | result['weight_decay'] = params.weight_decay 278 | result['domain_adapt'] = params.domain_adapt 279 | if params.domain_adapt: 280 | result['lr_d'] = params.d_lr 281 | result['lamda_d'] = params.lamda_d 282 | result['weight_decay_d'] = params.d_weight_decay 283 | else: 284 | result['lr_d'] = None 285 | result['lamda_d'] = None 286 | result['weight_decay_d'] = None 287 | 288 | for database_name in stats: 289 | result_database = {} 290 | result_database['recall_top1'] = float(stats[database_name]['ave_recall'][0]) 291 | result_database['recall_top1per'] = float(stats[database_name]['ave_one_percent_recall']) 292 | result_database['similarity'] = float(stats[database_name]['average_similarity']) 293 | result[database_name] = result_database 294 | json.dump(result, open(args.savejson, 'w')) 295 | -------------------------------------------------------------------------------- /generating_queries/generate_test_sets.py: -------------------------------------------------------------------------------- 1 | # Code taken from PointNetVLAD repo: https://github.com/mikacuy/pointnetvlad 2 | 3 | import os 4 | import pickle 5 | 6 | import numpy as np 7 | import pandas as pd 8 | from sklearn.neighbors import KDTree 9 | 10 | #####For training and test data split##### 11 | x_width = 150 12 | y_width = 150 13 | 14 | # For Oxford 15 | p1 = [5735712.768124, 620084.402381] 16 | p2 = [5735611.299219, 620540.270327] 17 | p3 = [5735237.358209, 620543.094379] 18 | p4 = [5734749.303802, 619932.693364] 19 | 20 | # For University Sector 21 | p5 = [363621.292362, 142864.19756] 22 | p6 = [364788.795462, 143125.746609] 23 | p7 = [363597.507711, 144011.414174] 24 | 25 | # For Residential Area 26 | p8 = [360895.486453, 144999.915143] 27 | p9 = [362357.024536, 144894.825301] 28 | p10 = [361368.907155, 145209.663042] 29 | 30 | p_dict = {"oxford": [p1, p2, p3, p4], "university": [p5, p6, p7], "residential": [p8, p9, p10], "business": [], 31 | "lgsvl": []} 32 | p_dict_thresold = {"oxford": 25, "university": 25, "residential": 25, "business": 25, "lgsvl": 60} 33 | 34 | 35 | def check_in_test_set(northing, easting, points, x_width, y_width): 36 | in_test_set = False 37 | if points == []: 38 | return True 39 | for point in points: 40 | if (point[0] - x_width < northing and northing < point[0] + x_width and point[ 41 | 1] - y_width < 
easting and easting < point[1] + y_width): 42 | in_test_set = True 43 | break 44 | return in_test_set 45 | 46 | 47 | ########################################## 48 | 49 | def output_to_file(output, filename): 50 | if not os.path.exists('pickles'): 51 | os.mkdir('pickles') 52 | filename = os.path.join('pickles', filename) 53 | 54 | with open(filename, 'wb') as handle: 55 | pickle.dump(output, handle, protocol=pickle.DEFAULT_PROTOCOL) 56 | print("Done ", filename) 57 | 58 | 59 | def construct_query_and_database_sets(base_path, runs_folder, folders, pointcloud_fols, filename, p, output_name): 60 | database_trees = [] 61 | test_trees = [] 62 | for folder in folders: 63 | print(folder) 64 | df_database = pd.DataFrame(columns=['file', 'northing', 'easting']) 65 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting']) 66 | 67 | df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',') 68 | for index, row in df_locations.iterrows(): 69 | # entire business district is in the test set 70 | if (output_name == "business"): 71 | df_test = df_test.append(row, ignore_index=True) 72 | elif (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)): 73 | df_test = df_test.append(row, ignore_index=True) 74 | df_database = df_database.append(row, ignore_index=True) 75 | 76 | database_tree = KDTree(df_database[['northing', 'easting']]) 77 | test_tree = KDTree(df_test[['northing', 'easting']]) 78 | database_trees.append(database_tree) 79 | test_trees.append(test_tree) 80 | 81 | test_sets = [] 82 | database_sets = [] 83 | for folder in folders: 84 | database = {} 85 | test = {} 86 | df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',') 87 | df_locations['timestamp'] = runs_folder + folder + pointcloud_fols + df_locations['timestamp'].astype( 88 | str) + '.bin' 89 | 90 | df_locations = df_locations.rename(columns={'timestamp': 'file'}) 91 | for index, row in df_locations.iterrows(): 92 | # entire business district is in the test set 93 | if (output_name == "business"): 94 | test[len(test.keys())] = {'query': row['file'], 'northing': row['northing'], 'easting': row['easting']} 95 | elif (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)): 96 | test[len(test.keys())] = {'query': row['file'], 'northing': row['northing'], 'easting': row['easting']} 97 | database[len(database.keys())] = {'query': row['file'], 'northing': row['northing'], 98 | 'easting': row['easting']} 99 | 100 | database_sets.append(database) 101 | test_sets.append(test) 102 | 103 | for i in range(len(database_sets)): 104 | tree = database_trees[i] 105 | for j in range(len(test_sets)): 106 | if (i == j): 107 | continue 108 | for key in range(len(test_sets[j].keys())): 109 | coor = np.array([[test_sets[j][key]["northing"], test_sets[j][key]["easting"]]]) 110 | index = tree.query_radius(coor, r=p_dict_thresold[output_name]) 111 | # indices of the positive matches in database i of each query (key) in test set j 112 | test_sets[j][key][i] = index[0].tolist() 113 | 114 | output_to_file(database_sets, output_name + '_evaluation_database.pickle') 115 | output_to_file(test_sets, output_name + '_evaluation_query.pickle') 116 | 117 | 118 | ###Building database and query files for evaluation 119 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 120 | base_path = "../benchmark_datasets/" 121 | 122 | # For Oxford 123 | folders = [] 124 | runs_folder = "oxford/" 125 | all_folders = sorted(os.listdir(os.path.join(BASE_DIR, base_path, 
runs_folder))) 126 | index_list = [5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 24, 31, 32, 33, 38, 39, 43, 44] 127 | print(len(index_list)) 128 | for index in index_list: 129 | folders.append(all_folders[index]) 130 | 131 | print(folders) 132 | construct_query_and_database_sets(base_path, runs_folder, folders, "/pointcloud_20m/", "pointcloud_locations_20m.csv", 133 | p_dict["oxford"], "oxford") 134 | 135 | # For lgsvl District 136 | # runs_folder = "lgsvl_new/" 137 | # all_folders=sorted(os.listdir(os.path.join(BASE_DIR,base_path,runs_folder))) 138 | # print(all_folders) 139 | # construct_query_and_database_sets(base_path, runs_folder, all_folders, "/pointcloud_20m/", "pointcloud_locations_20m.csv", p_dict["lgsvl"], "lgsvl") 140 | -------------------------------------------------------------------------------- /generating_queries/generate_test_sets_gtav5.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import sys 4 | 5 | import numpy as np 6 | import pandas as pd 7 | from sklearn.neighbors import KDTree 8 | 9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | sys.path.append(BASE_DIR) 11 | sys.path.append(os.path.abspath(os.path.join(BASE_DIR, ".."))) 12 | from tqdm import tqdm 13 | from glob import glob 14 | from registration import registration_withinit 15 | 16 | recall_dis = 1.25 17 | 18 | 19 | def main(): 20 | # Building database and query files for evaluation 21 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 22 | base_path = '../benchmark_datasets/' 23 | 24 | # For Oxford 25 | runs_folder = "GTA5/" 26 | all_folders = sorted(glob(os.path.join(BASE_DIR, base_path, runs_folder) + '/round*')) 27 | for i in range(len(all_folders)): 28 | all_folders[i] = os.path.basename(all_folders[i]) 29 | print("start process") 30 | construct_query_and_database_sets(base_path, runs_folder, all_folders, "/pcl/", 31 | "pointcloud_locations_20m.csv") 32 | 33 | 34 | def check_in_test_set(): 35 | return True 36 | 37 | 38 | ########################################## 39 | def output_to_file(output, filename): 40 | with open(filename, 'wb') as handle: 41 | pickle.dump(output, handle, protocol=pickle.HIGHEST_PROTOCOL) 42 | print("Done ", filename) 43 | 44 | 45 | def construct_query_and_database_sets(base_path, runs_folder, folders, pointcloud_fols, filename): 46 | database_trees = [] 47 | test_trees = [] 48 | for folder in tqdm(folders): 49 | # print(folder) 50 | df_database = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude']) 51 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude']) 52 | 53 | df_locations = pd.read_csv(os.path.join( 54 | base_path, runs_folder, folder, filename), sep=',') 55 | 56 | for index, row in df_locations.iterrows(): 57 | # entire business district is in the test set 58 | if (check_in_test_set()): 59 | df_test = df_test.append(row, ignore_index=True) 60 | df_database = df_database.append(row, ignore_index=True) 61 | 62 | database_tree = KDTree(df_database[['northing', 'easting', 'altitude']]) 63 | test_tree = KDTree(df_test[['northing', 'easting', 'altitude']]) 64 | database_trees.append(database_tree) 65 | test_trees.append(test_tree) 66 | 67 | test_sets = [] 68 | database_sets = [] 69 | for folder in tqdm(folders): 70 | database = {} 71 | test = {} 72 | df_locations = pd.read_csv(os.path.join( 73 | base_path, runs_folder, folder, filename), sep=',') 74 | df_locations['timestamp'] = runs_folder + folder + \ 75 | pointcloud_fols + 
df_locations['timestamp'].astype(str) + '.bin' 76 | df_locations = df_locations.rename(columns={'timestamp': 'file'}) 77 | for index, row in df_locations.iterrows(): 78 | # entire business district is in the test set 79 | if (check_in_test_set()): 80 | test[len(test.keys())] = { 81 | 'query': row['file'], 'northing': row['northing'], 'easting': row['easting'], 82 | 'altitude': row['altitude']} 83 | database[len(database.keys())] = { 84 | 'query': row['file'], 'northing': row['northing'], 'easting': row['easting'], 85 | 'altitude': row['altitude']} 86 | database_sets.append(database) 87 | test_sets.append(test) 88 | 89 | for i in tqdm(range(len(database_sets))): 90 | tree = database_trees[i] 91 | for j in range(len(test_sets)): 92 | if (i == j): 93 | continue 94 | for key in range(len(test_sets[j].keys())): 95 | coor = np.array( 96 | [[test_sets[j][key]["northing"], test_sets[j][key]["easting"], test_sets[j][key]["altitude"]]]) 97 | index = tree.query_radius(coor, r=recall_dis) 98 | # indices of the positive matches in database i of each query (key) in test set j 99 | test_sets[j][key][i] = index[0].tolist() 100 | 101 | positives_T = [] 102 | for pos_i in test_sets[j][key][i]: 103 | query = test_sets[j][key]["query"] 104 | test_i = test_sets[i][pos_i] 105 | query_i = test_i["query"] 106 | coor_i = np.array([test_i["northing"], test_i["easting"], test_i["altitude"]]) 107 | trans_pre = coor.squeeze() - coor_i.squeeze() 108 | T = registration_withinit(base_path + query, base_path + query_i, trans_pre=trans_pre) 109 | positives_T.append(T) 110 | test_sets[j][key]['positives_T'] = positives_T 111 | 112 | if not os.path.exists("pickles/"): 113 | os.mkdir("pickles/") 114 | 115 | output_to_file(database_sets, "pickles/" + 'gta5_evaluation_database.pickle') 116 | output_to_file(test_sets, "pickles/" + 'gta5_evaluation_query.pickle') 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /generating_queries/generate_training_tuples_baseline.py: -------------------------------------------------------------------------------- 1 | # Code taken from PointNetVLAD repo: https://github.com/mikacuy/pointnetvlad 2 | 3 | import os 4 | import pickle 5 | import random 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from sklearn.neighbors import KDTree 10 | from tqdm import tqdm 11 | 12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 13 | base_path = "../benchmark_datasets/" 14 | 15 | runs_folder = "oxford/" 16 | filename = "pointcloud_locations_20m_10overlap.csv" 17 | pointcloud_fols = "/pointcloud_20m_10overlap/" 18 | 19 | all_folders = sorted(os.listdir(os.path.join(BASE_DIR, base_path, runs_folder))) 20 | 21 | folders = [] 22 | 23 | # All runs are used for training (both full and partial) 24 | index_list = range(len(all_folders) - 1) 25 | print("Number of runs: " + str(len(index_list))) 26 | for index in index_list: 27 | folders.append(all_folders[index]) 28 | print(folders) 29 | 30 | #####For training and test data split##### 31 | x_width = 150 32 | y_width = 150 33 | p1 = [5735712.768124, 620084.402381] 34 | p2 = [5735611.299219, 620540.270327] 35 | p3 = [5735237.358209, 620543.094379] 36 | p4 = [5734749.303802, 619932.693364] 37 | p = [p1, p2, p3, p4] 38 | 39 | 40 | def check_in_test_set(northing, easting, points, x_width, y_width): 41 | in_test_set = False 42 | for point in points: 43 | if (point[0] - x_width < northing and northing < point[0] + x_width and point[ 44 | 1] - y_width < easting and 
easting < point[1] + y_width):
45 |             in_test_set = True
46 |             break
47 |     return in_test_set
48 |
49 |
50 | ##########################################
51 |
52 |
53 | def construct_query_dict(df_centroids, filename):
54 |     tree = KDTree(df_centroids[['northing', 'easting']])
55 |     ind_nn = tree.query_radius(df_centroids[['northing', 'easting']], r=10)
56 |     ind_r = tree.query_radius(df_centroids[['northing', 'easting']], r=50)
57 |     queries = {}
58 |     for i in tqdm(range(len(ind_nn))):
59 |         query = df_centroids.iloc[i]["file"]
60 |         positives = np.setdiff1d(ind_nn[i], [i]).tolist()
61 |         negatives = np.setdiff1d(df_centroids.index.values.tolist(), ind_r[i]).tolist()
62 |         random.shuffle(negatives)
63 |         queries[i] = {"query": query, "positives": positives, "negatives": negatives}
64 |
65 |     if not os.path.exists('pickles'):
66 |         os.mkdir('pickles')
67 |     filename = os.path.join('pickles', filename)
68 |     with open(filename, 'wb') as handle:
69 |         pickle.dump(queries, handle, protocol=pickle.DEFAULT_PROTOCOL)
70 |
71 |     print("Done ", filename)
72 |
73 |
74 | #### Initialize pandas DataFrame
75 | df_train = pd.DataFrame(columns=['file', 'northing', 'easting'])
76 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting'])
77 |
78 | for folder in tqdm(folders):
79 |     df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',')
80 |     df_locations['timestamp'] = runs_folder + folder + pointcloud_fols + df_locations['timestamp'].astype(str) + '.bin'
81 |     df_locations = df_locations.rename(columns={'timestamp': 'file'})
82 |
83 |     for index, row in df_locations.iterrows():
84 |         if (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)):
85 |             df_test = df_test.append(row, ignore_index=True)
86 |         else:
87 |             df_train = df_train.append(row, ignore_index=True)
88 |
89 | print("Number of training submaps: " + str(len(df_train['file'])))
90 | print("Number of non-disjoint test submaps: " + str(len(df_test['file'])))
91 | construct_query_dict(df_train, "training_queries_baseline.pickle")
92 | construct_query_dict(df_test, "test_queries_baseline.pickle")
93 |
--------------------------------------------------------------------------------
/generating_queries/generate_training_tuples_baseline_gtav5.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import pickle
 3 | import random
 4 | import sys
 5 | from glob import glob
 6 |
 7 | import numpy as np
 8 | import pandas as pd
 9 | from sklearn.neighbors import KDTree
10 | from tqdm import tqdm
11 |
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(BASE_DIR)
14 | sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "..")))
15 | from registration import registration_withinit
16 |
17 |
18 | base_path = '../benchmark_datasets/'
19 |
20 | runs_folder = "GTA5/"
21 | filename = "pointcloud_locations_20m_10overlap.csv"
22 | pointcloud_fols = "/pcl/"
23 |
24 | all_folders = sorted(glob(os.path.join(BASE_DIR, base_path, runs_folder) + '/round*'))
25 | for i in range(len(all_folders)):
26 |     all_folders[i] = os.path.basename(all_folders[i])
27 |
28 | # Positive sample range: +/-1 in point cloud units
29 | pos_dis = 0.5
30 | # Negative sample range
31 | neg_dis = 4
32 |
33 |
34 | def check_in_test_set(northing, easting, points, x_width, y_width):
35 |     in_test_set = False
36 |     for point in points:
37 |         if (point[0] - x_width < northing and northing < point[0] + x_width and point[
38 |                 1] - y_width < easting and easting < point[1] + y_width):
39 |
in_test_set = True 40 | break 41 | return in_test_set 42 | 43 | 44 | ########################################## 45 | 46 | 47 | def construct_query_dict(df_centroids, filename): 48 | tree = KDTree(df_centroids[['northing', 'easting', 'altitude']]) 49 | ind_nn = tree.query_radius(df_centroids[['northing', 'easting', 'altitude']], r=pos_dis) 50 | ind_r = tree.query_radius(df_centroids[['northing', 'easting', 'altitude']], r=neg_dis) 51 | queries = {} 52 | for i in tqdm(range(len(ind_nn))): 53 | query = df_centroids.iloc[i]["file"] 54 | positives = np.setdiff1d(ind_nn[i], [i]).tolist() 55 | list_neg = np.setdiff1d(df_centroids.index.values.tolist(), ind_r[i]).tolist() 56 | positives_T = [] 57 | for pos_i in range(len(positives)): 58 | trans_pre = np.asarray(df_centroids.iloc[i][['northing', 'easting', 'altitude']]) - np.asarray( 59 | df_centroids.iloc[positives[pos_i]][['northing', 'easting', 'altitude']]) 60 | T = registration_withinit(base_path + query, base_path + df_centroids.iloc[positives[pos_i]]["file"], 61 | trans_pre=trans_pre) 62 | positives_T.append(T) 63 | # if len(list_neg) > 4001: 64 | # random.shuffle(list_neg) 65 | # list_neg = list_neg[:4001] 66 | negatives = list_neg 67 | random.shuffle(negatives) 68 | queries[i] = {"query": query, "positives": positives, 69 | "positives_T": positives_T, "negatives": negatives} 70 | 71 | with open(filename, 'wb') as handle: 72 | pickle.dump(queries, handle, protocol=pickle.HIGHEST_PROTOCOL) 73 | 74 | print("Done ", filename) 75 | 76 | 77 | # Initialize pandas DataFrame 78 | df_train = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude']) 79 | 80 | for folder in all_folders: 81 | df_locations = pd.read_csv(os.path.join( 82 | base_path, runs_folder, folder, filename), sep=',') 83 | df_locations['timestamp'] = runs_folder + folder + \ 84 | pointcloud_fols + df_locations['timestamp'].astype(str) + '.bin' 85 | df_locations = df_locations.rename(columns={'timestamp': 'file'}) 86 | 87 | for index, row in df_locations.iterrows(): 88 | df_train = df_train.append(row, ignore_index=True) 89 | 90 | if not os.path.exists("pickles/"): 91 | os.mkdir("pickles/") 92 | 93 | print("Number of training submaps: " + str(len(df_train['file']))) 94 | construct_query_dict(df_train, "pickles/" + "gta5_training_queries_baseline.pickle") 95 | -------------------------------------------------------------------------------- /generating_queries/registration.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import sys 3 | 4 | import numpy as np 5 | import open3d as o3d 6 | 7 | # monkey patches visualization and provides helpers to load geometries 8 | sys.path.append('..') 9 | 10 | 11 | def draw_registration_result(source, target, transformation): 12 | source_temp = copy.deepcopy(source) 13 | target_temp = copy.deepcopy(target) 14 | source_temp.paint_uniform_color([1, 0.706, 0]) 15 | target_temp.paint_uniform_color([0, 0.651, 0.929]) 16 | # a=np.asarray(source_temp.points) 17 | source_temp.transform(transformation) 18 | # b=np.asarray(source_temp.points) 19 | # o3d.visualization.draw_geometries([source_temp, target_temp]) 20 | 21 | 22 | def preprocess_point_cloud(pcd, voxel_size, is_voxel=True): 23 | if is_voxel: 24 | # print(":: Downsample with a voxel size %.3f." % voxel_size) 25 | pcd_down = pcd.voxel_down_sample(voxel_size) 26 | else: 27 | pcd_down = pcd 28 | radius_normal = voxel_size * 2 29 | # print(":: Estimate normal with search radius %.3f." 
% radius_normal) 30 | pcd_down.estimate_normals( 31 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30)) 32 | 33 | radius_feature = voxel_size * 5 34 | # print(":: Compute FPFH feature with search radius %.3f." % radius_feature) 35 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature( 36 | pcd_down, 37 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100)) 38 | return pcd_down, pcd_fpfh 39 | 40 | 41 | def prepare_dataset(voxel_size): 42 | print(":: Load two point clouds and disturb initial pose.") 43 | source = o3d.io.read_point_cloud("./test_data/cloud_bin_0.pcd") 44 | target = o3d.io.read_point_cloud("./test_data/cloud_bin_1.pcd") 45 | # trans_init = np.asarray([[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], 46 | # [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0]]) 47 | # source.transform(trans_init) 48 | # draw_registration_result(source, target, np.identity(4)) 49 | 50 | source_down, source_fpfh = preprocess_point_cloud(source, voxel_size) 51 | target_down, target_fpfh = preprocess_point_cloud(target, voxel_size) 52 | return source, target, source_down, target_down, source_fpfh, target_fpfh 53 | 54 | 55 | def execute_global_registration(source_down, target_down, source_fpfh, 56 | target_fpfh, voxel_size): 57 | distance_threshold = voxel_size * 1.5 58 | # print(":: RANSAC registration on downsampled point clouds.") 59 | # print(" Since the downsampling voxel size is %.3f," % voxel_size) 60 | # print(" we use a liberal distance threshold %.3f." % distance_threshold) 61 | result = o3d.pipelines.registration.registration_ransac_based_on_feature_matching( 62 | source_down, target_down, source_fpfh, target_fpfh, distance_threshold, 63 | o3d.pipelines.registration.TransformationEstimationPointToPoint(False), 64 | 4, [ 65 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnEdgeLength( 66 | 0.9), 67 | o3d.pipelines.registration.CorrespondenceCheckerBasedOnDistance( 68 | distance_threshold) 69 | ], o3d.pipelines.registration.RANSACConvergenceCriteria(4000000, 500)) 70 | return result 71 | 72 | 73 | def refine_registration(source, target, init_transformation, voxel_size=-1): 74 | if voxel_size == -1: 75 | distance_threshold = 0.01 76 | else: 77 | distance_threshold = voxel_size * 0.4 78 | # print(":: Point-to-plane ICP registration is applied on original point") 79 | # print(" clouds to refine the alignment. This time we use a strict") 80 | # print(" distance threshold %.3f." 
% distance_threshold) 81 | result = o3d.pipelines.registration.registration_icp( 82 | source, target, distance_threshold, init_transformation, 83 | o3d.pipelines.registration.TransformationEstimationPointToPoint()) 84 | return result 85 | 86 | 87 | def execute_fast_global_registration(source_down, target_down, source_fpfh, 88 | target_fpfh, voxel_size): 89 | distance_threshold = voxel_size * 0.5 90 | print(":: Apply fast global registration with distance threshold %.3f" \ 91 | % distance_threshold) 92 | result = o3d.pipelines.registration.registration_fast_based_on_feature_matching( 93 | source_down, target_down, source_fpfh, target_fpfh, 94 | o3d.pipelines.registration.FastGlobalRegistrationOption( 95 | maximum_correspondence_distance=distance_threshold)) 96 | return result 97 | 98 | 99 | def demo_open3d(): 100 | voxel_size = 0.05 # means 5cm for this dataset 101 | source, target, source_down, target_down, source_fpfh, target_fpfh = prepare_dataset( 102 | voxel_size) 103 | 104 | result_ransac = execute_global_registration(source_down, target_down, 105 | source_fpfh, target_fpfh, 106 | voxel_size) 107 | draw_registration_result(source_down, target_down, result_ransac.transformation) 108 | 109 | result_fast = execute_fast_global_registration(source_down, target_down, 110 | source_fpfh, target_fpfh, 111 | voxel_size) 112 | draw_registration_result(source_down, target_down, result_fast.transformation) 113 | 114 | result_icp = refine_registration(source, target, result_fast.transformation, voxel_size) 115 | draw_registration_result(source, target, result_icp.transformation) 116 | 117 | 118 | def load_bin(path1, path2, voxel_size): 119 | source = o3d.geometry.PointCloud() 120 | source.points = o3d.utility.Vector3dVector(np.fromfile(path1).reshape([-1, 3])) 121 | target = o3d.geometry.PointCloud() 122 | target.points = o3d.utility.Vector3dVector(np.fromfile(path2).reshape([-1, 3])) 123 | 124 | source_down, source_fpfh = preprocess_point_cloud(source, voxel_size, is_voxel=False) 125 | target_down, target_fpfh = preprocess_point_cloud(target, voxel_size, is_voxel=False) 126 | return source, target, source_down, target_down, source_fpfh, target_fpfh 127 | 128 | 129 | def registration_numpy(pcl1, pcl2, trans_pre=np.eye(4)): 130 | source = o3d.geometry.PointCloud() 131 | source.points = o3d.utility.Vector3dVector(pcl1) 132 | target = o3d.geometry.PointCloud() 133 | target.points = o3d.utility.Vector3dVector(pcl2) 134 | 135 | draw_registration_result(source, target, np.eye(4)) 136 | draw_registration_result(source, target, trans_pre) 137 | 138 | 139 | def registration_withinit(path1, path2, trans_pre=[0, 0, 0]): 140 | source = o3d.geometry.PointCloud() 141 | source.points = o3d.utility.Vector3dVector(np.fromfile(path1).reshape([-1, 3])) 142 | target = o3d.geometry.PointCloud() 143 | target.points = o3d.utility.Vector3dVector(np.fromfile(path2).reshape([-1, 3])) 144 | 145 | trans_init = np.asarray([[1.0, 0.0, 0.0, trans_pre[0]], [0.0, 1.0, 0.0, trans_pre[1]], 146 | [0.0, 0.0, 1.0, trans_pre[2]], [0.0, 0.0, 0.0, 1.0]]) 147 | 148 | result_icp = refine_registration(source, target, trans_init) 149 | # draw_registration_result(source, target, result_icp.transformation) 150 | 151 | T = np.asarray(result_icp.transformation) 152 | 153 | return T 154 | 155 | 156 | if __name__ == '__main__': 157 | # demo_open3d() 158 | path1 = '../benchmark_datasets/GTA5/round1/pcl/1.bin' 159 | path2 = '../benchmark_datasets/GTA5/round1/pcl/2.bin' 160 | 161 | # path1 = 
'../benchmark_datasets/oxford/2014-05-19-13-20-57/pointcloud_20m_10overlap/1400505893170765.bin'
162 |     # path2 = '../benchmark_datasets/oxford/2014-05-19-13-20-57/pointcloud_20m_10overlap/1400505894395159.bin'
163 |     registration_withinit(path1, path2)
164 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/loss/__init__.py
--------------------------------------------------------------------------------
/loss/d_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from misc.log import log_string
4 |
5 |
6 | def make_d_loss(params):
7 |     if params.d_loss == 'BCELogitWithMask':
8 |         loss_fn = BCELogitWithMask()
9 |     elif params.d_loss == 'MSEWithMask':
10 |         loss_fn = MSEWithMask()
11 |     elif params.d_loss == 'MSE':
12 |         loss_fn = MSE()
13 |     elif params.d_loss == 'BCELogit':
14 |         loss_fn = BCELogit()
15 |     else:
16 |         log_string('Unknown loss: {}'.format(params.d_loss))
17 |         raise NotImplementedError
18 |     return loss_fn
19 |
20 |
21 | class DisLossBase:
22 |     # Base class for the GAN discriminator losses
23 |     def __init__(self):
24 |         self.distance = None
25 |         self.mask = False
26 |         if torch.cuda.is_available():
27 |             self.device = "cuda"
28 |         else:
29 |             self.device = "cpu"
30 |
31 |     def __call__(self, syn, hard_triplets, real=None):
32 |         assert self.distance is not None
33 |         assert len(syn.shape) == 2
34 |
35 |         if self.mask:
36 |             a, p, n = hard_triplets
37 |             syn = syn[a]
38 |             real = real[a] if real is not None else None
39 |
40 |         if real is not None:
41 |             assert len(real.shape) == 2
42 |             syn_label = torch.zeros((syn.shape[0], 1), dtype=torch.float32, device=self.device)
43 |             real_label = torch.ones((real.shape[0], 1), dtype=torch.float32, device=self.device)
44 |             label = torch.cat((syn_label, real_label), dim=1)
45 |             out = torch.cat((syn, real), dim=1)
46 |         else:
47 |             label = torch.ones((syn.shape[0], 1), dtype=torch.float32, device=self.device)
48 |             out = syn
49 |
50 |         loss = self.distance(out, label)
51 |
52 |         if real is not None:
53 |             stats = {'d_loss': loss.detach().cpu().item()}
54 |         else:
55 |             stats = {'g_d_loss': loss.detach().cpu().item()}
56 |
57 |         return loss, stats
58 |
59 |
60 | class BCELogitWithMask(DisLossBase):
61 |     def __init__(self):
62 |         super(BCELogitWithMask, self).__init__()
63 |         self.distance = torch.nn.BCEWithLogitsLoss(reduction='mean').to(self.device)
64 |         self.mask = True
65 |
66 |
67 | class MSEWithMask(DisLossBase):
68 |     def __init__(self):
69 |         super(MSEWithMask, self).__init__()
70 |         self.distance = torch.nn.MSELoss(reduction='mean').to(self.device)
71 |         self.mask = True
72 |
73 |
74 | class BCELogit(DisLossBase):
75 |     def __init__(self):
76 |         super(BCELogit, self).__init__()
77 |         self.distance = torch.nn.BCEWithLogitsLoss(reduction='mean').to(self.device)
78 |         self.mask = False
79 |
80 |
81 | class MSE(DisLossBase):
82 |     def __init__(self):
83 |         super(MSE, self).__init__()
84 |         self.distance = torch.nn.MSELoss(reduction='mean').to(self.device)
85 |         self.mask = False
--------------------------------------------------------------------------------
/loss/metric_loss.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from the MinkLoc3D repo: https://github.com/jac99/MinkLoc3D.git
4 |
5 |
6 | import numpy as np
7 | import torch
8 | from pytorch_metric_learning import losses
9 | from pytorch_metric_learning.distances import LpDistance
10 | from scipy.spatial.transform import Rotation
11 |
12 | from loss.mmd_loss import MMD_loss
13 | from misc.log import log_string
14 |
15 | D2G = np.pi / 180.0  # degrees-to-radians factor
16 |
17 |
18 | def make_loss(params):
19 |     if params.loss == 'BatchHardTripletMarginLoss':
20 |         # BatchHard mining with triplet margin loss
21 |         # Expects input: embeddings, positives_mask, negatives_mask
22 |         loss_fn = BatchHardTripletLossWithMasks(params.margin, params.normalize_embeddings, params.swap)
23 |     elif params.loss == 'BatchHardContrastiveLoss':
24 |         loss_fn = BatchHardContrastiveLossWithMasks(params.pos_margin, params.neg_margin, params.normalize_embeddings)
25 |     elif params.loss == 'LocQuatMSELoss':
26 |         loss_fn = BatchLocQuatLoss(torch.nn.MSELoss())
27 |     elif params.loss == 'LocQuatL1Loss':
28 |         loss_fn = BatchLocQuatLoss(torch.nn.L1Loss())
29 |     elif params.loss == 'LocQuatCloudLoss':
30 |         loss_fn = BatchLocQuatCloudLoss()
31 |     else:
32 |         log_string('Unknown loss: {}'.format(params.loss))
33 |         raise NotImplementedError
34 |     return loss_fn
35 |
36 |
37 | class MMDLoss:
38 |     def __init__(self, margin, normalize_embeddings, swap):
39 |         self.margin = margin
40 |         self.normalize_embeddings = normalize_embeddings
41 |         self.loss_fn = MMD_loss()
42 |         self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
43 |         # Triplets are mined with Euclidean distance
44 |         self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
45 |
46 |     def __call__(self, embeddings, positives_mask, negatives_mask):
47 |         hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
48 |         # MMD_loss compares two sample sets: use anchors vs. their hardest positives
49 |         loss = self.loss_fn(embeddings[hard_triplets[0]], embeddings[hard_triplets[1]])
50 |         stats = {'metric_loss': loss.detach().cpu().item(),
51 |                  # avg_embedding_norm / num_non_zero_triplets are tracked only by the
52 |                  # pytorch-metric-learning losses, so they are not reported here
53 |                  'num_triplets': len(hard_triplets[0]),
54 |                  'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist,
55 |                  'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist,
56 |                  'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist,
57 |                  'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist,
58 |                  'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist,
59 |                  'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist
60 |                  }
61 |         return loss, stats, hard_triplets
62 |
63 | def quat_to_mat_torch(q, req_grad=False):
64 |     x = q[0]
65 |     y = q[1]
66 |     z = q[2]
67 |     w = q[3]
68 |     m11 = 1 - 2 * y * y - 2 * z * z
69 |     m21 = 2 * x * y + 2 * w * z
70 |     m31 = 2 * x * z - 2 * w * y
71 |     m12 = 2 * x * y - 2 * w * z
72 |     m22 = 1 - 2 * x * x - 2 * z * z
73 |     m32 = 2 * y * z + 2 * w * x
74 |     m13 = 2 * x * z + 2 * w * y
75 |     m23 = 2 * y * z - 2 * w * x
76 |     m33 = 1 - 2 * x * x - 2 * y * y
77 |
78 |     return torch.tensor([[m11, m12, m13], [m21, m22, m23], [m31, m32, m33]], requires_grad=req_grad)
79 |
80 |
81 | def to_clouds(batch, device=torch.device('cpu')):
82 |     clouds = [[] for i in range(batch[-1][0] + 1)]
83 |     for p in batch:
84 |         clouds[p[0]].append(p[1:])
85 |     clouds = [torch.Tensor(c).to(device) for c in clouds]
86 |     return clouds
87 |
88 |
89 | class BatchLocQuatCloudLoss:
90 |     def __init__(self):
91 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
92 |         self.loss_fn = torch.nn.L1Loss()
93 |
94 |
95 |     def __call__(self, batch, quat, truth_quat):
96 |         batch = batch['coords'].detach().cpu().numpy()
97 |
98 |
99 |         loss = 0.0
100 |         raw_clouds = to_clouds(batch, self.device)
101 |         for q, qt, cloud in zip(quat, truth_quat, raw_clouds):
102 |             m1 = quat_to_mat_torch(q, True).to(self.device)
103 |             m2 = quat_to_mat_torch(qt).to(self.device)
104 |
105 |             c1 = torch.matmul(cloud, m1)
106 |             c2 = torch.matmul(cloud, m2)
107 |
108 |             loss += self.loss_fn(c1, c2)
109 |
110 |         loss = loss / len(raw_clouds)
111 |         return loss, {'loss': loss}
112 |
113 |
114 | class BatchLocQuatLoss:
115 |     def __init__(self, tloss):
116 |         self.loss_fn = tloss
117 |
118 |     def __call__(self, quats, quats_truth):
119 |         loss = self.loss_fn(quats, quats_truth)
120 |         rloss = np.array([0., 0., 0.])
121 |         maxrloss = 0
122 |         max_loss_r1 = []
123 |         max_loss_r2 = []
124 |         for q, t in zip(quats.detach().cpu().numpy(), quats_truth.detach().cpu().numpy()):
125 |             r1 = Rotation.from_quat(q).as_rotvec() / D2G
126 |             r2 = Rotation.from_quat(t).as_rotvec() / D2G
127 |             r = abs(r1 - r2)
128 |             rloss += r
129 |             if maxrloss < np.mean(r):
130 |                 maxrloss = np.mean(r)
131 |                 max_loss_r1, max_loss_r2 = r1, r2
132 |         rloss2 = np.linalg.norm(rloss) / len(quats)
133 |         rloss = np.mean(rloss) / len(quats)
134 |         return loss, {'loss': loss.detach().cpu(),
135 |                       'avg.rloss': rloss,
136 |                       'avg.rloss2': rloss2,
137 |                       'max_loss_r1_x': max_loss_r1[0],
138 |                       'max_loss_r1_y': max_loss_r1[1],
139 |                       'max_loss_r1_z': max_loss_r1[2],
140 |
141 |                       'max_loss_r2_x': max_loss_r2[0],
142 |                       'max_loss_r2_y': max_loss_r2[1],
143 |                       'max_loss_r2_z': max_loss_r2[2], }
144 |
145 |
146 | class HardTripletMinerWithMasks:
147 |     # Hard triplet miner
148 |     def __init__(self, distance):
149 |         self.distance = distance
150 |         # Stats
151 |         self.max_pos_pair_dist = None
152 |         self.max_neg_pair_dist = None
153 |         self.mean_pos_pair_dist = None
154 |         self.mean_neg_pair_dist = None
155 |         self.min_pos_pair_dist = None
156 |         self.min_neg_pair_dist = None
157 |
158 |     def __call__(self, embeddings, positives_mask, negatives_mask):
159 |         assert embeddings.dim() == 2
160 |         d_embeddings = embeddings.detach()
161 |         with torch.no_grad():
162 |             hard_triplets = self.mine(d_embeddings, positives_mask, negatives_mask)
163 |         return hard_triplets
164 |
165 |     def mine(self, embeddings, positives_mask, negatives_mask):
166 |         # Based on the pytorch-metric-learning implementation; dist_mat is n x n
167 |         dist_mat = self.distance(embeddings)
168 |         (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask)
169 |         (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask)
170 |         a_keep_idx = torch.where(a1p_keep & a2n_keep)
171 |         a = torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx]
172 |         p = hardest_positive_indices[a_keep_idx]
173 |         n = hardest_negative_indices[a_keep_idx]
174 |         self.max_pos_pair_dist = torch.max(hardest_positive_dist).item()
175 |         self.max_neg_pair_dist = torch.max(hardest_negative_dist).item()
176 |         self.mean_pos_pair_dist = torch.mean(hardest_positive_dist).item()
177 |         self.mean_neg_pair_dist = torch.mean(hardest_negative_dist).item()
178 |         self.min_pos_pair_dist = torch.min(hardest_positive_dist).item()
179 |         self.min_neg_pair_dist = torch.min(hardest_negative_dist).item()
180 |         return a, p, n
181 |
182 |
183 | def get_max_per_row(mat, mask):
184 |     non_zero_rows = torch.any(mask, dim=1)
185 |     mat_masked = mat.clone().detach()
186 |     mat_masked[~mask] = 0
187 |     # mat_masked2=torch.where(mask,mat,torch.zeros(mat.shape))
188 |     # pdb.set_trace()
189 |
return torch.max(mat_masked, dim=1), non_zero_rows 190 | 191 | 192 | def get_min_per_row(mat, mask): 193 | non_inf_rows = torch.any(mask, dim=1) 194 | mat_masked = mat.clone().detach() 195 | mat_masked[~mask] = float('inf') 196 | return torch.min(mat_masked, dim=1), non_inf_rows 197 | 198 | 199 | class BatchHardTripletLossWithMasks: 200 | def __init__(self, margin, normalize_embeddings, swap): 201 | self.margin = margin 202 | self.normalize_embeddings = normalize_embeddings 203 | self.distance = LpDistance(normalize_embeddings=normalize_embeddings) 204 | # We use triplet loss with Euclidean distance 205 | self.miner_fn = HardTripletMinerWithMasks(distance=self.distance) 206 | self.loss_fn = losses.TripletMarginLoss(margin=self.margin, swap=swap, distance=self.distance) 207 | 208 | # 3510 209 | def __call__(self, embeddings, positives_mask, negatives_mask): 210 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 211 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 212 | loss = self.loss_fn(embeddings, dummy_labels, hard_triplets) 213 | stats = {'metric_loss': loss.detach().cpu().item(), 214 | 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm, 215 | 'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter, 216 | 'num_triplets': len(hard_triplets[0]), 217 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 218 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 219 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 220 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 221 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 222 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist 223 | } 224 | 225 | return loss, stats, hard_triplets 226 | 227 | 228 | class BatchHardContrastiveLossWithMasks: 229 | def __init__(self, pos_margin, neg_margin, normalize_embeddings): 230 | self.pos_margin = pos_margin 231 | self.neg_margin = neg_margin 232 | self.distance = LpDistance(normalize_embeddings=normalize_embeddings) 233 | self.miner_fn = HardTripletMinerWithMasks(distance=self.distance) 234 | # We use contrastive loss with squared Euclidean distance 235 | self.loss_fn = losses.ContrastiveLoss(pos_margin=self.pos_margin, neg_margin=self.neg_margin, 236 | distance=self.distance) 237 | 238 | def __call__(self, embeddings, positives_mask, negatives_mask): 239 | hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask) 240 | dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device) 241 | loss = self.loss_fn(embeddings, dummy_labels, hard_triplets) 242 | stats = {'metric_loss': loss.detach().cpu().item(), 243 | 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm, 244 | 'pos_pairs_above_low': self.loss_fn.reducer.reducers['pos_loss'].pos_pairs_above_low, 245 | 'neg_pairs_above_low': self.loss_fn.reducer.reducers['neg_loss'].neg_pairs_above_low, 246 | 'pos_loss': self.loss_fn.reducer.reducers['pos_loss'].pos_loss, 247 | 'neg_loss': self.loss_fn.reducer.reducers['neg_loss'].neg_loss, 248 | 'num_pairs': 2 * len(hard_triplets[0]), 249 | 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist, 250 | 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist, 251 | 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist, 252 | 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist, 253 | 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist, 254 | 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist 255 | } 256 | 257 | return loss, stats, hard_triplets 258 | 
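All of the batch-hard losses above share one calling convention: the trainer passes an (N, D) embedding matrix together with boolean positives/negatives masks and receives (loss, stats, hard_triplets) back. A minimal usage sketch of that interface; the toy labels and masks below are illustrative placeholders, not the repo's real dataloader output:

```python
import torch

from loss.metric_loss import BatchHardTripletLossWithMasks

# Hypothetical batch: 8 embeddings, 4 classes of 2, so every anchor has a positive.
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
eye = torch.eye(8, dtype=torch.bool)
positives_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
negatives_mask = ~(positives_mask | eye)

embeddings = torch.randn(8, 256, requires_grad=True)
loss_fn = BatchHardTripletLossWithMasks(margin=0.2, normalize_embeddings=True, swap=True)
loss, stats, hard_triplets = loss_fn(embeddings, positives_mask, negatives_mask)
loss.backward()  # gradients flow back through the mined (anchor, positive, negative) triplets
print(stats['num_triplets'], stats['mean_pos_pair_dist'], stats['mean_neg_pair_dist'])
```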
-------------------------------------------------------------------------------- /loss/mmd_loss.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # encoding: utf-8 3 | 4 | min_var_est = 1e-8 5 | 6 | 7 | # Consider linear time MMD with a linear kernel: 8 | # K(f(x), f(y)) = f(x)^Tf(y) 9 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 10 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 11 | # 12 | # f_of_X: batch_size * k 13 | # f_of_Y: batch_size * k 14 | def linear_mmd2(f_of_X, f_of_Y): 15 | loss = 0.0 16 | delta = f_of_X - f_of_Y 17 | # pdb.set_trace() 18 | loss = torch.mean((delta[:-1] * delta[1:]).sum(1)) 19 | return loss 20 | 21 | 22 | # Consider linear time MMD with a polynomial kernel: 23 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d 24 | # f_of_X: batch_size * k 25 | # f_of_Y: batch_size * k 26 | def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0): 27 | K_XX = (alpha * (f_of_X[:-1] * f_of_X[1:]).sum(1) + c) 28 | K_XX_mean = torch.mean(K_XX.pow(d)) 29 | 30 | K_YY = (alpha * (f_of_Y[:-1] * f_of_Y[1:]).sum(1) + c) 31 | K_YY_mean = torch.mean(K_YY.pow(d)) 32 | 33 | K_XY = (alpha * (f_of_X[:-1] * f_of_Y[1:]).sum(1) + c) 34 | K_XY_mean = torch.mean(K_XY.pow(d)) 35 | 36 | K_YX = (alpha * (f_of_Y[:-1] * f_of_X[1:]).sum(1) + c) 37 | K_YX_mean = torch.mean(K_YX.pow(d)) 38 | 39 | return K_XX_mean + K_YY_mean - K_XY_mean - K_YX_mean 40 | 41 | 42 | def _mix_rbf_kernel(X, Y, sigma_list): 43 | assert (X.size(0) == Y.size(0)) 44 | m = X.size(0) 45 | 46 | Z = torch.cat((X, Y), 0) 47 | ZZT = torch.mm(Z, Z.t()) 48 | diag_ZZT = torch.diag(ZZT).unsqueeze(1) 49 | Z_norm_sqr = diag_ZZT.expand_as(ZZT) 50 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t() 51 | 52 | K = 0.0 53 | for sigma in sigma_list: 54 | gamma = 1.0 / (2 * sigma ** 2) 55 | K += torch.exp(-gamma * exponent) 56 | 57 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list) 58 | 59 | 60 | def mix_rbf_mmd2(X, Y, sigma_list, biased=True): 61 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 62 | # return _mmd2(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 63 | return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 64 | 65 | 66 | def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True): 67 | K_XX, K_XY, K_YY, d = _mix_rbf_kernel(X, Y, sigma_list) 68 | # return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=d, biased=biased) 69 | return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased) 70 | 71 | 72 | ################################################################################ 73 | # Helper functions to compute variances based on kernel matrices 74 | ################################################################################ 75 | 76 | 77 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 78 | m = K_XX.size(0) # assume X, Y are same shape 79 | 80 | # Get the various sums of kernels that we'll use 81 | # Kts drop the diagonal, but we don't need to compute them explicitly 82 | if const_diagonal is not False: 83 | diag_X = diag_Y = const_diagonal 84 | sum_diag_X = sum_diag_Y = m * const_diagonal 85 | else: 86 | diag_X = torch.diag(K_XX) # (m,) 87 | diag_Y = torch.diag(K_YY) # (m,) 88 | sum_diag_X = torch.sum(diag_X) 89 | sum_diag_Y = torch.sum(diag_Y) 90 | 91 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 92 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 93 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 94 | 95 | Kt_XX_sum = 
Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 96 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 97 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 98 | 99 | if biased: 100 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 101 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 102 | - 2.0 * K_XY_sum / (m * m)) 103 | else: 104 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 105 | + Kt_YY_sum / (m * (m - 1)) 106 | - 2.0 * K_XY_sum / (m * m)) 107 | 108 | return mmd2 109 | 110 | 111 | def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 112 | mmd2, var_est = _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased) 113 | loss = mmd2 / torch.sqrt(torch.clamp(var_est, min=min_var_est)) 114 | return loss, mmd2, var_est 115 | 116 | 117 | def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False): 118 | m = K_XX.size(0) # assume X, Y are same shape 119 | 120 | # Get the various sums of kernels that we'll use 121 | # Kts drop the diagonal, but we don't need to compute them explicitly 122 | if const_diagonal is not False: 123 | diag_X = diag_Y = const_diagonal 124 | sum_diag_X = sum_diag_Y = m * const_diagonal 125 | sum_diag2_X = sum_diag2_Y = m * const_diagonal ** 2 126 | else: 127 | diag_X = torch.diag(K_XX) # (m,) 128 | diag_Y = torch.diag(K_YY) # (m,) 129 | sum_diag_X = torch.sum(diag_X) 130 | sum_diag_Y = torch.sum(diag_Y) 131 | sum_diag2_X = diag_X.dot(diag_X) 132 | sum_diag2_Y = diag_Y.dot(diag_Y) 133 | 134 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X 135 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y 136 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e 137 | K_XY_sums_1 = K_XY.sum(dim=1) # K_{XY} * e 138 | 139 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e 140 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e 141 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e 142 | 143 | Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X # \| \tilde{K}_XX \|_F^2 144 | Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y # \| \tilde{K}_YY \|_F^2 145 | K_XY_2_sum = (K_XY ** 2).sum() # \| K_{XY} \|_F^2 146 | 147 | if biased: 148 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m) 149 | + (Kt_YY_sum + sum_diag_Y) / (m * m) 150 | - 2.0 * K_XY_sum / (m * m)) 151 | else: 152 | mmd2 = (Kt_XX_sum / (m * (m - 1)) 153 | + Kt_YY_sum / (m * (m - 1)) 154 | - 2.0 * K_XY_sum / (m * m)) 155 | 156 | var_est = ( 157 | 2.0 / (m ** 2 * (m - 1.0) ** 2) * ( 158 | 2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum) 159 | - (4.0 * m - 6.0) / (m ** 3 * (m - 1.0) ** 3) * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2) 160 | + 4.0 * (m - 2.0) / (m ** 3 * (m - 1.0) ** 2) * ( 161 | K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0)) 162 | - 4.0 * (m - 3.0) / (m ** 3 * (m - 1.0) ** 2) * (K_XY_2_sum) - (8 * m - 12) / ( 163 | m ** 5 * (m - 1)) * K_XY_sum ** 2 164 | + 8.0 / (m ** 3 * (m - 1.0)) * ( 165 | 1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum 166 | - Kt_XX_sums.dot(K_XY_sums_1) 167 | - Kt_YY_sums.dot(K_XY_sums_0)) 168 | ) 169 | return mmd2, var_est 170 | 171 | 172 | # https://github.com/ZongxianLee/MMD_Loss.Pytorch 173 | import torch 174 | import torch.nn as nn 175 | 176 | 177 | class MMD_loss(nn.Module): 178 | def __init__(self, kernel_mul=2.0, kernel_num=5): 179 | super(MMD_loss, self).__init__() 180 | self.kernel_num = kernel_num 181 | self.kernel_mul = kernel_mul 182 | self.fix_sigma = None 183 | return 184 | 185 | @staticmethod 186 | def guassian_kernel(source, target, kernel_mul=2.0, 
kernel_num=5, fix_sigma=None): 187 | n_samples = int(source.size()[0]) + int(target.size()[0]) 188 | total = torch.cat([source, target], dim=0) 189 | 190 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 191 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 192 | L2_distance = ((total0 - total1) ** 2).sum(2) 193 | if fix_sigma: 194 | bandwidth = fix_sigma 195 | else: 196 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 197 | bandwidth /= kernel_mul ** (kernel_num // 2) 198 | bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)] 199 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 200 | return sum(kernel_val) 201 | 202 | def forward(self, source, target): 203 | batch_size = int(source.size()[0]) 204 | kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, 205 | fix_sigma=self.fix_sigma) 206 | XX = kernels[:batch_size, :batch_size] 207 | YY = kernels[batch_size:, batch_size:] 208 | XY = kernels[:batch_size, batch_size:] 209 | YX = kernels[batch_size:, :batch_size] 210 | loss = torch.mean(XX + YY - XY - YX) 211 | return loss 212 | -------------------------------------------------------------------------------- /loss/ops/emd/LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /loss/ops/emd/README.md: -------------------------------------------------------------------------------- 1 | This code was taken from [https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd](https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd). 
It is part of the implementation of the following paper:
2 |
3 | ```
4 | @article{liu2019morphing,
5 |   title={Morphing and Sampling Network for Dense Point Cloud Completion},
6 |   author={Liu, Minghua and Sheng, Lu and Yang, Sheng and Shao, Jing and Hu, Shi-Min},
7 |   journal={arXiv preprint arXiv:1912.00280},
8 |   year={2019}
9 | }
10 | ```
11 |
12 | ## Earth Mover's Distance of point clouds
13 |
14 | Compared to the Chamfer Distance (CD), the Earth Mover's Distance (EMD) is more reliable for distinguishing the visual quality of point clouds. See our [paper](http://cseweb.ucsd.edu/~mil070/projects/AAAI2020/paper.pdf) for more details.
15 |
16 | We provide an EMD implementation for point cloud comparison which only needs $O(n)$ memory and thus supports dense point clouds (with 10,000 points or more) and large batch sizes. It is based on an approximate method (the auction algorithm) and cannot guarantee an exact (but a near) bijective assignment. It employs a parameter $\epsilon$ to balance the error rate against the speed of convergence: a smaller $\epsilon$ gives more accurate results but needs longer to converge. The time complexity is $O(n^2k)$, where $k$ is the number of iterations. We set $\epsilon = 0.005, k = 50$ during training and $\epsilon = 0.002, k = 10000$ during testing.
17 |
18 | ### Compile
19 | Run `python3 setup.py install` to compile.
20 |
21 | ### Example
22 | See `emd_module.py/test_emd()` for examples.
23 |
24 | ### Input
25 |
26 | - **xyz1, xyz2**: float tensors with shape `[#batch, #points, 3]`. xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud. The two point clouds should have the same size and be normalized to [0, 1]. The number of points should be a multiple of 1024. The batch size should be no greater than 512. Since we only calculate gradients for xyz1, please do not swap xyz1 and xyz2.
27 | - **eps**: a float, the parameter that balances the error rate and the speed of convergence.
28 | - **iters**: an int, the number of iterations.
29 |
30 | ### Output
31 |
32 | - **dist**: a float tensor with shape `[#batch, #points]`. sqrt(dist) gives the L2 distances between the pairs of matched points.
33 | - **assignment**: an int tensor with shape `[#batch, #points]`. The index of the matched point in the ground truth point cloud.
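A minimal usage sketch of the documented interface (shapes, normalization, and the eps/iters settings follow the notes above; it assumes the CUDA extension has been built and a GPU is available):

```python
import torch
from loss.ops.emd.emd_module import emdModule

emd_fn = emdModule()
# Point counts must be a multiple of 1024; coordinates normalized to [0, 1].
xyz1 = torch.rand(4, 1024, 3).cuda()  # prediction (the only input that receives gradients)
xyz2 = torch.rand(4, 1024, 3).cuda()  # ground truth
dist, assignment = emd_fn(xyz1, xyz2, 0.005, 50)  # training setting quoted above
loss = torch.sqrt(dist).mean()  # sqrt(dist) gives per-point L2 distances
```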
34 |
--------------------------------------------------------------------------------
/loss/ops/emd/emd.cpp:
--------------------------------------------------------------------------------
1 | // EMD approximation module (based on auction algorithm)
2 | // author: Minghua Liu
3 | #include <torch/extension.h>
4 | #include <vector>
5 |
6 | int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
7 |     at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
8 |     at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters);
9 |
10 | int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx);
11 |
12 |
13 |
14 | int emd_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
15 |     at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
16 |     at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) {
17 |     return emd_cuda_forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters);
18 | }
19 |
20 | int emd_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) {
21 |
22 |     return emd_cuda_backward(xyz1, xyz2, gradxyz, graddist, idx);
23 | }
24 |
25 |
26 |
27 |
28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
29 |     m.def("forward", &emd_forward, "emd forward (CUDA)");
30 |     m.def("backward", &emd_backward, "emd backward (CUDA)");
31 | }
--------------------------------------------------------------------------------
/loss/ops/emd/emd_cuda.cu:
--------------------------------------------------------------------------------
1 | // EMD approximation module (based on auction algorithm)
2 | // author: Minghua Liu
3 | #include <stdio.h>
4 | #include <ATen/ATen.h>
5 |
6 | #include <cuda.h>
7 | #include <iostream>
8 | #include <cuda_runtime.h>
9 |
10 | __device__ __forceinline__ float atomicMax(float *address, float val)
11 | {
12 |     int ret = __float_as_int(*address);
13 |     while(val > __int_as_float(ret))
14 |     {
15 |         int old = ret;
16 |         if((ret = atomicCAS((int *)address, old, __float_as_int(val))) == old)
17 |             break;
18 |     }
19 |     return __int_as_float(ret);
20 | }
21 |
22 |
23 | __global__ void clear(int b, int * cnt_tmp, int * unass_cnt) {
24 |     for (int i = threadIdx.x; i < b; i += blockDim.x) {
25 |         cnt_tmp[i] = 0;
26 |         unass_cnt[i] = 0;
27 |     }
28 | }
29 |
30 | __global__ void calc_unass_cnt(int b, int n, int * assignment, int * unass_cnt) {
31 |     // count the number of unassigned points in each batch
32 |     const int BLOCK_SIZE = 1024;
33 |     __shared__ int scan_array[BLOCK_SIZE];
34 |     for (int i = blockIdx.x; i < b; i += gridDim.x) {
35 |         scan_array[threadIdx.x] = assignment[i * n + blockIdx.y * BLOCK_SIZE + threadIdx.x] == -1 ?
1 : 0; 36 | __syncthreads(); 37 | 38 | int stride = 1; 39 | while(stride <= BLOCK_SIZE / 2) { 40 | int index = (threadIdx.x + 1) * stride * 2 - 1; 41 | if(index < BLOCK_SIZE) 42 | scan_array[index] += scan_array[index - stride]; 43 | stride = stride * 2; 44 | __syncthreads(); 45 | } 46 | __syncthreads(); 47 | 48 | if (threadIdx.x == BLOCK_SIZE - 1) { 49 | atomicAdd(&unass_cnt[i], scan_array[threadIdx.x]); 50 | } 51 | __syncthreads(); 52 | } 53 | } 54 | 55 | __global__ void calc_unass_cnt_sum(int b, int * unass_cnt, int * unass_cnt_sum) { 56 | // count the cumulative sum over over unass_cnt 57 | const int BLOCK_SIZE = 512; // batch_size <= 512 58 | __shared__ int scan_array[BLOCK_SIZE]; 59 | scan_array[threadIdx.x] = unass_cnt[threadIdx.x]; 60 | __syncthreads(); 61 | 62 | int stride = 1; 63 | while(stride <= BLOCK_SIZE / 2) { 64 | int index = (threadIdx.x + 1) * stride * 2 - 1; 65 | if(index < BLOCK_SIZE) 66 | scan_array[index] += scan_array[index - stride]; 67 | stride = stride * 2; 68 | __syncthreads(); 69 | } 70 | __syncthreads(); 71 | stride = BLOCK_SIZE / 4; 72 | while(stride > 0) { 73 | int index = (threadIdx.x + 1) * stride * 2 - 1; 74 | if((index + stride) < BLOCK_SIZE) 75 | scan_array[index + stride] += scan_array[index]; 76 | stride = stride / 2; 77 | __syncthreads(); 78 | } 79 | __syncthreads(); 80 | 81 | //printf("%d\n", unass_cnt_sum[b - 1]); 82 | unass_cnt_sum[threadIdx.x] = scan_array[threadIdx.x]; 83 | } 84 | 85 | __global__ void calc_unass_idx(int b, int n, int * assignment, int * unass_idx, int * unass_cnt, int * unass_cnt_sum, int * cnt_tmp) { 86 | // list all the unassigned points 87 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 88 | if (assignment[i * n + blockIdx.y * 1024 + threadIdx.x] == -1) { 89 | int idx = atomicAdd(&cnt_tmp[i], 1); 90 | unass_idx[unass_cnt_sum[i] - unass_cnt[i] + idx] = blockIdx.y * 1024 + threadIdx.x; 91 | } 92 | } 93 | } 94 | 95 | __global__ void Bid(int b, int n, const float * xyz1, const float * xyz2, float eps, int * assignment, int * assignment_inv, float * price, 96 | int * bid, float * bid_increments, float * max_increments, int * unass_cnt, int * unass_cnt_sum, int * unass_idx) { 97 | const int batch = 2048, block_size = 1024, block_cnt = n / 1024; 98 | __shared__ float xyz2_buf[batch * 3]; 99 | __shared__ float price_buf[batch]; 100 | __shared__ float best_buf[block_size]; 101 | __shared__ float better_buf[block_size]; 102 | __shared__ int best_i_buf[block_size]; 103 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 104 | int _unass_cnt = unass_cnt[i]; 105 | if (_unass_cnt == 0) 106 | continue; 107 | int _unass_cnt_sum = unass_cnt_sum[i]; 108 | int unass_per_block = (_unass_cnt + block_cnt - 1) / block_cnt; 109 | int thread_per_unass = block_size / unass_per_block; 110 | int unass_this_block = max(min(_unass_cnt - (int) blockIdx.y * unass_per_block, unass_per_block), 0); 111 | 112 | float x1, y1, z1, best = -1e9, better = -1e9; 113 | int best_i = -1, _unass_id = -1, thread_in_unass; 114 | 115 | if (threadIdx.x < thread_per_unass * unass_this_block) { 116 | _unass_id = unass_per_block * blockIdx.y + threadIdx.x / thread_per_unass + _unass_cnt_sum - _unass_cnt; 117 | _unass_id = unass_idx[_unass_id]; 118 | thread_in_unass = threadIdx.x % thread_per_unass; 119 | 120 | x1 = xyz1[(i * n + _unass_id) * 3 + 0]; 121 | y1 = xyz1[(i * n + _unass_id) * 3 + 1]; 122 | z1 = xyz1[(i * n + _unass_id) * 3 + 2]; 123 | } 124 | 125 | for (int k2 = 0; k2 < n; k2 += batch) { 126 | int end_k = min(n, k2 + batch) - k2; 127 | for (int j = threadIdx.x; j 
< end_k * 3; j += blockDim.x) { 128 | xyz2_buf[j] = xyz2[(i * n + k2) * 3 + j]; 129 | } 130 | for (int j = threadIdx.x; j < end_k; j += blockDim.x) { 131 | price_buf[j] = price[i * n + k2 + j]; 132 | } 133 | __syncthreads(); 134 | 135 | if (_unass_id != -1) { 136 | int delta = (end_k + thread_per_unass - 1) / thread_per_unass; 137 | int l = thread_in_unass * delta; 138 | int r = min((thread_in_unass + 1) * delta, end_k); 139 | for (int k = l; k < r; k++) 140 | //if (!last || assignment_inv[i * n + k + k2] == -1) 141 | { 142 | float x2 = xyz2_buf[k * 3 + 0] - x1; 143 | float y2 = xyz2_buf[k * 3 + 1] - y1; 144 | float z2 = xyz2_buf[k * 3 + 2] - z1; 145 | // the coordinates of points should be normalized to [0, 1] 146 | float d = 3.0 - sqrtf(x2 * x2 + y2 * y2 + z2 * z2) - price_buf[k]; 147 | if (d > best) { 148 | better = best; 149 | best = d; 150 | best_i = k + k2; 151 | } 152 | else if (d > better) { 153 | better = d; 154 | } 155 | } 156 | } 157 | __syncthreads(); 158 | } 159 | 160 | best_buf[threadIdx.x] = best; 161 | better_buf[threadIdx.x] = better; 162 | best_i_buf[threadIdx.x] = best_i; 163 | __syncthreads(); 164 | 165 | if (_unass_id != -1 && thread_in_unass == 0) { 166 | for (int j = threadIdx.x + 1; j < threadIdx.x + thread_per_unass; j++) { 167 | if (best_buf[j] > best) { 168 | better = max(best, better_buf[j]); 169 | best = best_buf[j]; 170 | best_i = best_i_buf[j]; 171 | } 172 | else better = max(better, best_buf[j]); 173 | } 174 | bid[i * n + _unass_id] = best_i; 175 | bid_increments[i * n + _unass_id] = best - better + eps; 176 | atomicMax(&max_increments[i * n + best_i], best - better + eps); 177 | } 178 | } 179 | } 180 | 181 | __global__ void GetMax(int b, int n, int * assignment, int * bid, float * bid_increments, float * max_increments, int * max_idx) { 182 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 183 | int j = threadIdx.x + blockIdx.y * blockDim.x; 184 | if (assignment[i * n + j] == -1) { 185 | int bid_id = bid[i * n + j]; 186 | float bid_inc = bid_increments[i * n + j]; 187 | float max_inc = max_increments[i * n + bid_id]; 188 | if (bid_inc - 1e-6 <= max_inc && max_inc <= bid_inc + 1e-6) 189 | { 190 | max_idx[i * n + bid_id] = j; 191 | } 192 | } 193 | } 194 | } 195 | 196 | __global__ void Assign(int b, int n, int * assignment, int * assignment_inv, float * price, int * bid, float * bid_increments, float * max_increments, int * max_idx, bool last) { 197 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 198 | int j = threadIdx.x + blockIdx.y * blockDim.x; 199 | if (assignment[i * n + j] == -1) { 200 | int bid_id = bid[i * n + j]; 201 | if (last || max_idx[i * n + bid_id] == j) 202 | { 203 | float bid_inc = bid_increments[i * n + j]; 204 | int ass_inv = assignment_inv[i * n + bid_id]; 205 | if (!last && ass_inv != -1) { 206 | assignment[i * n + ass_inv] = -1; 207 | } 208 | assignment_inv[i * n + bid_id] = j; 209 | assignment[i * n + j] = bid_id; 210 | price[i * n + bid_id] += bid_inc; 211 | max_increments[i * n + bid_id] = -1e9; 212 | } 213 | } 214 | } 215 | } 216 | 217 | __global__ void CalcDist(int b, int n, float * xyz1, float * xyz2, float * dist, int * assignment) { 218 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 219 | int j = threadIdx.x + blockIdx.y * blockDim.x; 220 | int k = assignment[i * n + j]; 221 | float deltax = xyz1[(i * n + j) * 3 + 0] - xyz2[(i * n + k) * 3 + 0]; 222 | float deltay = xyz1[(i * n + j) * 3 + 1] - xyz2[(i * n + k) * 3 + 1]; 223 | float deltaz = xyz1[(i * n + j) * 3 + 2] - xyz2[(i * n + k) * 3 + 2]; 224 | dist[i * n + j] = 
deltax * deltax + deltay * deltay + deltaz * deltaz;
225 |     }
226 | }
227 |
228 | int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
229 |     at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
230 |     at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) {
231 |
232 |     const auto batch_size = xyz1.size(0);
233 |     const auto n = xyz1.size(1);  //num_points point cloud A
234 |     const auto m = xyz2.size(1);  //num_points point cloud B
235 |
236 |     if (n != m) {
237 |         printf("Input Error! The two point clouds should have the same size.\n");
238 |         return -1;
239 |     }
240 |
241 |     if (batch_size > 512) {
242 |         printf("Input Error! The batch size should be less than 512.\n");
243 |         return -1;
244 |     }
245 |
246 |     if (n % 1024 != 0) {
247 |         printf("Input Error! The size of the point clouds should be a multiple of 1024.\n");
248 |         return -1;
249 |     }
250 |
251 |     //cudaEvent_t start,stop;
252 |     //cudaEventCreate(&start);
253 |     //cudaEventCreate(&stop);
254 |     //cudaEventRecord(start);
255 |     //int iters = 50;
256 |     for (int i = 0; i < iters; i++) {
257 |         clear<<<1, batch_size>>>(batch_size, cnt_tmp.data<int>(), unass_cnt.data<int>());
258 |         calc_unass_cnt<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), unass_cnt.data<int>());
259 |         calc_unass_cnt_sum<<<1, batch_size>>>(batch_size, unass_cnt.data<int>(), unass_cnt_sum.data<int>());
260 |         calc_unass_idx<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), unass_idx.data<int>(), unass_cnt.data<int>(),
261 |             unass_cnt_sum.data<int>(), cnt_tmp.data<int>());
262 |         Bid<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), eps, assignment.data<int>(), assignment_inv.data<int>(),
263 |             price.data<float>(), bid.data<int>(), bid_increments.data<float>(), max_increments.data<float>(),
264 |             unass_cnt.data<int>(), unass_cnt_sum.data<int>(), unass_idx.data<int>());
265 |         GetMax<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), bid.data<int>(), bid_increments.data<float>(), max_increments.data<float>(), max_idx.data<int>());
266 |         Assign<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, assignment.data<int>(), assignment_inv.data<int>(), price.data<float>(), bid.data<int>(),
267 |             bid_increments.data<float>(), max_increments.data<float>(), max_idx.data<int>(), i == iters - 1);
268 |     }
269 |     CalcDist<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), dist.data<float>(), assignment.data<int>());
270 |     //cudaEventRecord(stop);
271 |     //cudaEventSynchronize(stop);
272 |     //float elapsedTime;
273 |     //cudaEventElapsedTime(&elapsedTime,start,stop);
274 |     //printf("%lf\n", elapsedTime);
275 |
276 |     cudaError_t err = cudaGetLastError();
277 |     if (err != cudaSuccess) {
278 |         printf("error in nnd Output: %s\n", cudaGetErrorString(err));
279 |         return 0;
280 |     }
281 |     return 1;
282 | }
283 |
284 | __global__ void NmDistanceGradKernel(int b, int n, const float * xyz1, const float * xyz2, const float * grad_dist, const int * idx, float * grad_xyz){
285 |     for (int i = blockIdx.x; i < b; i += gridDim.x) {
286 |         for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) {
287 |             float x1 = xyz1[(i * n + j) * 3 + 0];
288 |             float y1 = xyz1[(i * n + j) * 3 + 1];
289 |             float z1 = xyz1[(i * n + j) * 3 + 2];
290 |             int j2 = idx[i * n + j];
291 |             float x2 = xyz2[(i * n + j2) * 3 + 0];
292 |             float y2 = xyz2[(i * n + j2) * 3 + 1];
293 |             float z2 = xyz2[(i * n + j2) * 3 + 2];
294 |             float g = grad_dist[i * n + j] * 2;
295 |             atomicAdd(&(grad_xyz[(i * n + j) * 3 + 0]), g * (x1 - x2));
296 |             atomicAdd(&(grad_xyz[(i * n + j) * 3 + 1]), g * (y1 - y2));
297 |             atomicAdd(&(grad_xyz[(i * n + j) * 3 + 2]), g * (z1 - z2));
298 |         }
299 |     }
300 | }
301 |
302 | int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx){
303 |     const auto batch_size = xyz1.size(0);
304 |     const auto n = xyz1.size(1);
305 |     const auto m = xyz2.size(1);
306 | 
307 |     NmDistanceGradKernel<<<dim3(batch_size, n / 1024, 1), 1024>>>(batch_size, n, xyz1.data<float>(), xyz2.data<float>(), graddist.data<float>(), idx.data<int>(), gradxyz.data<float>());
308 | 
309 |     cudaError_t err = cudaGetLastError();
310 |     if (err != cudaSuccess) {
311 |         printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
312 |         return 0;
313 |     }
314 |     return 1;
315 | 
316 | }
317 | 
--------------------------------------------------------------------------------
/loss/ops/emd/emd_module.py:
--------------------------------------------------------------------------------
1 | # EMD approximation module (based on auction algorithm)
2 | # memory complexity: O(n)
3 | # time complexity: O(n^2 * iter)
4 | # author: Minghua Liu
5 | 
6 | # Input:
7 | # xyz1, xyz2: [#batch, #points, 3]
8 | # where xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud
9 | # the two point clouds should have the same size and be normalized to [0, 1]
10 | # #points should be a multiple of 1024
11 | # #batch should be no greater than 512
12 | # eps is a parameter which balances the error rate and the speed of convergence
13 | # iters is the number of iterations
14 | # we only calculate gradient for xyz1
15 | 
16 | # Output:
17 | # dist: [#batch, #points], sqrt(dist) -> L2 distance
18 | # assignment: [#batch, #points], index of the matched point in the ground truth point cloud
19 | # the result is an approximation and the assignment is not guaranteed to be a bijection
20 | 
21 | import time
22 | 
23 | import emd
24 | import numpy as np
25 | import torch
26 | from torch import nn
27 | from torch.autograd import Function
28 | 
29 | 
30 | class emdFunction(Function):
31 |     @staticmethod
32 |     def forward(ctx, xyz1, xyz2, eps, iters):
33 |         batchsize, n, _ = xyz1.size()
34 |         _, m, _ = xyz2.size()
35 | 
36 |         assert (n == m)
37 |         assert (xyz1.size()[0] == xyz2.size()[0])
38 |         assert (n % 1024 == 0)
39 |         assert (batchsize <= 512)
40 | 
41 |         xyz1 = xyz1.contiguous().float().cuda()
42 |         xyz2 = xyz2.contiguous().float().cuda()
43 |         dist = torch.zeros(batchsize, n, device='cuda').contiguous()
44 |         assignment = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous() - 1
45 |         assignment_inv = torch.zeros(batchsize, m, device='cuda', dtype=torch.int32).contiguous() - 1
46 |         price = torch.zeros(batchsize, m, device='cuda').contiguous()
47 |         bid = torch.zeros(batchsize, n, device='cuda', dtype=torch.int32).contiguous()
48 |         bid_increments = torch.zeros(batchsize, n, device='cuda').contiguous()
49 |         max_increments = torch.zeros(batchsize, m, device='cuda').contiguous()
50 |         unass_idx = torch.zeros(batchsize * n, device='cuda', dtype=torch.int32).contiguous()
51 |         max_idx = torch.zeros(batchsize * m, device='cuda', dtype=torch.int32).contiguous()
52 |         unass_cnt = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
53 |         unass_cnt_sum = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
54 |         cnt_tmp = torch.zeros(512, dtype=torch.int32, device='cuda').contiguous()
55 | 
56 |         emd.forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx,
57 |                     unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters)
58 | 
59 |         ctx.save_for_backward(xyz1, xyz2, assignment)
60 |         return dist, assignment
61 | 
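    # Note on backward below: forward returns the squared L2 distance
    # d_j = ||xyz1_j - xyz2_{a(j)}||^2 under the fixed assignment a, so the
    # gradient w.r.t. xyz1 is grad_dist_j * 2 * (xyz1_j - xyz2_{a(j)}), which is
    # exactly what NmDistanceGradKernel accumulates. xyz2 receives no gradient
    # (gradxyz2 stays zero), matching the "we only calculate gradient for xyz1"
    # note in the header above.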
62 |     @staticmethod
63 |     def backward(ctx, graddist, gradidx):
64 |         xyz1, xyz2, assignment = ctx.saved_tensors
65 |         graddist = graddist.contiguous()
66 | 
67 |         gradxyz1 = torch.zeros(xyz1.size(), device='cuda').contiguous()
68 |         gradxyz2 = torch.zeros(xyz2.size(), device='cuda').contiguous()
69 | 
70 |         emd.backward(xyz1, xyz2, gradxyz1, graddist, assignment)
71 |         return gradxyz1, gradxyz2, None, None
72 | 
73 | 
74 | class emdModule(nn.Module):
75 |     def __init__(self):
76 |         super(emdModule, self).__init__()
77 | 
78 |     def forward(self, input1, input2, eps, iters):
79 |         return emdFunction.apply(input1, input2, eps, iters)
80 | 
81 | 
82 | def test_emd():
83 |     x1 = torch.rand(20, 8192, 3).cuda()
84 |     x2 = torch.rand(20, 8192, 3).cuda()
85 |     emdm = emdModule()
86 |     start_time = time.perf_counter()
87 |     dis, assignment = emdm(x1, x2, 0.05, 3000)
88 |     print("Input_size: ", x1.shape)
89 |     print("Runtime: %lfs" % (time.perf_counter() - start_time))
90 |     print("EMD: %lf" % np.sqrt(dis.cpu()).mean())
91 |     print("|set(assignment)|: %d" % assignment.unique().numel())
92 |     assignment = assignment.cpu().numpy()
93 |     assignment = np.expand_dims(assignment, -1)
94 |     x2 = np.take_along_axis(x2.cpu().numpy(), assignment, axis=1)  # move to numpy before indexing
95 |     d = (x1.cpu().numpy() - x2) * (x1.cpu().numpy() - x2)
96 |     print("Verified EMD: %lf" % np.sqrt(d.sum(-1)).mean())
97 | 
98 | 
99 | if __name__ == '__main__':
100 |     test_emd()
101 | 
--------------------------------------------------------------------------------
/loss/ops/emd/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | 
4 | setup(
5 |     name='emd',
6 |     ext_modules=[
7 |         CUDAExtension('emd', [
8 |             'emd.cpp',
9 |             'emd_cuda.cu',
10 |         ]),
11 |     ],
12 |     cmdclass={
13 |         'build_ext': BuildExtension
14 |     })
15 | 
--------------------------------------------------------------------------------
/loss/reg_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from scipy.spatial.transform import Rotation
8 | 
9 | from loss.ops.emd.emd_module import emdFunction
10 | from misc.point_utils import get_knn_idx_dist, group, gather
11 | 
12 | 
13 | def cd_loss(preds, gts):
14 |     def batch_pairwise_dist(x, y):
15 |         bs, num_points_x, points_dim = x.size()
16 |         _, num_points_y, _ = y.size()
17 |         xx = torch.bmm(x, x.transpose(2, 1))
18 |         yy = torch.bmm(y, y.transpose(2, 1))
19 |         zz = torch.bmm(x, y.transpose(2, 1))
20 | 
21 |         diag_ind_x = torch.arange(0, num_points_x).to(device=x.device)
22 |         diag_ind_y = torch.arange(0, num_points_y).to(device=y.device)
23 | 
24 |         rx = xx[:, diag_ind_x, diag_ind_x].unsqueeze(1).expand_as(zz.transpose(2, 1))
25 |         ry = yy[:, diag_ind_y, diag_ind_y].unsqueeze(1).expand_as(zz)
26 |         P = (rx.transpose(2, 1) + ry - 2 * zz)  # squared distances via ||x||^2 + ||y||^2 - 2 * x.y
27 |         return P
28 | 
29 |     P = batch_pairwise_dist(gts, preds)
30 |     mins, _ = torch.min(P, 1)
31 |     loss_1 = torch.mean(mins)
32 |     mins, _ = torch.min(P, 2)
33 |     loss_2 = torch.mean(mins)
34 | 
35 |     return loss_1 + loss_2
36 | 
37 | 
38 | def emd_loss(preds, gts, eps=0.005, iters=50):
39 |     loss, _ = emdFunction.apply(preds, gts, eps, iters)
40 |     return torch.mean(loss)
41 | 
42 | 
43 | class ChamferLoss(nn.Module):
44 | 
45 |     def __init__(self):
46 |         super().__init__()
47 | 
48 |     def forward(self, preds, gts, **kwargs):
49 |         return cd_loss(preds, gts)
50 | 
51 | 
52 | class EMDLoss(nn.Module):
53 | 
54 |     def __init__(self, eps=0.005, iters=50):
55 |         super().__init__()
56 |         self.eps = eps
57 |         self.iters = 
iters 58 | 59 | def forward(self, preds, gts, **kwargs): 60 | return emd_loss(preds, gts, eps=self.eps, iters=self.iters) 61 | 62 | 63 | class ProjectionLoss(nn.Module): 64 | 65 | def __init__(self, knn=8, sigma_p=0.03, sigma_n=math.radians(15)): 66 | super().__init__() 67 | self.sigma_p = sigma_p 68 | self.sigma_n = sigma_n 69 | self.knn = knn 70 | 71 | def distance_weight(self, dist): 72 | """ 73 | :param dist: (B, N, k), Squared L2 distance 74 | :return (B, N, k) 75 | """ 76 | return torch.exp(- dist / (self.sigma_p ** 2)) 77 | 78 | def angle_weight(self, nb_normals): 79 | """ 80 | :param nb_normals: (B, N, k, 3), Normals of neighboring points 81 | :return (B, N, k) 82 | """ 83 | estm_normal = nb_normals[:, :, 0:1, :] # (B, N, 1, 3) 84 | inner_prod = (nb_normals * estm_normal.expand_as(nb_normals)).sum(dim=-1) # (B, N, k) 85 | return torch.exp(- (1 - inner_prod) / (1 - math.cos(self.sigma_n))) 86 | 87 | def forward(self, preds, gts, normals, **kwargs): 88 | knn_idx, knn_dist = get_knn_idx_dist(gts, query=preds, k=self.knn, offset=0) # (B, N, k), squared L2 distance 89 | 90 | nb_points = group(gts, idx=knn_idx) # (B, N, k, 3) 91 | nb_normals = group(normals, idx=knn_idx) # (B, N, k, 3) 92 | 93 | distance_w = self.distance_weight(knn_dist) # (B, N, k) 94 | angle_w = self.angle_weight(nb_normals) # (B, N, k) 95 | weights = distance_w * angle_w # (B, N, k) 96 | 97 | inner_prod = ((preds.unsqueeze(-2).expand_as(nb_points) - nb_points) * nb_normals).sum(dim=-1) # (B, N, k) 98 | inner_prod = torch.abs(inner_prod) # (B, N, k) 99 | 100 | point_displacement = (inner_prod * weights).sum(dim=-1) / weights.sum(dim=-1) # (B, N) 101 | 102 | return point_displacement.sum() 103 | 104 | 105 | class UnsupervisedLoss(nn.Module): 106 | 107 | def __init__(self, k=64, radius=0.05, pdf_std=0.5657, inv_scale=0.05, decay_epoch=80, emd_eps=0.005, emd_iters=50): 108 | super().__init__() 109 | self.knn = k 110 | self.radius = radius 111 | self.pdf_std = pdf_std 112 | self.inv_scale = inv_scale 113 | self.decay_epoch = decay_epoch 114 | self.emd_eps = emd_eps 115 | self.emd_iters = emd_iters 116 | 117 | def stochastic_neighborhood(self, inputs): 118 | """ 119 | param: inputs: (B, N, 3) 120 | return: knn_idx: (B, N, k), Indices of neighboring points 121 | return: mask: (B, N, k), Mask 122 | """ 123 | knn_idx, knn_dist = get_knn_idx_dist(inputs, query=inputs, k=self.knn, offset=1) # (B, N, k), exclude self 124 | 125 | # Gaussian spatial prior 126 | SQRT_2PI = 2.5066282746 127 | prob = torch.exp(- (knn_dist / (self.inv_scale ** 2)) / (2 * self.pdf_std ** 2)) / ( 128 | self.pdf_std * SQRT_2PI) # (B, N, k) 129 | mask = torch.bernoulli(prob) # (B, N, k) 130 | 131 | prob = prob * (torch.sqrt(knn_dist) <= self.radius) # Radius cutoff 132 | 133 | # If all the neighbor of a point are rejected, then accept at least one to avoid zero loss 134 | # Here we accept the farthest one, because all-rejected probably happens when the point is displaced too far (high noise). 
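        # (Recap of the sampling above, assuming get_knn_idx_dist returns neighbors
        #  sorted by increasing distance, as is conventional: each neighbor is kept
        #  with probability exp(-(d / inv_scale^2) / (2 * pdf_std^2)) / (pdf_std * sqrt(2*pi))
        #  via torch.bernoulli, and the radius cutoff is applied to prob only after
        #  the mask has been drawn. Because the last column then holds the farthest
        #  neighbor, concatenating mask_farthest at the end below accepts exactly
        #  that point.)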
135 | mask_sum = mask.sum(dim=-1, keepdim=True) # (B, N, 1) 136 | mask_farthest = torch.where(mask_sum == 0, torch.ones_like(mask_sum), torch.zeros_like(mask_sum)) # (B, N, 1) 137 | mask_delta = torch.cat([torch.zeros_like(mask_sum).repeat(1, 1, self.knn - 1), mask_farthest], 138 | dim=-1) # (B, N, k) 139 | mask = mask + mask_delta # (B, N, k) 140 | 141 | return knn_idx, mask, prob 142 | 143 | def forward(self, preds, inputs, epoch, **kwargs): 144 | _, assignment = emdFunction.apply(inputs, preds, self.emd_eps, 145 | self.emd_iters) # (B, N), assign each input point to a predicted point 146 | 147 | # Permute the predicted points according to the assignment 148 | # One-to-one correspondent to input points 149 | permuted_preds = gather(preds, idx=assignment.long()) # (B, N, 3) 150 | 151 | input_nbh_idx, input_nbh_mask, _ = self.stochastic_neighborhood(inputs) # (B, N, k), (B, N, k) 152 | input_nbh_pos = group(inputs, idx=input_nbh_idx) # (B, N, k, 3) 153 | 154 | dist = (permuted_preds.unsqueeze(dim=-2).expand_as(input_nbh_pos) - input_nbh_pos) # (B, N, k, 3) 155 | dist = (dist ** 2).sum(dim=-1) # (B, N, k), squared-L2 distance 156 | 157 | num_nbh = input_nbh_mask.sum(dim=-1) # (B, N), number of neighbors 158 | avg_dist = (dist * input_nbh_mask).sum(dim=-1) / num_nbh # (B, N), average distance 159 | 160 | return avg_dist.sum() 161 | 162 | 163 | def get_loss_layer(name): 164 | if name == 'emd': 165 | return EMDLoss() 166 | elif name == 'cd': 167 | return ChamferLoss() 168 | elif name == 'proj': 169 | return ProjectionLoss() 170 | elif name == 'unsupervised': 171 | return UnsupervisedLoss() 172 | elif name is None or name == 'None': 173 | return None 174 | else: 175 | raise ValueError('Unknown loss: %s ' % name) 176 | 177 | 178 | class RepulsionLoss(nn.Module): 179 | 180 | def __init__(self, knn=4, h=0.03): 181 | super().__init__() 182 | self.knn = knn 183 | self.h = h 184 | 185 | def forward(self, pc): 186 | knn_idx, knn_dist = get_knn_idx_dist(pc, pc, k=self.knn, offset=1) # (B, N, k) 187 | weight = torch.exp(- knn_dist / (self.h ** 2)) 188 | loss = torch.sum(- knn_dist * weight) 189 | return loss 190 | 191 | 192 | def make_reg_loss(params): 193 | model_params = params.model_params 194 | if not params.is_register: 195 | return None 196 | if model_params.reg_loss == 'EMD': 197 | loss_fn = EMDLoss(iters=params.max_iter) 198 | elif model_params.reg_loss == 'ChamferLoss': 199 | loss_fn = ChamferLoss() 200 | elif model_params.reg_loss == 'PointMSE': 201 | loss_fn = PointMSE() 202 | elif model_params.reg_loss == 'PoseMSE': 203 | loss_fn = PoseMSE() 204 | else: 205 | print('Unknown loss: {}'.format(model_params.reg_loss)) 206 | raise NotImplementedError 207 | return loss_fn 208 | 209 | 210 | class PointMSE(nn.Module): 211 | 212 | def __init__(self): 213 | super().__init__() 214 | 215 | def forward(self, pred_pts, gt_pts): 216 | loss = torch.nn.functional.mse_loss(pred_pts, gt_pts) 217 | 218 | return loss 219 | 220 | 221 | class PoseMSE(nn.Module): 222 | 223 | def __init__(self): 224 | super().__init__() 225 | 226 | def forward(self, rotation_ab_pred, translation_ab_pred, rotation_ab, translation_ab): 227 | batch_size = rotation_ab_pred.size(0) 228 | identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1) 229 | loss = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \ 230 | + F.mse_loss(translation_ab_pred, translation_ab) 231 | 232 | return loss 233 | 234 | 235 | def npmat2euler(mats, seq='zyx'): 236 | eulers = [] 237 | for i in range(mats.shape[0]): 238 
| r = Rotation.from_matrix(mats[i]) 239 | eulers.append(r.as_euler(seq, degrees=True)) 240 | return np.asarray(eulers, dtype='float32') 241 | 242 | 243 | def dcp_metric(R_ab, t_ab, R_ab_pred, t_ab_pred): 244 | R_ab_pred_euler = npmat2euler(R_ab_pred) 245 | R_ab_euler = npmat2euler(R_ab) 246 | r_mse_ab = np.mean((R_ab_pred_euler - R_ab_euler) ** 2) 247 | r_mae_ab = np.mean(np.abs(R_ab_pred_euler - R_ab_euler)) 248 | t_mse_ab = np.mean((t_ab - t_ab_pred) ** 2) 249 | t_mae_ab = np.mean(np.abs(t_ab - t_ab_pred)) 250 | 251 | temp_state = {} 252 | temp_state.update({'r_mse_ab': r_mse_ab}) 253 | temp_state.update({'r_mae_ab': r_mae_ab}) 254 | temp_state.update({'t_mse_ab': t_mse_ab}) 255 | temp_state.update({'t_mae_ab': t_mae_ab}) 256 | 257 | return temp_state 258 | 259 | 260 | def cor_loss(dict_pose, gt_T): 261 | src = dict_pose['src'] 262 | src_corr = dict_pose['src_corr'] 263 | rotation_ab_pred = dict_pose['rotation_ab'] 264 | translation_ab_pred = dict_pose['translation_ab'] 265 | outlier_src_mask = dict_pose['outlier_src_mask'] 266 | gt_T = gt_T.view(-1, 4, 4).cuda() 267 | rotation_ab = gt_T[:, :3, :3] 268 | translation_ab = gt_T[:, :3, 3] 269 | 270 | transformed_srcK = transform_point_cloud(src, rotation_ab, translation_ab) 271 | loss = mse_mask(transformed_srcK, src_corr, outlier_src_mask) 272 | 273 | return loss 274 | 275 | 276 | def quat2mat(quat): 277 | x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3] 278 | 279 | B = quat.size(0) 280 | 281 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 282 | wx, wy, wz = w * x, w * y, w * z 283 | xy, xz, yz = x * y, x * z, y * z 284 | 285 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 286 | 2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 287 | 2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) 288 | return rotMat 289 | 290 | 291 | def transform_point_cloud(point_cloud, rotation, translation): 292 | if len(rotation.size()) == 2: 293 | rot_mat = quat2mat(rotation) 294 | else: 295 | rot_mat = rotation 296 | return torch.matmul(rot_mat, point_cloud) + translation.unsqueeze(2) 297 | 298 | 299 | def mse_mask(transformed_srcK, src_corrK, mask): 300 | b, _, num = transformed_srcK.size() 301 | mask = mask.contiguous().view(b, 1, num).repeat(1, 3, 1) 302 | transformed_srcK = torch.masked_fill(transformed_srcK, mask, 0) 303 | src_corrK = torch.masked_fill(src_corrK, mask, 0) 304 | loss = torch.nn.functional.mse_loss(transformed_srcK, src_corrK) 305 | return loss 306 | -------------------------------------------------------------------------------- /media/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/media/pipeline.png -------------------------------------------------------------------------------- /misc/log.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | 5 | from dateutil import tz 6 | 7 | log_dir = None 8 | reg_log_dir = None 9 | LOG_FOUT = None 10 | inited = False 11 | 12 | 13 | def setup_log(args): 14 | global LOG_FOUT, log_dir, inited, start_time, reg_log_dir 15 | if inited: 16 | return 17 | inited = True 18 | config = args.config.split('/')[-1].split('.')[0].replace('config_baseline', 'cb') 19 | model_config = args.model_config.split('/')[-1].split('.')[0] 20 | tz_sh = tz.gettz('Asia/Shanghai') 21 | now = datetime.now(tz=tz_sh) 
22 |     if not os.path.exists("./tf_logs"):
23 |         os.mkdir("./tf_logs")
24 |     # dir = '{}-{}-{}'.format(config, model_config, now.strftime("%m%d-%H%M%S"))
25 |     dir = '{}-{}'.format(config, model_config)
26 |     log_dir = os.path.join("./tf_logs", dir)
27 |     if not os.path.exists(log_dir):
28 |         os.makedirs(log_dir)
29 |     os.system('rm -r {}'.format(os.path.join("./tf_logs", 'latest')))
30 |     os.system("cd tf_logs && ln -s {} {} && cd ..".format(dir, "latest"))
31 | 
32 |     start_time = now
33 |     LOG_FOUT = open(os.path.join(log_dir, 'log_train.txt'), 'w')
34 |     log_string('log dir: {}'.format(log_dir))
35 |     reg_log_dir = os.path.join(log_dir, "registration")
36 |     if not os.path.exists(reg_log_dir):
37 |         os.makedirs(reg_log_dir)
38 | 
39 | 
40 | def log_string(out_str, end='\n'):
41 |     LOG_FOUT.write(out_str)
42 |     LOG_FOUT.write(end)
43 |     LOG_FOUT.flush()
44 |     print(out_str, end=end, flush=True)
45 | 
46 | 
47 | def log_silent(out_str, end='\n'):
48 |     LOG_FOUT.write(out_str)
49 |     LOG_FOUT.write(end)
50 |     LOG_FOUT.flush()
51 | 
52 | 
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('--config', type=str)
55 | parser.add_argument('--model_config', type=str)
56 | parser.add_argument('--debug', dest='debug', action='store_true')
57 | parser.set_defaults(debug=False)
58 | parser.add_argument('--checkpoint', type=str, required=False, help='Trained model weights', default="")
59 | parser.add_argument('--weights')
60 | parser.add_argument('--log', action='store_true')
61 | parser.add_argument('--visualize', dest='visualize', action='store_true')
62 | parser.set_defaults(visualize=False)
63 | args = parser.parse_known_args()[0]
64 | setup_log(args)
65 | 
66 | 
67 | def is_last_run_end(last_run_file):
68 |     with open(last_run_file) as f:
69 |         lines = f.readlines()
70 |     for i in lines:
71 |         if 'end' in i:
72 |             return True
73 |     return False
74 | 
75 | 
76 | cuda_dev = os.environ.get('CUDA_VISIBLE_DEVICES')
77 | if cuda_dev is None:
78 |     cuda_dev = '0'
79 | last_run = 'lastrun_{}'.format(cuda_dev)
80 | last_run_file = last_run + '.log'
81 | last_run_id = 1
82 | # while os.path.exists(last_run_file) and not is_last_run_end(last_run_file):
83 | #     last_run_file = last_run + str(last_run_id) + '.log'
84 | #     last_run_id += 1
85 | # with open(last_run_file, 'w') as f:
86 | #     f.write(f'start:{start_time.strftime("%m%d-%H%M%S")}\n')
87 | #     f.write(f'log_dir:{log_dir}\n')
88 | #     for k,v in vars(args).items():
89 | #         f.write(f'{k}:{v}\n')
90 | 
91 | # @atexit.register
92 | # def end_last_run():
93 | #     tz_sh = tz.gettz('Asia/Shanghai')
94 | #     now = datetime.now(tz=tz_sh)
95 | #     with open(last_run_file, 'a') as f:
96 | #         f.write(f'end:{now.strftime("%m%d-%H%M%S")}\n')
97 | 
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/__init__.py
--------------------------------------------------------------------------------
/models/discriminator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/discriminator/__init__.py
--------------------------------------------------------------------------------
/models/discriminator/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy 
(chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | import MinkowskiEngine as ME 25 | import torch.nn as nn 26 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 27 | 28 | 29 | class ResNetBase(nn.Module): 30 | BLOCK = None 31 | LAYERS = () 32 | INIT_DIM = 64 33 | PLANES = (64, 128, 256, 512) 34 | 35 | def __init__(self, in_channels, out_channels, D=3): 36 | nn.Module.__init__(self) 37 | self.D = D 38 | assert self.BLOCK is not None 39 | 40 | self.network_initialization(in_channels, out_channels, D) 41 | self.weight_initialization() 42 | 43 | def network_initialization(self, in_channels, out_channels, D): 44 | 45 | self.inplanes = self.INIT_DIM 46 | self.conv1 = nn.Sequential( 47 | ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D 49 | ), 50 | ME.MinkowskiBatchNorm(self.inplanes), 51 | ME.MinkowskiReLU(inplace=True), 52 | ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D), 53 | ) 54 | 55 | self.layer1 = self._make_layer( 56 | self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2 57 | ) 58 | self.layer2 = self._make_layer( 59 | self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2 60 | ) 61 | self.layer3 = self._make_layer( 62 | self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2 63 | ) 64 | self.layer4 = self._make_layer( 65 | self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2 66 | ) 67 | 68 | self.conv5 = nn.Sequential( 69 | ME.MinkowskiDropout(), 70 | ME.MinkowskiConvolution( 71 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D 72 | ), 73 | ME.MinkowskiBatchNorm(self.inplanes), 74 | ME.MinkowskiGELU(), 75 | ) 76 | 77 | self.glob_pool = ME.MinkowskiGlobalMaxPooling() 78 | 79 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 80 | 81 | def weight_initialization(self): 82 | for m in self.modules(): 83 | if isinstance(m, ME.MinkowskiConvolution): 84 | ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") 85 | 86 | if isinstance(m, ME.MinkowskiBatchNorm): 87 | nn.init.constant_(m.bn.weight, 1) 88 | nn.init.constant_(m.bn.bias, 0) 89 | 90 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): 91 | downsample = None 92 | if stride != 1 or 
self.inplanes != planes * block.expansion: 93 | downsample = nn.Sequential( 94 | ME.MinkowskiConvolution( 95 | self.inplanes, 96 | planes * block.expansion, 97 | kernel_size=1, 98 | stride=stride, 99 | dimension=self.D, 100 | ), 101 | ME.MinkowskiBatchNorm(planes * block.expansion), 102 | ) 103 | layers = [] 104 | layers.append( 105 | block( 106 | self.inplanes, 107 | planes, 108 | stride=stride, 109 | dilation=dilation, 110 | downsample=downsample, 111 | dimension=self.D, 112 | ) 113 | ) 114 | self.inplanes = planes * block.expansion 115 | for i in range(1, blocks): 116 | layers.append( 117 | block( 118 | self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D 119 | ) 120 | ) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.layer1(x) 127 | x = self.layer2(x) 128 | x = self.layer3(x) 129 | x = self.layer4(x) 130 | x = self.conv5(x) 131 | x = self.glob_pool(x) 132 | return self.final(x) 133 | 134 | 135 | class ResNet14(ResNetBase): 136 | BLOCK = BasicBlock 137 | LAYERS = (1, 1, 1, 1) 138 | 139 | 140 | class ResNet18(ResNetBase): 141 | BLOCK = BasicBlock 142 | LAYERS = (2, 2, 2, 2) 143 | 144 | 145 | class ResNet34(ResNetBase): 146 | BLOCK = BasicBlock 147 | LAYERS = (3, 4, 6, 3) 148 | 149 | 150 | class ResNet50(ResNetBase): 151 | BLOCK = Bottleneck 152 | LAYERS = (3, 4, 6, 3) 153 | 154 | 155 | class ResNet101(ResNetBase): 156 | BLOCK = Bottleneck 157 | LAYERS = (3, 4, 23, 3) 158 | 159 | 160 | class ResFieldNetBase(ResNetBase): 161 | def network_initialization(self, in_channels, out_channels, D): 162 | field_ch = 32 163 | field_ch2 = 64 164 | self.field_network = nn.Sequential( 165 | ME.MinkowskiSinusoidal(in_channels, field_ch), 166 | ME.MinkowskiBatchNorm(field_ch), 167 | ME.MinkowskiReLU(inplace=True), 168 | ME.MinkowskiLinear(field_ch, field_ch), 169 | ME.MinkowskiBatchNorm(field_ch), 170 | ME.MinkowskiReLU(inplace=True), 171 | ME.MinkowskiToSparseTensor(), 172 | ) 173 | self.field_network2 = nn.Sequential( 174 | ME.MinkowskiSinusoidal(field_ch + in_channels, field_ch2), 175 | ME.MinkowskiBatchNorm(field_ch2), 176 | ME.MinkowskiReLU(inplace=True), 177 | ME.MinkowskiLinear(field_ch2, field_ch2), 178 | ME.MinkowskiBatchNorm(field_ch2), 179 | ME.MinkowskiReLU(inplace=True), 180 | ME.MinkowskiToSparseTensor(), 181 | ) 182 | 183 | ResNetBase.network_initialization(self, field_ch2, out_channels, D) 184 | 185 | def forward(self, x): 186 | otensor = self.field_network(x) 187 | otensor2 = self.field_network2(otensor.cat_slice(x)) 188 | return ResNetBase.forward(self, otensor2) 189 | 190 | 191 | class ResFieldNet14(ResFieldNetBase): 192 | BLOCK = BasicBlock 193 | LAYERS = (1, 1, 1, 1) 194 | 195 | 196 | class ResFieldNet18(ResFieldNetBase): 197 | BLOCK = BasicBlock 198 | LAYERS = (2, 2, 2, 2) 199 | 200 | 201 | class ResFieldNet34(ResFieldNetBase): 202 | BLOCK = BasicBlock 203 | LAYERS = (3, 4, 6, 3) 204 | 205 | 206 | class ResFieldNet50(ResFieldNetBase): 207 | BLOCK = Bottleneck 208 | LAYERS = (3, 4, 6, 3) 209 | 210 | 211 | class ResFieldNet101(ResFieldNetBase): 212 | BLOCK = Bottleneck 213 | LAYERS = (3, 4, 23, 3) 214 | 215 | 216 | if __name__ == "__main__": 217 | criterion = nn.CrossEntropyLoss() 218 | net = ResNet14(in_channels=3, out_channels=5, D=2) 219 | print(net) 220 | -------------------------------------------------------------------------------- /models/discriminator/resnetD.py: -------------------------------------------------------------------------------- 1 | import MinkowskiEngine as ME 2 | import numpy as 
np 3 | import torch 4 | import torch.nn as nn 5 | from MinkowskiEngine.modules.resnet_block import BasicBlock 6 | 7 | from misc.log import log_string 8 | from models.discriminator.resnet import ResNetBase 9 | 10 | 11 | class ResNet14(ResNetBase): 12 | BLOCK = BasicBlock 13 | LAYERS = (1, 1, 1, 1) 14 | INIT_DIM = 32 15 | PLANES = (32, 64, 128, 128) 16 | 17 | 18 | class ResnetDor(torch.nn.Module): 19 | def __init__(self, planes=(32, 64, 64), n_layers=1, D=3): 20 | super().__init__() 21 | self.planes = planes 22 | self.n_layers = n_layers 23 | self.D = D 24 | self.network_initialization() 25 | 26 | def network_initialization(self): 27 | 28 | layers = len(self.planes) 29 | self.convs = nn.ModuleList() 30 | self.conv1 = nn.Sequential( 31 | ME.MinkowskiConvolution(self.planes[0], self.planes[0], kernel_size=2, stride=2, dimension=self.D), 32 | ME.MinkowskiBatchNorm(self.planes[0]), 33 | ME.MinkowskiReLU(inplace=True)) 34 | self.convs.append(self.conv1) 35 | for i in range(layers - 1): 36 | self.convs.append(nn.Sequential( 37 | ME.MinkowskiConvolution(self.planes[i] * 2, self.planes[i + 1], kernel_size=2, stride=2, 38 | dimension=self.D), 39 | ME.MinkowskiBatchNorm(self.planes[i + 1]), 40 | ME.MinkowskiReLU(inplace=True))) 41 | 42 | self.pool = ME.MinkowskiGlobalAvgPooling() 43 | 44 | self.fc = [] 45 | for j in range(self.n_layers): 46 | self.fc += [ME.MinkowskiLinear(self.planes[layers - 1], self.planes[layers - 1], bias=True)] 47 | self.fc += [ME.MinkowskiReLU(inplace=True)] 48 | self.fc += [ME.MinkowskiLinear(self.planes[layers - 1], 1, bias=True)] 49 | self.fc = nn.Sequential(*self.fc) 50 | 51 | # 32,32,64,64 52 | def forward(self, mid_feature): 53 | last_layer = None 54 | for i, conv in enumerate(self.convs): 55 | layer = mid_feature[i] if last_layer is None else ME.cat(last_layer, mid_feature[i]) 56 | layer = conv(layer) 57 | last_layer = layer 58 | layer = self.pool(layer) 59 | pred = self.fc(layer) 60 | 61 | return pred.F 62 | 63 | def print_info(self): 64 | 65 | para = sum([np.prod(list(p.size())) for p in self.parameters()]) 66 | log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000)) 67 | 68 | 69 | class MeanFCDor(torch.nn.Module): 70 | def __init__(self, planes=(32, 64, 64), feature_dim=256, n_layers=3, D=3): 71 | super().__init__() 72 | self.feature_dim = feature_dim 73 | self.n_layers = n_layers 74 | self.D = D 75 | self.network_initialization() 76 | self.pool = ME.MinkowskiGlobalAvgPooling() 77 | 78 | def network_initialization(self): 79 | self.fc = [] 80 | lay_dim = self.feature_dim 81 | lay_dim = 960 82 | # for j in range(self.n_layers): 83 | self.fc += [nn.Linear(960, 64, bias=True)] 84 | self.fc += [nn.ReLU(inplace=True)] 85 | self.fc += [nn.Linear(64, 64, bias=True)] 86 | self.fc += [nn.ReLU(inplace=True)] 87 | self.fc += [nn.Linear(64, 1, bias=True)] 88 | self.fc = nn.Sequential(*self.fc) 89 | 90 | # 32,32,64,64 91 | def forward(self, mid_feature): 92 | mid_layers = [] 93 | for i in range(len(mid_feature) - 1): 94 | mid_layers.append(self.pool(mid_feature[i]).F) 95 | mid_layers.append(mid_feature[-1]) 96 | mid_layers = torch.cat(mid_layers, dim=1) 97 | pred = self.fc(mid_layers) 98 | 99 | return pred 100 | 101 | def print_info(self): 102 | para = sum([np.prod(list(p.size())) for p in self.parameters()]) 103 | log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000)) 104 | 105 | 106 | class FCDor(torch.nn.Module): 107 | def __init__(self, feature_dim=256, n_layers=3, D=3): 108 | super().__init__() 109 | 
self.feature_dim = feature_dim 110 | self.n_layers = n_layers 111 | self.D = D 112 | self.network_initialization() 113 | 114 | def network_initialization(self): 115 | self.fc = [] 116 | lay_dim = self.feature_dim 117 | for j in range(self.n_layers): 118 | self.fc += [nn.Linear(lay_dim, lay_dim // (j + 1), bias=True)] 119 | self.fc += [nn.ReLU(inplace=True)] 120 | lay_dim = lay_dim // (j + 1) 121 | self.fc += [nn.Linear(lay_dim, 1, bias=True)] 122 | self.fc = nn.Sequential(*self.fc) 123 | 124 | # 32,32,64,64 125 | def forward(self, mid_feature): 126 | # layer = mid_feature 127 | pred = self.fc(mid_feature) 128 | 129 | return pred 130 | 131 | def print_info(self): 132 | para = sum([np.prod(list(p.size())) for p in self.parameters()]) 133 | log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000)) 134 | 135 | 136 | class ResnetFDor(torch.nn.Module): 137 | def __init__(self, planes=(32, 64, 64), n_layers=1, D=3): 138 | super().__init__() 139 | self.planes = planes 140 | self.n_layers = n_layers 141 | self.D = D 142 | self.network_initialization() 143 | 144 | def network_initialization(self): 145 | 146 | layers = len(self.planes) 147 | self.convs = nn.ModuleList() 148 | self.conv1 = nn.Sequential( 149 | nn.Conv1d(self.planes[0], self.planes[0], kernel_size=2, stride=2), 150 | nn.BatchNorm1d(self.planes[0]), 151 | nn.ReLU(inplace=True)) 152 | self.convs.append(self.conv1) 153 | for i in range(layers - 1): 154 | self.convs.append(nn.Sequential( 155 | nn.Conv1d(self.planes[i] * 2, self.planes[i + 1], kernel_size=2, stride=2), 156 | nn.BatchNorm1d(self.planes[i + 1]), 157 | nn.ReLU(inplace=True))) 158 | 159 | self.pool = lambda x: torch.mean(x, dim=-1, keepdim=True) 160 | 161 | self.fc = [] 162 | for j in range(self.n_layers): 163 | self.fc += [nn.Linear(self.planes[layers - 1], self.planes[layers - 1], bias=True)] 164 | self.fc += [nn.ReLU(inplace=True)] 165 | self.fc += [nn.Linear(self.planes[layers - 1], 1, bias=True)] 166 | self.fc = nn.Sequential(*self.fc) 167 | pass 168 | 169 | # 32,32,64,64 170 | def forward(self, mid_feature): 171 | last_layer = None 172 | for i, conv in enumerate(self.convs): 173 | layer = mid_feature[i] if last_layer is None else torch.cat([last_layer, mid_feature[i]], dim=2) 174 | layer = conv(layer) 175 | last_layer = layer 176 | layer = self.pool(layer) 177 | pred = self.fc(layer) 178 | 179 | return pred 180 | 181 | def print_info(self): 182 | 183 | para = sum([np.prod(list(p.size())) for p in self.parameters()]) 184 | log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000)) 185 | -------------------------------------------------------------------------------- /models/minkloc3d/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/minkloc3d/__init__.py -------------------------------------------------------------------------------- /models/minkloc3d/minkfpn.py: -------------------------------------------------------------------------------- 1 | # Author: Zhijian Qiao 2 | # Shanghai Jiao Tong University 3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git 4 | 5 | import MinkowskiEngine as ME 6 | import torch.nn as nn 7 | from MinkowskiEngine.modules.resnet_block import BasicBlock 8 | 9 | from models.minkloc3d.resnet_mink import ResNetBase 10 | from models.minkloc3d.senet_block import SEBasicBlock 11 | 12 | 13 | class MinkFPN(ResNetBase): 
14 |     # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks
15 |     def __init__(self, in_channels=1, model_params=None, block=BasicBlock):
16 |         assert len(model_params.layers) == len(model_params.planes)
17 |         assert 1 <= len(model_params.layers)
18 |         assert 0 <= model_params.num_top_down <= len(model_params.layers)
19 |         self.num_bottom_up = len(model_params.layers)
20 |         self.num_top_down = model_params.num_top_down
21 |         self.conv0_kernel_size = model_params.conv0_kernel_size
22 |         self.block = block
23 |         self.layers = model_params.layers
24 |         self.planes = model_params.planes
25 |         self.lateral_dim = model_params.feature_size
26 |         self.init_dim = model_params.planes[0]
27 |         ResNetBase.__init__(self, in_channels, model_params.feature_size, D=3)
28 | 
29 |     def network_initialization(self, in_channels, out_channels, D):
30 |         assert len(self.layers) == len(self.planes)
31 |         assert len(self.planes) == self.num_bottom_up
32 | 
33 |         self.convs = nn.ModuleList()  # Bottom-up convolutional blocks with stride=2
34 |         self.bn = nn.ModuleList()  # Bottom-up BatchNorms
35 |         self.blocks = nn.ModuleList()  # Bottom-up blocks
36 |         self.tconvs = nn.ModuleList()  # Top-down transposed convolutions
37 |         self.conv1x1 = nn.ModuleList()  # 1x1 convolutions in lateral connections
38 | 
39 |         # The first convolution is a special case, with kernel size = 5
40 |         self.inplanes = conv0_out = self.planes[0]
41 |         self.conv0 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size,
42 |                                              dimension=D)
43 |         self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)
44 | 
45 |         for plane, layer in zip(self.planes, self.layers):
46 |             self.convs.append(
47 |                 ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D))
48 |             self.bn.append(ME.MinkowskiBatchNorm(self.inplanes))
49 |             self.blocks.append(self._make_layer(self.block, plane, layer))
50 | 
51 |         # Lateral connections
52 |         for i in range(self.num_top_down):
53 |             self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - i], self.lateral_dim, kernel_size=1,
54 |                                                         stride=1, dimension=D))
55 | 
56 |             self.tconvs.append(ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2,
57 |                                                                 stride=2, dimension=D))
58 |         # There's one more lateral connection than top-down TConv blocks
59 |         if self.num_top_down < self.num_bottom_up:
60 |             # Lateral connection from Conv block 1 or above
61 |             self.conv1x1.append(
62 |                 ME.MinkowskiConvolution(self.planes[-1 - self.num_top_down], self.lateral_dim, kernel_size=1,
63 |                                         stride=1, dimension=D))
64 |         else:
65 |             # Lateral connection from Conv0 block
66 |             self.conv1x1.append(ME.MinkowskiConvolution(conv0_out, self.lateral_dim, kernel_size=1,
67 |                                                         stride=1, dimension=D))
68 | 
69 |         # self.relu = ME.MinkowskiPReLU()
70 |         self.relu = ME.MinkowskiReLU(inplace=True)
71 | 
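    # (Illustrative shape walk-through with hypothetical values, not a config from
    #  this repo: for planes=(32, 64, 64), layers=(1, 1, 1) and num_top_down=1,
    #  the bottom-up pass applies three stride-2 convs, one feature map is kept
    #  for the lateral path, conv1x1 holds num_top_down + 1 = 2 projections to
    #  lateral_dim = feature_size, and a single transposed conv upsamples the
    #  deepest map once before the remaining lateral connection is added.)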
72 |     def forward(self, x):
73 |         # *** BOTTOM-UP PASS ***
74 |         # First bottom-up convolution is special (with a bigger kernel and no stride)
75 |         feature_maps = []
76 |         x = self.conv0(x)
77 |         x = self.bn0(x)
78 |         x = self.relu(x)
79 |         if self.num_top_down == self.num_bottom_up:
80 |             feature_maps.append(x)
81 | 
82 |         # BOTTOM-UP PASS
83 |         for ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)):
84 |             x = conv(x)  # Decreases spatial resolution (conv stride=2)
85 |             x = bn(x)
86 |             x = self.relu(x)
87 |             x = block(x)
88 | 
89 |             if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1:
90 |                 feature_maps.append(x)
91 | 
92 |         assert len(feature_maps) == self.num_top_down
93 | 
94 |         x = self.conv1x1[0](x)
95 | 
96 |         # TOP-DOWN PASS
97 |         for ndx, tconv in enumerate(self.tconvs):
98 |             x = tconv(x)  # Upsample using transposed convolution
99 |             x = x + self.conv1x1[ndx + 1](feature_maps[-ndx - 1])
100 | 
101 |         return x
102 | 
103 | 
104 | class MinkSEFPN(MinkFPN):
105 |     def __init__(self, in_channels=1, model_params=None, block=SEBasicBlock):
106 |         MinkFPN.__init__(self, in_channels=in_channels, model_params=model_params, block=block)
--------------------------------------------------------------------------------
/models/minkloc3d/minkpool.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 | 
4 | import MinkowskiEngine as ME
5 | import torch
6 | 
7 | import models.minkloc3d.pooling as pooling
8 | from models.minkloc3d.minkfpn import MinkFPN
9 | 
10 | 
11 | class MinkPool(torch.nn.Module):
12 |     def __init__(self, model_params, in_channels):
13 |         super().__init__()
14 |         self.model = model_params.mink_model
15 |         self.in_channels = in_channels
16 |         self.feature_size = model_params.feature_size  # Size of local features produced by local feature extraction block
17 |         self.output_dim = model_params.output_dim  # Dimensionality of the global descriptor
18 |         self.backbone = MinkFPN(in_channels=in_channels, model_params=model_params)
19 |         self.n_backbone_features = model_params.output_dim
20 |         self.pooling = pooling.GeM()
21 |         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 | 
23 |     # [B,dims,num_points,1], [B,1,num_points,3]
24 |     def __input(self, feat_map, coords_xyz=None):
25 |         # coords_xyz = coords_xyz.cpu().squeeze(1)
26 |         coords = []
27 |         feat_maps = []
28 |         for e in range(len(coords_xyz)):
29 |             coords_e, inds_e = ME.utils.sparse_quantize(coords_xyz[e][0].cpu().squeeze(1), return_index=True,
30 |                                                         quantization_size=0.01)
31 |             feat_map_e = feat_map[e][:, inds_e, :].squeeze(-1).transpose(-1, -2)
32 |             coords.append(coords_e)
33 |             feat_maps.append(feat_map_e)
34 |         # coords = [ME.utils.sparse_quantize(coords_xyz[e,...], quantization_size=0.01)
35 |         #           for e in range(coords_xyz.shape[0])]
36 |         coords = ME.utils.batched_coordinates(coords)
37 |         # Stack the per-cloud local feature maps as the sparse tensor's features
38 |         # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
39 |         feats = torch.vstack(feat_maps)
40 |         batch = {'coords': coords, 'features': feats}
41 |         return batch
42 | 
43 |     def forward(self, feat_map, coords_xyz=None):
44 |         batch = self.__input(feat_map, coords_xyz=coords_xyz)
45 |         # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
46 |         x = ME.SparseTensor(batch['features'], coordinates=batch['coords'], device=self.device)
47 |         x = self.backbone(x)
48 |         # x is (num_points, n_features) tensor
49 |         assert x.shape[1] == self.feature_size, 'Backbone output tensor has: {} channels. Expected: {}'.format(
50 |             x.shape[1], self.feature_size)
51 |         x = self.pooling(x)
52 |         assert x.dim() == 2, 'Expected 2-dimensional tensor (batch_size,output_dim). Got {} dimensions.'.format(x.dim())
53 |         assert x.shape[1] == self.output_dim, 'Output tensor has: {} channels. 
Expected: {}'.format(x.shape[1], 54 | self.output_dim) 55 | # x is (batch_size, output_dim) tensor 56 | return {'g_fea': x} 57 | 58 | def print_info(self): 59 | print('Model class: MinkPool') 60 | n_params = sum([param.nelement() for param in self.parameters()]) 61 | print('Total parameters: {}'.format(n_params)) 62 | n_params = sum([param.nelement() for param in self.backbone.parameters()]) 63 | print('Backbone parameters: {}'.format(n_params)) 64 | n_params = sum([param.nelement() for param in self.pooling.parameters()]) 65 | print('Aggregation parameters: {}'.format(n_params)) 66 | if hasattr(self.backbone, 'print_info'): 67 | self.backbone.print_info() 68 | if hasattr(self.pooling, 'print_info'): 69 | self.pooling.print_info() 70 | -------------------------------------------------------------------------------- /models/minkloc3d/pooling.py: -------------------------------------------------------------------------------- 1 | # Code taken from: https://github.com/filipradenovic/cnnimageretrieval-pytorch 2 | # and ported to MinkowskiEngine by Jacek Komorowski 3 | 4 | import MinkowskiEngine as ME 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class MAC(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | self.f = ME.MinkowskiGlobalMaxPooling() 13 | 14 | def forward(self, x: ME.SparseTensor): 15 | x = self.f(x) 16 | return x.F # Return (batch_size, n_features) tensor 17 | 18 | 19 | class SPoC(nn.Module): 20 | def __init__(self): 21 | super().__init__() 22 | self.f = ME.MinkowskiGlobalAvgPooling() 23 | 24 | def forward(self, x: ME.SparseTensor): 25 | x = self.f(x) 26 | return x.F # Return (batch_size, n_features) tensor 27 | 28 | 29 | class GeM(nn.Module): 30 | def __init__(self, p=3, eps=1e-6): 31 | super(GeM, self).__init__() 32 | self.p = nn.Parameter(torch.ones(1) * p) 33 | self.eps = eps 34 | self.f = ME.MinkowskiGlobalAvgPooling() 35 | 36 | def forward(self, x: ME.SparseTensor): 37 | # This implicitly applies ReLU on x (clamps negative values) 38 | temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p), coordinates=x.C) 39 | temp = self.f(temp) # Apply ME.MinkowskiGlobalAvgPooling 40 | return temp.F.pow(1. / self.p) # Return (batch_size, n_features) tensor 41 | -------------------------------------------------------------------------------- /models/minkloc3d/resnet_mink.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | 25 | import MinkowskiEngine as ME 26 | import torch.nn as nn 27 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 28 | 29 | 30 | class ResNetBase(nn.Module): 31 | block = None 32 | layers = () 33 | init_dim = 64 34 | planes = (64, 128, 256, 512) 35 | 36 | # default: 1, 256 37 | def __init__(self, in_channels, out_channels, D=3): 38 | nn.Module.__init__(self) 39 | self.D = D 40 | assert self.block is not None 41 | 42 | self.network_initialization(in_channels, out_channels, D) 43 | self.weight_initialization() 44 | 45 | def network_initialization(self, in_channels, out_channels, D): 46 | self.inplanes = self.init_dim 47 | self.conv1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 49 | 50 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 51 | self.relu = ME.MinkowskiReLU(inplace=True) 52 | 53 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 54 | 55 | self.layer1 = self._make_layer( 56 | self.block, self.planes[0], self.layers[0], stride=2) 57 | self.layer2 = self._make_layer( 58 | self.block, self.planes[1], self.layers[1], stride=2) 59 | self.layer3 = self._make_layer( 60 | self.block, self.planes[2], self.layers[2], stride=2) 61 | self.layer4 = self._make_layer( 62 | self.block, self.planes[3], self.layers[3], stride=2) 63 | 64 | self.conv5 = ME.MinkowskiConvolution( 65 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 66 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 67 | 68 | self.glob_avg = ME.MinkowskiGlobalMaxPooling() 69 | 70 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 71 | 72 | def weight_initialization(self): 73 | for m in self.modules(): 74 | if isinstance(m, ME.MinkowskiConvolution): 75 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 76 | 77 | if isinstance(m, ME.MinkowskiBatchNorm): 78 | nn.init.constant_(m.bn.weight, 1) 79 | nn.init.constant_(m.bn.bias, 0) 80 | 81 | def _make_layer(self, 82 | block, 83 | planes, 84 | blocks, 85 | stride=1, 86 | dilation=1, 87 | bn_momentum=0.1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | ME.MinkowskiConvolution( 92 | self.inplanes, 93 | planes * block.expansion, 94 | kernel_size=1, 95 | stride=stride, 96 | dimension=self.D), 97 | ME.MinkowskiBatchNorm(planes * block.expansion)) 98 | layers = [] 99 | layers.append( 100 | block( 101 | self.inplanes, 102 | planes, 103 | stride=stride, 104 | dilation=dilation, 105 | downsample=downsample, 106 | dimension=self.D)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, blocks): 109 | layers.append( 110 | block( 111 | self.inplanes, 112 | planes, 113 | stride=1, 114 | dilation=dilation, 115 | dimension=self.D)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.conv1(x) 121 | x = self.bn1(x) 122 | x = self.relu(x) 123 | x = self.pool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 
128 |         x = self.layer4(x)
129 | 
130 |         x = self.conv5(x)
131 |         x = self.bn5(x)
132 |         x = self.relu(x)
133 | 
134 |         x = self.glob_avg(x)
135 |         return self.final(x)
136 | 
137 | 
138 | class ResNet14(ResNetBase):
139 |     block = BasicBlock  # lowercase: ResNetBase reads self.block / self.layers
140 |     layers = (1, 1, 1, 1)
141 | 
142 | 
143 | class ResNet18(ResNetBase):
144 |     block = BasicBlock
145 |     layers = (2, 2, 2, 2)
146 | 
147 | 
148 | class ResNet34(ResNetBase):
149 |     block = BasicBlock
150 |     layers = (3, 4, 6, 3)
151 | 
152 | 
153 | class ResNet50(ResNetBase):
154 |     block = Bottleneck
155 |     layers = (3, 4, 6, 3)
156 | 
157 | 
158 | class ResNet101(ResNetBase):
159 |     block = Bottleneck
160 |     layers = (3, 4, 23, 3)
161 | 
--------------------------------------------------------------------------------
/models/minkloc3d/senet_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code. 
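# The classes below implement squeeze-and-excitation (SE) gating on sparse
# tensors: SELayer global-pools the features, squeezes them through a
# channel -> channel/reduction -> channel MLP with a sigmoid, and
# broadcast-multiplies the resulting per-channel gates back onto the input;
# SEBasicBlock and SEBottleneck apply this gating just before the residual add.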
24 | import MinkowskiEngine as ME 25 | import torch.nn as nn 26 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 27 | 28 | 29 | class SELayer(nn.Module): 30 | 31 | def __init__(self, channel, reduction=16, dimension=-1): 32 | # Global coords does not require coords_key 33 | super(SELayer, self).__init__() 34 | self.fc = nn.Sequential( 35 | ME.MinkowskiLinear(channel, channel // reduction), 36 | ME.MinkowskiReLU(inplace=True), 37 | ME.MinkowskiLinear(channel // reduction, channel), 38 | ME.MinkowskiSigmoid()) 39 | self.pooling = ME.MinkowskiGlobalPooling() 40 | self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() 41 | 42 | def forward(self, x): 43 | y = self.pooling(x) 44 | y = self.fc(y) 45 | return self.broadcast_mul(x, y) 46 | 47 | 48 | class SEBasicBlock(BasicBlock): 49 | 50 | def __init__(self, 51 | inplanes, 52 | planes, 53 | stride=1, 54 | dilation=1, 55 | downsample=None, 56 | reduction=16, 57 | dimension=-1): 58 | super(SEBasicBlock, self).__init__( 59 | inplanes, 60 | planes, 61 | stride=stride, 62 | dilation=dilation, 63 | downsample=downsample, 64 | dimension=dimension) 65 | self.se = SELayer(planes, reduction=reduction, dimension=dimension) 66 | 67 | def forward(self, x): 68 | residual = x 69 | 70 | out = self.conv1(x) 71 | out = self.norm1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.norm2(out) 76 | out = self.se(out) 77 | 78 | if self.downsample is not None: 79 | residual = self.downsample(x) 80 | 81 | out += residual 82 | out = self.relu(out) 83 | 84 | return out 85 | 86 | 87 | class SEBottleneck(Bottleneck): 88 | 89 | def __init__(self, 90 | inplanes, 91 | planes, 92 | stride=1, 93 | dilation=1, 94 | downsample=None, 95 | dimension=3, 96 | reduction=16): 97 | super(SEBottleneck, self).__init__( 98 | inplanes, 99 | planes, 100 | stride=stride, 101 | dilation=dilation, 102 | downsample=downsample, 103 | dimension=dimension) 104 | self.se = SELayer(planes * self.expansion, reduction=reduction, dimension=dimension) 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.norm1(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.norm2(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv3(out) 118 | out = self.norm3(out) 119 | out = self.se(out) 120 | 121 | if self.downsample is not None: 122 | residual = self.downsample(x) 123 | 124 | out += residual 125 | out = self.relu(out) 126 | 127 | return out 128 | -------------------------------------------------------------------------------- /models/model_factory.py: -------------------------------------------------------------------------------- 1 | # Author: Zhijian Qiao 2 | # Shanghai Jiao Tong University 3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git 4 | 5 | import os 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from misc.log import log_string 11 | from models.discriminator.resnetD import FCDor 12 | from models.vcrnet.vcrnet import VCRNet 13 | from models.vlpdnet.vlpdnet import vLPDNet 14 | 15 | 16 | def model_factory(params): 17 | if 'vLPDNet' in params.model_params.model: 18 | model = vLPDNet(params) 19 | else: 20 | raise NotImplementedError('Model not implemented: {}'.format(params.model_params.model)) 21 | 22 | if hasattr(model, 'print_info'): 23 | model.print_info() 24 | 25 | para = sum([np.prod(list(p.size())) for p in model.parameters()]) 26 | log_string('Model {} : params: {:4f}M'.format(model._get_name(), para * 4 / 1000 / 1000)) 27 | 
28 |     if params.domain_adapt:
29 |         if params.d_model == 'FCDor':
30 |             d_model = FCDor(feature_dim=params.model_params.feature_size, n_layers=3, D=3)
31 |         else:
32 |             raise NotImplementedError('d_model not implemented: {}'.format(params.d_model))
33 | 
34 |         if hasattr(d_model, 'print_info'):
35 |             d_model.print_info()
36 |     else:
37 |         d_model = None
38 | 
39 |     vcr_model = None
40 |     if params.is_register:
41 |         vcr_model = VCRNet(params.model_params)
42 |         if params.model_params.reg_model_path != "":
43 |             checkpoint_dict = torch.load(params.model_params.reg_model_path, map_location=torch.device('cpu'))
44 |             vcr_model.load_state_dict(checkpoint_dict, strict=True)
45 |             log_string('load vcr_model with {}'.format(params.model_params.reg_model_path))
46 | 
47 |     # Move the model to the proper device before configuring the optimizer
48 |     if torch.cuda.is_available():
49 |         device = "cuda"
50 |         model.to(device)
51 |         d_model = d_model.to(device) if params.domain_adapt else None
52 |         vcr_model = vcr_model.to(device) if params.is_register else None
53 |         if torch.cuda.device_count() > 1:
54 |             vcr_model = torch.nn.DataParallel(vcr_model)
55 |     else:
56 |         device = "cpu"
57 | 
58 |     log_string('Model device: {}'.format(device))
59 | 
60 |     return model, device, d_model, vcr_model
61 | 
62 | 
63 | def load_weights(weights, model, optimizer=None, scheduler=None):
64 |     starting_epoch = 0
65 | 
66 |     if weights is None or weights == "":
67 |         return starting_epoch
68 | 
69 |     if not os.path.exists(weights):
70 |         log_string("error: checkpoint does not exist!")
71 | 
72 |     checkpoint_dict = torch.load(weights, map_location=torch.device('cpu'))
73 |     if 'state_dict' in checkpoint_dict.keys():
74 |         saved_state_dict = checkpoint_dict['state_dict']
75 |     else:
76 |         saved_state_dict = checkpoint_dict
77 | 
78 |     model_state_dict = model.state_dict()  # get the state_dict of the already-constructed net
79 |     saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_state_dict}
80 |     model_state_dict.update(saved_state_dict)
81 |     model.load_state_dict(model_state_dict, strict=True)
82 | 
83 |     if 'epoch' in checkpoint_dict.keys():
84 |         starting_epoch = checkpoint_dict['epoch'] + 1
85 | 
86 |     if 'optimizer' in checkpoint_dict.keys() and optimizer is not None:
87 |         optimizer.load_state_dict(checkpoint_dict['optimizer'])
88 | 
89 |     if 'scheduler' in checkpoint_dict.keys() and scheduler is not None:
90 |         scheduler.load_state_dict(checkpoint_dict['scheduler'])
91 | 
92 |     log_string("load checkpoint " + weights + " starting_epoch: " + str(starting_epoch))
93 |     torch.cuda.empty_cache()
94 | 
95 |     return starting_epoch
96 | 
--------------------------------------------------------------------------------
/models/vcrnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/vcrnet/__init__.py
--------------------------------------------------------------------------------
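A minimal resume-from-checkpoint sketch for model_factory/load_weights above (illustrative only, not part of the repository: 'params' is assumed to be the repo's parsed parameter object, and 'checkpoints/latest.pth' is a placeholder path):

import torch
from models.model_factory import model_factory, load_weights

model, device, d_model, vcr_model = model_factory(params)  # params: assumed parsed elsewhere
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# load_weights copies the matching weights into the model (plus optimizer and
# scheduler state when the checkpoint carries them) and returns the epoch to
# resume from; an empty path returns epoch 0.
starting_epoch = load_weights('checkpoints/latest.pth', model, optimizer=optimizer)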
EPCOR(model_params=model_params) 17 | 18 | self.svd = SVD_Weighted() 19 | 20 | self.loss = torch.nn.functional.mse_loss 21 | 22 | self.origin = origin 23 | 24 | def forward(self, src, tgt, src_embedding, tgt_embedding, positive_T, svd=True): 25 | # input: #[B,1,num,3] [B,1,num,3] [B,1,C,num,1] [B,posi_num,C,num,1] 26 | # expected: [B,C,num] 27 | batch, posi_num, num_points, C = tgt.size() 28 | # src = src.repeat(1,posi_num,1,1).transpose(-2,-1).contiguous().view(-1,C,num_points) 29 | src = src.transpose(-2, -1).contiguous().view(-1, C, num_points) 30 | tgt = tgt.transpose(-2, -1).contiguous().view(-1, C, num_points) 31 | C = tgt_embedding.size(2) 32 | src_embedding = src_embedding.squeeze(-1).repeat(1, posi_num, 1, 1).contiguous().view(-1, C, num_points) 33 | tgt_embedding = tgt_embedding.squeeze(-1).contiguous().view(-1, C, num_points) 34 | 35 | src_embedding_p, tgt_embedding_p = self.pointer(src_embedding, tgt_embedding) 36 | src_embedding = src_embedding + src_embedding_p 37 | tgt_embedding = tgt_embedding + tgt_embedding_p 38 | 39 | src_corr, src_weight, outlier_src_mask, mask_tgt = self.head(src_embedding, tgt_embedding, src, tgt) 40 | 41 | if svd: 42 | with torch.no_grad(): 43 | rotation_ab_pred, translation_ab_pred = self.svd(src, src_corr, src_weight) 44 | 45 | if self.origin: 46 | b, _, num = tgt.size() 47 | mask = outlier_src_mask.contiguous().view(b, 1, num).repeat(1, 3, 1) 48 | srcK = torch.masked_fill(src, mask, 0) 49 | src_corrK = torch.masked_fill(src_corr, mask, 0) 50 | return {'rotation_ab_pred': rotation_ab_pred, 'translation_ab_pred': translation_ab_pred, 51 | 'srcK': srcK, 'src_corrK': src_corrK} 52 | 53 | if positive_T is None: 54 | return rotation_ab_pred, translation_ab_pred 55 | 56 | if svd: 57 | loss = self.cor_loss(src, src_corr, outlier_src_mask, positive_T) 58 | 59 | rotation_ab, translation_ab, loss_pose = self.pose_loss(rotation_ab_pred, translation_ab_pred, positive_T) 60 | 61 | return {'rotation_ab_pred': rotation_ab_pred, 'translation_ab_pred': translation_ab_pred, 62 | 'rotation_ab': rotation_ab, 'translation_ab': translation_ab, 63 | 'loss_point': loss, 'loss_pose': loss_pose, 'mask_tgt': mask_tgt} 64 | else: 65 | return {'rotation_ab_pred': None, 'translation_ab_pred': None, 66 | 'rotation_ab': None, 'translation_ab': None, 67 | 'loss_point': None, 'loss_pose': None, 'mask_tgt': mask_tgt} 68 | 69 | def pose_loss(self, rotation_ab_pred, translation_ab_pred, gt_T): 70 | 71 | gt_T = gt_T.view(-1, 4, 4) 72 | rotation_ab = gt_T[:, :3, :3] 73 | translation_ab = gt_T[:, :3, 3] 74 | batch_size = translation_ab.size(0) 75 | identity = torch.eye(3).cuda().unsqueeze(0).repeat(batch_size, 1, 1) 76 | loss_pose = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity) \ 77 | + F.mse_loss(translation_ab_pred, translation_ab) 78 | return rotation_ab, translation_ab, loss_pose 79 | 80 | def cor_loss(self, src, src_corr, outlier_src_mask, gt_T): 81 | 82 | gt_T = gt_T.view(-1, 4, 4) 83 | rotation_ab = gt_T[:, :3, :3] 84 | translation_ab = gt_T[:, :3, 3] 85 | 86 | transformed_srcK = torch.matmul(rotation_ab, src) + translation_ab.unsqueeze(2) 87 | 88 | b, _, num = transformed_srcK.size() 89 | mask = outlier_src_mask.contiguous().view(b, 1, num).repeat(1, 3, 1) 90 | 91 | transformed_srcK = torch.masked_fill(transformed_srcK, mask, 0) 92 | src_corrK = torch.masked_fill(src_corr, mask, 0) 93 | 94 | loss = self.loss(transformed_srcK, src_corrK) 95 | 96 | return loss 97 | 98 | 99 | class VCRNet(nn.Module): 100 | def __init__(self, model_params: 
ModelParams):
101 |         super(VCRNet, self).__init__()
102 |         if model_params.featnet == "lpdnet":
103 |             self.emb_nn = LPDNet(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
104 |         elif model_params.featnet.lower() == "lpdnetorigin":
105 |             self.emb_nn = LPDNetOrign(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
106 |         else:
107 |             raise NotImplementedError('featnet not implemented: {}'.format(model_params.featnet))
108 |         self.solver = PoseSolver(model_params, origin=True)
109 | 
110 |     def forward(self, *input):
111 |         src = input[0]
112 |         tgt = input[1]
113 | 
114 |         src_embedding = self.emb_nn(src)
115 |         tgt_embedding = self.emb_nn(tgt)
116 | 
117 |         src = src.unsqueeze(1)
118 |         tgt = tgt.unsqueeze(1)
119 |         src_embedding = src_embedding.unsqueeze(1)
120 |         tgt_embedding = tgt_embedding.unsqueeze(1)
121 | 
122 |         # input: [B,1,num,3] [B,posi_num,num,3] [B,1,C,num,1] [B,posi_num,C,num,1]
123 |         reg_dict = self.solver(src, tgt, src_embedding, tgt_embedding, None)
124 |         rotation_ab = reg_dict['rotation_ab_pred']
125 |         translation_ab = reg_dict['translation_ab_pred']
126 |         rotation_ba = rotation_ab.transpose(2, 1).contiguous()
127 |         translation_ba = -torch.matmul(rotation_ba, translation_ab.unsqueeze(2)).squeeze(2)
128 |         srcK = reg_dict['srcK']
129 |         src_corrK = reg_dict['src_corrK']
130 | 
131 |         return srcK, src_corrK, rotation_ab, translation_ab, rotation_ba, translation_ba
132 | 
133 | 
134 | class SVD_Weighted(nn.Module):
135 |     def __init__(self):
136 |         super(SVD_Weighted, self).__init__()
137 | 
138 |     def forward(self, src, src_corr, weights: torch.Tensor = None):
139 |         _EPS = 1e-5  # To prevent division by zero
140 |         a = src.transpose(-1, -2)
141 |         b = src_corr.transpose(-1, -2)
142 |         batch_size, num_src, dims = a.size()
143 | 
144 |         if weights is None:
145 |             weights = torch.ones(size=(batch_size, num_src), device=a.device) / num_src
146 | 
147 |         """Compute the rigid transform between two weighted point sets
148 | 
149 |         Args:
150 |             a (torch.Tensor): (B, M, 3) points
151 |             b (torch.Tensor): (B, M, 3) corresponding points
152 |             weights (torch.Tensor): (B, M)
153 | 
154 |         Returns:
155 |             Rotation R (B, 3, 3) and translation t (B, 3) such that R @ a_i + t = b_i
156 |         """
157 | 
158 |         weights_normalized = weights[..., None] / (torch.sum(weights[..., None], dim=1, keepdim=True) + _EPS)
159 |         centroid_a = torch.sum(a * weights_normalized, dim=1)
160 |         centroid_b = torch.sum(b * weights_normalized, dim=1)
161 |         a_centered = a - centroid_a[:, None, :]
162 |         b_centered = b - centroid_b[:, None, :]
163 |         cov = a_centered.transpose(-2, -1) @ (b_centered * weights_normalized)
164 | 
165 |         # Compute rotation using the Kabsch algorithm. Compute two candidates with +/-V[:, :, 2]
166 |         # and choose based on the determinant to avoid reflections
167 |         u, s, v = torch.svd(cov, some=False, compute_uv=True)
168 |         rot_mat_pos = v @ u.transpose(-1, -2)
169 |         v_neg = v.clone()
170 |         v_neg[:, :, 2] *= -1
171 |         rot_mat_neg = v_neg @ u.transpose(-1, -2)
172 |         rot_mat = torch.where(torch.det(rot_mat_pos)[:, None, None] > 0, rot_mat_pos, rot_mat_neg)
173 |         assert torch.all(torch.det(rot_mat) > 0)
174 | 
175 |         # Compute translation (uncenter centroid)
176 |         translation = -rot_mat @ centroid_a[:, :, None] + centroid_b[:, :, None]
177 | 
178 |         return rot_mat, translation.view(batch_size, 3)
179 | 
180 | 
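# Illustrative usage sketch (this helper is hypothetical and is not called
# anywhere in the codebase): with exact correspondences and uniform weights,
# SVD_Weighted should recover a known rigid transform up to numerical noise.
def _svd_weighted_sanity_check(batch_size=2, num_points=64):
    # Random rotations via QR decomposition, with the determinant forced to +1.
    q, _ = torch.qr(torch.randn(batch_size, 3, 3))
    q[:, :, 2] *= torch.det(q)[:, None]  # flip the last column where det == -1
    t = torch.randn(batch_size, 3)
    src = torch.randn(batch_size, 3, num_points)
    src_corr = q @ src + t[:, :, None]  # perfect correspondences
    rot_mat, trans = SVD_Weighted()(src, src_corr)
    assert torch.allclose(rot_mat, q, atol=1e-3)
    assert torch.allclose(trans, t, atol=1e-3)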
181 | class EPCOR(nn.Module):
182 |     """Generate virtual corresponding points (VCPs) from the K most similar target points.
183 | 
184 |     """
185 | 
186 |     def __init__(self, model_params):
187 |         super(EPCOR, self).__init__()
188 |         self.model_params = model_params
189 | 
190 |     def forward(self, *input):
191 | 
192 |         src_emb = input[0]
193 |         tgt_emb = input[1]
194 |         src = input[2]
195 |         tgt = input[3]
196 | 
197 |         if self.training:
198 |             src_corr, src_weight, outlier_src_mask, mask_tgt = self.getCopairALL(src, src_emb, tgt, tgt_emb)
199 |         else:
200 |             src_corr, src_weight, outlier_src_mask, mask_tgt = self.selectCom_adap(src, src_emb, tgt, tgt_emb)
201 | 
202 |         return src_corr, src_weight, outlier_src_mask, mask_tgt
203 | 
204 |     def get_sparse_w(self, pairwise_distance, tau, K):
205 |         batch_size, num_points, num_points_t = pairwise_distance.size()
206 |         tgt_K = int(num_points_t * tau)
207 |         src_K = int(num_points * tau)
208 | 
209 |         scoresSoftCol = torch.softmax(pairwise_distance, dim=2)  # [b,num,num]
210 |         scoresColSum = torch.sum(scoresSoftCol, dim=1, keepdim=True)
211 |         topmin = scoresColSum.topk(k=tgt_K, dim=-1, largest=False)[0][..., -1].unsqueeze(-1)
212 |         mask_tgt = scoresColSum < topmin
213 |         # a = torch.sum(mask_tgt, dim=-1)
214 | 
215 |         scoresSoftRow = torch.softmax(pairwise_distance, dim=1)  # [b,num,num]
216 |         scoresRowSum = torch.sum(scoresSoftRow, dim=2, keepdim=True)
217 |         topmin = scoresRowSum.topk(k=src_K, dim=1, largest=False)[0][:, -1, :].unsqueeze(-1)
218 |         mask_src = scoresRowSum < topmin
219 |         # mask_src = scoresRowSum < tau
220 | 
221 |         mask = mask_src + mask_tgt
222 | 
223 |         s = (torch.sum(mask_src, dim=-2) + torch.sum(mask_tgt, dim=-1)) / torch.full(size=(batch_size, 1),
224 |                                                                                      fill_value=num_points + num_points_t).to(
225 |             mask_tgt.device).type_as(pairwise_distance)
226 | 
227 |         topk = scoresSoftCol.topk(k=K, dim=2)[0][..., -1].unsqueeze(-1)
228 |         mask = mask + scoresSoftCol < topk
229 | 
230 |         src_tgt_weight_sparse = torch.masked_fill(scoresSoftCol, mask, 0)
231 |         scoresColSum_sparse = torch.sum(src_tgt_weight_sparse, dim=-1, keepdim=True)  # [B, num, 1]
232 |         scoresColSum_sparse = torch.masked_fill(scoresColSum_sparse, scoresColSum_sparse < 1e-5, 1e-5)  # avoid division by zero
233 |         src_tgt_weight = torch.div(src_tgt_weight_sparse, scoresColSum_sparse)
234 | 
235 |         # src_tgt_weight = torch.softmax(src_tgt_weight, dim=2)
236 | 
237 |         # Accumulate per-source inlier weights
238 |         val_sum = torch.sum(src_tgt_weight, dim=-1, keepdim=True)  # [B, num, 1]
239 |         val_sum_inlier = torch.masked_fill(val_sum, mask_src, 0.0).squeeze(-1)
240 |         sum_ = torch.sum(val_sum_inlier, dim=[-1], keepdim=True)
241 |         src_weight = torch.div(val_sum_inlier, sum_)
242 | 
243 |         return s, src_weight, mask_src, mask_tgt, src_tgt_weight
244 | 
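    # How get_sparse_w prunes correspondences, with illustrative numbers: for
    # num_points = num_points_t = 4096 and tau = 0.3, the int(4096 * 0.3) = 1228
    # source rows / target columns with the lowest soft-assignment mass are
    # flagged as outliers, and K = 1 (used by selectCom_adap below) is intended
    # to keep only the highest-scoring target entry per source point before the
    # surviving weights are re-normalized row-wise.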
245 |     def selectCom_adap(self, src, src_emb, tgt, tgt_emb, tau=0.3):
246 | 
247 |         batch_size, _, num_points = src.size()
248 |         batch_size, _, num_points_t = tgt.size()
249 | 
250 |         inner = -2 * torch.matmul(src_emb.transpose(2, 1).contiguous(), tgt_emb)
251 |         xx = torch.sum(src_emb ** 2, dim=1, keepdim=True).transpose(2, 1).contiguous()
252 |         yy = torch.sum(tgt_emb ** 2, dim=1, keepdim=True)
253 | 
254 |         pairwise_distance = -xx - inner
255 |         src_tgt_weight = pairwise_distance - yy
256 | 
257 |         s, src_weight, mask_src, mask_tgt, src_tgt_weight = self.get_sparse_w(src_tgt_weight, tau, K=1)
258 | 
259 |         src_corr = torch.matmul(tgt, src_tgt_weight.transpose(2, 1).contiguous())
260 | 
261 |         self.mask_src, self.mask_tgt = mask_src, mask_tgt  # cache masks for return_mask()
262 |         return src_corr, src_weight, mask_src, mask_tgt
263 | 
264 |     def getCopairALL(self, src, src_emb, tgt, tgt_emb):
265 | 
266 |         batch_size, n_dims, num_points = src.size()
267 |         # Calculate the distance matrix
268 |         inner = -2 * torch.matmul(src_emb.transpose(2, 1).contiguous(), tgt_emb)
269 |         xx = torch.sum(src_emb ** 2, dim=1, keepdim=True).transpose(2, 1).contiguous()
270 |         yy = torch.sum(tgt_emb ** 2, dim=1, keepdim=True)
271 | 
272 |         pairwise_distance = -xx - inner
273 |         pairwise_distance = pairwise_distance - yy
274 | 
275 |         scores = torch.softmax(pairwise_distance, dim=2)  # [b,num,num]
276 |         src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())
277 | 
278 |         src_weight = torch.ones(size=(batch_size, num_points), device=src_corr.device) / num_points
279 | 
280 |         outlier_src_mask = torch.full((batch_size, num_points, 1), False, device=src_corr.device, dtype=torch.bool)
281 |         mask_tgt = torch.full((batch_size, num_points, 1), False, device=src_corr.device, dtype=torch.bool)
282 | 
283 |         self.mask_src, self.mask_tgt = outlier_src_mask, mask_tgt  # cache masks for return_mask()
284 |         return src_corr, src_weight, outlier_src_mask, mask_tgt
285 | 
286 |     def return_mask(self):
287 |         # Masks are cached by the most recent call to selectCom_adap/getCopairALL
288 |         return self.mask_src, self.mask_tgt
289 | 
--------------------------------------------------------------------------------
/models/vlpdnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/vlpdnet/__init__.py
--------------------------------------------------------------------------------
/models/vlpdnet/lpdnet_model.py:
--------------------------------------------------------------------------------
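# Both feature extractors below consume raw point clouds shaped [B, 1, N, 3]
# and produce per-point embeddings shaped [B, emb_dims, N, 1], e.g.
# (illustrative call):
#   emb = LPDNet(emb_dims=1024)(torch.rand(2, 1, 4096, 3))  # -> [2, 1024, 4096, 1]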
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from misc.point_utils import get_graph_feature_Origin, get_graph_feature, knn
6 | 
7 | 
8 | class LPDNetOrign(nn.Module):
9 |     def __init__(self, emb_dims=512, channels=[64, 64, 128]):
10 |         super(LPDNetOrign, self).__init__()
11 |         self.act_f = nn.LeakyReLU(inplace=True, negative_slope=1e-2)
12 |         self.k = 20
13 |         self.emb_dims = emb_dims
14 |         self.convDG1 = nn.Sequential(nn.Conv2d(channels[0] * 2, channels[0], kernel_size=1, bias=False),
15 |                                      nn.BatchNorm2d(channels[0]), self.act_f)
16 |         self.convDG2 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
17 |                                      nn.BatchNorm2d(channels[0]), self.act_f)
18 |         self.convSN1 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
19 |                                      nn.BatchNorm2d(channels[0]), self.act_f)
20 |         self.convSN2 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
21 |                                      nn.BatchNorm2d(channels[0]), self.act_f)
22 |         self.conv1_lpd = nn.Sequential(nn.Conv1d(3, channels[0], kernel_size=1, bias=False),
23 |                                        nn.BatchNorm1d(channels[0]), self.act_f)
24 |         self.conv2_lpd = nn.Sequential(nn.Conv1d(channels[0], channels[0], kernel_size=1, bias=False),
25 |                                        nn.BatchNorm1d(channels[0]), self.act_f)
26 |         self.conv3_lpd = nn.Sequential(nn.Conv1d(channels[0], channels[1], kernel_size=1, bias=False),
27 |                                        nn.BatchNorm1d(channels[1]), self.act_f)
28 |         self.conv4_lpd = nn.Sequential(nn.Conv1d(channels[1], channels[2], kernel_size=1, bias=False),
29 |                                        nn.BatchNorm1d(channels[2]), self.act_f)
30 |         self.conv5_lpd = nn.Sequential(nn.Conv1d(channels[2], self.emb_dims, kernel_size=1, bias=False),
31 |                                        nn.BatchNorm1d(self.emb_dims), self.act_f)
32 | 
33 |     # input x: [B,1,num,num_dims]
34 |     # output x: [b,emb_dims,num,1]
35 |     def forward(self, x):
36 |         x = torch.squeeze(x, dim=1).transpose(2, 1)  # [B,num_dims,num]
37 |         batch_size, num_dims, num_points = x.size()
38 | 
39 |         xInit3d = x
40 | 
41 |         x = self.conv1_lpd(x)
42 |         x = self.conv2_lpd(x)
43 | 
44 |         # Serial structure
45 |         # Dynamic Graph CNN for feature space
46 |         x = get_graph_feature_Origin(x, k=self.k)  # [b,channel*2,num,20]
47 |         x = self.convDG1(x)  # [b,channel,num,20]
48 |         x = self.convDG2(x)  # [b,channel,num,20]
49 |         x = x.max(dim=-1, keepdim=True)[0]  # [b,channel,num,1]
50 | 
51 |         # Spatial Neighborhood fusion for Cartesian space
52 |         idx = knn(xInit3d, k=self.k)
53 |         x = get_graph_feature_Origin(x, idx=idx, k=self.k, cat=False)  # [b,channel,num,20]
54 |         x = self.convSN1(x)  # [b,channel,num,20]
55 |         x = self.convSN2(x)  # [b,channel,num,20]
56 |         x = x.max(dim=-1, keepdim=True)[0].squeeze(-1)  # [b,channel,num]
57 | 
58 |         x = self.conv3_lpd(x)  # [b,channel,num]
59 |         x = self.conv4_lpd(x)  # [b,128,num]
60 |         x = self.conv5_lpd(x)  # [b,emb_dims,num]
61 |         x = x.unsqueeze(-1)  # [b,emb_dims,num,1]
62 | 
63 |         return x
64 | 
65 | 
66 | class LPDNet(nn.Module):
67 |     """PyTorch implementation of LPDNet.
68 | 
69 |     """
70 | 
71 |     def __init__(self, emb_dims=1024, channels=[64, 128, 256]):
72 |         super(LPDNet, self).__init__()
73 |         self.negative_slope = 0.0
74 |         self.k = 20
75 |         self.emb_dims = emb_dims
76 |         # [b,6,num,20]
77 |         self.convDG1 = nn.Sequential(nn.Conv2d(channels[0] * 2, channels[1], kernel_size=1, bias=True),
78 |                                      nn.LeakyReLU(negative_slope=self.negative_slope))
79 |         self.convDG2 = nn.Sequential(nn.Conv2d(channels[1], channels[1], kernel_size=1, bias=True),
80 |                                      nn.LeakyReLU(negative_slope=self.negative_slope))
81 |         self.convSN1 = nn.Sequential(nn.Conv2d(channels[1] * 2, channels[2], kernel_size=1, bias=True),
82 |                                      nn.LeakyReLU(negative_slope=self.negative_slope))
83 | 
84 |         self.conv1_lpd = nn.Conv1d(3, channels[0], kernel_size=1, bias=True)
85 |         self.conv2_lpd = nn.Conv1d(channels[0], channels[0], kernel_size=1, bias=True)
86 |         self.conv3_lpd = nn.Conv1d((channels[1] + channels[1] + channels[2]), self.emb_dims, kernel_size=1, bias=True)
87 | 
88 |     # input x: [B,1,num,3]
89 |     # output x: [b,emb_dims,num,1]
90 |     def forward(self, x):
91 |         x = torch.squeeze(x, dim=1).transpose(2, 1)  # [B,num_dims,num]
92 |         batch_size, num_dims, num_points = x.size()
93 | 
94 |         xInit3d = x
95 | 
96 |         x = F.leaky_relu(self.conv1_lpd(x), negative_slope=self.negative_slope)
97 |         x = F.leaky_relu(self.conv2_lpd(x), negative_slope=self.negative_slope)
98 | 
99 |         # Serial structure
100 |         # Dynamic Graph CNN for feature space
101 |         x = get_graph_feature(x, k=self.k)  # [b,64*2,num,20]
102 |         x = self.convDG1(x)  # [b,128,num,20]
103 |         x1 = x.max(dim=-1, keepdim=True)[0]  # [b,128,num,1]
104 |         x = self.convDG2(x)  # [b,128,num,20]
105 |         x2 = x.max(dim=-1, keepdim=True)[0]  # [b,128,num,1]
106 | 
107 |         # Spatial Neighborhood fusion for Cartesian space
108 |         idx = knn(xInit3d, k=self.k)
109 |         x = get_graph_feature(x2.squeeze(-1), idx=idx, k=self.k)  # [b,128*2,num,20]
110 |         x = self.convSN1(x)  # [b,256,num,20]
111 |         x3 = x.max(dim=-1, keepdim=True)[0]  # [b,256,num,1]
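        # Multi-scale concatenation with the default channels=[64, 128, 256]:
        # 128 + 128 + 256 = 512 channels, matching conv3_lpd's input width of
        # channels[1] + channels[1] + channels[2].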
112 | 
113 |         x = torch.cat((x1, x2, x3), dim=1).squeeze(-1)  # [b,512,num]
114 |         x = F.leaky_relu(self.conv3_lpd(x), negative_slope=self.negative_slope).view(batch_size, -1,
115 |                                                                                      num_points)  # [b,emb_dims,num]
116 |         x = x.unsqueeze(-1)  # [b,emb_dims,num,1]
117 |         return x
118 | 
--------------------------------------------------------------------------------
/models/vlpdnet/vlpdnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import math
4 | 
5 | import MinkowskiEngine as ME
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.nn.parallel
11 | import torch.utils.data
12 | 
13 | from models.minkloc3d.minkpool import MinkPool
14 | from models.vcrnet.vcrnet import PoseSolver
15 | from models.vlpdnet.lpdnet_model import LPDNet, LPDNetOrign
16 | 
17 | 
18 | class vLPDNet(nn.Module):
19 |     def __init__(self, params):
20 |         super(vLPDNet, self).__init__()
21 |         self.params = params
22 |         model_params = params.model_params
23 |         self.featnet = model_params.featnet
24 |         if model_params.featnet == "lpdnet":
25 |             self.emb_nn = LPDNet(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
26 |         elif model_params.featnet.lower() == "lpdnetorigin":
27 |             self.emb_nn = LPDNetOrign(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
28 |         else:
29 |             raise NotImplementedError('featnet not implemented: {}'.format(model_params.featnet))
30 |         if params.lpd_fixed:
31 |             self.emb_nn.requires_grad_(False)
32 |             self.emb_nn.eval()
33 |         # self.net_vlad = NetVLADLoupe(feature_size=args.emb_dims, cluster_size=64,
34 |         #                              output_dim=args.output_dim, gating=True, add_batch_norm=True,
35 |         #                              is_training=True)
36 |         self.net_vlad = MinkPool(model_params, in_channels=model_params.emb_dims)
37 | 
38 |         self.is_register = params.is_register
39 |         # self.is_register = True
40 |         if params.domain_adapt:
41 |             self.is_register = False
42 |             self.params.lpd_fixed = False
43 |         if params.is_register:
44 |             self.pose_solver = PoseSolver(model_params=model_params)
45 |         else:
46 |             self.pose_solver = None
47 | 
48 |     # input x [B,1,N,3]
49 |     # intermediate feat_x [B,C,N,1]
50 |     # output g_fea: global descriptor [B, output_dim]
51 |     def forward(self, source_batch=None, target_batch=None, gt_T=None):
52 |         if self.params.lpd_fixed:
53 |             self.emb_nn.eval()
54 |             with torch.no_grad():
55 |                 if self.is_register and source_batch is not None:
56 |                     source = source_batch["cloud"].unsqueeze(1)
57 |                     feat_x_s = self.emb_nn(source)
58 |                 target = target_batch["cloud"].unsqueeze(1)
59 |                 feat_x_t = self.emb_nn(target)
60 |         else:
61 |             if self.is_register and source_batch is not None:
62 |                 source = source_batch["cloud"].unsqueeze(1)
63 |                 feat_x_s = self.emb_nn(source)
64 |             target = target_batch["cloud"].unsqueeze(1)
65 |             feat_x_t = self.emb_nn(target)
66 | 
67 |         # registration
68 |         # batch_R, batch_t, transformed_xyz = self.reg_model(feat_x_s, feat_x_t, source)
69 |         if self.is_register and source_batch is not None:
70 |             reg_dict = self.pose_solver(source, target, feat_x_s.unsqueeze(1), feat_x_t.unsqueeze(1), gt_T, svd=False)
71 |             fs = []
72 |             pcls = []
73 |             mask_tgt = reg_dict['mask_tgt'].squeeze(-1)
74 |             for i in range(mask_tgt.shape[0]):
75 |                 mask = ~mask_tgt[i]
76 |                 # a = torch.sum(mask)
77 |                 fs.append(feat_x_t[i][:, mask, :])
78 |                 pcls.append(target[i][:, mask, :])
79 |         else:
80 |             reg_dict = None
81 |             fs = []
82 |             pcls = []
83 |             for i in range(feat_x_t.shape[0]):
84 |                 fs.append(feat_x_t[i])
85 |                 pcls.append(target[i])
86 | 
87 |         # g_fea = self.net_vlad(feat_x_t, target)
88 |         g_fea = 
self.net_vlad(fs, pcls)
89 | 
90 |         if isinstance(g_fea, dict):
91 |             g_fea = g_fea['g_fea']
92 | 
93 |         return {
94 |             'embeddings': g_fea,
95 |             'reg_dict': reg_dict
96 |         }
97 | 
98 |         # return {'g_fea': g_fea, 'l_fea': feat_x_t, 'node_fea': None}  # unreachable legacy return
99 | 
100 |     def registration(self, src, tgt, src_embedding, tgt_embedding, gt_T):
101 | 
102 |         reg_dict = self.pose_solver(src, tgt, src_embedding, tgt_embedding, gt_T)
103 | 
104 |         return reg_dict
105 | 
106 |     def registration_only(self, src, tgt, gt_T):
107 |         # input x [B,1,N,3]
108 |         src_embedding = self.emb_nn(src)
109 |         src_embedding = src_embedding.permute(0, 3, 1, 2)
110 |         if self.is_register:
111 |             tgt_embedding = self.emb_nn(tgt)
112 |             tgt_embedding = tgt_embedding.permute(0, 3, 1, 2)
113 |             reg_dict = self.pose_solver(src, tgt, src_embedding, tgt_embedding, gt_T)
114 |             return reg_dict  # reg_dict is only defined when is_register is set
115 | 
116 | 
117 | class NetVLADLoupe(nn.Module):
118 |     def __init__(self, feature_size, cluster_size, output_dim,
119 |                  gating=True, add_batch_norm=True, is_training=True):
120 |         super(NetVLADLoupe, self).__init__()
121 |         self.feature_size = feature_size
122 |         self.output_dim = output_dim
123 |         self.is_training = is_training
124 |         self.gating = gating
125 |         self.add_batch_norm = add_batch_norm
126 |         self.cluster_size = cluster_size
127 |         self.softmax = nn.Softmax(dim=-1)
128 |         self.cluster_weights = nn.Parameter(torch.randn(
129 |             feature_size, cluster_size) * 1 / math.sqrt(feature_size))
130 |         self.cluster_weights2 = nn.Parameter(torch.randn(
131 |             1, feature_size, cluster_size) * 1 / math.sqrt(feature_size))
132 |         self.hidden1_weights = nn.Parameter(
133 |             torch.randn(cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size))
134 | 
135 |         if add_batch_norm:
136 |             self.cluster_biases = None
137 |             self.bn1 = nn.BatchNorm1d(cluster_size)
138 |         else:
139 |             self.cluster_biases = nn.Parameter(torch.randn(
140 |                 cluster_size) * 1 / math.sqrt(feature_size))
141 |             self.bn1 = None
142 | 
143 |         self.bn2 = nn.BatchNorm1d(output_dim)
144 |         self.context_gating = GatingContext(output_dim, add_batch_norm=add_batch_norm)
145 |         self.weight_initialization_all()
146 | 
147 |     # [B, dims, num_points, 1]
148 |     def forward(self, x, coord=None):
149 |         max_samples = x.size(2)  # num of points
150 |         x = x.transpose(1, 3).contiguous()
151 |         x = x.view((-1, max_samples, self.feature_size))
152 |         activation = torch.matmul(x, self.cluster_weights)
153 |         if self.add_batch_norm:
154 |             # activation = activation.transpose(1,2).contiguous()
155 |             activation = activation.view(-1, self.cluster_size)
156 |             activation = self.bn1(activation)
157 |             activation = activation.view(-1,
158 |                                          max_samples, self.cluster_size)
159 |             # activation = activation.transpose(1,2).contiguous()
160 |         else:
161 |             activation = activation + self.cluster_biases
162 |         activation = self.softmax(activation)
163 |         activation = activation.view((-1, max_samples, self.cluster_size))
164 | 
165 |         a_sum = activation.sum(-2, keepdim=True)
166 |         a = a_sum * self.cluster_weights2
167 | 
168 |         activation = torch.transpose(activation, 2, 1)
169 |         x = x.view((-1, max_samples, self.feature_size))
170 |         vlad = torch.matmul(activation, x)
171 |         vlad = torch.transpose(vlad, 2, 1)
172 |         vlad = vlad - a
173 | 
174 |         vlad = F.normalize(vlad, dim=1, p=2)
175 |         # print(vlad.shape, self.cluster_size, self.feature_size, self.cluster_size * self.feature_size)
176 |         vlad = vlad.reshape((-1, self.cluster_size * self.feature_size))
177 |         vlad = F.normalize(vlad, dim=1, p=2)
178 | 
179 |         vlad = torch.matmul(vlad, self.hidden1_weights)
180 | 
181 |         vlad = self.bn2(vlad)
182 | 
183 |         if self.gating:
184 |             vlad = self.context_gating(vlad)
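            # Context gating (Miech et al., 2017): re-weights the descriptor as
            # vlad * sigmoid(vlad @ W), implemented by GatingContext below.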
185 | 186 | return vlad 187 | 188 | def weight_initialization_all(self): 189 | for m in self.modules(): 190 | if isinstance(m, nn.Sequential): 191 | for m in self.modules(): 192 | self.weight_init(m) 193 | else: 194 | self.weight_init(m) 195 | 196 | def weight_init(self, m): 197 | if isinstance(m, nn.Conv1d): 198 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 199 | if isinstance(m, nn.Conv2d): 200 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 201 | elif isinstance(m, nn.BatchNorm1d): 202 | nn.init.constant_(m.weight, 1) 203 | nn.init.constant_(m.bias, 0) 204 | elif isinstance(m, nn.BatchNorm2d): 205 | nn.init.constant_(m.weight, 1) 206 | nn.init.constant_(m.bias, 0) 207 | 208 | 209 | class GatingContext(nn.Module): 210 | def __init__(self, dim, add_batch_norm=True): 211 | super(GatingContext, self).__init__() 212 | self.dim = dim 213 | self.add_batch_norm = add_batch_norm 214 | self.gating_weights = nn.Parameter( 215 | torch.randn(dim, dim) * 1 / math.sqrt(dim)) 216 | self.sigmoid = nn.Sigmoid() 217 | 218 | if add_batch_norm: 219 | self.gating_biases = None 220 | self.bn1 = nn.BatchNorm1d(dim) 221 | else: 222 | self.gating_biases = nn.Parameter( 223 | torch.randn(dim) * 1 / math.sqrt(dim)) 224 | self.bn1 = None 225 | 226 | self.weight_initialization_all() 227 | 228 | def weight_initialization_all(self): 229 | for m in self.modules(): 230 | if isinstance(m, nn.Sequential): 231 | for m in self.modules(): 232 | self.weight_init(m) 233 | else: 234 | self.weight_init(m) 235 | 236 | def weight_init(self, m): 237 | if isinstance(m, nn.Conv1d): 238 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 239 | if isinstance(m, nn.Conv2d): 240 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 241 | elif isinstance(m, nn.BatchNorm1d): 242 | nn.init.constant_(m.weight, 1) 243 | nn.init.constant_(m.bias, 0) 244 | elif isinstance(m, nn.BatchNorm2d): 245 | nn.init.constant_(m.weight, 1) 246 | nn.init.constant_(m.bias, 0) 247 | 248 | def forward(self, x): 249 | gates = torch.matmul(x, self.gating_weights) 250 | 251 | if self.add_batch_norm: 252 | if gates.size(0) == 1: 253 | gates = gates 254 | else: 255 | gates = self.bn1(gates) 256 | else: 257 | gates = gates + self.gating_biases 258 | 259 | gates = self.sigmoid(gates) 260 | 261 | activation = x * gates 262 | 263 | return activation 264 | 265 | 266 | def extract_features(model, 267 | xyz, 268 | rgb=None, 269 | normal=None, 270 | voxel_size=0.05, 271 | device=None, 272 | skip_check=False, 273 | is_eval=True): 274 | ''' 275 | xyz is a N x 3 matrix 276 | rgb is a N x 3 matrix and all color must range from [0, 1] or None 277 | normal is a N x 3 matrix and all normal range from [-1, 1] or None 278 | 279 | if both rgb and normal are None, we use Nx1 one vector as an input 280 | 281 | if device is None, it tries to use gpu by default 282 | 283 | if skip_check is True, skip rigorous checks to speed up 284 | 285 | model = model.to(device) 286 | xyz, feats = extract_features(model, xyz) 287 | ''' 288 | if is_eval: 289 | model.eval() 290 | 291 | if device is None: 292 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 293 | 294 | feats = [] 295 | if rgb is not None: 296 | # [0, 1] 297 | feats.append(rgb - 0.5) 298 | 299 | if normal is not None: 300 | # [-1, 1] 301 | feats.append(normal / 2) 302 | 303 | if rgb is None and normal is None: 304 | feats.append(np.ones((len(xyz), 1))) 305 | 306 | feats = np.hstack(feats) 307 | 308 | # Voxelize xyz and 
feats
309 |     coords = np.floor(xyz / voxel_size)
310 | 
311 |     coords, inds = ME.utils.sparse_quantize(coords, return_index=True)
312 | 
313 |     # Convert to batched coords compatible with ME
314 |     coords = ME.utils.batched_coordinates([coords])
315 |     return_coords = xyz[inds]
316 | 
317 |     feats = feats[inds]
318 | 
319 |     feats = torch.tensor(feats, dtype=torch.float32)
320 |     coords = coords.clone().detach()
321 | 
322 |     stensor = ME.SparseTensor(feats, coordinates=coords, device=device)
323 | 
324 |     return return_coords, model(stensor).F
325 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 | 
5 | 
6 | import argparse
7 | import os
8 | import sys
9 | 
10 | import torch
11 | 
12 | sys.path.append(os.path.join(os.path.abspath("./"), "../"))
13 | from training.trainer import Trainer
14 | from misc.utils import MinkLocParams
15 | from misc.log import log_string
16 | from dataloader.dataset_utils import make_dataloaders
17 | import pdb
18 | 
19 | if __name__ == '__main__':
20 |     parser = argparse.ArgumentParser(description='Train Minkowski Net embeddings using BatchHard negative mining')
21 |     parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
22 |     parser.add_argument('--model_config', type=str, required=True, help='Path to the model-specific configuration file')
23 |     parser.add_argument('--debug', dest='debug', action='store_true')
24 |     parser.add_argument('--qzj_debug', action='store_true')
25 |     parser.set_defaults(debug=False)
26 |     parser.add_argument('--checkpoint', type=str, required=False, help='Trained model weights', default="")
27 |     parser.add_argument('--visualize', dest='visualize', action='store_true')
28 |     parser.set_defaults(visualize=False)
29 | 
30 |     args = parser.parse_args()
31 |     log_string('Training config path: {}'.format(args.config))
32 |     log_string('Model config path: {}'.format(args.model_config))
33 |     log_string('Debug mode: {}'.format(args.debug))
34 |     log_string('Visualize: {}'.format(args.visualize))
35 | 
36 |     params = MinkLocParams(args.config, args.model_config)
37 |     if args.qzj_debug:
38 |         params.batch_size = 4
39 |         if params.is_register:
40 |             params.reg.batch_size = 4
41 |     params.print()
42 | 
43 |     if args.debug:
44 |         torch.autograd.set_detect_anomaly(True)
45 | 
46 |     dataloaders = make_dataloaders(params, debug=args.debug)
47 |     trainer = Trainer(dataloaders, params, debug=args.debug, visualize=args.visualize, checkpoint=args.checkpoint)
48 |     if args.debug:
49 |         pdb.set_trace()
50 | 
51 |     trainer.do_train()
52 | 
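# Example invocation (illustrative paths; these configs ship in config/):
#   python train.py --config config/config_baseline.txt \
#                   --model_config config/models/vlpdnet.txt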
--------------------------------------------------------------------------------
/training/optimizer_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from misc.utils import MinkLocParams
4 | 
5 | 
6 | def optimizer_factory(params: MinkLocParams, model, d_model):
7 |     # Training elements
8 |     if params.weight_decay is None or params.weight_decay == 0:
9 |         optimizer = torch.optim.Adam(model.parameters(), lr=params.lr)
10 |         optimizer_d = torch.optim.Adam(d_model.parameters(), lr=params.d_lr) if params.domain_adapt else None
11 |         print("optimizer uses Adam with lr {}".format(params.lr))
12 |         if optimizer_d is not None:
13 |             print("optimizer_d uses Adam with lr {}".format(params.d_lr))
14 |     else:
15 |         if params.lpd_fixed:
16 |             ignored_params = list(map(id, model.emb_nn.parameters()))
17 |             base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
18 |             optimizer = torch.optim.Adam([
19 |                 {'params': base_params},
20 |                 {'params': model.emb_nn.parameters(), 'lr': 0.0}], lr=params.lr, weight_decay=params.weight_decay)
21 |         else:
22 |             optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, weight_decay=params.weight_decay)
23 |         optimizer_d = torch.optim.Adam(d_model.parameters(), lr=params.d_lr,
24 |                                        weight_decay=params.d_weight_decay) if params.domain_adapt else None
25 |         print("optimizer uses Adam with weight_decay {} and lr {}".format(params.weight_decay, params.lr))
26 |         if optimizer_d is not None:
27 |             print("optimizer_d uses Adam with weight_decay {} and lr {}".format(params.d_weight_decay, params.d_lr))
28 | 
29 |     return optimizer, optimizer_d
30 | 
31 | 
32 | def scheduler_factory(params: MinkLocParams, optimizer, optimizer_d=None):
33 |     # Training elements
34 |     optimizer_d = optimizer if optimizer_d is None else optimizer_d
35 |     if params.scheduler is None:
36 |         scheduler = None
37 |         scheduler_d = None
38 |     else:
39 |         if params.scheduler == 'CosineAnnealingLR':
40 |             scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=params.epochs + 1,
41 |                                                                    eta_min=params.min_lr)
42 |             scheduler_d = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_d, T_max=params.epochs + 1,
43 |                                                                      eta_min=params.min_lr) if params.domain_adapt else None
44 |         elif params.scheduler == 'MultiStepLR':
45 |             scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, params.scheduler_milestones, gamma=0.1)
46 |             scheduler_d = torch.optim.lr_scheduler.MultiStepLR(optimizer_d, params.scheduler_milestones,
47 |                                                                gamma=0.1) if params.domain_adapt else None
48 |         else:
49 |             raise NotImplementedError('Unsupported LR scheduler: {}'.format(params.scheduler))
50 | 
51 |     return scheduler, scheduler_d
52 | 
--------------------------------------------------------------------------------
/weights/vlpdnet-registration.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/weights/vlpdnet-registration.t7
--------------------------------------------------------------------------------
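A minimal sketch of loading the shipped registration checkpoint into VCRNet, mirroring the loading code in models/model_factory.py (the config paths below are illustrative and should be adapted to your setup):

    import torch
    from misc.utils import MinkLocParams
    from models.vcrnet.vcrnet import VCRNet

    # Build params from a training config plus a model config (example paths).
    params = MinkLocParams('config/config_baseline.txt', 'config/models/vlpdnet.txt')
    vcr_model = VCRNet(params.model_params)
    # Load the pretrained registration weights distributed with the repository.
    state_dict = torch.load('weights/vlpdnet-registration.t7', map_location=torch.device('cpu'))
    vcr_model.load_state_dict(state_dict, strict=True)
    vcr_model.eval()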