├── .gitignore
├── LICENSE
├── README.md
├── config
├── config_baseline.txt
├── config_baseline_gtav5.txt
└── models
│ └── vlpdnet.txt
├── dataloader
├── dataset_utils.py
├── gta5.py
├── oxford.py
└── samplers.py
├── evaluate.py
├── generating_queries
├── generate_test_sets.py
├── generate_test_sets_gtav5.py
├── generate_training_tuples_baseline.py
├── generate_training_tuples_baseline_gtav5.py
└── registration.py
├── loss
├── __init__.py
├── d_loss.py
├── metric_loss.py
├── mmd_loss.py
├── ops
│ └── emd
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── emd.cpp
│ │ ├── emd_cuda.cu
│ │ ├── emd_module.py
│ │ └── setup.py
└── reg_loss.py
├── media
└── pipeline.png
├── misc
├── log.py
├── point_utils.py
└── utils.py
├── models
├── __init__.py
├── discriminator
│ ├── __init__.py
│ ├── resnet.py
│ └── resnetD.py
├── minkloc3d
│ ├── __init__.py
│ ├── minkfpn.py
│ ├── minkpool.py
│ ├── pooling.py
│ ├── resnet_mink.py
│ └── senet_block.py
├── model_factory.py
├── vcrnet
│ ├── __init__.py
│ ├── transformer.py
│ └── vcrnet.py
└── vlpdnet
│ ├── __init__.py
│ ├── lpdnet_model.py
│ └── vlpdnet.py
├── train.py
├── training
├── optimizer_factory.py
├── reg_train.py
└── trainer.py
└── weights
└── vlpdnet-registration.t7
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | generating_queries/pickles
10 |
11 | # Distribution / packaging
12 | .Python
13 | build
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # PyCharm settings
133 | .idea
134 |
135 | *.pyc
136 | *.pth
137 | *.pickle
138 | results*
139 | __pycache__/
140 | checkpoints/
141 | data/
142 | /.idea/
143 | *.zip
144 | *.backup
145 | *.npz
146 | *.npy
147 | *.h5
148 | .fuse*
149 | .~lock*
150 | *.so
151 | *.ckpt.*
152 | *.pyc
153 | *.out.*
154 | /benchmark_datasets/
155 | benchmark_datasets
156 | *benchmark_datasets*
157 | *.pickle
158 | /tf_logs
159 | *gpu_mem_track.txt
160 | *.ckpt
161 | benchmark_datasets/
162 |
163 | tf_logs
164 |
165 | *.log
166 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Intelligent Robotics and Machine Vision Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # vLPD-Net: Registration-aided 3D Point Cloud Learning for Large-Scale Place Recognition
2 |
3 | Code for our IROS 2021 paper: Registration-aided 3D Point Cloud Learning for Large-Scale Place Recognition.
4 |
5 | [Paper](https://arxiv.org/abs/2012.05018) [Video](https://www.youtube.com/watch?v=Nnj9KUXIFns&t=11s)
6 |
7 | Authors: Zhijian Qiao<sup>1†</sup>, Hanjiang Hu<sup>1,2†</sup>, Weiang Shi<sup>1</sup>, Siyuan Chen<sup>1</sup>, Zhe Liu<sup>1,3</sup>, Hesheng Wang<sup>1</sup>.
8 |
9 | Shanghai Jiao Tong University<sup>1</sup> & Carnegie Mellon University<sup>2</sup> & University of Cambridge<sup>3</sup>
10 |
11 |
12 | 
13 |
14 | ### Introduction
15 | This is our IROS 2021 work. We use GTA-V to build a synthetic virtual dataset that covers multi-environment traverses without laborious human effort, while accurate registration ground truth is obtained at the same time.
16 |
17 | To this end, we propose vLPD-Net, a novel registration-aided 3D domain adaptation network for point cloud based place recognition. A structure-aware registration network is proposed to leverage geometric properties and co-contextual information between two input point clouds. Alongside it, an adversarial training pipeline is implemented to reduce the gap between the synthetic and real-world domains.
18 |
19 | ### Citation
20 | If you find this work useful, please cite:
21 | ```
22 | @INPROCEEDINGS{9635878, author={Qiao, Zhijian and Hu, Hanjiang and Shi, Weiang and Chen, Siyuan and Liu, Zhe and Wang, Hesheng}, booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, title={A Registration-aided Domain Adaptation Network for 3D Point Cloud Based Place Recognition}, year={2021}, volume={}, number={}, pages={1317-1322}, doi={10.1109/IROS51168.2021.9635878}}
23 | ```
24 |
25 | ### Environment and Dependencies
26 | Code was tested using Python 3.8 with PyTorch 1.7.1 and MinkowskiEngine 0.5.0 on Ubuntu 18.04 with CUDA 10.2.
27 |
28 | The following Python packages are required:
29 | * PyTorch (version 1.7)
30 | * MinkowskiEngine (version 0.5.0)
31 | * pytorch_metric_learning (version 0.9.94 or above)
32 | * tensorboard
33 | * pandas
34 | * psutil
35 | * bitarray
36 |
37 | ### Datasets
38 | **vLPD-Net** is trained on a subset of the Oxford RobotCar dataset and on a synthetic virtual dataset built on GTAV5. Processed Oxford RobotCar datasets are available in [PointNetVLAD](https://github.com/mikacuy/pointnetvlad.git). To build the synthetic virtual dataset, we use the plugin [DeepGTAV](https://github.com/aitorzip/DeepGTAV) and its extension [A-virtual-LiDAR-for-DeepGTAV](https://github.com/gdpinchina/A-virtual-LiDAR-for-DeepGTAV), which provide the vehicle with cameras and a LiDAR. In addition, we use [VPilot](https://github.com/aitorzip/VPilot) for easier interaction with DeepGTAV. The data collected by the plugins is saved in txt files, then processed and stored in the Oxford RobotCar format.
39 |
40 |
41 | The directory structure is as follows:
42 | ```
43 | vLPD-Net
44 | └── benchmark_datasets
45 | ├── GTA5
46 | └── oxford
47 | ```
48 |
49 | To generate pickles, run:
50 | ```generate pickles
51 | cd generating_queries/
52 |
53 | # Generate training tuples for the Oxford Dataset
54 | python generate_training_tuples_baseline.py
55 |
56 | # Generate evaluation Oxford tuples
57 | python generate_test_sets.py
58 |
59 | # In addition, if you also want to experiment on GTAV5, we have open sourced the code for processing this dataset. We hope it will be helpful to you.
60 | # Generate training tuples for the gtav5 Dataset
61 | python generate_training_tuples_baseline_gtav5.py
62 |
63 | # Generate evaluation gtav5 tuples
64 | python generate_test_sets_gtav5.py
65 | ```
66 |
67 | ### Training
68 | We train vLPD-Net on one 2080Ti GPU.
69 |
70 | To train the network, run:
71 |
72 | ```train baseline
73 | # To train vLPD-Net model on the Oxford Dataset
74 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python train.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt
75 |
76 | # To train vLPD-Net model on the GTAV5 Dataset
77 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python train.py --config ./config/config_baseline_gtav5.txt --model_config ./config/models/vlpdnet.txt
78 | ```
79 |
80 | For registration on the Oxford dataset, we transform the point clouds ourselves to generate the ground truth. The registration training pipelines also differ slightly between the two datasets: on Oxford we use EPCOR only at inference time, whereas on GTAV5 we train with EPCOR on top of a pretrained whole-to-whole registration model such as [VCR-Net](https://github.com/qiaozhijian/VCR-Net.git). This is because registration on the Oxford dataset is naturally whole-to-whole.
81 | ### Pre-trained Models
82 |
83 | Pretrained models are available in the `weights` directory:
84 | - `vlpdnet-oxford.pth` trained on the Oxford Dataset
85 | - `vlpdnet-gtav5.pth` trained on the GTAV5 Dataset with domain adaptation
86 | - `vlpdnet-registration.t7` trained on the Oxford Dataset for registration.
87 |
88 | ### Evaluation
89 |
90 | To evaluate pretrained models, run the following commands:
91 |
92 | ```eval baseline
93 |
94 | # To evaluate the model trained on the Oxford Dataset
95 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --weights=./weights/vlpdnet-oxford.pth
96 |
97 | # To evaluate the model trained on the GTAV5 Dataset
98 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --weights=./weights/vlpdnet-gtav5.pth
99 |
100 | # To evaluate the model trained on the Oxford Dataset for registration
101 | export OMP_NUM_THREADS=24;CUDA_VISIBLE_DEVICES=0 python evaluate.py --config ./config/config_baseline.txt --model_config ./config/models/vlpdnet.txt --eval_reg=./weights/vlpdnet-registration.t7
102 | ```
103 |
104 | ### Acknowledgment
105 |
106 | [MinkLoc3D](https://github.com/jac99/MinkLoc3D.git)
107 |
108 | [VCR-Net](https://github.com/qiaozhijian/VCR-Net.git)
109 |
110 | [LPD-Net-Pytorch](https://github.com/qiaozhijian/LPD-Net-Pytorch.git)
111 |
112 | ### License
113 | Our code is released under the MIT License (see LICENSE file for details).
114 |
--------------------------------------------------------------------------------
/config/config_baseline.txt:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | num_points = 4096
3 |
4 | dataset_folder = ./benchmark_datasets
5 | queries_folder = ./generating_queries/pickles
6 |
7 | [TRAIN]
8 | num_workers = 16
9 | batch_size = 56
10 | batch_size_limit = 56
11 | batch_expansion_rate = 1.4
12 | batch_expansion_th = 0.7
13 |
14 | lr = 1e-3
15 | epochs = 50
16 | scheduler_milestones = 40
17 |
18 | aug_mode = 2
19 | weight_decay = 1e-3
20 | eval_simple = True
21 |
22 | loss = BatchHardTripletMarginLoss
23 | normalize_embeddings = False
24 | margin = 0.2
25 | swap = True
26 | lamda = 1
27 | lamda_reg = 0.001
28 | lpd_fixed = True
29 |
30 | train_file = training_queries_baseline.pickle
31 |
32 | [REGISTARTION]
33 | batch_size = 30
34 | epochs = 100
35 | num_points = 4096
36 | iter = 1
37 | overlap = 1.0
--------------------------------------------------------------------------------
/config/config_baseline_gtav5.txt:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | num_points = 4096
3 |
4 | dataset_folder = ./benchmark_datasets
5 | queries_folder = ./generating_queries/pickles
6 |
7 | [TRAIN]
8 | num_workers = 16
9 | batch_size = 32
10 | batch_size_limit = 42
11 | batch_expansion_rate = 1.4
12 | batch_expansion_th = 0.7
13 |
14 | lr = 1e-3
15 | epochs = 40
16 | scheduler_milestones = 30
17 |
18 | aug_mode = 2
19 | weight_decay = 1e-3
20 | eval_simple = True
21 |
22 | loss = BatchHardTripletMarginLoss
23 | normalize_embeddings = False
24 | margin = 0.2
25 | swap = True
26 | lamda = 1
27 | lamda_reg = 0.001
28 | lpd_fixed = True
29 |
30 | train_file = gta5_training_queries_baseline.pickle
31 |
32 | [REGISTARTION]
33 | batch_size = 30
34 | epochs = 100
35 | num_points = 4096
36 | iter = 1
37 | overlap = 1.0
38 |
39 | [DOMAIN_ADAPT]
40 | model = FCDor
41 | lr = 1e-4
42 | epochs = 30
43 | scheduler_milestones = 20
44 | weight_decay = 1e-3
45 |
46 | lamda_gd = 0.01
47 | lamda_d = 0.01
48 | loss = MSEWithMask
49 |
50 | train_file = training_queries_baseline.pickle
--------------------------------------------------------------------------------
/config/models/vlpdnet.txt:
--------------------------------------------------------------------------------
1 | # MinkLoc3D model
2 | [MODEL]
3 | model = vLPDNet
4 | mink_quantization_size = 0.01
5 | planes = 32,64,64
6 | layers = 1,1,1
7 | num_top_down = 1
8 | conv0_kernel_size = 5
9 | feature_size = 256
10 | separa_mode = Simple
11 |
12 | [LPD]
13 | emb_dims = 128
14 | featnet = lpdnet
15 | lpd_channels = 16,32,64
16 |
17 |
18 | [REGRESS]
19 | model = vcr
20 | loss = PointMSE
21 | ff_dims = 128
--------------------------------------------------------------------------------
/dataloader/dataset_utils.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 |
5 | import MinkowskiEngine as ME
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import DataLoader
9 |
10 | from dataloader.gta5 import GTAV5
11 | from dataloader.oxford import OxfordDataset, TrainTransform, TrainSetTransform, Oxford
12 | from dataloader.samplers import BatchSampler
13 | from misc.utils import MinkLocParams
14 |
15 |
def make_datasets(params: MinkLocParams, debug=False):
    """Create the training dataset and, if ``params.val_file`` is set, a validation dataset.

    :param params: MinkLocParams; ``aug_mode``, ``train_file`` and ``val_file`` are read.
    :param debug: when True the training set is capped at 1000 elements (passed as ``max_elems``).
    :return: dict with key 'train' and, optionally, 'val'.
    """
    max_elems = 1000 if debug else None

    # Per-cloud and whole-batch augmentations are only used for training.
    datasets = {
        'train': OxfordDataset(params, params.train_file, TrainTransform(params.aug_mode),
                               set_transform=TrainSetTransform(params.aug_mode),
                               max_elems=max_elems),
    }

    if params.val_file is not None:
        # Validation clouds are loaded without any transform.
        datasets['val'] = OxfordDataset(params, params.val_file, None)
    return datasets
32 |
33 |
def make_eval_dataset(params: MinkLocParams):
    """Return the evaluation dataset read from ``params.test_file``, with no transform applied."""
    return OxfordDataset(params, params.test_file, transform=None)
38 |
39 |
def sparse_quantize(xyz_batch, params=None):
    """Quantize a batch of point clouds into MinkowskiEngine sparse-tensor inputs.

    :param xyz_batch: iterable of per-cloud point tensors; make_collate_fn passes a stacked
        (batch_size, n_points, 3) tensor.
    :param params: MinkLocParams; ``params.model_params.mink_quantization_size`` is the voxel
        size. Despite the ``None`` default this argument is required.
    :return: dict with
        'coords'   - batched voxel coordinates from ``ME.utils.batched_coordinates``,
        'features' - float32 ones of shape (num_coords, 1), one dummy feature per voxel,
        'index'    - for every kept voxel, the index of its source point in the flattened
                     (concatenated) input clouds,
        'cloud'    - ``xyz_batch`` unchanged.
    """
    quantization_size = params.model_params.mink_quantization_size
    quantized = [ME.utils.sparse_quantize(cloud, quantization_size=quantization_size, return_index=True)
                 for cloud in xyz_batch]

    # ``return_index`` yields indices into each *original* cloud, so cloud i must be offset by the
    # number of input points that precede it. Offsetting by the number of *kept voxels*
    # (``len(index_i)``) is wrong as soon as quantization merges any points.
    # NOTE(review): assumes consumers gather from the flattened (batch_size * n_points) cloud.
    index_parts = []
    offset = 0
    for cloud, (_, index) in zip(xyz_batch, quantized):
        index_parts.append(index + offset)
        offset += len(cloud)
    coords_idx = torch.cat(index_parts)

    coords = ME.utils.batched_coordinates([voxel_coords for voxel_coords, _ in quantized])

    # Assign a dummy feature equal to 1 to each voxel.
    # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation.
    feats = torch.ones((coords.shape[0], 1), dtype=torch.float32)
    return {'coords': coords, 'features': feats, 'index': coords_idx, 'cloud': xyz_batch}
55 |
56 |
def make_collate_fn(dataset: OxfordDataset, params=None):
    """Build the collate function that turns a list of dataset items into one training batch.

    The returned callable produces
    ``(batch, positives_mask, negatives_mask, da_batch, R, t, source_batch, gt_T)``.
    If ``dataset.set_transform`` is set, it is applied once to the whole stacked batch. Its
    rotation/translation is then composed with each item's own (R2, t2) to form R, t and the
    4x4 ``gt_T``.
    """

    def collate_fn(data_list):
        n = len(data_list)
        labels = [item[1] for item in data_list]

        # Stack query clouds: (batch_size, n_points, 3).
        xyz_batch = torch.stack([item[0] for item in data_list], dim=0)
        if dataset.set_transform is None:
            R_set, t_set = np.eye(3), np.zeros((3, 1))
        else:
            # The same transformation is applied to every element of the batch.
            xyz_batch, R_set, t_set = dataset.set_transform(xyz_batch)
        batch = sparse_quantize(xyz_batch, params)

        source_batch = sparse_quantize(torch.stack([item[5] for item in data_list], dim=0), params)

        # Optional domain-adaptation clouds; presence is decided by the first item only.
        da_batch = None
        if data_list[0][2] is not None:
            da_batch = sparse_quantize(torch.stack([item[2][0] for item in data_list], dim=0), params)

        # batch_size x batch_size boolean masks; queries[label]['positives'] is indexable by label.
        queries = dataset.queries
        positives_mask = torch.tensor([[queries[row]['positives'][col] for col in labels] for row in labels])
        negatives_mask = torch.tensor([[queries[row]['negatives'][col] for col in labels] for row in labels])

        R2 = torch.tensor(np.stack([item[3].reshape(3, 3) for item in data_list]), dtype=torch.float32)
        t2 = torch.tensor(np.stack([item[4].reshape(3, 1) for item in data_list]), dtype=torch.float32)

        R_set = torch.tensor(R_set.reshape(1, 3, 3), dtype=torch.float32).repeat(n, 1, 1)
        t_set = torch.tensor(t_set.reshape(1, 3, 1), dtype=torch.float32).repeat(n, 1, 1)

        # Compose: x -> R_set (R2 x + t2) + t_set.
        R = torch.bmm(R_set, R2)
        t = torch.bmm(R_set, t2) + t_set

        gt_T = torch.eye(4).unsqueeze(0).repeat(n, 1, 1)
        gt_T[:, :3, :3] = R
        gt_T[:, :3, 3] = t[..., 0]

        return batch, positives_mask, negatives_mask, da_batch, R.float(), t.float(), source_batch, gt_T.float()

    return collate_fn
113 |
114 |
def make_dataloaders(params: MinkLocParams, debug=False):
    """Create the training, optional validation and optional registration dataloaders.

    Training/validation batches are formed by BatchSampler as groups of k=2 similar elements.

    :param params: MinkLocParams. It provides batch settings (``batch_size``,
        ``batch_size_limit``, ``batch_expansion_rate``), ``num_workers``, ``train_file``,
        ``is_register``, and, when registering, ``reg.batch_size``.
    :param debug: forwarded to make_datasets (limits the training set size).
    :return: dict with 'train', optionally 'val', and 'reg_train'/'reg_test' when
        ``params.is_register`` is true.
    """
    datasets = make_datasets(params, debug=debug)
    dataloaders = {}

    train_sampler = BatchSampler(datasets['train'], batch_size=params.batch_size,
                                 batch_size_limit=params.batch_size_limit,
                                 batch_expansion_rate=params.batch_expansion_rate)
    # The collate function groups items into a batch and applies a 'set transform' on the whole batch.
    train_collate_fn = make_collate_fn(datasets['train'], params)
    dataloaders['train'] = DataLoader(datasets['train'], batch_sampler=train_sampler,
                                      collate_fn=train_collate_fn,
                                      num_workers=params.num_workers, pin_memory=True)

    if 'val' in datasets:
        val_sampler = BatchSampler(datasets['val'], batch_size=params.batch_size)
        # Currently the validation dataset has no set_transform, but that may change in the future.
        val_collate_fn = make_collate_fn(datasets['val'], params)
        dataloaders['val'] = DataLoader(datasets['val'], batch_sampler=val_sampler,
                                        collate_fn=val_collate_fn,
                                        num_workers=params.num_workers, pin_memory=True)

    if params.is_register:
        # GTAV5 registration pairs come from their own dataset class; everything else uses Oxford.
        reg_dataset_cls = GTAV5 if 'gta' in params.train_file.lower() else Oxford
        # NOTE: registration loaders use a fixed worker count, independent of params.num_workers.
        reg_num_workers = 16
        dataloaders['reg_train'] = DataLoader(
            reg_dataset_cls(params=params, partition='train'),
            batch_size=params.reg.batch_size, shuffle=True, drop_last=True, num_workers=reg_num_workers)
        dataloaders['reg_test'] = DataLoader(
            reg_dataset_cls(params=params, partition='test'),
            batch_size=int(params.reg.batch_size * 1.2), shuffle=False, drop_last=False,
            num_workers=reg_num_workers)

    return dataloaders
159 |
--------------------------------------------------------------------------------
/dataloader/gta5.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import numpy as np
5 | import open3d as o3d
6 | from scipy.spatial.transform import Rotation
7 | from sklearn.neighbors import NearestNeighbors
8 | from torch.utils.data import Dataset
9 |
10 | from misc.utils import MinkLocParams
11 |
12 |
def load_data(params, partition):
    """Load the GTAV5 query mapping used to build registration pairs.

    :param params: object with ``queries_folder`` and ``train_file`` attributes; the pickle at
        ``queries_folder/train_file`` is loaded.
    :param partition: 'train' or 'test'. NOTE: currently ignored. Both partitions read
        ``params.train_file`` and receive the same queries. A split here would also need to keep
        the keys stored in ``'positives'`` valid, because GTAV5.__getitem__ looks positives up
        by key in the returned mapping.
    :return: the unpickled queries. GTAV5 expects
        key -> {'query': file, 'positives': [keys], 'positives_T': [4x4 transforms], ...}.
    """
    query_filepath = os.path.join(params.queries_folder, params.train_file)
    # SECURITY: pickle.load can execute arbitrary code; only load trusted query files.
    with open(query_filepath, 'rb') as handle:
        return pickle.load(handle)
29 |
30 |
class GTAV5(Dataset):
    """Registration pairs built from the GTAV5 query pickle.

    Each item pairs a query cloud with one of its positives. The query cloud is moved with the
    stored ``positives_T`` transform. The positive cloud is then perturbed by a random rotation
    (angles drawn uniformly in [0, 5) deg about x and y and [0, 35) deg about z) and a random
    translation drawn uniformly in [-0.5, 0.5) per axis.
    """

    def __init__(self, params: "MinkLocParams", partition='train'):
        self.num_points = params.reg.num_points
        self.overlap = params.reg.overlap
        self.partition = partition
        print('Load GTAV5 Dataset')
        base_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): resolved relative to this file rather than params.dataset_folder - confirm intended.
        self.DATA_DIR = os.path.join(base_dir, '../../dataset/benchmark_datasets/')
        self.data = load_data(params, partition)

    def _load_cloud(self, relative_path):
        """Read a float64 xyz binary file and return up to num_points randomly chosen points, shape (3, N)."""
        filename = os.path.join(self.DATA_DIR, relative_path)
        points = np.fromfile(filename, dtype=np.float64)
        points = np.reshape(points, (points.shape[0] // 3, 3))
        return np.random.permutation(points)[: self.num_points].T

    def __getitem__(self, item):
        """Return one registration sample.

        :return: (pointcloud1, pointcloud2, R_ab, translation_ab, R_ba, translation_ba,
                  euler_ab, euler_ba, 0) as float32 arrays. The clouds are (3, num_points).
                  pointcloud1 is the query after ``positives_T``. pointcloud2 is the positive
                  cloud after the random (R_ab, translation_ab) perturbation. Both are shuffled.
        :raises ValueError: if the query has no positives.
        """
        if self.partition != 'train':
            # Deterministic sampling for evaluation (reseeds NumPy's global RNG).
            np.random.seed(item)

        query_dict = self.data[item]
        pointcloud1 = self._load_cloud(query_dict['query'])

        num_samples = len(query_dict['positives'])
        if num_samples == 0:
            raise ValueError('GTAV5 item {} has no positives'.format(item))
        pos_idx = np.random.randint(num_samples)
        positive_dict = self.data[query_dict['positives'][pos_idx]]
        pointcloud2 = self._load_cloud(positive_dict['query'])

        # Move the query cloud with the stored query->positive transform.
        positive_T = query_dict['positives_T'][pos_idx]
        R_pos = positive_T[:3, :3]
        t_pos = positive_T[:3, 3].reshape(3, 1)
        pointcloud1 = R_pos @ pointcloud1 + t_pos

        # self.visual_pcl_simple(pcl1=pointcloud1.T, pcl2=pointcloud2.T)

        # Random perturbation applied to the positive cloud.
        anglex = np.random.uniform() * 5.0 / 180.0 * np.pi
        angley = np.random.uniform() * 5.0 / 180.0 * np.pi
        anglez = np.random.uniform() * 35.0 / 180.0 * np.pi

        cosx, sinx = np.cos(anglex), np.sin(anglex)
        cosy, siny = np.cos(angley), np.sin(angley)
        cosz, sinz = np.cos(anglez), np.sin(anglez)
        Rx = np.array([[1, 0, 0],
                       [0, cosx, -sinx],
                       [0, sinx, cosx]])
        Ry = np.array([[cosy, 0, siny],
                       [0, 1, 0],
                       [-siny, 0, cosy]])
        Rz = np.array([[cosz, -sinz, 0],
                       [sinz, cosz, 0],
                       [0, 0, 1]])
        R_ab = Rx.dot(Ry).dot(Rz)
        R_ba = R_ab.T

        # Train and test draw the translation from the same range.
        translation_ab = np.array([np.random.uniform(-0.5, 0.5), np.random.uniform(-0.5, 0.5),
                                   np.random.uniform(-0.5, 0.5)])
        translation_ba = -R_ba.dot(translation_ab)

        # Extrinsic 'zyx' with angles (z, y, x) is the same rotation as Rx @ Ry @ Rz.
        rot_ab = Rotation.from_euler('zyx', [anglez, angley, anglex])
        pointcloud2 = rot_ab.apply(pointcloud2.T).T + np.expand_dims(translation_ab, axis=1)

        euler_ab = np.asarray([anglez, angley, anglex])
        euler_ba = -euler_ab[::-1]

        pointcloud1 = np.random.permutation(pointcloud1.T).T
        pointcloud2 = np.random.permutation(pointcloud2.T).T

        # [3,num_points] (3,)
        return pointcloud1.astype('float32'), pointcloud2.astype('float32'), R_ab.astype('float32'), \
            translation_ab.astype('float32'), R_ba.astype('float32'), translation_ba.astype('float32'), \
            euler_ab.astype('float32'), euler_ba.astype('float32'), 0

    def __len__(self):
        return len(self.data)

    def visual_pcl_simple(self, pcl1, pcl2, name='Open3D Origin'):
        """Show two (N, >=3) point clouds in an Open3D window (first yellow, second blue)."""
        pcd1 = o3d.geometry.PointCloud()
        pcd1.points = o3d.utility.Vector3dVector(pcl1[:, :3])
        pcd2 = o3d.geometry.PointCloud()
        pcd2.points = o3d.utility.Vector3dVector(pcl2[:, :3])
        pcd1.paint_uniform_color([1, 0.706, 0])
        pcd2.paint_uniform_color([0, 0.651, 0.929])
        o3d.visualization.draw_geometries([pcd1, pcd2], window_name=name, width=1920, height=1080,
                                          left=50,
                                          top=50,
                                          point_show_normal=False, mesh_show_wireframe=False,
                                          mesh_show_back_face=False)

    def nearest_neighbor(self, dst, reserve):
        """Keep the ``reserve`` fraction of points of ``dst`` nearest to its last point.

        Assumes ``dst`` is (3, N) (it is transposed before fitting) - TODO confirm; returns (3, M).
        """
        dst = dst.T
        num = np.max([dst.shape[0], dst.shape[1]])
        num = int(num * reserve)
        src = dst[-1, :].reshape(1, -1)
        neigh = NearestNeighbors(n_neighbors=num)
        neigh.fit(dst)
        indices = neigh.kneighbors(src, return_distance=False)
        indices = indices.ravel()
        return dst[indices, :].T
145 |
--------------------------------------------------------------------------------
/dataloader/samplers.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 |
4 | import copy
5 | import random
6 |
7 | from torch.utils.data import DataLoader, Sampler
8 |
9 | from dataloader.oxford import OxfordDataset
10 | from misc.log import log_string
11 |
12 | VERBOSE = False  # Verbosity flag; not referenced anywhere else in this module.
13 |
14 |
class BatchSampler(Sampler):
    """Batch sampler yielding lists of dataset indices built from positive pairs.

    Batches are laid out as item1_1, item1_2, item2_1, item2_2, ..., itemn_1, itemn_2. Each group
    holds k=2 similar elements (an element and one of its positives), so negatives for any
    element can be found among the other groups of the same batch. Batches are regenerated at
    the start of every epoch, i.e. on every call to ``__iter__``.

    :param dataset: object exposing ``queries`` (iterated for element keys) and
        ``get_positives_ndx(key)``.
    :param batch_size: target batch size, raised to 2*k if smaller. Elements are added in pairs,
        so a batch can exceed an odd batch_size by one.
    :param batch_size_limit: upper bound used by ``expand_batch``.
    :param batch_expansion_rate: growth factor used by ``expand_batch`` (must be > 1 when given).
    """

    def __init__(self, dataset: "OxfordDataset", batch_size: int, batch_size_limit: int = None,
                 batch_expansion_rate: float = None):
        if batch_expansion_rate is not None:
            assert batch_expansion_rate > 1., 'batch_expansion_rate must be greater than 1'
            assert batch_size <= batch_size_limit, 'batch_size_limit must be greater or equal to batch_size'

        self.batch_size = batch_size
        self.batch_size_limit = batch_size_limit
        self.batch_expansion_rate = batch_expansion_rate
        self.dataset = dataset
        self.k = 2  # Number of positive examples per group must be 2
        if self.batch_size < 2 * self.k:
            self.batch_size = 2 * self.k
            print('WARNING: Batch too small. Batch size increased to {}.'.format(self.batch_size))

        # Index lists of each batch; (re)generated at the start of every epoch by __iter__.
        self.batch_idx = []
        # Keys of all dataset elements (values are unused placeholders).
        self.elems_ndx = {ndx: True for ndx in self.dataset.queries}

    def __iter__(self):
        # Re-generate batches every epoch
        self.generate_batches()
        yield from self.batch_idx

    def __len__(self) -> int:
        """Number of batches of the current epoch (0 before the first iteration)."""
        return len(self.batch_idx)

    def expand_batch(self):
        """Grow batch_size by batch_expansion_rate, capped at batch_size_limit."""
        if self.batch_expansion_rate is None:
            print('WARNING: batch_expansion_rate is None')
            return

        if self.batch_size >= self.batch_size_limit:
            return

        old_batch_size = self.batch_size
        self.batch_size = int(self.batch_size * self.batch_expansion_rate)
        self.batch_size = min(self.batch_size, self.batch_size_limit)
        log_string('=> Batch size increased from: {} to {}'.format(old_batch_size, self.batch_size))

    def generate_batches(self):
        """Rebuild ``self.batch_idx`` for a new epoch.

        Elements are visited in random order. Each still-unused element is paired with a random
        unused positive, or with any positive if all of them are already used. Elements without
        positives are skipped. A trailing batch with fewer than 2*k elements is dropped.
        """
        assert self.k == 2, 'sampler can sample only k=2 elements from the same class'
        self.batch_idx = []

        # Walking a shuffled key list visits elements in uniformly random order (as repeatedly
        # calling random.choice on the remaining elements would), but in O(n) overall instead of
        # rebuilding a list of all unused keys for every draw.
        order = list(self.elems_ndx)
        random.shuffle(order)
        unused = set(order)
        current_batch = []

        for selected_element in order:
            if selected_element not in unused:
                # Already consumed as another element's second positive.
                continue
            unused.discard(selected_element)

            positives = self.dataset.get_positives_ndx(selected_element)
            if len(positives) == 0:
                # Broken dataset element without any positives
                continue

            # Prefer a positive that has not been used in this epoch yet.
            unused_positives = [e for e in positives if e in unused]
            if len(unused_positives) > 0:
                second_positive = random.choice(unused_positives)
                unused.discard(second_positive)
            else:
                second_positive = random.choice(positives)

            # Each batch holds batch_size / 2 positive pairs.
            current_batch += [selected_element, second_positive]

            if len(current_batch) >= self.batch_size:
                # A batch needs at least two groups, otherwise no negatives could be found in it.
                if len(current_batch) >= 2 * self.k:
                    self.batch_idx.append(current_batch)
                current_batch = []

        # Flush the last, possibly smaller, batch under the same two-group requirement.
        if len(current_batch) >= 2 * self.k:
            self.batch_idx.append(current_batch)

        for batch in self.batch_idx:
            assert len(batch) % self.k == 0, 'Incorrect bach size: {}'.format(len(batch))
112 |
113 |
114 | if __name__ == '__main__':
115 | dataset_path = '/media/sf_Datasets/PointNetVLAD'
116 | query_filename = 'test_queries_baseline.pickle'
117 |
118 | from configparser import ConfigParser
119 |
120 | config = ConfigParser()
121 | config.dataset_path = dataset_path
122 |
123 | ds = OxfordDataset(config, query_filename)
124 | sampler = BatchSampler(ds, batch_size=16)
125 | dataloader = DataLoader(ds, batch_sampler=sampler)
126 | e = ds[0]
127 | res = next(iter(dataloader))
128 | log_string(res)
129 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 |
5 | import argparse
6 | import json
7 | import os
8 | import pickle
9 |
10 | import numpy as np
11 | import torch
12 | from sklearn.neighbors import KDTree
13 | from tqdm import tqdm
14 |
15 | from misc.log import log_string
16 | from misc.utils import MinkLocParams
17 | from models.model_factory import model_factory, load_weights
18 | from training.reg_train import testVCRNet
19 | from torch.utils.data import DataLoader
20 | from dataloader.oxford import Oxford
21 | DEBUG = False
22 |
23 |
24 | def evaluate(model, device, params, log=False):
25 | # Run evaluation on all eval datasets
26 |
27 | if DEBUG:
28 | params.eval_database_files = params.eval_database_files[0:1]
29 | params.eval_query_files = params.eval_query_files[0:1]
30 |
31 | assert len(params.eval_database_files) == len(params.eval_query_files)
32 |
33 | stats = {}
34 | for database_file, query_file in zip(params.eval_database_files, params.eval_query_files):
35 | # Extract location name from query and database files
36 | location_name = database_file.split('_')[0]
37 | temp = query_file.split('_')[0]
38 | assert location_name == temp, 'Database location: {} does not match query location: {}'.format(database_file,
39 | query_file)
40 |
41 | p = os.path.join(params.queries_folder, database_file)
42 | with open(p, 'rb') as f:
43 | database_sets = pickle.load(f)
44 |
45 | p = os.path.join(params.queries_folder, query_file)
46 | with open(p, 'rb') as f:
47 | query_sets = pickle.load(f)
48 | if log:
49 | print('Evaluation:{} on {}'.format(database_file, query_file))
50 | temp = evaluate_dataset(model, device, params, database_sets, query_sets, log=log)
51 | stats[location_name] = temp
52 |
53 | for database_name in stats:
54 | log_string('Dataset: {} '.format(database_name), end='')
55 | t = 'Avg. top 1 recall: {:.2f} Avg. top 1% recall: {:.2f} Avg. similarity: {:.4f}'
56 | log_string(t.format(stats[database_name]['ave_recall'][0],
57 | stats[database_name]['ave_one_percent_recall'],
58 | stats[database_name]['average_similarity']))
59 | return stats
60 |
61 |
62 | def evaluate_dataset(model, device, params, database_sets, query_sets, log=False):
63 | # Run evaluation on a single dataset
64 | recall = np.zeros(25)
65 | count = 0
66 | similarity = []
67 | one_percent_recall = []
68 |
69 | database_embeddings = []
70 | query_embeddings = []
71 |
72 | model.eval()
73 | if log:
74 | tqdm_ = lambda x, desc: tqdm(x, desc=desc)
75 | else:
76 | tqdm_ = lambda x, desc: x
77 |
78 | torch.cuda.empty_cache()
79 | for set in tqdm_(database_sets, 'Database'):
80 | database_embeddings.append(get_latent_vectors(model, set, device, params))
81 |
82 | for set in tqdm_(query_sets, ' Query'):
83 | query_embeddings.append(get_latent_vectors(model, set, device, params))
84 |
85 | for i in tqdm_(range(len(query_sets)), ' Test'):
86 | for j in range(len(query_sets)):
87 | if i == j:
88 | continue
89 | pair_recall, pair_similarity, pair_opr = get_recall(i, j, database_embeddings, query_embeddings, query_sets,
90 | database_sets, log=log)
91 | recall += np.array(pair_recall)
92 | count += 1
93 | one_percent_recall.append(pair_opr)
94 | for x in pair_similarity:
95 | similarity.append(x)
96 |
97 | ave_recall = recall / count
98 | average_similarity = np.mean(similarity)
99 | ave_one_percent_recall = np.mean(one_percent_recall)
100 | stats = {'ave_one_percent_recall': ave_one_percent_recall, 'ave_recall': ave_recall,
101 | 'average_similarity': average_similarity, 'Loc rebuild': 1}
102 | return stats
103 |
104 |
105 | def load_pc(file_name, params, make_tensor=True):
106 | # returns Nx3 matrix
107 | file_path = os.path.join(params.dataset_folder, file_name)
108 | pc = np.fromfile(file_path, dtype=np.float64)
109 | # coords are within -1..1 range in each dimension
110 | assert pc.shape[0] == params.num_points * 3, "Error in point cloud shape: {}".format(file_path)
111 | pc = np.reshape(pc, (pc.shape[0] // 3, 3))
112 | if make_tensor:
113 | pc = torch.tensor(pc, dtype=torch.float)
114 | return pc
115 |
116 |
117 | def load_pc_files(elem_ndxs, set, params):
118 | pcs = []
119 | for elem_ndx in elem_ndxs:
120 | pc = load_pc(set[elem_ndx]["query"], params, make_tensor=False)
121 | if (pc.shape[0] != 4096):
122 | assert 0, 'pc.shape[0] != 4096'
123 | pcs.append(pc)
124 | pcs = np.asarray(pcs)
125 |
126 | pcs = torch.tensor(pcs, dtype=torch.float)
127 | return pcs
128 |
129 |
130 | def genrate_batch(num, batch_size):
131 | sets = np.arange(0, num, batch_size)
132 | sets = sets.tolist()
133 | if sets[-1] != num:
134 | sets.append(num)
135 | return sets
136 |
137 |
138 | def get_latent_vectors(model, set, device, params: MinkLocParams):
139 | if DEBUG:
140 | embeddings = np.random.rand(len(set), 256)
141 | return embeddings
142 |
143 | model.eval()
144 | embeddings_l = []
145 |
146 | batch_set = genrate_batch(len(set), int(params.batch_size * 1.5))
147 | for batch_id in range(len(batch_set) - 1):
148 | elem_ndx = np.arange(batch_set[batch_id], batch_set[batch_id + 1])
149 | x = load_pc_files(elem_ndx, set, params)
150 | with torch.no_grad():
151 | batch = {'cloud': x.cuda()}
152 | embedding = model(target_batch=batch)['embeddings']
153 | # embedding is (1, 1024) tensor
154 | if params.normalize_embeddings:
155 | embedding = torch.nn.functional.normalize(embedding, p=2, dim=1) # Normalize embeddings
156 |
157 | embedding = embedding.detach().cpu().numpy()
158 | embeddings_l.append(embedding)
159 |
160 | embeddings = np.vstack(embeddings_l)
161 | return embeddings
162 |
163 |
164 | def get_recall(m, n, database_vectors, query_vectors, query_sets, database_sets, log=False):
165 | # Original PointNetVLAD code
166 | database_output = database_vectors[m]
167 | queries_output = query_vectors[n]
168 |
169 | database_nbrs = KDTree(database_output)
170 |
171 | num_neighbors = 25
172 | recall = [0] * num_neighbors
173 |
174 | top1_similarity_score = []
175 | one_percent_retrieved = 0
176 | threshold = max(int(round(len(database_output) / 100.0)), 1)
177 |
178 | num_evaluated = 0
179 | for i in range(len(queries_output)):
180 | # i is query element ndx
181 | query_details = query_sets[n][i] # {'query': path, 'northing': , 'easting': }
182 | true_neighbors = query_details[m]
183 | if len(true_neighbors) == 0:
184 | continue
185 | num_evaluated += 1
186 | distances, indices = database_nbrs.query(np.array([queries_output[i]]), k=num_neighbors)
187 |
188 | for j in range(len(indices[0])):
189 | if indices[0][j] in true_neighbors:
190 | if j == 0:
191 | similarity = np.dot(queries_output[i], database_output[indices[0][j]])
192 | top1_similarity_score.append(similarity)
193 | recall[j] += 1
194 | break
195 |
196 | if len(list(set(indices[0][0:threshold]).intersection(set(true_neighbors)))) > 0:
197 | one_percent_retrieved += 1
198 |
199 | one_percent_recall = (one_percent_retrieved / float(num_evaluated)) * 100
200 | recall = (np.cumsum(recall) / float(num_evaluated)) * 100
201 | # log_string(recall)
202 | # log_string(np.mean(top1_similarity_score))
203 | # log_string(one_percent_recall)
204 | return recall, top1_similarity_score, one_percent_recall
205 |
206 |
207 | def export_eval_stats(file_name, prefix, eval_stats):
208 | s = prefix
209 | ave_1p_recall_l = []
210 | ave_recall_l = []
211 | # Print results on the final model
212 | with open(file_name, "a") as f:
213 | for ds in ['oxford', 'university', 'residential', 'business']:
214 | ave_1p_recall = eval_stats[ds]['ave_one_percent_recall']
215 | ave_1p_recall_l.append(ave_1p_recall)
216 | ave_recall = eval_stats[ds]['ave_recall'][0]
217 | ave_recall_l.append(ave_recall)
218 | s += ", {:0.2f}, {:0.2f}".format(ave_1p_recall, ave_recall)
219 |
220 | mean_1p_recall = np.mean(ave_1p_recall_l)
221 | mean_recall = np.mean(ave_recall_l)
222 | s += ", {:0.2f}, {:0.2f}\n".format(mean_1p_recall, mean_recall)
223 | f.write(s)
224 |
225 |
226 | if __name__ == "__main__":
227 | parser = argparse.ArgumentParser(description='Evaluate model on PointNetVLAD (Oxford) dataset')
228 | parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
229 | parser.add_argument('--model_config', type=str, required=True, help='Path to the model-specific configuration file')
230 | parser.add_argument('--weights', type=str, required=False, help='Trained model weights')
231 | parser.add_argument('--debug', dest='debug', action='store_true')
232 | parser.set_defaults(debug=False)
233 | parser.add_argument('--visualize', dest='visualize', action='store_true')
234 | parser.set_defaults(visualize=False)
235 | parser.add_argument('--savejson', type=str, default='', help='')
236 | parser.add_argument('--eval_reg', type=str, default="")
237 |
238 | args = parser.parse_args()
239 | log_string('Config path: {}'.format(args.config))
240 | log_string('Model config path: {}'.format(args.model_config))
241 | if args.weights is None:
242 | w = 'RANDOM WEIGHTS'
243 | else:
244 | w = args.weights
245 | log_string('Weights: {}'.format(w))
246 | log_string('Debug mode: {}'.format(args.debug))
247 | log_string('Visualize: {}'.format(args.visualize))
248 |
249 | params = MinkLocParams(args.config, args.model_config)
250 | params.print()
251 |
252 | model, device, d_model, vcr_model = model_factory(params)
253 |
254 | load_weights(args.weights, model)
255 |
256 | if args.eval_reg != "":
257 | test_loader = DataLoader(
258 | Oxford(params=params, partition='test'),
259 | batch_size=int(params.reg.batch_size * 1.2), shuffle=False, drop_last=False, num_workers=16)
260 | checkpoint_dict = torch.load(args.eval_reg, map_location=torch.device('cpu'))
261 | vcr_model.load_state_dict(checkpoint_dict, strict=True)
262 | log_string('load vcr_model with {}'.format(args.eval_reg))
263 | testVCRNet(1, vcr_model, test_loader)
264 | else:
265 | stats = evaluate(model, device, params, True)
266 |
267 | for database_name in stats:
268 | log_string(' Avg. recall @N:')
269 | log_string(str(stats[database_name]['ave_recall']))
270 |
271 | if len(args.savejson) > 0:
272 | result = {}
273 | result['trainfile'] = params.train_file
274 | result['weightfile'] = args.weights
275 | result['lr'] = params.lr
276 | result['lamda_g'] = params.lamda_gd
277 | result['weight_decay'] = params.weight_decay
278 | result['domain_adapt'] = params.domain_adapt
279 | if params.domain_adapt:
280 | result['lr_d'] = params.d_lr
281 | result['lamda_d'] = params.lamda_d
282 | result['weight_decay_d'] = params.d_weight_decay
283 | else:
284 | result['lr_d'] = None
285 | result['lamda_d'] = None
286 | result['weight_decay_d'] = None
287 |
288 | for database_name in stats:
289 | result_database = {}
290 | result_database['recall_top1'] = float(stats[database_name]['ave_recall'][0])
291 | result_database['recall_top1per'] = float(stats[database_name]['ave_one_percent_recall'])
292 | result_database['similarity'] = float(stats[database_name]['average_similarity'])
293 | result[database_name] = result_database
294 | json.dump(result, open(args.savejson, 'w'))
295 |
--------------------------------------------------------------------------------
/generating_queries/generate_test_sets.py:
--------------------------------------------------------------------------------
1 | # Code taken from PointNetVLAD repo: https://github.com/mikacuy/pointnetvlad
2 |
3 | import os
4 | import pickle
5 |
6 | import numpy as np
7 | import pandas as pd
8 | from sklearn.neighbors import KDTree
9 |
10 | #####For training and test data split#####
11 | x_width = 150
12 | y_width = 150
13 |
14 | # For Oxford
15 | p1 = [5735712.768124, 620084.402381]
16 | p2 = [5735611.299219, 620540.270327]
17 | p3 = [5735237.358209, 620543.094379]
18 | p4 = [5734749.303802, 619932.693364]
19 |
20 | # For University Sector
21 | p5 = [363621.292362, 142864.19756]
22 | p6 = [364788.795462, 143125.746609]
23 | p7 = [363597.507711, 144011.414174]
24 |
25 | # For Residential Area
26 | p8 = [360895.486453, 144999.915143]
27 | p9 = [362357.024536, 144894.825301]
28 | p10 = [361368.907155, 145209.663042]
29 |
30 | p_dict = {"oxford": [p1, p2, p3, p4], "university": [p5, p6, p7], "residential": [p8, p9, p10], "business": [],
31 | "lgsvl": []}
32 | p_dict_thresold = {"oxford": 25, "university": 25, "residential": 25, "business": 25, "lgsvl": 60}
33 |
34 |
35 | def check_in_test_set(northing, easting, points, x_width, y_width):
36 | in_test_set = False
37 | if points == []:
38 | return True
39 | for point in points:
40 | if (point[0] - x_width < northing and northing < point[0] + x_width and point[
41 | 1] - y_width < easting and easting < point[1] + y_width):
42 | in_test_set = True
43 | break
44 | return in_test_set
45 |
46 |
47 | ##########################################
48 |
49 | def output_to_file(output, filename):
50 | if not os.path.exists('pickles'):
51 | os.mkdir('pickles')
52 | filename = os.path.join('pickles', filename)
53 |
54 | with open(filename, 'wb') as handle:
55 | pickle.dump(output, handle, protocol=pickle.DEFAULT_PROTOCOL)
56 | print("Done ", filename)
57 |
58 |
59 | def construct_query_and_database_sets(base_path, runs_folder, folders, pointcloud_fols, filename, p, output_name):
60 | database_trees = []
61 | test_trees = []
62 | for folder in folders:
63 | print(folder)
64 | df_database = pd.DataFrame(columns=['file', 'northing', 'easting'])
65 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting'])
66 |
67 | df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',')
68 | for index, row in df_locations.iterrows():
69 | # entire business district is in the test set
70 | if (output_name == "business"):
71 | df_test = df_test.append(row, ignore_index=True)
72 | elif (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)):
73 | df_test = df_test.append(row, ignore_index=True)
74 | df_database = df_database.append(row, ignore_index=True)
75 |
76 | database_tree = KDTree(df_database[['northing', 'easting']])
77 | test_tree = KDTree(df_test[['northing', 'easting']])
78 | database_trees.append(database_tree)
79 | test_trees.append(test_tree)
80 |
81 | test_sets = []
82 | database_sets = []
83 | for folder in folders:
84 | database = {}
85 | test = {}
86 | df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',')
87 | df_locations['timestamp'] = runs_folder + folder + pointcloud_fols + df_locations['timestamp'].astype(
88 | str) + '.bin'
89 |
90 | df_locations = df_locations.rename(columns={'timestamp': 'file'})
91 | for index, row in df_locations.iterrows():
92 | # entire business district is in the test set
93 | if (output_name == "business"):
94 | test[len(test.keys())] = {'query': row['file'], 'northing': row['northing'], 'easting': row['easting']}
95 | elif (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)):
96 | test[len(test.keys())] = {'query': row['file'], 'northing': row['northing'], 'easting': row['easting']}
97 | database[len(database.keys())] = {'query': row['file'], 'northing': row['northing'],
98 | 'easting': row['easting']}
99 |
100 | database_sets.append(database)
101 | test_sets.append(test)
102 |
103 | for i in range(len(database_sets)):
104 | tree = database_trees[i]
105 | for j in range(len(test_sets)):
106 | if (i == j):
107 | continue
108 | for key in range(len(test_sets[j].keys())):
109 | coor = np.array([[test_sets[j][key]["northing"], test_sets[j][key]["easting"]]])
110 | index = tree.query_radius(coor, r=p_dict_thresold[output_name])
111 | # indices of the positive matches in database i of each query (key) in test set j
112 | test_sets[j][key][i] = index[0].tolist()
113 |
114 | output_to_file(database_sets, output_name + '_evaluation_database.pickle')
115 | output_to_file(test_sets, output_name + '_evaluation_query.pickle')
116 |
117 |
118 | ###Building database and query files for evaluation
119 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
120 | base_path = "../benchmark_datasets/"
121 |
122 | # For Oxford
123 | folders = []
124 | runs_folder = "oxford/"
125 | all_folders = sorted(os.listdir(os.path.join(BASE_DIR, base_path, runs_folder)))
126 | index_list = [5, 6, 7, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 22, 24, 31, 32, 33, 38, 39, 43, 44]
127 | print(len(index_list))
128 | for index in index_list:
129 | folders.append(all_folders[index])
130 |
131 | print(folders)
132 | construct_query_and_database_sets(base_path, runs_folder, folders, "/pointcloud_20m/", "pointcloud_locations_20m.csv",
133 | p_dict["oxford"], "oxford")
134 |
135 | # For lgsvl District
136 | # runs_folder = "lgsvl_new/"
137 | # all_folders=sorted(os.listdir(os.path.join(BASE_DIR,base_path,runs_folder)))
138 | # print(all_folders)
139 | # construct_query_and_database_sets(base_path, runs_folder, all_folders, "/pointcloud_20m/", "pointcloud_locations_20m.csv", p_dict["lgsvl"], "lgsvl")
140 |
--------------------------------------------------------------------------------
/generating_queries/generate_test_sets_gtav5.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import sys
4 |
5 | import numpy as np
6 | import pandas as pd
7 | from sklearn.neighbors import KDTree
8 |
9 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
10 | sys.path.append(BASE_DIR)
11 | sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "..")))
12 | from tqdm import tqdm
13 | from glob import glob
14 | from registration import registration_withinit
15 |
16 | recall_dis = 1.25
17 |
18 |
19 | def main():
20 | # Building database and query files for evaluation
21 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
22 | base_path = '../benchmark_datasets/'
23 |
24 | # For Oxford
25 | runs_folder = "GTA5/"
26 | all_folders = sorted(glob(os.path.join(BASE_DIR, base_path, runs_folder) + '/round*'))
27 | for i in range(len(all_folders)):
28 | all_folders[i] = os.path.basename(all_folders[i])
29 | print("start process")
30 | construct_query_and_database_sets(base_path, runs_folder, all_folders, "/pcl/",
31 | "pointcloud_locations_20m.csv")
32 |
33 |
34 | def check_in_test_set():
35 | return True
36 |
37 |
38 | ##########################################
39 | def output_to_file(output, filename):
40 | with open(filename, 'wb') as handle:
41 | pickle.dump(output, handle, protocol=pickle.HIGHEST_PROTOCOL)
42 | print("Done ", filename)
43 |
44 |
45 | def construct_query_and_database_sets(base_path, runs_folder, folders, pointcloud_fols, filename):
46 | database_trees = []
47 | test_trees = []
48 | for folder in tqdm(folders):
49 | # print(folder)
50 | df_database = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude'])
51 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude'])
52 |
53 | df_locations = pd.read_csv(os.path.join(
54 | base_path, runs_folder, folder, filename), sep=',')
55 |
56 | for index, row in df_locations.iterrows():
57 | # entire business district is in the test set
58 | if (check_in_test_set()):
59 | df_test = df_test.append(row, ignore_index=True)
60 | df_database = df_database.append(row, ignore_index=True)
61 |
62 | database_tree = KDTree(df_database[['northing', 'easting', 'altitude']])
63 | test_tree = KDTree(df_test[['northing', 'easting', 'altitude']])
64 | database_trees.append(database_tree)
65 | test_trees.append(test_tree)
66 |
67 | test_sets = []
68 | database_sets = []
69 | for folder in tqdm(folders):
70 | database = {}
71 | test = {}
72 | df_locations = pd.read_csv(os.path.join(
73 | base_path, runs_folder, folder, filename), sep=',')
74 | df_locations['timestamp'] = runs_folder + folder + \
75 | pointcloud_fols + df_locations['timestamp'].astype(str) + '.bin'
76 | df_locations = df_locations.rename(columns={'timestamp': 'file'})
77 | for index, row in df_locations.iterrows():
78 | # entire business district is in the test set
79 | if (check_in_test_set()):
80 | test[len(test.keys())] = {
81 | 'query': row['file'], 'northing': row['northing'], 'easting': row['easting'],
82 | 'altitude': row['altitude']}
83 | database[len(database.keys())] = {
84 | 'query': row['file'], 'northing': row['northing'], 'easting': row['easting'],
85 | 'altitude': row['altitude']}
86 | database_sets.append(database)
87 | test_sets.append(test)
88 |
89 | for i in tqdm(range(len(database_sets))):
90 | tree = database_trees[i]
91 | for j in range(len(test_sets)):
92 | if (i == j):
93 | continue
94 | for key in range(len(test_sets[j].keys())):
95 | coor = np.array(
96 | [[test_sets[j][key]["northing"], test_sets[j][key]["easting"], test_sets[j][key]["altitude"]]])
97 | index = tree.query_radius(coor, r=recall_dis)
98 | # indices of the positive matches in database i of each query (key) in test set j
99 | test_sets[j][key][i] = index[0].tolist()
100 |
101 | positives_T = []
102 | for pos_i in test_sets[j][key][i]:
103 | query = test_sets[j][key]["query"]
104 | test_i = test_sets[i][pos_i]
105 | query_i = test_i["query"]
106 | coor_i = np.array([test_i["northing"], test_i["easting"], test_i["altitude"]])
107 | trans_pre = coor.squeeze() - coor_i.squeeze()
108 | T = registration_withinit(base_path + query, base_path + query_i, trans_pre=trans_pre)
109 | positives_T.append(T)
110 | test_sets[j][key]['positives_T'] = positives_T
111 |
112 | if not os.path.exists("pickles/"):
113 | os.mkdir("pickles/")
114 |
115 | output_to_file(database_sets, "pickles/" + 'gta5_evaluation_database.pickle')
116 | output_to_file(test_sets, "pickles/" + 'gta5_evaluation_query.pickle')
117 |
118 |
119 | if __name__ == '__main__':
120 | main()
121 |
--------------------------------------------------------------------------------
/generating_queries/generate_training_tuples_baseline.py:
--------------------------------------------------------------------------------
1 | # Code taken from PointNetVLAD repo: https://github.com/mikacuy/pointnetvlad
2 |
3 | import os
4 | import pickle
5 | import random
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from sklearn.neighbors import KDTree
10 | from tqdm import tqdm
11 |
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 | base_path = "../benchmark_datasets/"
14 |
15 | runs_folder = "oxford/"
16 | filename = "pointcloud_locations_20m_10overlap.csv"
17 | pointcloud_fols = "/pointcloud_20m_10overlap/"
18 |
19 | all_folders = sorted(os.listdir(os.path.join(BASE_DIR, base_path, runs_folder)))
20 |
21 | folders = []
22 |
23 | # All runs are used for training (both full and partial)
24 | index_list = range(len(all_folders) - 1)
25 | print("Number of runs: " + str(len(index_list)))
26 | for index in index_list:
27 | folders.append(all_folders[index])
28 | print(folders)
29 |
30 | #####For training and test data split#####
31 | x_width = 150
32 | y_width = 150
33 | p1 = [5735712.768124, 620084.402381]
34 | p2 = [5735611.299219, 620540.270327]
35 | p3 = [5735237.358209, 620543.094379]
36 | p4 = [5734749.303802, 619932.693364]
37 | p = [p1, p2, p3, p4]
38 |
39 |
40 | def check_in_test_set(northing, easting, points, x_width, y_width):
41 | in_test_set = False
42 | for point in points:
43 | if (point[0] - x_width < northing and northing < point[0] + x_width and point[
44 | 1] - y_width < easting and easting < point[1] + y_width):
45 | in_test_set = True
46 | break
47 | return in_test_set
48 |
49 |
50 | ##########################################
51 |
52 |
53 | def construct_query_dict(df_centroids, filename):
54 | tree = KDTree(df_centroids[['northing', 'easting']])
55 | ind_nn = tree.query_radius(df_centroids[['northing', 'easting']], r=10)
56 | ind_r = tree.query_radius(df_centroids[['northing', 'easting']], r=50)
57 | queries = {}
58 | for i in tqdm(range(len(ind_nn))):
59 | query = df_centroids.iloc[i]["file"]
60 | positives = np.setdiff1d(ind_nn[i], [i]).tolist()
61 | negatives = np.setdiff1d(df_centroids.index.values.tolist(), ind_r[i]).tolist()
62 | random.shuffle(negatives)
63 | queries[i] = {"query": query, "positives": positives, "negatives": negatives}
64 |
65 | if not os.path.exists('pickles'):
66 | os.mkdir('pickles')
67 | filename = os.path.join('pickles', filename)
68 | with open(filename, 'wb') as handle:
69 | pickle.dump(queries, handle, protocol=pickle.DEFAULT_PROTOCOL)
70 |
71 | print("Done ", filename)
72 |
73 |
74 | ####Initialize pandas DataFrame
75 | df_train = pd.DataFrame(columns=['file', 'northing', 'easting'])
76 | df_test = pd.DataFrame(columns=['file', 'northing', 'easting'])
77 |
78 | for folder in tqdm(folders):
79 | df_locations = pd.read_csv(os.path.join(base_path, runs_folder, folder, filename), sep=',')
80 | df_locations['timestamp'] = runs_folder + folder + pointcloud_fols + df_locations['timestamp'].astype(str) + '.bin'
81 | df_locations = df_locations.rename(columns={'timestamp': 'file'})
82 |
83 | for index, row in df_locations.iterrows():
84 | if (check_in_test_set(row['northing'], row['easting'], p, x_width, y_width)):
85 | df_test = df_test.append(row, ignore_index=True)
86 | else:
87 | df_train = df_train.append(row, ignore_index=True)
88 |
89 | print("Number of training submaps: " + str(len(df_train['file'])))
90 | print("Number of non-disjoint test submaps: " + str(len(df_test['file'])))
91 | construct_query_dict(df_train, "training_queries_baseline.pickle")
92 | construct_query_dict(df_test, "test_queries_baseline.pickle")
93 |
--------------------------------------------------------------------------------
/generating_queries/generate_training_tuples_baseline_gtav5.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | import sys
5 | from glob import glob
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from sklearn.neighbors import KDTree
10 | from tqdm import tqdm
11 |
12 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(BASE_DIR)
14 | sys.path.append(os.path.abspath(os.path.join(BASE_DIR, "..")))
15 | from registration import registration_withinit
16 |
17 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
18 | base_path = '../benchmark_datasets/'
19 |
20 | runs_folder = "GTA5/"
21 | filename = "pointcloud_locations_20m_10overlap.csv"
22 | pointcloud_fols = "/pcl/"
23 |
24 | all_folders = sorted(glob(os.path.join(BASE_DIR, base_path, runs_folder) + '/round*'))
25 | for i in range(len(all_folders)):
26 | all_folders[i] = os.path.basename(all_folders[i])
27 |
28 | # 正样本范围 点云单位上正负1
29 | pos_dis = 0.5
30 | # 负样本范围
31 | neg_dis = 4
32 |
33 |
34 | def check_in_test_set(northing, easting, points, x_width, y_width):
35 | in_test_set = False
36 | for point in points:
37 | if (point[0] - x_width < northing and northing < point[0] + x_width and point[
38 | 1] - y_width < easting and easting < point[1] + y_width):
39 | in_test_set = True
40 | break
41 | return in_test_set
42 |
43 |
44 | ##########################################
45 |
46 |
47 | def construct_query_dict(df_centroids, filename):
48 | tree = KDTree(df_centroids[['northing', 'easting', 'altitude']])
49 | ind_nn = tree.query_radius(df_centroids[['northing', 'easting', 'altitude']], r=pos_dis)
50 | ind_r = tree.query_radius(df_centroids[['northing', 'easting', 'altitude']], r=neg_dis)
51 | queries = {}
52 | for i in tqdm(range(len(ind_nn))):
53 | query = df_centroids.iloc[i]["file"]
54 | positives = np.setdiff1d(ind_nn[i], [i]).tolist()
55 | list_neg = np.setdiff1d(df_centroids.index.values.tolist(), ind_r[i]).tolist()
56 | positives_T = []
57 | for pos_i in range(len(positives)):
58 | trans_pre = np.asarray(df_centroids.iloc[i][['northing', 'easting', 'altitude']]) - np.asarray(
59 | df_centroids.iloc[positives[pos_i]][['northing', 'easting', 'altitude']])
60 | T = registration_withinit(base_path + query, base_path + df_centroids.iloc[positives[pos_i]]["file"],
61 | trans_pre=trans_pre)
62 | positives_T.append(T)
63 | # if len(list_neg) > 4001:
64 | # random.shuffle(list_neg)
65 | # list_neg = list_neg[:4001]
66 | negatives = list_neg
67 | random.shuffle(negatives)
68 | queries[i] = {"query": query, "positives": positives,
69 | "positives_T": positives_T, "negatives": negatives}
70 |
71 | with open(filename, 'wb') as handle:
72 | pickle.dump(queries, handle, protocol=pickle.HIGHEST_PROTOCOL)
73 |
74 | print("Done ", filename)
75 |
76 |
77 | # Initialize pandas DataFrame
78 | df_train = pd.DataFrame(columns=['file', 'northing', 'easting', 'altitude'])
79 |
80 | for folder in all_folders:
81 | df_locations = pd.read_csv(os.path.join(
82 | base_path, runs_folder, folder, filename), sep=',')
83 | df_locations['timestamp'] = runs_folder + folder + \
84 | pointcloud_fols + df_locations['timestamp'].astype(str) + '.bin'
85 | df_locations = df_locations.rename(columns={'timestamp': 'file'})
86 |
87 | for index, row in df_locations.iterrows():
88 | df_train = df_train.append(row, ignore_index=True)
89 |
90 | if not os.path.exists("pickles/"):
91 | os.mkdir("pickles/")
92 |
93 | print("Number of training submaps: " + str(len(df_train['file'])))
94 | construct_query_dict(df_train, "pickles/" + "gta5_training_queries_baseline.pickle")
95 |
--------------------------------------------------------------------------------
/generating_queries/registration.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import sys
3 |
4 | import numpy as np
5 | import open3d as o3d
6 |
7 | # monkey patches visualization and provides helpers to load geometries
8 | sys.path.append('..')
9 |
10 |
11 | def draw_registration_result(source, target, transformation):
12 | source_temp = copy.deepcopy(source)
13 | target_temp = copy.deepcopy(target)
14 | source_temp.paint_uniform_color([1, 0.706, 0])
15 | target_temp.paint_uniform_color([0, 0.651, 0.929])
16 | # a=np.asarray(source_temp.points)
17 | source_temp.transform(transformation)
18 | # b=np.asarray(source_temp.points)
19 | # o3d.visualization.draw_geometries([source_temp, target_temp])
20 |
21 |
22 | def preprocess_point_cloud(pcd, voxel_size, is_voxel=True):
23 | if is_voxel:
24 | # print(":: Downsample with a voxel size %.3f." % voxel_size)
25 | pcd_down = pcd.voxel_down_sample(voxel_size)
26 | else:
27 | pcd_down = pcd
28 | radius_normal = voxel_size * 2
29 | # print(":: Estimate normal with search radius %.3f." % radius_normal)
30 | pcd_down.estimate_normals(
31 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_normal, max_nn=30))
32 |
33 | radius_feature = voxel_size * 5
34 | # print(":: Compute FPFH feature with search radius %.3f." % radius_feature)
35 | pcd_fpfh = o3d.pipelines.registration.compute_fpfh_feature(
36 | pcd_down,
37 | o3d.geometry.KDTreeSearchParamHybrid(radius=radius_feature, max_nn=100))
38 | return pcd_down, pcd_fpfh
39 |
40 |
def prepare_dataset(voxel_size):
    """Load the two demo clouds from ./test_data and preprocess both.

    Note: despite the printed message, the initial-pose disturbance is currently disabled.

    Returns:
        (source, target, source_down, target_down, source_fpfh, target_fpfh)
    """
    print(":: Load two point clouds and disturb initial pose.")
    source = o3d.io.read_point_cloud("./test_data/cloud_bin_0.pcd")
    target = o3d.io.read_point_cloud("./test_data/cloud_bin_1.pcd")

    (source_down, source_fpfh), (target_down, target_fpfh) = (
        preprocess_point_cloud(cloud, voxel_size) for cloud in (source, target))
    return source, target, source_down, target_down, source_fpfh, target_fpfh
53 |
54 |
def execute_global_registration(source_down, target_down, source_fpfh,
                                target_fpfh, voxel_size):
    """RANSAC global registration on FPFH feature correspondences.

    The correspondence distance threshold is 1.5 * ``voxel_size``. Candidates are
    filtered by an edge-length check (0.9) and a distance check. RANSAC runs with
    4-point samples and RANSACConvergenceCriteria(4000000, 500).

    Returns:
        Open3D registration result (``.transformation`` holds the 4x4 estimate).
    """
    reg = o3d.pipelines.registration
    max_dist = voxel_size * 1.5
    checkers = [
        reg.CorrespondenceCheckerBasedOnEdgeLength(0.9),
        reg.CorrespondenceCheckerBasedOnDistance(max_dist),
    ]
    # NOTE(review): arguments are passed positionally. This layout has no
    # ``mutual_filter`` argument, which newer Open3D releases insert before the distance
    # (and they reinterpret the 2nd RANSACConvergenceCriteria value as a confidence).
    # Confirm against the pinned Open3D version.
    return reg.registration_ransac_based_on_feature_matching(
        source_down, target_down, source_fpfh, target_fpfh, max_dist,
        reg.TransformationEstimationPointToPoint(False),
        4, checkers,
        reg.RANSACConvergenceCriteria(4000000, 500))
71 |
72 |
def refine_registration(source, target, init_transformation, voxel_size=-1):
    """Refine an initial alignment with point-to-point ICP.

    Args:
        source, target: Open3D point clouds.
        init_transformation: 4x4 initial guess.
        voxel_size: -1 (default) selects a fixed 0.01 correspondence distance;
            any other value uses 0.4 * voxel_size.

    Returns:
        Open3D ICP registration result.
    """
    distance_threshold = 0.01 if voxel_size == -1 else voxel_size * 0.4
    return o3d.pipelines.registration.registration_icp(
        source, target, distance_threshold, init_transformation,
        o3d.pipelines.registration.TransformationEstimationPointToPoint())
85 |
86 |
def execute_fast_global_registration(source_down, target_down, source_fpfh,
                                     target_fpfh, voxel_size):
    """Fast Global Registration on FPFH features.

    The maximum correspondence distance is 0.5 * ``voxel_size``, and the chosen
    threshold is printed.

    Returns:
        Open3D registration result.
    """
    threshold = voxel_size * 0.5
    print(":: Apply fast global registration with distance threshold %.3f" % threshold)
    option = o3d.pipelines.registration.FastGlobalRegistrationOption(
        maximum_correspondence_distance=threshold)
    return o3d.pipelines.registration.registration_fast_based_on_feature_matching(
        source_down, target_down, source_fpfh, target_fpfh, option)
97 |
98 |
def demo_open3d():
    """Run RANSAC, FGR and ICP registration on the ./test_data demo clouds."""
    voxel_size = 0.05  # means 5cm for this dataset
    (source, target, source_down, target_down,
     source_fpfh, target_fpfh) = prepare_dataset(voxel_size)
    feature_args = (source_down, target_down, source_fpfh, target_fpfh, voxel_size)

    result_ransac = execute_global_registration(*feature_args)
    draw_registration_result(source_down, target_down, result_ransac.transformation)

    result_fast = execute_fast_global_registration(*feature_args)
    draw_registration_result(source_down, target_down, result_fast.transformation)

    # ICP on the full-resolution clouds, seeded with the FGR estimate.
    result_icp = refine_registration(source, target, result_fast.transformation, voxel_size)
    draw_registration_result(source, target, result_icp.transformation)
116 |
117 |
def load_bin(path1, path2, voxel_size):
    """Load two raw ``.bin`` clouds and compute normals/FPFH without downsampling.

    Each file is read with ``np.fromfile`` (default dtype float64) and reshaped to N x 3.

    Returns:
        (source, target, source_down, target_down, source_fpfh, target_fpfh)
    """
    def _read_cloud(path):
        # Wrap the flat float buffer of one file as an Open3D point cloud.
        cloud = o3d.geometry.PointCloud()
        cloud.points = o3d.utility.Vector3dVector(np.fromfile(path).reshape([-1, 3]))
        return cloud

    source = _read_cloud(path1)
    target = _read_cloud(path2)

    source_down, source_fpfh = preprocess_point_cloud(source, voxel_size, is_voxel=False)
    target_down, target_fpfh = preprocess_point_cloud(target, voxel_size, is_voxel=False)
    return source, target, source_down, target_down, source_fpfh, target_fpfh
127 |
128 |
def registration_numpy(pcl1, pcl2, trans_pre=None):
    """Wrap two point arrays as Open3D clouds and pass them to ``draw_registration_result``.

    Despite the name, no registration is computed. The pair is "drawn" once with the
    identity and once with ``trans_pre``, and the function returns None.

    Args:
        pcl1, pcl2: point arrays accepted by ``o3d.utility.Vector3dVector`` (N x 3).
        trans_pre: 4x4 transform to apply to ``pcl1``. ``None`` (default) means identity.
            It replaces the former shared ``np.eye(4)`` default-argument object.
    """
    if trans_pre is None:
        trans_pre = np.eye(4)

    source = o3d.geometry.PointCloud()
    source.points = o3d.utility.Vector3dVector(pcl1)
    target = o3d.geometry.PointCloud()
    target.points = o3d.utility.Vector3dVector(pcl2)

    draw_registration_result(source, target, np.eye(4))
    draw_registration_result(source, target, trans_pre)
137 |
138 |
def registration_withinit(path1, path2, trans_pre=(0, 0, 0), voxel_size=-1):
    """Refine the alignment of two raw ``.bin`` clouds with point-to-point ICP.

    Args:
        path1: source points, read with ``np.fromfile`` (default float64) and reshaped to N x 3.
        path2: target points, same format.
        trans_pre: (x, y, z) initial translation of the source. The initial rotation is
            the identity.
        voxel_size: forwarded to ``refine_registration``. The default -1 keeps its fixed
            0.01 correspondence distance; other values use 0.4 * voxel_size. This
            parameter exists so the ``voxel_size=`` keyword used by the script entry
            point is accepted (it previously raised TypeError).

    Returns:
        4x4 numpy array holding the ICP transformation.
    """
    source = o3d.geometry.PointCloud()
    source.points = o3d.utility.Vector3dVector(np.fromfile(path1).reshape([-1, 3]))
    target = o3d.geometry.PointCloud()
    target.points = o3d.utility.Vector3dVector(np.fromfile(path2).reshape([-1, 3]))

    # Identity rotation plus the requested translation.
    trans_init = np.eye(4)
    trans_init[:3, 3] = [trans_pre[0], trans_pre[1], trans_pre[2]]

    result_icp = refine_registration(source, target, trans_init, voxel_size)
    return np.asarray(result_icp.transformation)
154 |
155 |
if __name__ == '__main__':
    # Script entry point: ICP-align two GTA5 point-cloud .bin files.
    # demo_open3d()
    path1 = '../benchmark_datasets/GTA5/round1/pcl/1.bin'
    path2 = '../benchmark_datasets/GTA5/round1/pcl/2.bin'

    # Alternative Oxford pair:
    # path1 = '../benchmark_datasets/oxford/2014-05-19-13-20-57/pointcloud_20m_10overlap/1400505893170765.bin'
    # path2 = '../benchmark_datasets/oxford/2014-05-19-13-20-57/pointcloud_20m_10overlap/1400505894395159.bin'
    # NOTE(review): this passes ``voxel_size`` by keyword, so registration_withinit must
    # accept a ``voxel_size`` parameter. The returned 4x4 transform is discarded.
    registration_withinit(path1, path2, voxel_size=0.00125)
164 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/loss/__init__.py
--------------------------------------------------------------------------------
/loss/d_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from misc.log import log_string
4 |
5 |
def make_d_loss(params):
    """Instantiate the discriminator loss named by ``params.d_loss``.

    Raises:
        NotImplementedError: if ``params.d_loss`` is not a known loss name
            (the name is logged first).
    """
    loss_classes = {
        'BCELogitWithMask': BCELogitWithMask,
        'MSEWithMask': MSEWithMask,
        'MSE': MSE,
        'BCELogit': BCELogit,
    }
    loss_cls = loss_classes.get(params.d_loss)
    if loss_cls is None:
        log_string('Unknown loss: {}'.format(params.d_loss))
        raise NotImplementedError
    return loss_cls()
19 |
20 |
class DisLossBase:
    """Base class for discriminator / adversarial losses.

    Subclasses set ``self.distance`` (a criterion called as ``distance(out, label)``)
    and ``self.mask``. When ``mask`` is True, only the rows indexed by the anchors of
    ``hard_triplets`` are scored.
    """

    def __init__(self):
        self.distance = None  # criterion, set by subclasses
        self.mask = False  # restrict to hard-triplet anchors when True
        # Subclasses move their criterion here; labels are created on the inputs' device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def __call__(self, syn, hard_triplets, real=None):
        """Compute the loss.

        Args:
            syn: 2-D discriminator outputs for synthetic samples (labelled 0 when ``real``
                is given, 1 otherwise, i.e. the generator's objective).
                Assumed (batch, 1); TODO confirm.
            hard_triplets: ``(anchors, positives, negatives)`` index tensors. Only
                ``anchors`` is used, and only when ``self.mask`` is True.
            real: optional 2-D outputs for real samples (labelled 1).

        Returns:
            ``(loss, stats)``. ``stats`` has key ``'d_loss'`` when ``real`` is given,
            otherwise ``'g_d_loss'``.
        """
        assert self.distance is not None
        assert len(syn.shape) == 2

        if self.mask:
            anchors, _, _ = hard_triplets
            syn = syn[anchors]
            real = real[anchors] if real is not None else None

        if real is not None:
            assert len(real.shape) == 2
            syn_label = torch.zeros((syn.shape[0], 1), dtype=torch.float32, device=syn.device)
            real_label = torch.ones((real.shape[0], 1), dtype=torch.float32, device=syn.device)
            # Stack samples along the batch dimension so syn and real may differ in size.
            # (Concatenating along dim=1 required equal batch sizes; for equal sizes the
            # mean loss is unchanged.)
            label = torch.cat((syn_label, real_label), dim=0)
            out = torch.cat((syn, real), dim=0)
            stats_key = 'd_loss'
        else:
            label = torch.ones((syn.shape[0], 1), dtype=torch.float32, device=syn.device)
            out = syn
            stats_key = 'g_d_loss'

        loss = self.distance(out, label)
        return loss, {stats_key: loss.detach().cpu().item()}
58 |
59 |
class BCELogitWithMask(DisLossBase):
    """BCE-with-logits discriminator loss restricted to the hard-triplet anchors."""

    def __init__(self):
        super().__init__()
        self.mask = True
        self.distance = torch.nn.BCEWithLogitsLoss(reduction='mean').to(self.device)
65 |
66 |
class MSEWithMask(DisLossBase):
    """Mean-squared-error discriminator loss restricted to the hard-triplet anchors."""

    def __init__(self):
        super().__init__()
        self.mask = True
        self.distance = torch.nn.MSELoss(reduction='mean').to(self.device)
72 |
73 |
class BCELogit(DisLossBase):
    """BCE-with-logits discriminator loss over every row of the batch."""

    def __init__(self):
        super().__init__()
        self.mask = False
        self.distance = torch.nn.BCEWithLogitsLoss(reduction='mean').to(self.device)
79 |
80 |
class MSE(DisLossBase):
    """Mean-squared-error discriminator loss over every row of the batch."""

    def __init__(self):
        super().__init__()
        self.mask = False
        self.distance = torch.nn.MSELoss(reduction='mean').to(self.device)
86 |
--------------------------------------------------------------------------------
/loss/metric_loss.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 |
5 |
6 | import numpy as np
7 | import torch
8 | from pytorch_metric_learning import losses
9 | from pytorch_metric_learning.distances import LpDistance
10 | from scipy.spatial.transform import Rotation
11 |
12 | from loss.mmd_loss import MMD_loss
13 | from misc.log import log_string
14 |
# Degrees-to-radians factor (pi / 180); dividing an angle in radians by it gives degrees.
D2G = np.pi / 180.0
16 |
17 |
def make_loss(params):
    """Build the metric / pose loss selected by ``params.loss``.

    Raises:
        NotImplementedError: if ``params.loss`` is not a known loss name (it is logged first).
    """
    if params.loss == 'BatchHardTripletMarginLoss':
        # BatchHard mining with triplet margin loss
        # Expects input: embeddings, positives_mask, negatives_mask
        loss_fn = BatchHardTripletLossWithMasks(params.margin, params.normalize_embeddings, params.swap)
    elif params.loss == 'BatchHardContrastiveLoss':
        loss_fn = BatchHardContrastiveLossWithMasks(params.pos_margin, params.neg_margin, params.normalize_embeddings)
    elif params.loss == 'LocQuatMSELoss':
        loss_fn = BatchLocQuatLoss(torch.nn.MSELoss())
    elif params.loss == 'LocQuatL1Loss':
        loss_fn = BatchLocQuatLoss(torch.nn.L1Loss())
    elif params.loss == 'LocQuatCloudLoss':
        loss_fn = BatchLocQuatCloudLoss()
    else:
        # Report the attribute that was actually tested. The former ``params.ldss``
        # typo raised AttributeError instead of this message.
        log_string('Unknown loss: {}'.format(params.loss))
        raise NotImplementedError('Unknown loss: {}'.format(params.loss))
    return loss_fn
35 |
36 |
class MMDLoss:
    """MMD-flavoured variant of the batch-hard metric loss (appears unfinished).

    NOTE(review): this class does not look usable as written; confirm intent before use.
      * ``self.loss_fn`` is ``MMD_loss``, whose ``forward(source, target)`` takes two
        tensors, but ``__call__`` passes three arguments (embeddings, labels, triplets).
      * ``MMD_loss`` defines no ``distance`` or ``reducer`` attribute, yet ``stats``
        reads ``self.loss_fn.distance`` and ``self.loss_fn.reducer``.
      * ``__call__`` builds ``stats`` but has no ``return``, so it returns None.
      * ``margin`` is stored but unused, and ``swap`` is ignored.
      * It is not one of the options offered by ``make_loss`` in this module.
    """

    def __init__(self, margin, normalize_embeddings, swap):
        self.margin = margin
        self.normalize_embeddings = normalize_embeddings
        self.loss_fn = MMD_loss()
        self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
        # Hard triplets are mined with the same Lp distance as the triplet loss.
        self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)

    def __call__(self, embeddings, positives_mask, negatives_mask):
        hard_triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
        # One distinct label per embedding; structure comes from the mined triplets.
        dummy_labels = torch.arange(embeddings.shape[0]).to(embeddings.device)
        # NOTE(review): three-argument call into MMD_loss.forward(source, target).
        loss = self.loss_fn(embeddings, dummy_labels, hard_triplets)
        stats = {'metric_loss': loss.detach().cpu().item(),
                 'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm,
                 'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter,
                 'num_triplets': len(hard_triplets[0]),
                 'mean_pos_pair_dist': self.miner_fn.mean_pos_pair_dist,
                 'mean_neg_pair_dist': self.miner_fn.mean_neg_pair_dist,
                 'max_pos_pair_dist': self.miner_fn.max_pos_pair_dist,
                 'max_neg_pair_dist': self.miner_fn.max_neg_pair_dist,
                 'min_pos_pair_dist': self.miner_fn.min_pos_pair_dist,
                 'min_neg_pair_dist': self.miner_fn.min_neg_pair_dist
                 }
        # NOTE(review): no return statement; callers receive None.
61 |
62 |
def quat_to_mat_torch(q, req_grad=False):
    """Convert a quaternion ``(x, y, z, w)`` to a 3x3 rotation matrix.

    The matrix is assembled with ``torch.stack``, so when ``req_grad`` is True and ``q``
    requires grad, gradients flow back into ``q``. The former
    ``torch.tensor([...], requires_grad=...)`` built a new leaf and silently cut the graph.

    Args:
        q: 4-element quaternion (tensor, array or sequence), scalar part last.
            Assumed unit-norm; it is not normalised here.
        req_grad: if False, the returned matrix is detached. If True, it stays connected
            to ``q`` (and is made to require grad even when ``q`` does not).

    Returns:
        3x3 tensor on ``q``'s device (plain sequences produce a CPU tensor).
    """
    q = torch.as_tensor(q)
    x, y, z, w = q[0], q[1], q[2], q[3]
    mat = torch.stack([
        torch.stack([1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * w * z, 2 * x * z + 2 * w * y]),
        torch.stack([2 * x * y + 2 * w * z, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * w * x]),
        torch.stack([2 * x * z - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x * x - 2 * y * y]),
    ])
    if not req_grad:
        return mat.detach()
    if not mat.requires_grad:
        # No graph to preserve; mirror the old behaviour of returning a grad-enabled leaf.
        mat.requires_grad_(True)
    return mat
79 |
80 |
def to_clouds(batch, device=torch.device('cpu')):
    """Split batched coordinate rows ``[batch_idx, c1, c2, ...]`` into per-cloud tensors.

    The number of clouds is taken from the batch index of the LAST row, so ``batch`` is
    assumed to be ordered by batch index (TODO confirm with callers).

    Returns:
        List of float tensors on ``device``, one per batch index.
    """
    n_clouds = batch[-1][0] + 1
    buckets = [[] for _ in range(n_clouds)]
    for row in batch:
        cloud_idx, coords = row[0], row[1:]
        buckets[cloud_idx].append(coords)
    return [torch.Tensor(bucket).to(device) for bucket in buckets]
87 |
88 |
class BatchLocQuatCloudLoss:
    """Average L1 distance between each cloud rotated by predicted vs. true quaternions."""

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.loss_fn = torch.nn.L1Loss()

    def __call__(self, batch, quat, truth_quat):
        """Return ``(loss, {'loss': loss})``.

        Args:
            batch: dict whose ``'coords'`` tensor holds rows ``[batch_idx, x, y, z]``
                (the layout ``to_clouds`` splits on; TODO confirm column meaning).
            quat: predicted quaternions, one per cloud.
            truth_quat: ground-truth quaternions, one per cloud.
        """
        # .cpu() so this also works when the coordinates live on a GPU.
        coords = batch['coords'].detach().cpu().numpy()
        raw_clouds = to_clouds(coords, self.device)

        loss = 0.0
        for q, qt, cloud in zip(quat, truth_quat, raw_clouds):
            m1 = quat_to_mat_torch(q, True).to(self.device)
            m2 = quat_to_mat_torch(qt).to(self.device)
            loss += self.loss_fn(torch.matmul(cloud, m1), torch.matmul(cloud, m2))

        loss = loss / len(raw_clouds)
        return loss, {'loss': loss}
112 |
113 |
class BatchLocQuatLoss:
    """Quaternion regression loss plus rotation-vector error statistics in degrees.

    Args:
        tloss: criterion applied directly to (predicted, ground-truth) quaternion tensors,
            e.g. ``torch.nn.MSELoss()`` or ``torch.nn.L1Loss()``.
    """

    def __init__(self, tloss):
        self.loss_fn = tloss

    def __call__(self, quats, quats_truth):
        """Return ``(loss, stats)``.

        Rows of ``quats`` / ``quats_truth`` are passed to ``Rotation.from_quat`` (x, y, z, w).
        ``stats`` contains:
          * ``avg.rloss``: mean over the 3 axes of the summed per-axis |rotvec error|
            (degrees), divided by the batch size;
          * ``avg.rloss2``: L2 norm of the summed per-axis error vector, divided by the
            batch size (previously the norm of the already-averaged scalar, divided twice);
          * ``max_loss_r{1,2}_{x,y,z}``: predicted / true rotation vectors (degrees) of the
            sample with the largest mean per-axis error (previously the running maximum
            was never updated, so the last non-zero sample was reported).
        """
        loss = self.loss_fn(quats, quats_truth)
        batch_size = len(quats)
        err_sum = np.zeros(3)
        max_err = -np.inf
        # Zero vectors, so the stats can always be indexed.
        max_loss_r1 = max_loss_r2 = np.zeros(3)
        for q, t in zip(quats.detach().cpu().numpy(), quats_truth.detach().cpu().numpy()):
            r1 = np.degrees(Rotation.from_quat(q).as_rotvec())
            r2 = np.degrees(Rotation.from_quat(t).as_rotvec())
            r = np.abs(r1 - r2)
            err_sum += r
            if np.mean(r) > max_err:
                max_err = np.mean(r)
                max_loss_r1, max_loss_r2 = r1, r2
        rloss = np.mean(err_sum) / batch_size
        rloss2 = np.linalg.norm(err_sum) / batch_size
        return loss, {'loss': loss.detach().cpu(),
                      'avg.rloss': rloss,
                      'avg.rloss2': rloss2,
                      'max_loss_r1_x': max_loss_r1[0],
                      'max_loss_r1_y': max_loss_r1[1],
                      'max_loss_r1_z': max_loss_r1[2],

                      'max_loss_r2_x': max_loss_r2[0],
                      'max_loss_r2_y': max_loss_r2[1],
                      'max_loss_r2_z': max_loss_r2[2], }
144 |
145 |
class HardTripletMinerWithMasks:
    """Batch-hard triplet miner driven by boolean positive/negative masks.

    For each anchor row it picks the farthest positive and the nearest negative under
    ``distance``. Only rows that have at least one positive AND one negative are kept.
    """

    def __init__(self, distance):
        self.distance = distance
        # Pair-distance statistics of the most recently mined batch (None before first use).
        self.max_pos_pair_dist = None
        self.max_neg_pair_dist = None
        self.mean_pos_pair_dist = None
        self.mean_neg_pair_dist = None
        self.min_pos_pair_dist = None
        self.min_neg_pair_dist = None

    def __call__(self, embeddings, positives_mask, negatives_mask):
        """Mine ``(anchors, positives, negatives)`` index tensors without tracking gradients."""
        assert embeddings.dim() == 2
        frozen = embeddings.detach()
        with torch.no_grad():
            return self.mine(frozen, positives_mask, negatives_mask)

    def mine(self, embeddings, positives_mask, negatives_mask):
        """Return hardest-positive / hardest-negative triplets and refresh the statistics.

        NOTE: the statistics cover every row, including rows without positives (which
        contribute 0) or without negatives (which contribute inf).
        """
        dist_mat = self.distance(embeddings)
        (pos_dist, pos_idx), has_pos = get_max_per_row(dist_mat, positives_mask)
        (neg_dist, neg_idx), has_neg = get_min_per_row(dist_mat, negatives_mask)

        keep = torch.where(has_pos & has_neg)
        anchors = torch.arange(dist_mat.size(0)).to(pos_idx.device)[keep]
        positives = pos_idx[keep]
        negatives = neg_idx[keep]

        self.max_pos_pair_dist = torch.max(pos_dist).item()
        self.max_neg_pair_dist = torch.max(neg_dist).item()
        self.mean_pos_pair_dist = torch.mean(pos_dist).item()
        self.mean_neg_pair_dist = torch.mean(neg_dist).item()
        self.min_pos_pair_dist = torch.min(pos_dist).item()
        self.min_neg_pair_dist = torch.min(neg_dist).item()
        return anchors, positives, negatives
181 |
182 |
def get_max_per_row(mat, mask):
    """Row-wise maximum of ``mat`` over the entries where ``mask`` is True.

    Masked-out entries count as 0, so a row with no True entry yields 0.

    Returns:
        ``(torch.max(..., dim=1) result, bool tensor marking rows with any True entry)``
    """
    rows_with_entries = mask.any(dim=1)
    masked = mat.detach().masked_fill(~mask, 0)
    return masked.max(dim=1), rows_with_entries
190 |
191 |
def get_min_per_row(mat, mask):
    """Row-wise minimum of ``mat`` over the entries where ``mask`` is True.

    Masked-out entries count as +inf, so a row with no True entry yields inf.

    Returns:
        ``(torch.min(..., dim=1) result, bool tensor marking rows with any True entry)``
    """
    rows_with_entries = mask.any(dim=1)
    masked = mat.detach().masked_fill(~mask, float('inf'))
    return masked.min(dim=1), rows_with_entries
197 |
198 |
class BatchHardTripletLossWithMasks:
    """Triplet margin loss over batch-hard triplets mined from boolean masks.

    Calling the instance returns ``(loss, stats, hard_triplets)``.
    """

    def __init__(self, margin, normalize_embeddings, swap):
        self.margin = margin
        self.normalize_embeddings = normalize_embeddings
        # The same Lp distance drives both the miner and the triplet loss.
        self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
        self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
        self.loss_fn = losses.TripletMarginLoss(margin=self.margin, swap=swap, distance=self.distance)

    def __call__(self, embeddings, positives_mask, negatives_mask):
        triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
        # One distinct label per embedding; the explicit triplets define the structure.
        labels = torch.arange(embeddings.shape[0]).to(embeddings.device)
        loss = self.loss_fn(embeddings, labels, triplets)

        stats = {
            'metric_loss': loss.detach().cpu().item(),
            'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm,
            'num_non_zero_triplets': self.loss_fn.reducer.triplets_past_filter,
            'num_triplets': len(triplets[0]),
        }
        for name in ('mean_pos_pair_dist', 'mean_neg_pair_dist',
                     'max_pos_pair_dist', 'max_neg_pair_dist',
                     'min_pos_pair_dist', 'min_neg_pair_dist'):
            stats[name] = getattr(self.miner_fn, name)

        return loss, stats, triplets
226 |
227 |
class BatchHardContrastiveLossWithMasks:
    """Contrastive loss applied to the pairs of batch-hard triplets mined from masks.

    Calling the instance returns ``(loss, stats, hard_triplets)``.
    """

    def __init__(self, pos_margin, neg_margin, normalize_embeddings):
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.distance = LpDistance(normalize_embeddings=normalize_embeddings)
        self.miner_fn = HardTripletMinerWithMasks(distance=self.distance)
        # NOTE(review): an earlier comment said "squared Euclidean distance", but
        # LpDistance is built with its default power here. Confirm whether squaring was intended.
        self.loss_fn = losses.ContrastiveLoss(pos_margin=self.pos_margin, neg_margin=self.neg_margin,
                                              distance=self.distance)

    def __call__(self, embeddings, positives_mask, negatives_mask):
        triplets = self.miner_fn(embeddings, positives_mask, negatives_mask)
        # One distinct label per embedding; the explicit triplets define the structure.
        labels = torch.arange(embeddings.shape[0]).to(embeddings.device)
        loss = self.loss_fn(embeddings, labels, triplets)

        sub_reducers = self.loss_fn.reducer.reducers
        stats = {
            'metric_loss': loss.detach().cpu().item(),
            'avg_embedding_norm': self.loss_fn.distance.final_avg_query_norm,
            'pos_pairs_above_low': sub_reducers['pos_loss'].pos_pairs_above_low,
            'neg_pairs_above_low': sub_reducers['neg_loss'].neg_pairs_above_low,
            'pos_loss': sub_reducers['pos_loss'].pos_loss,
            'neg_loss': sub_reducers['neg_loss'].neg_loss,
            # Reported as two pairs (anchor-positive, anchor-negative) per mined triplet.
            'num_pairs': 2 * len(triplets[0]),
        }
        for name in ('mean_pos_pair_dist', 'mean_neg_pair_dist',
                     'max_pos_pair_dist', 'max_neg_pair_dist',
                     'min_pos_pair_dist', 'min_neg_pair_dist'):
            stats[name] = getattr(self.miner_fn, name)

        return loss, stats, triplets
258 |
--------------------------------------------------------------------------------
/loss/mmd_loss.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
# Lower bound used to clamp the MMD^2 variance estimate before its square root is taken
# in _mmd2_and_ratio, which keeps the ratio finite.
min_var_est = 1e-8
5 |
6 |
7 | # Consider linear time MMD with a linear kernel:
8 | # K(f(x), f(y)) = f(x)^Tf(y)
9 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
10 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)]
11 | #
12 | # f_of_X: batch_size * k
13 | # f_of_Y: batch_size * k
def linear_mmd2(f_of_X, f_of_Y):
    """Linear-time MMD^2 estimate with a linear kernel.

    Averages ``h = [f(x_i) - f(y_i)]^T [f(x_{i+1}) - f(y_{i+1})]`` over consecutive rows.
    Inputs are ``batch_size x k``.
    """
    diff = f_of_X - f_of_Y
    consecutive_dots = (diff[:-1] * diff[1:]).sum(1)
    return torch.mean(consecutive_dots)
20 |
21 |
22 | # Consider linear time MMD with a polynomial kernel:
23 | # K(f(x), f(y)) = (alpha*f(x)^Tf(y) + c)^d
24 | # f_of_X: batch_size * k
25 | # f_of_Y: batch_size * k
def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
    """Linear-time MMD^2 estimate with the polynomial kernel ``(alpha * a.b + c) ** d``.

    Each kernel term pairs row ``i`` of one input with row ``i + 1`` of the other and
    averages over ``i``. Inputs are ``batch_size x k``.
    """
    def _mean_kernel(a, b):
        # Polynomial kernel between a[i] and b[i + 1], averaged over i.
        return torch.mean((alpha * (a[:-1] * b[1:]).sum(1) + c).pow(d))

    return (_mean_kernel(f_of_X, f_of_X) + _mean_kernel(f_of_Y, f_of_Y)
            - _mean_kernel(f_of_X, f_of_Y) - _mean_kernel(f_of_Y, f_of_X))
40 |
41 |
42 | def _mix_rbf_kernel(X, Y, sigma_list):
43 | assert (X.size(0) == Y.size(0))
44 | m = X.size(0)
45 |
46 | Z = torch.cat((X, Y), 0)
47 | ZZT = torch.mm(Z, Z.t())
48 | diag_ZZT = torch.diag(ZZT).unsqueeze(1)
49 | Z_norm_sqr = diag_ZZT.expand_as(ZZT)
50 | exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()
51 |
52 | K = 0.0
53 | for sigma in sigma_list:
54 | gamma = 1.0 / (2 * sigma ** 2)
55 | K += torch.exp(-gamma * exponent)
56 |
57 | return K[:m, :m], K[:m, m:], K[m:, m:], len(sigma_list)
58 |
59 |
def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    """MMD^2 between ``X`` and ``Y`` under a sum of RBF kernels (biased by default).

    The kernel diagonals are read from the matrices (``const_diagonal=False``) rather
    than assumed equal to the number of kernels.
    """
    K_XX, K_XY, K_YY, _ = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
64 |
65 |
def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
    """Return ``(ratio, mmd2, var_est)`` for ``X`` vs ``Y`` under a sum of RBF kernels.

    The kernel diagonals are read from the matrices (``const_diagonal=False``).
    """
    K_XX, K_XY, K_YY, _ = _mix_rbf_kernel(X, Y, sigma_list)
    return _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=biased)
70 |
71 |
72 | ################################################################################
73 | # Helper functions to compute variances based on kernel matrices
74 | ################################################################################
75 |
76 |
77 | def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
78 | m = K_XX.size(0) # assume X, Y are same shape
79 |
80 | # Get the various sums of kernels that we'll use
81 | # Kts drop the diagonal, but we don't need to compute them explicitly
82 | if const_diagonal is not False:
83 | diag_X = diag_Y = const_diagonal
84 | sum_diag_X = sum_diag_Y = m * const_diagonal
85 | else:
86 | diag_X = torch.diag(K_XX) # (m,)
87 | diag_Y = torch.diag(K_YY) # (m,)
88 | sum_diag_X = torch.sum(diag_X)
89 | sum_diag_Y = torch.sum(diag_Y)
90 |
91 | Kt_XX_sums = K_XX.sum(dim=1) - diag_X # \tilde{K}_XX * e = K_XX * e - diag_X
92 | Kt_YY_sums = K_YY.sum(dim=1) - diag_Y # \tilde{K}_YY * e = K_YY * e - diag_Y
93 | K_XY_sums_0 = K_XY.sum(dim=0) # K_{XY}^T * e
94 |
95 | Kt_XX_sum = Kt_XX_sums.sum() # e^T * \tilde{K}_XX * e
96 | Kt_YY_sum = Kt_YY_sums.sum() # e^T * \tilde{K}_YY * e
97 | K_XY_sum = K_XY_sums_0.sum() # e^T * K_{XY} * e
98 |
99 | if biased:
100 | mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
101 | + (Kt_YY_sum + sum_diag_Y) / (m * m)
102 | - 2.0 * K_XY_sum / (m * m))
103 | else:
104 | mmd2 = (Kt_XX_sum / (m * (m - 1))
105 | + Kt_YY_sum / (m * (m - 1))
106 | - 2.0 * K_XY_sum / (m * m))
107 |
108 | return mmd2
109 |
110 |
def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Return ``(mmd2 / sqrt(clamped var_est), mmd2, var_est)`` from kernel blocks.

    The variance is clamped below at ``min_var_est`` so the ratio stays finite.
    """
    mmd2, var_est = _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased)
    std_est = torch.sqrt(torch.clamp(var_est, min=min_var_est))
    return mmd2 / std_est, mmd2, var_est
115 |
116 |
def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Return ``(mmd2, var_est)`` from precomputed ``m x m`` kernel blocks.

    ``mmd2`` uses the same formula as ``_mmd2``: the V-statistic when ``biased`` is truthy,
    otherwise the within-set terms exclude their diagonals. ``var_est`` is a closed-form
    variance estimate built from the off-diagonal ("tilde") sums and the cross kernel.
    It does not depend on ``biased``.
    NOTE(review): presumably the estimator from the optimized-MMD (opt-mmd) line of work;
    verify the formula against its source before modifying it.

    Args:
        const_diagonal: if not False, the constant value of the K_XX / K_YY diagonals;
            otherwise the diagonals are read from the matrices.
    """
    m = K_XX.size(0)  # assume X, Y are same shape

    # Get the various sums of kernels that we'll use
    # Kts drop the diagonal, but we don't need to compute them explicitly
    if const_diagonal is not False:
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
        sum_diag2_X = sum_diag2_Y = m * const_diagonal ** 2
    else:
        diag_X = torch.diag(K_XX)  # (m,)
        diag_Y = torch.diag(K_YY)  # (m,)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)

    Kt_XX_sums = K_XX.sum(dim=1) - diag_X  # \tilde{K}_XX * e = K_XX * e - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y  # \tilde{K}_YY * e = K_YY * e - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)  # K_{XY}^T * e
    K_XY_sums_1 = K_XY.sum(dim=1)  # K_{XY} * e

    Kt_XX_sum = Kt_XX_sums.sum()  # e^T * \tilde{K}_XX * e
    Kt_YY_sum = Kt_YY_sums.sum()  # e^T * \tilde{K}_YY * e
    K_XY_sum = K_XY_sums_0.sum()  # e^T * K_{XY} * e

    Kt_XX_2_sum = (K_XX ** 2).sum() - sum_diag2_X  # \| \tilde{K}_XX \|_F^2
    Kt_YY_2_sum = (K_YY ** 2).sum() - sum_diag2_Y  # \| \tilde{K}_YY \|_F^2
    K_XY_2_sum = (K_XY ** 2).sum()  # \| K_{XY} \|_F^2

    if biased:
        mmd2 = ((Kt_XX_sum + sum_diag_X) / (m * m)
                + (Kt_YY_sum + sum_diag_Y) / (m * m)
                - 2.0 * K_XY_sum / (m * m))
    else:
        mmd2 = (Kt_XX_sum / (m * (m - 1))
                + Kt_YY_sum / (m * (m - 1))
                - 2.0 * K_XY_sum / (m * m))

    # Closed-form variance estimate; only off-diagonal and cross-kernel sums enter it.
    var_est = (
            2.0 / (m ** 2 * (m - 1.0) ** 2) * (
            2 * Kt_XX_sums.dot(Kt_XX_sums) - Kt_XX_2_sum + 2 * Kt_YY_sums.dot(Kt_YY_sums) - Kt_YY_2_sum)
            - (4.0 * m - 6.0) / (m ** 3 * (m - 1.0) ** 3) * (Kt_XX_sum ** 2 + Kt_YY_sum ** 2)
            + 4.0 * (m - 2.0) / (m ** 3 * (m - 1.0) ** 2) * (
                    K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
            - 4.0 * (m - 3.0) / (m ** 3 * (m - 1.0) ** 2) * (K_XY_2_sum) - (8 * m - 12) / (
                    m ** 5 * (m - 1)) * K_XY_sum ** 2
            + 8.0 / (m ** 3 * (m - 1.0)) * (
                    1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
                    - Kt_XX_sums.dot(K_XY_sums_1)
                    - Kt_YY_sums.dot(K_XY_sums_0))
    )
    return mmd2, var_est
170 |
171 |
172 | # https://github.com/ZongxianLee/MMD_Loss.Pytorch
173 | import torch
174 | import torch.nn as nn
175 |
176 |
class MMD_loss(nn.Module):
    """Biased MMD^2 between two sample sets under a sum of Gaussian (RBF) kernels.

    Adapted from https://github.com/ZongxianLee/MMD_Loss.Pytorch.

    Args:
        kernel_mul: ratio between consecutive kernel bandwidths.
        kernel_num: number of Gaussian kernels summed.
        fix_sigma: fixed base bandwidth. ``None`` (default) derives it from the mean
            off-diagonal squared distance of the pooled samples.
    """

    def __init__(self, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = fix_sigma

    @staticmethod
    def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Summed multi-bandwidth Gaussian kernel over the rows of ``cat(source, target)``.

        (The misspelled name is kept for backward compatibility.)
        """
        total = torch.cat([source, target], dim=0)
        n_samples = total.size(0)
        # Pairwise squared Euclidean distances via broadcasting, (n_samples, n_samples).
        l2_distance = ((total.unsqueeze(0) - total.unsqueeze(1)) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # .data keeps the data-driven bandwidth out of the autograd graph.
            bandwidth = torch.sum(l2_distance.data) / (n_samples ** 2 - n_samples)
        # Shift the geometric ladder of bandwidths so it spans below and above the base value.
        bandwidth /= kernel_mul ** (kernel_num // 2)
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        return sum(torch.exp(-l2_distance / bw) for bw in bandwidth_list)

    def forward(self, source, target):
        """Return the biased MMD^2 between ``source`` (n x d) and ``target`` (m x d).

        Each kernel block is averaged separately, so ``n`` and ``m`` may differ. For
        ``n == m`` this equals the previous ``mean(XX + YY - XY - YX)``.
        """
        n = source.size(0)
        kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul,
                                       kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
        xx = kernels[:n, :n]
        yy = kernels[n:, n:]
        xy = kernels[:n, n:]
        yx = kernels[n:, :n]
        return torch.mean(xx) + torch.mean(yy) - torch.mean(xy) - torch.mean(yx)
212 |
--------------------------------------------------------------------------------
/loss/ops/emd/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/loss/ops/emd/README.md:
--------------------------------------------------------------------------------
1 | This code was taken from [https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd](https://github.com/Colin97/MSN-Point-Cloud-Completion/tree/master/emd). It is a part of the implementation of the following paper:
2 |
3 | ```
4 | @article{liu2019morphing,
5 | title={Morphing and Sampling Network for Dense Point Cloud Completion},
6 | author={Liu, Minghua and Sheng, Lu and Yang, Sheng and Shao, Jing and Hu, Shi-Min},
7 | journal={arXiv preprint arXiv:1912.00280},
8 | year={2019}
9 | }
10 | ```
11 |
12 | ## Earth Mover's Distance of point clouds
13 |
14 | Compared to the Chamfer Distance (CD), the Earth Mover's Distance (EMD) is more reliable for distinguishing the visual quality of point clouds. See our [paper](http://cseweb.ucsd.edu/~mil070/projects/AAAI2020/paper.pdf) for more details.
15 |
16 | We provide an EMD implementation for point cloud comparison that needs only $O(n)$ memory and thus supports dense point clouds (10,000 points or more) and large batch sizes. It is based on an approximation algorithm (the auction algorithm), so the resulting assignment is close to, but not guaranteed to be, a bijection. A parameter $\epsilon$ balances the error rate against the speed of convergence: a smaller $\epsilon$ gives more accurate results but takes longer to converge. The time complexity is $O(n^2k)$, where $k$ is the number of iterations. We use $\epsilon = 0.005, k = 50$ during training and $\epsilon = 0.002, k = 10000$ during testing.
17 |
18 | ### Compile
19 | Run `python3 setup.py install` to compile.
20 |
21 | ### Example
22 | See `emd_module.py/test_emd()` for examples.
23 |
24 | ### Input
25 |
26 | - **xyz1, xyz2**: float tensors with shape `[#batch, #points, 3]`. xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud. The two point clouds should have the same size and be normalized to [0, 1]. The number of points should be a multiple of 1024. The batch size should be no greater than 512. Since we only calculate gradients for xyz1, please do not swap xyz1 and xyz2.
27 | - **eps**: a float, the parameter that balances the error rate and the speed of convergence.
28 | - **iters**: an int, the number of iterations.
29 |
30 | ### Output
31 |
32 | - **dist**: a float tensor with shape `[#batch, #points]`. sqrt(dist) are the L2 distances between the pairs of points.
33 | - **assignment**: an int tensor with shape `[#batch, #points]`. The index of the matched point in the ground truth point cloud.
34 |
--------------------------------------------------------------------------------
/loss/ops/emd/emd.cpp:
--------------------------------------------------------------------------------
1 | // EMD approximation module (based on auction algorithm)
2 | // author: Minghua Liu
3 | #include
4 | #include
5 |
// Forward declarations of the CUDA host drivers defined in emd_cuda.cu
// (setup.py compiles both files into the same `emd` extension).
// emd_cuda_forward runs the auction-based EMD approximation; emd_cuda_backward
// computes the gradient with respect to xyz1. Both return an int status code.
int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
    at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
    at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters);

int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx);
11 |
12 |
13 |
14 | int emd_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
15 | at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
16 | at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp, at::Tensor max_idx, float eps, int iters) {
17 | return emd_cuda_forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments, max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx, eps, iters);
18 | }
19 |
20 | int emd_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) {
21 |
22 | return emd_cuda_backward(xyz1, xyz2, gradxyz, graddist, idx);
23 | }
24 |
25 |
26 |
27 |
// Python bindings: emd_forward and emd_backward are exposed as `forward` and
// `backward`. The module name comes from TORCH_EXTENSION_NAME at build time;
// setup.py builds it as `emd`, and emd_module.py calls emd.forward/emd.backward.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &emd_forward, "emd forward (CUDA)");
    m.def("backward", &emd_backward, "emd backward (CUDA)");
}
--------------------------------------------------------------------------------
/loss/ops/emd/emd_cuda.cu:
--------------------------------------------------------------------------------
1 | // EMD approximation module (based on auction algorithm)
2 | // author: Minghua Liu
3 | #include
4 | #include
5 |
6 | #include
7 | #include
8 | #include
9 |
// Float atomic maximum built from atomicCAS on the value's int bit pattern.
// The loop retries until either *address already holds a value >= val (no
// write needed) or our compare-and-swap installs val.
// Returns the last value observed at *address; when the swap succeeds, that is
// the value that was replaced.
__device__ __forceinline__ float atomicMax(float *address, float val)
{
    int ret = __float_as_int(*address);
    while(val > __int_as_float(ret))
    {
        int old = ret;
        // atomicCAS returns the previous contents. Equality with `old` means our
        // write went through; otherwise `ret` now holds the fresh value to retry against.
        if((ret = atomicCAS((int *)address, old, __float_as_int(val))) == old)
            break;
    }
    return __int_as_float(ret);
}
21 |
22 |
// Reset the per-batch counters before each auction round:
//   unass_cnt[i] - number of unassigned points in batch i (rebuilt by calc_unass_cnt)
//   cnt_tmp[i]   - slot cursor used by calc_unass_idx
// The threads of the block stride over the b batch entries.
__global__ void clear(int b, int * cnt_tmp, int * unass_cnt) {
    int slot = threadIdx.x;
    while (slot < b) {
        unass_cnt[slot] = 0;
        cnt_tmp[slot] = 0;
        slot += blockDim.x;
    }
}
29 |
// Count, per batch, how many xyz1 points are still unassigned (assignment == -1)
// and add the total into unass_cnt[i]. emd_cuda_forward zeroes unass_cnt with
// `clear` right before this kernel.
// Each block handles one 1024-point slice (selected by blockIdx.y) of a batch.
// The indexing assumes 1024 threads per block. Batches are strided over blockIdx.x.
__global__ void calc_unass_cnt(int b, int n, int * assignment, int * unass_cnt) {
    // count the number of unassigned points in each batch
    const int BLOCK_SIZE = 1024;
    __shared__ int scan_array[BLOCK_SIZE];
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        // 1 for an unassigned point of this slice, 0 otherwise.
        scan_array[threadIdx.x] = assignment[i * n + blockIdx.y * BLOCK_SIZE + threadIdx.x] == -1 ? 1 : 0;
        __syncthreads();

        // In-place tree reduction: afterwards scan_array[BLOCK_SIZE - 1] holds
        // the number of unassigned points in this slice.
        int stride = 1;
        while(stride <= BLOCK_SIZE / 2) {
            int index = (threadIdx.x + 1) * stride * 2 - 1;
            if(index < BLOCK_SIZE)
                scan_array[index] += scan_array[index - stride];
            stride = stride * 2;
            __syncthreads();
        }
        __syncthreads();

        // One thread per block publishes the slice total. Other slices of the
        // same batch run in other blocks, hence the atomic.
        if (threadIdx.x == BLOCK_SIZE - 1) {
            atomicAdd(&unass_cnt[i], scan_array[threadIdx.x]);
        }
        __syncthreads();
    }
}
54 |
// Inclusive prefix sum over the per-batch counts:
//   unass_cnt_sum[i] = unass_cnt[0] + ... + unass_cnt[i].
// Runs as a single block with one thread per batch (emd_cuda_forward launches it
// as <<<1, batch_size>>>). The 512-entry shared buffer caps the batch count at 512.
// The `b` argument is not read.
__global__ void calc_unass_cnt_sum(int b, int * unass_cnt, int * unass_cnt_sum) {
    // compute the cumulative sum over unass_cnt
    const int BLOCK_SIZE = 512; // batch_size <= 512
    __shared__ int scan_array[BLOCK_SIZE];
    // NOTE(review): only the first blockDim.x entries are written here; the rest
    // of scan_array is uninitialized. Every update below adds a lower index into
    // a higher one, so entries below blockDim.x never depend on them.
    scan_array[threadIdx.x] = unass_cnt[threadIdx.x];
    __syncthreads();

    // Up-sweep: accumulate partial sums along a binary tree.
    int stride = 1;
    while(stride <= BLOCK_SIZE / 2) {
        int index = (threadIdx.x + 1) * stride * 2 - 1;
        if(index < BLOCK_SIZE)
            scan_array[index] += scan_array[index - stride];
        stride = stride * 2;
        __syncthreads();
    }
    __syncthreads();
    // Down-sweep: push partial sums forward so each entry holds its inclusive prefix.
    stride = BLOCK_SIZE / 4;
    while(stride > 0) {
        int index = (threadIdx.x + 1) * stride * 2 - 1;
        if((index + stride) < BLOCK_SIZE)
            scan_array[index + stride] += scan_array[index];
        stride = stride / 2;
        __syncthreads();
    }
    __syncthreads();

    //printf("%d\n", unass_cnt_sum[b - 1]);
    unass_cnt_sum[threadIdx.x] = scan_array[threadIdx.x];
}
84 |
// Write the indices of all unassigned xyz1 points into unass_idx, grouped by batch.
// Batch i's entries occupy [unass_cnt_sum[i] - unass_cnt[i], unass_cnt_sum[i]).
// Slots within a batch are claimed with atomicAdd on cnt_tmp[i] (zeroed by `clear`),
// so the order of indices inside a batch's range is not deterministic.
// Assumes 1024 threads per block, with blockIdx.y selecting a 1024-point slice.
__global__ void calc_unass_idx(int b, int n, int * assignment, int * unass_idx, int * unass_cnt, int * unass_cnt_sum, int * cnt_tmp) {
    // list all the unassigned points
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        if (assignment[i * n + blockIdx.y * 1024 + threadIdx.x] == -1) {
            int idx = atomicAdd(&cnt_tmp[i], 1);
            // unass_cnt_sum is inclusive, so subtracting unass_cnt gives the batch's start offset.
            unass_idx[unass_cnt_sum[i] - unass_cnt[i] + idx] = blockIdx.y * 1024 + threadIdx.x;
        }
    }
}
94 |
// One bidding round of the auction. Every unassigned xyz1 point picks the xyz2
// point that maximizes
//   value = 3 - ||p1 - p2|| - price[p2]
// and bids on it. For batch i, the unassigned points listed in unass_idx are
// split evenly over the n / 1024 blocks along blockIdx.y. Each point gets
// thread_per_unass threads, which scan disjoint sub-ranges of xyz2.
// Outputs, for each unassigned point j:
//   bid[j]            - index of its best xyz2 point
//   bid_increments[j] - best - second_best + eps
// max_increments[target] additionally receives, via atomicMax, the largest
// increment bid on that target.
// Assumes blockDim.x == block_size (1024) - TODO confirm against the launch site.
__global__ void Bid(int b, int n, const float * xyz1, const float * xyz2, float eps, int * assignment, int * assignment_inv, float * price,
    int * bid, float * bid_increments, float * max_increments, int * unass_cnt, int * unass_cnt_sum, int * unass_idx) {
    const int batch = 2048, block_size = 1024, block_cnt = n / 1024;
    // Shared-memory tile of up to `batch` xyz2 points and their prices.
    __shared__ float xyz2_buf[batch * 3];
    __shared__ float price_buf[batch];
    // Per-thread best value, second-best value and best index, merged per point below.
    __shared__ float best_buf[block_size];
    __shared__ float better_buf[block_size];
    __shared__ int best_i_buf[block_size];
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        int _unass_cnt = unass_cnt[i];
        // Depends only on i, so the whole block skips together and the
        // __syncthreads() calls below stay matched.
        if (_unass_cnt == 0)
            continue;
        int _unass_cnt_sum = unass_cnt_sum[i];
        int unass_per_block = (_unass_cnt + block_cnt - 1) / block_cnt;
        int thread_per_unass = block_size / unass_per_block;
        // Unassigned points handled by this block (0 for trailing blocks).
        int unass_this_block = max(min(_unass_cnt - (int) blockIdx.y * unass_per_block, unass_per_block), 0);

        float x1, y1, z1, best = -1e9, better = -1e9;
        // thread_in_unass is only set (and only read) when _unass_id != -1.
        int best_i = -1, _unass_id = -1, thread_in_unass;

        // Threads past thread_per_unass * unass_this_block stay idle (_unass_id == -1)
        // but still take part in tile loading and barriers.
        if (threadIdx.x < thread_per_unass * unass_this_block) {
            // Position in unass_idx: batch i's range starts at _unass_cnt_sum - _unass_cnt.
            _unass_id = unass_per_block * blockIdx.y + threadIdx.x / thread_per_unass + _unass_cnt_sum - _unass_cnt;
            _unass_id = unass_idx[_unass_id];
            thread_in_unass = threadIdx.x % thread_per_unass;

            x1 = xyz1[(i * n + _unass_id) * 3 + 0];
            y1 = xyz1[(i * n + _unass_id) * 3 + 1];
            z1 = xyz1[(i * n + _unass_id) * 3 + 2];
        }

        // Walk xyz2 in tiles of `batch` points staged cooperatively in shared memory.
        for (int k2 = 0; k2 < n; k2 += batch) {
            int end_k = min(n, k2 + batch) - k2;
            for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) {
                xyz2_buf[j] = xyz2[(i * n + k2) * 3 + j];
            }
            for (int j = threadIdx.x; j < end_k; j += blockDim.x) {
                price_buf[j] = price[i * n + k2 + j];
            }
            __syncthreads();

            if (_unass_id != -1) {
                // This thread's contiguous share [l, r) of the current tile.
                int delta = (end_k + thread_per_unass - 1) / thread_per_unass;
                int l = thread_in_unass * delta;
                int r = min((thread_in_unass + 1) * delta, end_k);
                for (int k = l; k < r; k++)
                //if (!last || assignment_inv[i * n + k + k2] == -1)
                {
                    float x2 = xyz2_buf[k * 3 + 0] - x1;
                    float y2 = xyz2_buf[k * 3 + 1] - y1;
                    float z2 = xyz2_buf[k * 3 + 2] - z1;
                    // the coordinates of points should be normalized to [0, 1]
                    // (then the distance is at most sqrt(3), below the 3.0 offset).
                    float d = 3.0 - sqrtf(x2 * x2 + y2 * y2 + z2 * z2) - price_buf[k];
                    // Track the best and second-best values seen by this thread.
                    if (d > best) {
                        better = best;
                        best = d;
                        best_i = k + k2;
                    }
                    else if (d > better) {
                        better = d;
                    }
                }
            }
            // Keep the tile intact until every thread has finished reading it.
            __syncthreads();
        }

        best_buf[threadIdx.x] = best;
        better_buf[threadIdx.x] = better;
        best_i_buf[threadIdx.x] = best_i;
        __syncthreads();

        // The first thread of each point's group merges the group's results,
        // then publishes the bid and its increment.
        if (_unass_id != -1 && thread_in_unass == 0) {
            for (int j = threadIdx.x + 1; j < threadIdx.x + thread_per_unass; j++) {
                if (best_buf[j] > best) {
                    better = max(best, better_buf[j]);
                    best = best_buf[j];
                    best_i = best_i_buf[j];
                }
                else better = max(better, best_buf[j]);
            }
            bid[i * n + _unass_id] = best_i;
            bid_increments[i * n + _unass_id] = best - better + eps;
            atomicMax(&max_increments[i * n + best_i], best - better + eps);
        }
    }
}
180 |
// For every unassigned point j, check whether its bid increment matches (within
// 1e-6) the maximum increment that Bid recorded for its target. If it does, j is
// stored as that target's winner in max_idx. Several matching bidders write the
// same slot with plain stores, so whichever store lands last wins.
// Point j = threadIdx.x + blockIdx.y * blockDim.x is used without a bounds
// check, so the grid must cover exactly n points per batch.
__global__ void GetMax(int b, int n, int * assignment, int * bid, float * bid_increments, float * max_increments, int * max_idx) {
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        int j = threadIdx.x + blockIdx.y * blockDim.x;
        if (assignment[i * n + j] == -1) {
            int bid_id = bid[i * n + j];
            float bid_inc = bid_increments[i * n + j];
            float max_inc = max_increments[i * n + bid_id];
            // Tolerance comparison of this bid's increment against the recorded maximum.
            if (bid_inc - 1e-6 <= max_inc && max_inc <= bid_inc + 1e-6)
            {
                max_idx[i * n + bid_id] = j;
            }
        }
    }
}
195 |
// Resolve this round's bids. An unassigned point j whose bid won
// (max_idx[target] == j) takes its target:
//   - the target's previous owner, if any, is evicted back to -1;
//   - assignment and assignment_inv are updated;
//   - the target's price rises by j's increment;
//   - max_increments[target] is reset to -1e9.
// On the final round (`last`), every unassigned point takes its bid target with
// no winner check and no eviction. Several points can therefore share a target,
// so the final assignment is not guaranteed to be a bijection. The plain `+=` on
// price can race in that round; price is not read again by emd_cuda_forward after
// the last round.
__global__ void Assign(int b, int n, int * assignment, int * assignment_inv, float * price, int * bid, float * bid_increments, float * max_increments, int * max_idx, bool last) {
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        int j = threadIdx.x + blockIdx.y * blockDim.x;
        if (assignment[i * n + j] == -1) {
            int bid_id = bid[i * n + j];
            if (last || max_idx[i * n + bid_id] == j)
            {
                float bid_inc = bid_increments[i * n + j];
                int ass_inv = assignment_inv[i * n + bid_id];
                // Evict the previous owner so it bids again next round.
                if (!last && ass_inv != -1) {
                    assignment[i * n + ass_inv] = -1;
                }
                assignment_inv[i * n + bid_id] = j;
                assignment[i * n + j] = bid_id;
                price[i * n + bid_id] += bid_inc;
                max_increments[i * n + bid_id] = -1e9;
            }
        }
    }
}
216 |
// Squared Euclidean distance from each xyz1 point j to the xyz2 point it is
// assigned to: dist[j] = ||xyz1[j] - xyz2[assignment[j]]||^2.
// Point j = blockIdx.y * blockDim.x + threadIdx.x is used without a bounds
// check, so the grid must cover exactly n points per batch.
__global__ void CalcDist(int b, int n, float * xyz1, float * xyz2, float * dist, int * assignment) {
    for (int batch = blockIdx.x; batch < b; batch += gridDim.x) {
        const int src = batch * n + (threadIdx.x + blockIdx.y * blockDim.x);
        const int dst = batch * n + assignment[src];
        const float *p = xyz1 + src * 3;
        const float *q = xyz2 + dst * 3;
        const float dx = p[0] - q[0];
        const float dy = p[1] - q[1];
        const float dz = p[2] - q[2];
        dist[src] = dx * dx + dy * dy + dz * dz;
    }
}
227 |
// Host driver for the auction-based EMD approximation.
//
// Runs `iters` auction rounds: clear counters, count and list unassigned
// points, bid, pick winners, assign. It then writes into `dist` the squared
// distance from every xyz1 point to the xyz2 point it was assigned to.
// Tensor dtypes must match the kernel signatures:
//   - float32: xyz1, xyz2, dist, price, bid_increments, max_increments;
//   - int32: every index and counter buffer.
// All tensors are expected to be contiguous and on the current CUDA device.
//
// Returns 1 on success, 0 if a CUDA error was reported after the launches, and
// -1 (without launching anything) when the input violates the kernels' limits.
int emd_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist, at::Tensor assignment, at::Tensor price,
                     at::Tensor assignment_inv, at::Tensor bid, at::Tensor bid_increments, at::Tensor max_increments,
                     at::Tensor unass_idx, at::Tensor unass_cnt, at::Tensor unass_cnt_sum, at::Tensor cnt_tmp,
                     at::Tensor max_idx, float eps, int iters) {

    const auto batch_size = xyz1.size(0);
    const auto n = xyz1.size(1);  // num_points point cloud A
    const auto m = xyz2.size(1);  // num_points point cloud B

    if (n != m) {
        printf("Input Error! The two point clouds should have the same size.\n");
        return -1;
    }

    if (batch_size > 512) {
        // calc_unass_cnt_sum scans the per-batch counts in one 512-entry block.
        printf("Input Error! The batch size should be no greater than 512.\n");
        return -1;
    }

    if (n % 1024 != 0) {
        printf("Input Error! The size of the point clouds should be a multiple of 1024.\n");
        return -1;
    }

    if (iters < 1) {
        // With zero rounds every assignment stays -1 and CalcDist would index
        // xyz2 out of bounds.
        printf("Input Error! The number of iterations should be at least 1.\n");
        return -1;
    }

    // The per-point kernels address point j as blockIdx.y * 1024 + threadIdx.x
    // with no bounds check, and Bid / calc_unass_cnt size their shared buffers
    // for 1024 threads. The launch must therefore be n / 1024 blocks of 1024
    // threads along y. Batches are strided over blockIdx.x inside each kernel.
    const dim3 grid(batch_size, n / 1024, 1);
    const int threads = 1024;

    for (int i = 0; i < iters; i++) {
        clear<<<1, batch_size>>>(batch_size, cnt_tmp.data_ptr<int>(), unass_cnt.data_ptr<int>());
        calc_unass_cnt<<<grid, threads>>>(batch_size, n, assignment.data_ptr<int>(), unass_cnt.data_ptr<int>());
        calc_unass_cnt_sum<<<1, batch_size>>>(batch_size, unass_cnt.data_ptr<int>(), unass_cnt_sum.data_ptr<int>());
        calc_unass_idx<<<grid, threads>>>(batch_size, n, assignment.data_ptr<int>(), unass_idx.data_ptr<int>(),
                                          unass_cnt.data_ptr<int>(), unass_cnt_sum.data_ptr<int>(),
                                          cnt_tmp.data_ptr<int>());
        Bid<<<grid, threads>>>(batch_size, n, xyz1.data_ptr<float>(), xyz2.data_ptr<float>(), eps,
                               assignment.data_ptr<int>(), assignment_inv.data_ptr<int>(),
                               price.data_ptr<float>(), bid.data_ptr<int>(), bid_increments.data_ptr<float>(),
                               max_increments.data_ptr<float>(), unass_cnt.data_ptr<int>(),
                               unass_cnt_sum.data_ptr<int>(), unass_idx.data_ptr<int>());
        GetMax<<<grid, threads>>>(batch_size, n, assignment.data_ptr<int>(), bid.data_ptr<int>(),
                                  bid_increments.data_ptr<float>(), max_increments.data_ptr<float>(),
                                  max_idx.data_ptr<int>());
        // The final round assigns every remaining bidder (see Assign).
        Assign<<<grid, threads>>>(batch_size, n, assignment.data_ptr<int>(), assignment_inv.data_ptr<int>(),
                                  price.data_ptr<float>(), bid.data_ptr<int>(), bid_increments.data_ptr<float>(),
                                  max_increments.data_ptr<float>(), max_idx.data_ptr<int>(), i == iters - 1);
    }
    CalcDist<<<grid, threads>>>(batch_size, n, xyz1.data_ptr<float>(), xyz2.data_ptr<float>(),
                                dist.data_ptr<float>(), assignment.data_ptr<int>());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in emd forward: %s\n", cudaGetErrorString(err));
        return 0;
    }
    return 1;
}
283 |
// Backward kernel. For each point j of batch i it adds
//   2 * grad_dist[j] * (xyz1[j] - xyz2[idx[j]])
// to grad_xyz[j]; that is the gradient, with respect to xyz1, of the squared
// distance that CalcDist computes. No gradient is produced for xyz2.
// Points are strided along the grid's y dimension (j += blockDim.x * gridDim.y)
// and batches along x, so any launch shape covers every point.
__global__ void NmDistanceGradKernel(int b, int n, const float * xyz1, const float * xyz2, const float * grad_dist, const int * idx, float * grad_xyz){
    for (int i = blockIdx.x; i < b; i += gridDim.x) {
        for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; j += blockDim.x * gridDim.y) {
            float x1 = xyz1[(i * n + j) * 3 + 0];
            float y1 = xyz1[(i * n + j) * 3 + 1];
            float z1 = xyz1[(i * n + j) * 3 + 2];
            int j2 = idx[i * n + j];
            float x2 = xyz2[(i * n + j2) * 3 + 0];
            float y2 = xyz2[(i * n + j2) * 3 + 1];
            float z2 = xyz2[(i * n + j2) * 3 + 2];
            float g = grad_dist[i * n + j] * 2;
            // NOTE(review): each (i, j) appears to be visited by exactly one
            // thread, so these atomics look unnecessary, though harmless.
            atomicAdd(&(grad_xyz[(i * n + j) * 3 + 0]), g * (x1 - x2));
            atomicAdd(&(grad_xyz[(i * n + j) * 3 + 1]), g * (y1 - y2));
            atomicAdd(&(grad_xyz[(i * n + j) * 3 + 2]), g * (z1 - z2));
        }
    }
}
301 |
// Host driver for the backward pass. Accumulates into gradxyz the gradient of
// dist[j] = ||xyz1[j] - xyz2[idx[j]]||^2 with respect to xyz1, scaled by graddist.
// No gradient is produced for xyz2.
// Expects contiguous float32 xyz1/xyz2/gradxyz/graddist and int32 idx on the
// current CUDA device.
// Returns 1 on success and 0 if a CUDA error was reported after the launch.
int emd_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz, at::Tensor graddist, at::Tensor idx) {
    const auto batch_size = xyz1.size(0);
    const auto n = xyz1.size(1);

    // NmDistanceGradKernel strides over points, so any y extent is valid. Use one
    // 1024-thread block per 1024 points (rounded up), matching the forward kernels.
    const dim3 grid(batch_size, (n + 1023) / 1024, 1);
    NmDistanceGradKernel<<<grid, 1024>>>(batch_size, n, xyz1.data_ptr<float>(), xyz2.data_ptr<float>(),
                                         graddist.data_ptr<float>(), idx.data_ptr<int>(),
                                         gradxyz.data_ptr<float>());

    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        printf("error in emd get grad: %s\n", cudaGetErrorString(err));
        return 0;
    }
    return 1;
}
317 |
--------------------------------------------------------------------------------
/loss/ops/emd/emd_module.py:
--------------------------------------------------------------------------------
1 | # EMD approximation module (based on auction algorithm)
2 | # memory complexity: O(n)
3 | # time complexity: O(n^2 * iter)
4 | # author: Minghua Liu
5 |
6 | # Input:
7 | # xyz1, xyz2: [#batch, #points, 3]
8 | # where xyz1 is the predicted point cloud and xyz2 is the ground truth point cloud
9 | # the two point clouds should have the same size and be normalized to [0, 1]
10 | # #points should be a multiple of 1024
11 | # #batch should be no greater than 512
12 | # eps is a parameter which balances the error rate and the speed of convergence
13 | # iters is the number of iterations
14 | # we only calculate gradient for xyz1
15 |
16 | # Output:
17 | # dist: [#batch, #points], sqrt(dist) -> L2 distance
18 | # assignment: [#batch, #points], index of the matched point in the ground truth point cloud
19 | # the result is an approximation and the assignment is not guaranteed to be a bijection
20 |
21 | import time
22 |
23 | import emd
24 | import numpy as np
25 | import torch
26 | from torch import nn
27 | from torch.autograd import Function
28 |
29 |
class emdFunction(Function):
    """Autograd wrapper around the compiled ``emd`` CUDA extension.

    ``emdFunction.apply(xyz1, xyz2, eps, iters)`` returns ``(dist, assignment)``.
    ``dist[b, j]`` is the squared distance from ``xyz1[b, j]`` to the ``xyz2``
    point it was matched with, and ``assignment[b, j]`` is that point's index.
    Only ``xyz1`` receives a non-zero gradient.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2, eps, iters):
        """Run the auction-based EMD approximation on the GPU.

        :param xyz1: (B, N, 3) predicted point cloud.
        :param xyz2: (B, N, 3) reference point cloud.
        :param eps: auction step; smaller values are more accurate but converge slower.
        :param iters: number of auction rounds (>= 1).
        :return: ``dist`` (B, N) float32 and ``assignment`` (B, N) int32, both on
            the CUDA device the computation ran on (``xyz1``'s device after ``.cuda()``).
        :raises ValueError: if the inputs violate the kernel's constraints.
        :raises RuntimeError: if the extension reports a failure.
        """
        batchsize, n, _ = xyz1.size()
        batchsize2, m, _ = xyz2.size()

        # Explicit checks instead of asserts, which are stripped under `python -O`.
        if n != m:
            raise ValueError(f"xyz1 and xyz2 must have the same number of points, got {n} and {m}")
        if batchsize != batchsize2:
            raise ValueError(f"xyz1 and xyz2 must have the same batch size, got {batchsize} and {batchsize2}")
        if n % 1024 != 0:
            raise ValueError(f"the number of points must be a multiple of 1024, got {n}")
        if batchsize > 512:
            raise ValueError(f"the batch size must be no greater than 512, got {batchsize}")
        if iters < 1:
            raise ValueError(f"iters must be at least 1, got {iters}")

        # Remember where the inputs came from so backward returns gradients there.
        ctx.input_devices = (xyz1.device, xyz2.device)

        xyz1 = xyz1.contiguous().float().cuda()
        # Keep xyz2 and every work buffer on xyz1's GPU rather than on whichever
        # device happens to be the current default.
        device = xyz1.device
        xyz2 = xyz2.to(device=device, dtype=torch.float32).contiguous()

        float_opts = dict(device=device, dtype=torch.float32)
        int_opts = dict(device=device, dtype=torch.int32)
        dist = torch.zeros(batchsize, n, **float_opts)
        # -1 marks "unassigned" in both directions.
        assignment = torch.full((batchsize, n), -1, **int_opts)
        assignment_inv = torch.full((batchsize, m), -1, **int_opts)
        price = torch.zeros(batchsize, m, **float_opts)
        bid = torch.zeros(batchsize, n, **int_opts)
        bid_increments = torch.zeros(batchsize, n, **float_opts)
        max_increments = torch.zeros(batchsize, m, **float_opts)
        unass_idx = torch.zeros(batchsize * n, **int_opts)
        max_idx = torch.zeros(batchsize * m, **int_opts)
        # Per-batch counters are sized for the kernel's hard limit of 512 batches.
        unass_cnt = torch.zeros(512, **int_opts)
        unass_cnt_sum = torch.zeros(512, **int_opts)
        cnt_tmp = torch.zeros(512, **int_opts)

        # The extension launches on the current CUDA device; make that xyz1's device.
        with torch.cuda.device(device):
            status = emd.forward(xyz1, xyz2, dist, assignment, price, assignment_inv, bid, bid_increments,
                                 max_increments, unass_idx, unass_cnt, unass_cnt_sum, cnt_tmp, max_idx,
                                 float(eps), int(iters))
        if status != 1:
            raise RuntimeError(f"emd.forward failed with status {status}")

        ctx.save_for_backward(xyz1, xyz2, assignment)
        ctx.mark_non_differentiable(assignment)
        return dist, assignment

    @staticmethod
    def backward(ctx, graddist, gradidx):
        """Propagate ``d loss / d dist`` to ``xyz1``; ``xyz2`` receives zeros.

        :return: gradients for ``(xyz1, xyz2, eps, iters)``; the last two are None.
        :raises RuntimeError: if the extension reports a failure.
        """
        xyz1, xyz2, assignment = ctx.saved_tensors
        graddist = graddist.contiguous()

        gradxyz1 = torch.zeros_like(xyz1)
        with torch.cuda.device(xyz1.device):
            status = emd.backward(xyz1, xyz2, gradxyz1, graddist, assignment)
        if status != 1:
            raise RuntimeError(f"emd.backward failed with status {status}")

        xyz1_device, xyz2_device = ctx.input_devices
        gradxyz2 = torch.zeros_like(xyz2, device=xyz2_device)
        return gradxyz1.to(xyz1_device), gradxyz2, None, None
72 |
73 |
class emdModule(nn.Module):
    """Parameter-free ``nn.Module`` front-end for :class:`emdFunction`.

    ``forward(input1, input2, eps, iters)`` returns the ``(dist, assignment)``
    pair produced by ``emdFunction.apply``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input1, input2, eps, iters):
        result = emdFunction.apply(input1, input2, eps, iters)
        return result
80 |
81 |
def test_emd(batch_size=20, num_points=8192, eps=0.05, iters=3000):
    """Smoke-test and time the EMD module on random CUDA point clouds.

    Prints:
      - the runtime;
      - the mean matched distance reported by the kernel;
      - the number of distinct assigned targets (``num_points`` for a bijection);
      - the same mean distance recomputed independently from the assignment.

    Requires a CUDA device and the compiled ``emd`` extension.
    """
    x1 = torch.rand(batch_size, num_points, 3).cuda()
    x2 = torch.rand(batch_size, num_points, 3).cuda()
    emdm = emdModule()
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    dis, assignment = emdm(x1, x2, eps, iters)
    # Kernel launches are asynchronous; wait for them before reading the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time
    print("Input_size: ", x1.shape)
    print("Runtime: %lfs" % elapsed)
    print("EMD: %lf" % dis.sqrt().mean().item())
    print("|set(assignment)|: %d" % assignment.unique().numel())
    # Recompute the matched distances on the torch side. The earlier numpy
    # round-trip (np.take_along_axis on a CUDA tensor) cannot run.
    gather_idx = assignment.long().unsqueeze(-1).expand(-1, -1, 3)
    matched = torch.gather(x2, 1, gather_idx)
    d = ((x1 - matched) ** 2).sum(-1)
    print("Verified EMD: %lf" % d.sqrt().mean().item())


if __name__ == '__main__':
    test_emd()
101 |
--------------------------------------------------------------------------------
/loss/ops/emd/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
# Build the `emd` extension from its pybind11 bindings (emd.cpp) and CUDA
# kernels (emd_cuda.cu); run `python3 setup.py install` in this directory.
setup(
    name='emd',
    cmdclass={'build_ext': BuildExtension},
    ext_modules=[CUDAExtension('emd', ['emd.cpp', 'emd_cuda.cu'])],
)
15 |
--------------------------------------------------------------------------------
/loss/reg_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from scipy.spatial.transform import Rotation
8 |
9 | from loss.ops.emd.emd_module import emdFunction
10 | from misc.point_utils import get_knn_idx_dist, group, gather
11 |
12 |
def cd_loss(preds, gts):
    """Symmetric Chamfer distance between two batched point clouds.

    :param preds: (B, Np, D) predicted points.
    :param gts: (B, Ng, D) reference points.
    :return: scalar tensor equal to
        ``mean_j min_i ||g_i - p_j||^2 + mean_i min_j ||g_i - p_j||^2``,
        using squared L2 distances and averaging over the batch as well.
    """
    def batch_pairwise_dist(x, y):
        """Return ``P[b, i, j] = ||x[b, i] - y[b, j]||^2`` with shape (B, Nx, Ny).

        Uses ``||x||^2 + ||y||^2 - 2 x.y``. The squared norms are computed
        directly instead of as diagonals of (Nx, Nx) / (Ny, Ny) Gram matrices,
        so only the (Nx, Ny) cross term needs a matrix product.
        """
        rx = (x * x).sum(dim=-1)               # (B, Nx)
        ry = (y * y).sum(dim=-1)               # (B, Ny)
        zz = torch.bmm(x, y.transpose(2, 1))   # (B, Nx, Ny)
        return rx.unsqueeze(2) + ry.unsqueeze(1) - 2 * zz

    P = batch_pairwise_dist(gts, preds)  # (B, Ng, Np)
    mins, _ = torch.min(P, 1)            # each prediction -> nearest gt
    loss_1 = torch.mean(mins)
    mins, _ = torch.min(P, 2)            # each gt -> nearest prediction
    loss_2 = torch.mean(mins)

    return loss_1 + loss_2
36 |
37 |
def emd_loss(preds, gts, eps=0.005, iters=50):
    """Mean per-point distance under the approximate EMD matching.

    :param preds: predicted points, passed to ``emdFunction`` as ``xyz1`` (the
        only argument that receives gradients).
    :param gts: reference points, passed as ``xyz2``.
    :param eps: auction step forwarded to ``emdFunction``.
    :param iters: number of auction iterations forwarded to ``emdFunction``.
    :return: scalar tensor, the mean of the per-point distances ``emdFunction`` returns.
    """
    per_point_dist = emdFunction.apply(preds, gts, eps, iters)[0]
    return per_point_dist.mean()
41 |
42 |
class ChamferLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`cd_loss` (symmetric Chamfer distance).

    ``forward`` accepts and ignores extra keyword arguments.
    """

    def __init__(self):
        super(ChamferLoss, self).__init__()

    def forward(self, preds, gts, **kwargs):
        loss = cd_loss(preds, gts)
        return loss
50 |
51 |
class EMDLoss(nn.Module):
    """``nn.Module`` wrapper around :func:`emd_loss` with fixed auction settings.

    :param eps: auction step forwarded to ``emd_loss``.
    :param iters: number of auction iterations forwarded to ``emd_loss``.
    """

    def __init__(self, eps=0.005, iters=50):
        super().__init__()
        self.eps, self.iters = eps, iters

    def forward(self, preds, gts, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        return emd_loss(preds, gts, iters=self.iters, eps=self.eps)
61 |
62 |
class ProjectionLoss(nn.Module):
    """Weighted point-to-plane distance from predictions to nearby gt points.

    For each predicted point, its ``knn`` nearest ground-truth points and their
    normals are gathered. The per-point term is a weighted average of
    ``|(p - q) . n_q|``, where each weight multiplies:
      - a Gaussian on the squared distance (width ``sigma_p``);
      - an angular term comparing each neighbour normal with the normal in
        neighbour slot 0 (width ``sigma_n``, in radians).
    The loss is the sum of these per-point terms.
    """

    def __init__(self, knn=8, sigma_p=0.03, sigma_n=math.radians(15)):
        super().__init__()
        self.sigma_p = sigma_p
        self.sigma_n = sigma_n
        self.knn = knn

    def distance_weight(self, dist):
        """
        :param dist: (B, N, k), Squared L2 distance
        :return (B, N, k), exp(-dist / sigma_p^2)
        """
        return torch.exp(- dist / (self.sigma_p ** 2))

    def angle_weight(self, nb_normals):
        """
        :param nb_normals: (B, N, k, 3), Normals of neighboring points
        :return (B, N, k), exp(-(1 - <n_i, n_0>) / (1 - cos(sigma_n)))
        """
        # Normal of neighbour slot 0 (presumably the nearest one, depending on
        # get_knn_idx_dist's ordering) serves as the reference normal.
        estm_normal = nb_normals[:, :, 0:1, :]                 # (B, N, 1, 3)
        inner_prod = (nb_normals * estm_normal).sum(dim=-1)    # (B, N, k), broadcast over k
        return torch.exp(- (1 - inner_prod) / (1 - math.cos(self.sigma_n)))

    def forward(self, preds, gts, normals, **kwargs):
        """
        :param preds: predicted points, used as knn queries into ``gts``.
        :param gts: reference points.
        :param normals: normals of ``gts``, gathered with the same knn indices.
        :return: scalar tensor, the summed per-point weighted projection distance.
        """
        knn_idx, knn_dist = get_knn_idx_dist(gts, query=preds, k=self.knn, offset=0)  # (B, N, k), squared L2 distance

        nb_points = group(gts, idx=knn_idx)        # (B, N, k, 3)
        nb_normals = group(normals, idx=knn_idx)   # (B, N, k, 3)

        weights = self.distance_weight(knn_dist) * self.angle_weight(nb_normals)  # (B, N, k)

        inner_prod = ((preds.unsqueeze(-2) - nb_points) * nb_normals).sum(dim=-1)  # (B, N, k)
        inner_prod = torch.abs(inner_prod)

        weight_sum = weights.sum(dim=-1)  # (B, N)
        # exp(-d / sigma_p^2) underflows to exactly 0 for distant neighbours; a row
        # of all-zero weights would give 0 / 0 = NaN. Divide such rows by 1 instead
        # (their numerator is 0 as well); every other row is unchanged.
        safe_sum = torch.where(weight_sum > 0, weight_sum, torch.ones_like(weight_sum))
        point_displacement = (inner_prod * weights).sum(dim=-1) / safe_sum  # (B, N)

        return point_displacement.sum()
103 |
104 |
class UnsupervisedLoss(nn.Module):
    """Unsupervised point-cloud loss built from an EMD assignment and random neighborhoods.

    Each input point is matched one-to-one to a predicted point via
    ``emdFunction``. The matched prediction is compared (squared L2) with a
    randomly sampled subset of the input point's ``k`` nearest neighbors. The
    per-point averages are then summed.

    :param k: number of nearest neighbors considered per input point.
    :param radius: neighbors farther than this (Euclidean) are never sampled.
    :param pdf_std: std of the Gaussian prior over scaled neighbor distances.
    :param inv_scale: distance scale applied before evaluating the prior.
    :param decay_epoch: stored only; not used by this class.
    :param emd_eps: forwarded to ``emdFunction``.
    :param emd_iters: forwarded to ``emdFunction``.
    """

    def __init__(self, k=64, radius=0.05, pdf_std=0.5657, inv_scale=0.05, decay_epoch=80, emd_eps=0.005, emd_iters=50):
        super().__init__()
        self.knn = k
        self.radius = radius
        self.pdf_std = pdf_std
        self.inv_scale = inv_scale
        self.decay_epoch = decay_epoch
        self.emd_eps = emd_eps
        self.emd_iters = emd_iters

    def _sample_mask(self, knn_dist):
        """Sample a neighbor mask from squared neighbor distances.

        :param knn_dist: (B, N, k), squared L2 distances to the neighbors.
        :return: (mask, prob), both (B, N, k). ``prob`` is the radius-limited
            Gaussian prior. ``mask`` is a Bernoulli sample of it in which every
            point keeps at least one neighbor.
        """
        # Gaussian spatial prior
        SQRT_2PI = 2.5066282746
        prob = torch.exp(- (knn_dist / (self.inv_scale ** 2)) / (2 * self.pdf_std ** 2)) / (
                self.pdf_std * SQRT_2PI)  # (B, N, k)
        # Radius cutoff. It must happen *before* sampling, otherwise neighbors beyond
        # ``radius`` can still be drawn. Comparing squared distances avoids taking the
        # sqrt of slightly negative round-off values (which would give NaN).
        prob = prob * (knn_dist <= self.radius ** 2)
        # The Gaussian density exceeds 1 when pdf_std < 1/sqrt(2*pi), but
        # torch.bernoulli only accepts probabilities in [0, 1].
        mask = torch.bernoulli(prob.clamp(max=1.0))  # (B, N, k)

        # If all the neighbors of a point are rejected, accept the last one (the farthest,
        # assuming neighbors are sorted by distance) to avoid a zero loss. All-rejected
        # probably happens when the point is displaced too far (high noise).
        fallback = torch.zeros_like(mask)
        fallback[..., -1] = (mask.sum(dim=-1) == 0).to(mask.dtype)
        return mask + fallback, prob

    def stochastic_neighborhood(self, inputs):
        """
        param: inputs: (B, N, 3)
        return: knn_idx: (B, N, k), Indices of neighboring points
        return: mask: (B, N, k), Mask
        return: prob: (B, N, k), Radius-limited sampling prior
        """
        knn_idx, knn_dist = get_knn_idx_dist(inputs, query=inputs, k=self.knn, offset=1)  # (B, N, k), exclude self
        mask, prob = self._sample_mask(knn_dist)
        return knn_idx, mask, prob

    def forward(self, preds, inputs, epoch, **kwargs):
        """
        :param preds: predicted points (B, N, 3).
        :param inputs: input points (B, N, 3).
        :param epoch: accepted for interface compatibility; not used.
        :return: scalar sum of per-point average squared distances.
        """
        _, assignment = emdFunction.apply(inputs, preds, self.emd_eps,
                                          self.emd_iters)  # (B, N), assign each input point to a predicted point

        # Permute the predicted points according to the assignment
        # One-to-one correspondent to input points
        permuted_preds = gather(preds, idx=assignment.long())  # (B, N, 3)

        input_nbh_idx, input_nbh_mask, _ = self.stochastic_neighborhood(inputs)  # (B, N, k), (B, N, k)
        input_nbh_pos = group(inputs, idx=input_nbh_idx)  # (B, N, k, 3)

        dist = (permuted_preds.unsqueeze(dim=-2).expand_as(input_nbh_pos) - input_nbh_pos)  # (B, N, k, 3)
        dist = (dist ** 2).sum(dim=-1)  # (B, N, k), squared-L2 distance

        num_nbh = input_nbh_mask.sum(dim=-1)  # (B, N), >= 1 thanks to the fallback
        avg_dist = (dist * input_nbh_mask).sum(dim=-1) / num_nbh  # (B, N), average distance

        return avg_dist.sum()
161 |
162 |
def get_loss_layer(name):
    """Instantiate a point-cloud loss by its short name.

    :param name: one of 'emd', 'cd', 'proj', 'unsupervised', or None / 'None'.
    :return: the loss module, or None for None / 'None'.
    :raises ValueError: for any other name.
    """
    # Constructors are wrapped in lambdas so only the selected class is looked up.
    factories = (
        ('emd', lambda: EMDLoss()),
        ('cd', lambda: ChamferLoss()),
        ('proj', lambda: ProjectionLoss()),
        ('unsupervised', lambda: UnsupervisedLoss()),
    )
    for key, make in factories:
        if name == key:
            return make()
    if name is None or name == 'None':
        return None
    raise ValueError('Unknown loss: %s ' % name)
176 |
177 |
class RepulsionLoss(nn.Module):
    """Repulsion term computed over each point's nearest neighbors.

    The ``knn`` neighbors of every point are found with the point itself
    excluded (``offset=1``). The loss is ``sum(-d * exp(-d / h**2))`` over all
    squared neighbor distances ``d``.
    """

    def __init__(self, knn=4, h=0.03):
        super().__init__()
        self.knn = knn
        self.h = h

    def forward(self, pc):
        _, sq_dist = get_knn_idx_dist(pc, pc, k=self.knn, offset=1)  # (B, N, k)
        bandwidth = self.h ** 2
        return (-sq_dist * torch.exp(-sq_dist / bandwidth)).sum()
190 |
191 |
def make_reg_loss(params):
    """Build the registration loss selected by ``params.model_params.reg_loss``.

    :param params: configuration object. ``is_register`` is always read. Only
        when it is truthy are ``model_params.reg_loss`` and (for 'EMD')
        ``max_iter`` read.
    :return: the loss module, or None when registration is disabled.
    :raises NotImplementedError: if ``reg_loss`` names an unknown loss.
    """
    # Check the switch first so configs without model_params work when registration is off.
    if not params.is_register:
        return None
    reg_loss = params.model_params.reg_loss
    if reg_loss == 'EMD':
        return EMDLoss(iters=params.max_iter)
    if reg_loss == 'ChamferLoss':
        return ChamferLoss()
    if reg_loss == 'PointMSE':
        return PointMSE()
    if reg_loss == 'PoseMSE':
        return PoseMSE()
    raise NotImplementedError('Unknown loss: {}'.format(reg_loss))
208 |
209 |
class PointMSE(nn.Module):
    """Mean-squared error between predicted and ground-truth point coordinates."""

    def __init__(self):
        super().__init__()

    def forward(self, pred_pts, gt_pts):
        """Return the element-wise MSE (mean reduction) of ``pred_pts`` against ``gt_pts``."""
        mse = torch.nn.functional.mse_loss
        return mse(pred_pts, gt_pts)
219 |
220 |
class PoseMSE(nn.Module):
    """Pose loss: ``MSE(R_pred^T @ R_gt, I) + MSE(t_pred, t_gt)``.

    Rotations are (B, 3, 3) matrices; translations are compared element-wise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, rotation_ab_pred, translation_ab_pred, rotation_ab, translation_ab):
        batch_size = rotation_ab_pred.size(0)
        # Build the identity on the predictions' device/dtype instead of hard-coding
        # ``.cuda()``. The loss then also works on CPU, on a non-default GPU, and in
        # float64.
        identity = torch.eye(3, device=rotation_ab_pred.device, dtype=rotation_ab_pred.dtype)
        identity = identity.unsqueeze(0).expand(batch_size, 3, 3)
        rotation_err = F.mse_loss(torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity)
        translation_err = F.mse_loss(translation_ab_pred, translation_ab)
        return rotation_err + translation_err
233 |
234 |
def npmat2euler(mats, seq='zyx'):
    """Convert a stack of 3x3 rotation matrices to Euler angles in degrees.

    :param mats: array-like with a ``shape`` whose first axis indexes matrices.
    :param seq: axis sequence passed to ``Rotation.as_euler``.
    :return: float32 array with one row of angles per matrix.
    """
    angles = [Rotation.from_matrix(mats[idx]).as_euler(seq, degrees=True)
              for idx in range(mats.shape[0])]
    return np.asarray(angles, dtype='float32')
241 |
242 |
def dcp_metric(R_ab, t_ab, R_ab_pred, t_ab_pred):
    """Rotation and translation error metrics for predicted poses.

    Rotations are compared as Euler angles in degrees (via ``npmat2euler``).
    Translations are compared directly.

    :return: dict with 'r_mse_ab', 'r_mae_ab', 't_mse_ab', 't_mae_ab'.
    """
    rot_err = npmat2euler(R_ab_pred) - npmat2euler(R_ab)
    trans_err = t_ab - t_ab_pred
    return {
        'r_mse_ab': np.mean(rot_err ** 2),
        'r_mae_ab': np.mean(np.abs(rot_err)),
        't_mse_ab': np.mean(trans_err ** 2),
        't_mae_ab': np.mean(np.abs(trans_err)),
    }
258 |
259 |
def cor_loss(dict_pose, gt_T):
    """Correspondence loss between ground-truth-transformed source points and predicted matches.

    :param dict_pose: dict providing 'src' (source points), 'src_corr' (predicted
        correspondences) and 'outlier_src_mask'. Entries where the mask is True
        are zeroed in both tensors by ``mse_mask``. Other keys are ignored.
    :param gt_T: ground-truth transforms, viewed as (-1, 4, 4).
    :return: masked MSE computed by ``mse_mask``.
    """
    src = dict_pose['src']
    src_corr = dict_pose['src_corr']
    outlier_src_mask = dict_pose['outlier_src_mask']
    # Follow the source points' device instead of hard-coding ``.cuda()``.
    gt_T = gt_T.view(-1, 4, 4).to(src.device)
    rotation_ab = gt_T[:, :3, :3]
    translation_ab = gt_T[:, :3, 3]

    transformed_src = transform_point_cloud(src, rotation_ab, translation_ab)
    return mse_mask(transformed_src, src_corr, outlier_src_mask)
274 |
275 |
def quat2mat(quat):
    """Convert quaternions in scalar-last ``(x, y, z, w)`` order to rotation matrices.

    :param quat: (B, 4) tensor.
    :return: (B, 3, 3) tensor (a proper rotation only for unit quaternions).
    """
    x, y, z, w = (quat[:, col] for col in range(4))
    x2, y2, z2, w2 = x.pow(2), y.pow(2), z.pow(2), w.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    row0 = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz], dim=1)
    row1 = torch.stack([2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx], dim=1)
    row2 = torch.stack([2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=1)
    return torch.stack([row0, row1, row2], dim=1)
289 |
290 |
def transform_point_cloud(point_cloud, rotation, translation):
    """Apply ``R @ p + t`` to a batch of point clouds.

    :param point_cloud: (B, 3, N) points, one column per point.
    :param rotation: (B, 3, 3) matrices, or 2-D (B, 4) quaternions converted with ``quat2mat``.
    :param translation: (B, 3), broadcast over the points.
    """
    rot_mat = quat2mat(rotation) if rotation.dim() == 2 else rotation
    offset = translation.unsqueeze(2)
    return torch.matmul(rot_mat, point_cloud) + offset
297 |
298 |
def mse_mask(transformed_srcK, src_corrK, mask):
    """MSE between two (B, 3, N) point sets after zeroing masked points in both.

    Points where ``mask`` (B*N booleans) is True are set to 0 in both tensors.
    They add no error but still count in the mean's denominator.
    """
    batch, _, num_pts = transformed_srcK.size()
    hidden = mask.reshape(batch, 1, num_pts).expand(batch, 3, num_pts)
    zeroed_pred = transformed_srcK.masked_fill(hidden, 0)
    zeroed_corr = src_corrK.masked_fill(hidden, 0)
    return torch.nn.functional.mse_loss(zeroed_pred, zeroed_corr)
306 |
--------------------------------------------------------------------------------
/media/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/media/pipeline.png
--------------------------------------------------------------------------------
/misc/log.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from datetime import datetime
4 |
5 | from dateutil import tz
6 |
# Module-level logging state; populated by setup_log().
log_dir = reg_log_dir = LOG_FOUT = None
inited = False
11 |
12 |
def setup_log(args):
    """Create the run's log directory and open its log file (first call only).

    The run directory is ``./tf_logs/<config>-<model_config>``. Each part is
    the basename of the given path up to its first '.', with
    'config_baseline' abbreviated to 'cb'. ``./tf_logs/latest`` is re-pointed
    at the run directory. ``log_train.txt`` is opened for writing, truncating
    any previous log. A ``registration`` sub-directory is also created.

    Sets the module globals ``inited``, ``log_dir``, ``start_time``,
    ``LOG_FOUT`` and ``reg_log_dir``. Later calls return immediately.

    :param args: namespace with ``config`` and ``model_config`` path strings.
    """
    global LOG_FOUT, log_dir, inited, start_time, reg_log_dir
    import shutil

    if inited:
        return
    inited = True
    config = args.config.split('/')[-1].split('.')[0].replace('config_baseline', 'cb')
    model_config = args.model_config.split('/')[-1].split('.')[0]
    tz_sh = tz.gettz('Asia/Shanghai')
    now = datetime.now(tz=tz_sh)

    logs_root = "./tf_logs"
    run_name = '{}-{}'.format(config, model_config)
    log_dir = os.path.join(logs_root, run_name)
    # Also creates ./tf_logs when it is missing.
    os.makedirs(log_dir, exist_ok=True)

    # Re-point ./tf_logs/latest at this run. This uses os calls rather than shell
    # commands because run_name is derived from user-supplied paths. It stays
    # best-effort, as before: a failure is reported but does not abort setup.
    latest = os.path.join(logs_root, 'latest')
    try:
        if os.path.islink(latest) or os.path.isfile(latest):
            os.remove(latest)
        elif os.path.isdir(latest):
            shutil.rmtree(latest)
        # Relative target, resolved inside tf_logs/ (like `cd tf_logs && ln -s <run> latest`).
        os.symlink(run_name, latest)
    except OSError as err:
        print('Could not update {}: {}'.format(latest, err), flush=True)

    start_time = now
    LOG_FOUT = open(os.path.join(log_dir, 'log_train.txt'), 'w')
    log_string('log dir: {}'.format(log_dir))
    reg_log_dir = os.path.join(log_dir, "registration")
    os.makedirs(reg_log_dir, exist_ok=True)
38 |
39 |
def log_string(out_str, end='\n'):
    """Append ``out_str`` followed by ``end`` to the run log, flush it, and echo it to stdout."""
    for piece in (out_str, end):
        LOG_FOUT.write(piece)
    LOG_FOUT.flush()
    print(out_str, end=end, flush=True)
45 |
46 |
def log_silent(out_str, end='\n'):
    """Append ``out_str`` followed by ``end`` to the run log and flush it, without printing."""
    sink = LOG_FOUT
    sink.write(out_str)
    sink.write(end)
    sink.flush()
51 |
52 |
# Command-line options shared by the scripts that import this module.
# parse_known_args() tolerates unrecognised options (returned as the discarded
# second element), so other argument parsers can coexist.
# NOTE: this runs at import time. Importing misc.log parses sys.argv and calls
# setup_log(), which creates directories under ./tf_logs and opens the log file.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str)
parser.add_argument('--model_config', type=str)
parser.add_argument('--debug', dest='debug', action='store_true')
parser.set_defaults(debug=False)
parser.add_argument('--checkpoint', type=str, required=False, help='Trained model weights', default="")
parser.add_argument('--weights')
parser.add_argument('--log', action='store_true')
parser.add_argument('--visualize', dest='visualize', action='store_true')
parser.set_defaults(visualize=False)
args = parser.parse_known_args()[0]
# NOTE(review): --config and --model_config default to None and setup_log() calls
# .split() on them, so importing this module without both options raises.
setup_log(args)
65 |
66 |
def is_last_run_end(last_run_file):
    """Return True if any line of ``last_run_file`` contains the substring 'end'."""
    with open(last_run_file) as handle:
        return any('end' in line for line in handle)
74 |
75 |
# Per-GPU name of the "last run" bookkeeping file used by the disabled code below.
# CUDA reads CUDA_VISIBLE_DEVICES; the previously used, misspelled
# CUDA_VISIBLE_DEVICE is still honoured as a fallback. Unset or empty -> '0'.
cuda_dev = os.environ.get('CUDA_VISIBLE_DEVICES') or os.environ.get('CUDA_VISIBLE_DEVICE') or '0'
last_run = 'lastrun_{}'.format(cuda_dev)
last_run_file = last_run + '.log'
last_run_id = 1
# while os.path.exists(last_run_file) and not is_last_run_end(last_run_file):
#     last_run_file = last_run + str(last_run_id) + '.log'
#     last_run_id += 1
# with open(last_run_file, 'w') as f:
#     f.write(f'start:{start_time.strftime("%m%d-%H%M%S")}\n')
#     f.write(f'log_dir:{log_dir}\n')
#     for k,v in vars(args).items():
#         f.write(f'{k}:{v}\n')

# @atexit.register
# def end_last_run():
#     tz_sh = tz.gettz('Asia/Shanghai')
#     now = datetime.now(tz=tz_sh)
#     with open(last_run_file, 'a') as f:
#         f.write(f'end:{now.strftime("%m%d-%H%M%S")}\n')
97 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/__init__.py
--------------------------------------------------------------------------------
/models/discriminator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/discriminator/__init__.py
--------------------------------------------------------------------------------
/models/discriminator/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code.
24 | import MinkowskiEngine as ME
25 | import torch.nn as nn
26 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
27 |
28 |
class ResNetBase(nn.Module):
    """Sparse (MinkowskiEngine) ResNet base class.

    Subclasses set ``BLOCK`` (residual block class) and ``LAYERS`` (blocks per
    stage). The network is:
    - a stride-2 stem convolution followed by max pooling;
    - four stride-2 residual stages;
    - a dropout/conv/BN/GELU head;
    - global max pooling;
    - a final linear layer producing ``out_channels`` features.
    """
    BLOCK = None  # residual block class; subclasses must override (checked in __init__)
    LAYERS = ()  # number of residual blocks in each of the four stages
    INIT_DIM = 64  # output channels of the stem convolution
    PLANES = (64, 128, 256, 512)  # base channel width of stages 1-4

    def __init__(self, in_channels, out_channels, D=3):
        """
        :param in_channels: number of input feature channels.
        :param out_channels: number of features produced by the final linear layer.
        :param D: spatial dimension of the sparse convolutions.
        """
        nn.Module.__init__(self)
        self.D = D
        assert self.BLOCK is not None

        self.network_initialization(in_channels, out_channels, D)
        self.weight_initialization()

    def network_initialization(self, in_channels, out_channels, D):
        # Layer construction order determines state_dict keys and RNG consumption
        # during initialisation; keep it stable for checkpoint compatibility.

        self.inplanes = self.INIT_DIM
        # Stem: stride-2 convolution, then stride-2 max pooling.
        self.conv1 = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels, self.inplanes, kernel_size=3, stride=2, dimension=D
            ),
            ME.MinkowskiBatchNorm(self.inplanes),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiMaxPooling(kernel_size=2, stride=2, dimension=D),
        )

        # Four residual stages, each downsampling by 2; _make_layer updates self.inplanes.
        self.layer1 = self._make_layer(
            self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2
        )
        self.layer2 = self._make_layer(
            self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2
        )
        self.layer3 = self._make_layer(
            self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2
        )
        self.layer4 = self._make_layer(
            self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2
        )

        # Head: dropout, stride-3 convolution, BN and GELU at constant width.
        self.conv5 = nn.Sequential(
            ME.MinkowskiDropout(),
            ME.MinkowskiConvolution(
                self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D
            ),
            ME.MinkowskiBatchNorm(self.inplanes),
            ME.MinkowskiGELU(),
        )

        self.glob_pool = ME.MinkowskiGlobalMaxPooling()

        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

    def weight_initialization(self):
        """Kaiming-normal init for MinkowskiConvolution kernels; BatchNorm set to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1):
        """Build one residual stage of ``blocks`` blocks; only the first one strides.

        A 1x1 conv + BN projection is added to the first block when the stride or
        the channel count changes. Updates ``self.inplanes`` to the stage output
        width. NOTE: ``bn_momentum`` is accepted but unused.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D,
                ),
                ME.MinkowskiBatchNorm(planes * block.expansion),
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                dimension=self.D,
            )
        )
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D
                )
            )

        return nn.Sequential(*layers)

    def forward(self, x):
        # NOTE(review): x is presumably an ME.SparseTensor — confirm with callers.
        x = self.conv1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.conv5(x)
        x = self.glob_pool(x)
        return self.final(x)
133 |
134 |
class ResNet14(ResNetBase):
    """ResNet14 layout: BasicBlock units, one per stage."""

    BLOCK = BasicBlock
    LAYERS = (1,) * 4


class ResNet18(ResNetBase):
    """ResNet18 layout: BasicBlock units, two per stage."""

    BLOCK = BasicBlock
    LAYERS = (2,) * 4


class ResNet34(ResNetBase):
    """ResNet34 layout: BasicBlock units, (3, 4, 6, 3) per stage."""

    BLOCK = BasicBlock
    LAYERS = (3, 4, 6, 3)


class ResNet50(ResNetBase):
    """ResNet50 layout: Bottleneck units, (3, 4, 6, 3) per stage."""

    BLOCK = Bottleneck
    LAYERS = (3, 4, 6, 3)


class ResNet101(ResNetBase):
    """ResNet101 layout: Bottleneck units, (3, 4, 23, 3) per stage."""

    BLOCK = Bottleneck
    LAYERS = (3, 4, 23, 3)
158 |
159 |
class ResFieldNetBase(ResNetBase):
    """ResNetBase variant with two sinusoidal "field" networks in front of the backbone.

    Each field network (``ME.MinkowskiSinusoidal`` + BN/ReLU/Linear, ending in
    ``ME.MinkowskiToSparseTensor``) lifts the input features. The result,
    ``field_ch2`` channels, feeds the standard ``ResNetBase`` layers.
    """
    def network_initialization(self, in_channels, out_channels, D):
        field_ch = 32  # width of the first field network
        field_ch2 = 64  # width of the second field network (= backbone input channels)
        self.field_network = nn.Sequential(
            ME.MinkowskiSinusoidal(in_channels, field_ch),
            ME.MinkowskiBatchNorm(field_ch),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiLinear(field_ch, field_ch),
            ME.MinkowskiBatchNorm(field_ch),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiToSparseTensor(),
        )
        # Input width is field_ch + in_channels: forward() concatenates the first
        # network's output with the raw input features.
        self.field_network2 = nn.Sequential(
            ME.MinkowskiSinusoidal(field_ch + in_channels, field_ch2),
            ME.MinkowskiBatchNorm(field_ch2),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiLinear(field_ch2, field_ch2),
            ME.MinkowskiBatchNorm(field_ch2),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiToSparseTensor(),
        )

        # The backbone layers consume the field_ch2-wide field features.
        ResNetBase.network_initialization(self, field_ch2, out_channels, D)

    def forward(self, x):
        # NOTE(review): x is presumably an ME.TensorField (cat_slice needs one) — confirm.
        otensor = self.field_network(x)
        # cat_slice presumably maps the sparse output back onto x's points and
        # concatenates x's features, matching field_network2's input width.
        otensor2 = self.field_network2(otensor.cat_slice(x))
        return ResNetBase.forward(self, otensor2)
189 |
190 |
class ResFieldNet14(ResFieldNetBase):
    """Field-network ResNet14: BasicBlock units, one per stage."""

    BLOCK = BasicBlock
    LAYERS = (1,) * 4


class ResFieldNet18(ResFieldNetBase):
    """Field-network ResNet18: BasicBlock units, two per stage."""

    BLOCK = BasicBlock
    LAYERS = (2,) * 4


class ResFieldNet34(ResFieldNetBase):
    """Field-network ResNet34: BasicBlock units, (3, 4, 6, 3) per stage."""

    BLOCK = BasicBlock
    LAYERS = (3, 4, 6, 3)


class ResFieldNet50(ResFieldNetBase):
    """Field-network ResNet50: Bottleneck units, (3, 4, 6, 3) per stage."""

    BLOCK = Bottleneck
    LAYERS = (3, 4, 6, 3)


class ResFieldNet101(ResFieldNetBase):
    """Field-network ResNet101: Bottleneck units, (3, 4, 23, 3) per stage."""

    BLOCK = Bottleneck
    LAYERS = (3, 4, 23, 3)
214 |
215 |
if __name__ == "__main__":
    # Smoke test: build a 2-D ResNet14 and print its architecture.
    criterion = nn.CrossEntropyLoss()  # constructed but unused by this demo
    demo_net = ResNet14(in_channels=3, out_channels=5, D=2)
    print(demo_net)
220 |
--------------------------------------------------------------------------------
/models/discriminator/resnetD.py:
--------------------------------------------------------------------------------
1 | import MinkowskiEngine as ME
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | from MinkowskiEngine.modules.resnet_block import BasicBlock
6 |
7 | from misc.log import log_string
8 | from models.discriminator.resnet import ResNetBase
9 |
10 |
class ResNet14(ResNetBase):
    """Narrow ResNet14: BasicBlock units, one per stage, stem width 32."""

    BLOCK = BasicBlock
    LAYERS = (1,) * 4
    INIT_DIM = 32
    PLANES = (32, 64, 128, 128)
16 |
17 |
class ResnetDor(torch.nn.Module):
    """Sparse multi-scale discriminator built from MinkowskiEngine layers.

    Stage 0 convolves ``mid_feature[0]``. Each later stage ``i`` concatenates
    the previous stage's output with ``mid_feature[i]`` using ``ME.cat``, then
    applies a stride-2 conv + BN + ReLU. The final map is globally
    average-pooled and passed through an MLP with a single output. The
    feature matrix ``pred.F`` is returned.

    :param planes: output channels per stage; stage ``i >= 1`` expects
        ``planes[i - 1] * 2`` input channels.
    :param n_layers: number of hidden Linear+ReLU layers in the MLP head.
    :param D: spatial dimension of the sparse convolutions.
    """
    def __init__(self, planes=(32, 64, 64), n_layers=1, D=3):
        super().__init__()
        self.planes = planes
        self.n_layers = n_layers
        self.D = D
        self.network_initialization()

    def network_initialization(self):
        # Construction order determines state_dict keys and RNG use; keep it stable.

        layers = len(self.planes)
        self.convs = nn.ModuleList()
        self.conv1 = nn.Sequential(
            ME.MinkowskiConvolution(self.planes[0], self.planes[0], kernel_size=2, stride=2, dimension=self.D),
            ME.MinkowskiBatchNorm(self.planes[0]),
            ME.MinkowskiReLU(inplace=True))
        # NOTE: conv1 is registered both as self.conv1 and as self.convs[0]; the
        # parameters are shared and presumably appear under both prefixes in state_dict.
        self.convs.append(self.conv1)
        for i in range(layers - 1):
            # Input = previous stage output concatenated with mid_feature[i + 1].
            self.convs.append(nn.Sequential(
                ME.MinkowskiConvolution(self.planes[i] * 2, self.planes[i + 1], kernel_size=2, stride=2,
                                        dimension=self.D),
                ME.MinkowskiBatchNorm(self.planes[i + 1]),
                ME.MinkowskiReLU(inplace=True)))

        self.pool = ME.MinkowskiGlobalAvgPooling()

        # MLP head: n_layers hidden Linear+ReLU layers at width planes[-1], then one output.
        self.fc = []
        for j in range(self.n_layers):
            self.fc += [ME.MinkowskiLinear(self.planes[layers - 1], self.planes[layers - 1], bias=True)]
            self.fc += [ME.MinkowskiReLU(inplace=True)]
        self.fc += [ME.MinkowskiLinear(self.planes[layers - 1], 1, bias=True)]
        self.fc = nn.Sequential(*self.fc)

    # NOTE(review): "32,32,64,64" presumably lists the channel widths of the
    # mid_feature entries for the default planes — confirm with callers.
    # 32,32,64,64
    def forward(self, mid_feature):
        last_layer = None
        for i, conv in enumerate(self.convs):
            # ME.cat concatenates features; presumably both tensors must share coordinates.
            layer = mid_feature[i] if last_layer is None else ME.cat(last_layer, mid_feature[i])
            layer = conv(layer)
            last_layer = layer
        layer = self.pool(layer)
        pred = self.fc(layer)

        return pred.F

    def print_info(self):
        # NOTE(review): logs (parameter count * 4) / 1e6, i.e. size in MB for float32
        # parameters, although the label says "params ... M".

        para = sum([np.prod(list(p.size())) for p in self.parameters()])
        log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000))
67 |
68 |
class MeanFCDor(torch.nn.Module):
    """MLP discriminator over globally averaged multi-scale features.

    Every entry of ``mid_feature`` except the last is pooled with
    ``ME.MinkowskiGlobalAvgPooling`` and its feature matrix ``.F`` is taken.
    The last entry is used unchanged. The results are concatenated along dim 1
    and fed to a 960 -> 64 -> 64 -> 1 MLP.

    NOTE: ``planes`` is accepted but ignored. ``feature_dim`` and ``n_layers``
    are stored but do not affect the architecture; the input width is fixed at 960.
    """

    def __init__(self, planes=(32, 64, 64), feature_dim=256, n_layers=3, D=3):
        super().__init__()
        self.feature_dim = feature_dim
        self.n_layers = n_layers
        self.D = D
        self.network_initialization()
        self.pool = ME.MinkowskiGlobalAvgPooling()

    def network_initialization(self):
        hidden = 64
        self.fc = nn.Sequential(
            nn.Linear(960, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, 1, bias=True),
        )

    def forward(self, mid_feature):
        pooled = [self.pool(mid_feature[idx]).F for idx in range(len(mid_feature) - 1)]
        pooled.append(mid_feature[-1])
        return self.fc(torch.cat(pooled, dim=1))

    def print_info(self):
        n_params = sum(np.prod(list(p.size())) for p in self.parameters())
        log_string('Model {} : params: {:4f}M'.format(self._get_name(), n_params * 4 / 1000 / 1000))
104 |
105 |
class FCDor(torch.nn.Module):
    """Plain MLP discriminator: ``n_layers`` Linear+ReLU stages, then a Linear to one output.

    Stage ``j`` (0-based) maps ``d`` features to ``d // (j + 1)``. With the
    defaults (256 features, 3 layers) the widths are 256 -> 256 -> 128 -> 42 -> 1.
    ``D`` is stored but unused.
    """

    def __init__(self, feature_dim=256, n_layers=3, D=3):
        super().__init__()
        self.feature_dim = feature_dim
        self.n_layers = n_layers
        self.D = D
        self.network_initialization()

    def network_initialization(self):
        layers = []
        width = self.feature_dim
        for depth in range(1, self.n_layers + 1):
            narrower = width // depth
            layers.extend([nn.Linear(width, narrower, bias=True), nn.ReLU(inplace=True)])
            width = narrower
        layers.append(nn.Linear(width, 1, bias=True))
        self.fc = nn.Sequential(*layers)

    def forward(self, mid_feature):
        return self.fc(mid_feature)

    def print_info(self):
        n_params = sum(np.prod(list(p.size())) for p in self.parameters())
        log_string('Model {} : params: {:4f}M'.format(self._get_name(), n_params * 4 / 1000 / 1000))
134 |
135 |
class ResnetFDor(torch.nn.Module):
    """Dense multi-scale discriminator: Conv1d stages followed by an MLP head.

    ``mid_feature`` is a sequence of feature maps shaped ``(B, C, L)``.
    - Stage 0 convolves ``mid_feature[0]``.
    - Each later stage ``i`` concatenates the previous stage's output with
      ``mid_feature[i]`` along the channel axis, then applies a stride-2
      Conv1d + BatchNorm + ReLU.
    - The last map is averaged over its length and passed through the MLP,
      giving an output of shape ``(B, 1)``.

    For ``planes = (p0, p1, ...)``: ``mid_feature[0]`` needs ``p0`` channels.
    ``mid_feature[i]`` (i >= 1) needs ``planes[i - 1]`` channels and the same
    length as the previous stage's output (each stage halves the length,
    rounding down).

    :param planes: output channels of each stage.
    :param n_layers: number of hidden Linear+ReLU layers in the head.
    :param D: stored; not used by this class.
    """

    def __init__(self, planes=(32, 64, 64), n_layers=1, D=3):
        super().__init__()
        self.planes = planes
        self.n_layers = n_layers
        self.D = D
        self.network_initialization()

    def network_initialization(self):
        layers = len(self.planes)
        self.convs = nn.ModuleList()
        self.conv1 = nn.Sequential(
            nn.Conv1d(self.planes[0], self.planes[0], kernel_size=2, stride=2),
            nn.BatchNorm1d(self.planes[0]),
            nn.ReLU(inplace=True))
        # conv1 stays registered both as an attribute and as convs[0], as before,
        # so state_dict keys are unchanged.
        self.convs.append(self.conv1)
        for i in range(layers - 1):
            # Input = previous output (planes[i]) concatenated with mid_feature[i + 1] (planes[i]).
            self.convs.append(nn.Sequential(
                nn.Conv1d(self.planes[i] * 2, self.planes[i + 1], kernel_size=2, stride=2),
                nn.BatchNorm1d(self.planes[i + 1]),
                nn.ReLU(inplace=True)))

        fc = []
        for _ in range(self.n_layers):
            fc += [nn.Linear(self.planes[layers - 1], self.planes[layers - 1], bias=True)]
            fc += [nn.ReLU(inplace=True)]
        fc += [nn.Linear(self.planes[layers - 1], 1, bias=True)]
        self.fc = nn.Sequential(*fc)

    def pool(self, x):
        """Average over the length axis: (B, C, L) -> (B, C), the shape the Linear head expects."""
        # A method instead of a lambda attribute, which also keeps the module picklable.
        return torch.mean(x, dim=-1)

    def forward(self, mid_feature):
        layer = None
        for i, conv in enumerate(self.convs):
            # Concatenate on the channel axis (dim=1): the stage-i conv expects
            # planes[i - 1] * 2 input channels.
            layer = mid_feature[i] if layer is None else torch.cat([layer, mid_feature[i]], dim=1)
            layer = conv(layer)
        return self.fc(self.pool(layer))

    def print_info(self):
        # NOTE(review): logs (parameter count * 4) / 1e6, i.e. MB for float32, labelled "M".
        para = sum([np.prod(list(p.size())) for p in self.parameters()])
        log_string('Model {} : params: {:4f}M'.format(self._get_name(), para * 4 / 1000 / 1000))
185 |
--------------------------------------------------------------------------------
/models/minkloc3d/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/minkloc3d/__init__.py
--------------------------------------------------------------------------------
/models/minkloc3d/minkfpn.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from MinkLoc3D: https://github.com/jac99/MinkLoc3D.git
4 |
5 | import MinkowskiEngine as ME
6 | import torch.nn as nn
7 | from MinkowskiEngine.modules.resnet_block import BasicBlock
8 |
9 | from models.minkloc3d.resnet_mink import ResNetBase
10 | from models.minkloc3d.senet_block import SEBasicBlock
11 |
12 |
class MinkFPN(ResNetBase):
    """Feature Pyramid Network (FPN) built from MinkowskiEngine ResNet blocks.

    Bottom-up pass:
    - a stem convolution with kernel ``conv0_kernel_size``;
    - one stage per entry of ``model_params.layers`` / ``model_params.planes``,
      each a stride-2 convolution + BN + ReLU followed by residual blocks.

    Top-down pass:
    - the last stage is projected to ``feature_size`` channels by a 1x1 convolution;
    - ``num_top_down`` stride-2 transposed convolutions upsample it, each adding
      a 1x1-projected lateral map from the bottom-up pass.
    """
    # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks
    def __init__(self, in_channels=1, model_params=None, block=BasicBlock):
        """
        :param in_channels: number of input feature channels.
        :param model_params: config providing ``layers``, ``planes``, ``num_top_down``,
            ``conv0_kernel_size`` and ``feature_size``.
        :param block: residual block class used in every bottom-up stage.
        """
        assert len(model_params.layers) == len(model_params.planes)
        assert 1 <= len(model_params.layers)
        assert 0 <= model_params.num_top_down <= len(model_params.layers)
        # Set before ResNetBase.__init__, which presumably calls network_initialization()
        # (that method reads these attributes) — confirm in resnet_mink.py.
        self.num_bottom_up = len(model_params.layers)
        self.num_top_down = model_params.num_top_down
        self.conv0_kernel_size = model_params.conv0_kernel_size
        self.block = block
        self.layers = model_params.layers
        self.planes = model_params.planes
        self.lateral_dim = model_params.feature_size
        self.init_dim = model_params.planes[0]
        ResNetBase.__init__(self, in_channels, model_params.feature_size, D=3)

    def network_initialization(self, in_channels, out_channels, D):
        assert len(self.layers) == len(self.planes)
        assert len(self.planes) == self.num_bottom_up

        self.convs = nn.ModuleList()  # Bottom-up convolutional blocks with stride=2
        self.bn = nn.ModuleList()  # Bottom-up BatchNorms
        self.blocks = nn.ModuleList()  # Bottom-up blocks
        self.tconvs = nn.ModuleList()  # Top-down transposed convolutions
        self.conv1x1 = nn.ModuleList()  # 1x1 convolutions in lateral connections

        # The first convolution is a special case: its kernel size is conv0_kernel_size
        self.inplanes = conv0_out = self.planes[0]
        self.conv0 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size,
                                             dimension=D)
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        for plane, layer in zip(self.planes, self.layers):
            # Downsampling conv keeps the current width; _make_layer then presumably
            # widens to `plane` and updates self.inplanes.
            self.convs.append(
                ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D))
            self.bn.append(ME.MinkowskiBatchNorm(self.inplanes))
            self.blocks.append(self._make_layer(self.block, plane, layer))

        # Lateral connections
        for i in range(self.num_top_down):
            self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - i], self.lateral_dim, kernel_size=1,
                                                        stride=1, dimension=D))

            self.tconvs.append(ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2,
                                                                stride=2, dimension=D))
        # There's one more lateral connection than top-down TConv blocks
        if self.num_top_down < self.num_bottom_up:
            # Lateral connection from Conv block 1 or above
            self.conv1x1.append(
                ME.MinkowskiConvolution(self.planes[-1 - self.num_top_down], self.lateral_dim, kernel_size=1,
                                        stride=1, dimension=D))
        else:
            # Lateral connection from Conv0 block
            self.conv1x1.append(ME.MinkowskiConvolution(conv0_out, self.lateral_dim, kernel_size=1,
                                                        stride=1, dimension=D))

        # self.relu = ME.MinkowskiPReLU()
        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, x):
        # *** BOTTOM-UP PASS ***
        # First bottom-up convolution is special (custom kernel size)
        feature_maps = []
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu(x)
        # With as many top-down steps as stages, the stem output feeds the last lateral connection.
        if self.num_top_down == self.num_bottom_up:
            feature_maps.append(x)

        # BOTTOM-UP PASS
        for ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)):
            x = conv(x)  # Decreases spatial resolution (conv stride=2)
            x = bn(x)
            x = self.relu(x)
            x = block(x)

            # Keep the outputs of the stages that feed lateral connections; the last
            # stage is excluded because it is projected by conv1x1[0] below.
            if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1:
                feature_maps.append(x)

        assert len(feature_maps) == self.num_top_down

        x = self.conv1x1[0](x)

        # TOP-DOWN PASS
        for ndx, tconv in enumerate(self.tconvs):
            x = tconv(x)  # Upsample using transposed convolution
            # Add the lateral map of matching resolution (consumed from the finest-kept end backwards).
            x = x + self.conv1x1[ndx + 1](feature_maps[-ndx - 1])

        return x
102 |
103 |
class MinkSEFPN(MinkFPN):
    """``MinkFPN`` whose bottom-up stages default to ``SEBasicBlock`` residual blocks."""

    def __init__(self, in_channels=1, model_params=None, block=SEBasicBlock):
        super().__init__(in_channels=in_channels, model_params=model_params, block=block)
107 |
--------------------------------------------------------------------------------
/models/minkloc3d/minkpool.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski
2 | # Warsaw University of Technology
3 |
4 | import MinkowskiEngine as ME
5 | import torch
6 |
7 | import models.minkloc3d.pooling as pooling
8 | from models.minkloc3d.minkfpn import MinkFPN
9 |
10 |
class MinkPool(torch.nn.Module):
    """Global descriptor head: sparse MinkFPN backbone followed by ``pooling.GeM``.

    Per-point features are sparsified with ``ME.utils.sparse_quantize``
    (quantization size 0.01) and run through ``MinkFPN``. They are then pooled
    into a (batch_size, output_dim) descriptor, returned as ``{'g_fea': x}``.
    """
    def __init__(self, model_params, in_channels):
        super().__init__()
        self.model = model_params.mink_model  # stored; not used in this class
        self.in_channels = in_channels
        self.feature_size = model_params.feature_size  # Size of local features produced by local feature extraction block
        self.output_dim = model_params.output_dim  # Dimensionality of the global descriptor
        self.backbone = MinkFPN(in_channels=in_channels, model_params=model_params)
        self.n_backbone_features = model_params.output_dim
        self.pooling = pooling.GeM()
        # NOTE(review): chosen once at construction; the sparse tensor is always built on
        # this device, even if the module is later moved to another one.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # [B,dims,num_points,1], [B,1,num_points,3]
    def __input(self, feat_map, coords_xyz=None):
        """Quantise each sample's coordinates and gather the features of the kept points.

        Per the note above (presumably; confirm with callers), ``feat_map`` is
        [B, dims, num_points, 1] and ``coords_xyz`` is [B, 1, num_points, 3].
        ``coords_xyz`` is required despite its None default.

        :return: dict with batched ``'coords'`` (quantised from CPU copies of the
            coordinates) and the vertically stacked per-point ``'features'``.
        """
        # coords_xyz = coords_xyz.cpu().squeeze(1)
        coords = []
        feat_maps = []
        for e in range(len(coords_xyz)):
            # inds_e indexes the points retained by quantisation.
            coords_e, inds_e = ME.utils.sparse_quantize(coords_xyz[e][0].cpu().squeeze(1), return_index=True,
                                                        quantization_size=0.01)
            # Features of the retained points; (num_kept, dims) under the shapes above.
            feat_map_e = feat_map[e][:, inds_e, :].squeeze(-1).transpose(-1, -2)
            coords.append(coords_e)
            feat_maps.append(feat_map_e)
        # coords = [ME.utils.sparse_quantize(coords_xyz[e,...], quantization_size=0.01)
        #           for e in range(coords_xyz.shape[0])]
        coords = ME.utils.batched_coordinates(coords)
        # Features are the gathered per-point features (not a dummy constant).
        # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
        feats = torch.vstack(feat_maps)
        batch = {'coords': coords, 'features': feats}
        return batch

    def forward(self, feat_map, coords_xyz=None):
        batch = self.__input(feat_map, coords_xyz=coords_xyz)
        # Coords must be on CPU, features can be on GPU - see MinkowskiEngine documentation
        x = ME.SparseTensor(batch['features'], coordinates=batch['coords'], device=self.device)
        x = self.backbone(x)
        # x is (num_points, n_features) tensor
        assert x.shape[1] == self.feature_size, 'Backbone output tensor has: {} channels. Expected: {}'.format(
            x.shape[1], self.feature_size)
        x = self.pooling(x)
        assert x.dim() == 2, 'Expected 2-dimensional tensor (batch_size,output_dim). Got {} dimensions.'.format(x.dim())
        assert x.shape[1] == self.output_dim, 'Output tensor has: {} channels. Expected: {}'.format(x.shape[1],
                                                                                                   self.output_dim)
        # x is (batch_size, output_dim) tensor
        return {'g_fea': x}

    def print_info(self):
        """Print parameter counts for the whole model, the backbone and the pooling layer."""
        print('Model class: MinkPool')
        n_params = sum([param.nelement() for param in self.parameters()])
        print('Total parameters: {}'.format(n_params))
        n_params = sum([param.nelement() for param in self.backbone.parameters()])
        print('Backbone parameters: {}'.format(n_params))
        n_params = sum([param.nelement() for param in self.pooling.parameters()])
        print('Aggregation parameters: {}'.format(n_params))
        if hasattr(self.backbone, 'print_info'):
            self.backbone.print_info()
        if hasattr(self.pooling, 'print_info'):
            self.pooling.print_info()
70 |
--------------------------------------------------------------------------------
/models/minkloc3d/pooling.py:
--------------------------------------------------------------------------------
1 | # Code taken from: https://github.com/filipradenovic/cnnimageretrieval-pytorch
2 | # and ported to MinkowskiEngine by Jacek Komorowski
3 |
4 | import MinkowskiEngine as ME
5 | import torch
6 | import torch.nn as nn
7 |
8 |
class MAC(nn.Module):
    """MAC descriptor: global max pooling over a sparse tensor."""

    def __init__(self):
        super().__init__()
        self.f = ME.MinkowskiGlobalMaxPooling()

    def forward(self, x: ME.SparseTensor):
        # One pooled row per batch element: (batch_size, n_features).
        pooled = self.f(x)
        return pooled.F
17 |
18 |
class SPoC(nn.Module):
    """SPoC descriptor: global average pooling over a sparse tensor."""

    def __init__(self):
        super().__init__()
        self.f = ME.MinkowskiGlobalAvgPooling()

    def forward(self, x: ME.SparseTensor):
        # One pooled row per batch element: (batch_size, n_features).
        pooled = self.f(x)
        return pooled.F
27 |
28 |
class GeM(nn.Module):
    """Generalized-mean (GeM) global pooling for sparse tensors.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` per batch element with a learnable
    exponent ``p`` (initialised to ``p``).
    """

    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        self.f = ME.MinkowskiGlobalAvgPooling()

    def forward(self, x: ME.SparseTensor):
        # This implicitly applies ReLU on x (clamps negative values) and keeps pow() finite.
        # Reuse x's coordinate manager/map instead of building a fresh coordinate map from x.C
        # on every call; the pooled values are the same.
        temp = ME.SparseTensor(x.F.clamp(min=self.eps).pow(self.p),
                               coordinate_map_key=x.coordinate_map_key,
                               coordinate_manager=x.coordinate_manager)
        temp = self.f(temp)  # Apply ME.MinkowskiGlobalAvgPooling
        return temp.F.pow(1. / self.p)  # Return (batch_size, n_features) tensor
41 |
--------------------------------------------------------------------------------
/models/minkloc3d/resnet_mink.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code.
24 |
25 | import MinkowskiEngine as ME
26 | import torch.nn as nn
27 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
28 |
29 |
class ResNetBase(nn.Module):
    """Minkowski ResNet backbone base class.

    The residual block type and per-stage depths come from the lower-case ``block`` /
    ``layers`` attributes (e.g. set by a subclass before calling ``__init__``). If those
    are unset, the upper-case ``BLOCK`` / ``LAYERS`` class attributes used by the
    ``ResNet14`` ... ``ResNet101`` presets are used instead.
    """
    block = None
    layers = ()
    init_dim = 64
    planes = (64, 128, 256, 512)

    # default: 1, 256
    def __init__(self, in_channels, out_channels, D=3):
        nn.Module.__init__(self)
        self.D = D
        # Presets only declare BLOCK/LAYERS; without this fallback `self.block` would stay None.
        self._resolve_architecture()
        assert self.block is not None, 'ResNetBase subclasses must set `block` (or the `BLOCK` alias)'

        self.network_initialization(in_channels, out_channels, D)
        self.weight_initialization()

    def _resolve_architecture(self):
        """Fill ``block``/``layers`` from the ``BLOCK``/``LAYERS`` aliases when they are unset."""
        if self.block is None:
            self.block = getattr(self, 'BLOCK', None)
        if not self.layers:
            self.layers = getattr(self, 'LAYERS', ())

    def network_initialization(self, in_channels, out_channels, D):
        """Build the stem, four strided residual stages, conv5 and the global-pool + linear head."""
        self.inplanes = self.init_dim
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D)

        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.relu = ME.MinkowskiReLU(inplace=True)

        self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D)

        self.layer1 = self._make_layer(
            self.block, self.planes[0], self.layers[0], stride=2)
        self.layer2 = self._make_layer(
            self.block, self.planes[1], self.layers[1], stride=2)
        self.layer3 = self._make_layer(
            self.block, self.planes[2], self.layers[2], stride=2)
        self.layer4 = self._make_layer(
            self.block, self.planes[3], self.layers[3], stride=2)

        self.conv5 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D)
        self.bn5 = ME.MinkowskiBatchNorm(self.inplanes)

        # NOTE: despite the attribute name, this is global *max* pooling.
        self.glob_avg = ME.MinkowskiGlobalMaxPooling()

        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

    def weight_initialization(self):
        """Kaiming-normal init for sparse convolutions; BatchNorm weight=1, bias=0."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    dilation=1,
                    bn_momentum=0.1):
        """Stack ``blocks`` residual blocks; the first may downsample via a 1x1 strided projection.

        ``bn_momentum`` is accepted for API compatibility but is currently unused.
        Updates ``self.inplanes`` to the stage's output width.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D),
                ME.MinkowskiBatchNorm(planes * block.expansion))
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                dimension=self.D))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=dilation,
                    dimension=self.D))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.conv5(x)
        x = self.bn5(x)
        x = self.relu(x)

        x = self.glob_avg(x)
        return self.final(x)
136 |
137 |
class ResNet14(ResNetBase):
    """ResNet-14 preset: BasicBlock, one block per stage."""
    BLOCK = BasicBlock
    LAYERS = (1, 1, 1, 1)
    # ResNetBase reads the lower-case names; alias them so the preset works on its own.
    block = BLOCK
    layers = LAYERS
141 |
142 |
class ResNet18(ResNetBase):
    """ResNet-18 preset: BasicBlock, two blocks per stage."""
    BLOCK = BasicBlock
    LAYERS = (2, 2, 2, 2)
    # ResNetBase reads the lower-case names; alias them so the preset works on its own.
    block = BLOCK
    layers = LAYERS
146 |
147 |
class ResNet34(ResNetBase):
    """ResNet-34 preset: BasicBlock, (3, 4, 6, 3) blocks per stage."""
    BLOCK = BasicBlock
    LAYERS = (3, 4, 6, 3)
    # ResNetBase reads the lower-case names; alias them so the preset works on its own.
    block = BLOCK
    layers = LAYERS
151 |
152 |
class ResNet50(ResNetBase):
    """ResNet-50 preset: Bottleneck, (3, 4, 6, 3) blocks per stage."""
    BLOCK = Bottleneck
    LAYERS = (3, 4, 6, 3)
    # ResNetBase reads the lower-case names; alias them so the preset works on its own.
    block = BLOCK
    layers = LAYERS
156 |
157 |
class ResNet101(ResNetBase):
    """ResNet-101 preset: Bottleneck, (3, 4, 23, 3) blocks per stage."""
    BLOCK = Bottleneck
    LAYERS = (3, 4, 23, 3)
    # ResNetBase reads the lower-case names; alias them so the preset works on its own.
    block = BLOCK
    layers = LAYERS
161 |
--------------------------------------------------------------------------------
/models/minkloc3d/senet_block.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code.
24 | import MinkowskiEngine as ME
25 | import torch.nn as nn
26 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
27 |
28 |
class SELayer(nn.Module):
    """Squeeze-and-excitation gate for sparse tensors.

    Globally pools the input, maps the pooled features through a two-layer bottleneck
    (Linear -> ReLU -> Linear -> Sigmoid) and broadcast-multiplies the result back onto
    every point of the input.
    """

    def __init__(self, channel, reduction=16, dimension=-1):
        # `dimension` is not used: global pooling does not need a coordinate key.
        super(SELayer, self).__init__()
        squeezed = channel // reduction
        self.fc = nn.Sequential(
            ME.MinkowskiLinear(channel, squeezed),
            ME.MinkowskiReLU(inplace=True),
            ME.MinkowskiLinear(squeezed, channel),
            ME.MinkowskiSigmoid())
        self.pooling = ME.MinkowskiGlobalPooling()
        self.broadcast_mul = ME.MinkowskiBroadcastMultiplication()

    def forward(self, x):
        gates = self.fc(self.pooling(x))
        return self.broadcast_mul(x, gates)
46 |
47 |
class SEBasicBlock(BasicBlock):
    """MinkowskiEngine BasicBlock with a squeeze-and-excitation gate before the residual add."""

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 reduction=16,
                 dimension=-1):
        super(SEBasicBlock, self).__init__(inplanes, planes, stride=stride, dilation=dilation,
                                           downsample=downsample, dimension=dimension)
        # The basic block outputs `planes` channels, so the gate works on `planes` channels.
        self.se = SELayer(planes, reduction=reduction, dimension=dimension)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.se(self.norm2(self.conv2(y)))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
85 |
86 |
class SEBottleneck(Bottleneck):
    """MinkowskiEngine Bottleneck block with a squeeze-and-excitation gate before the residual add."""

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 dimension=3,
                 reduction=16):
        super(SEBottleneck, self).__init__(inplanes, planes, stride=stride, dilation=dilation,
                                           downsample=downsample, dimension=dimension)
        # The bottleneck widens its output to planes * expansion channels; the gate must match.
        self.se = SELayer(planes * self.expansion, reduction=reduction, dimension=dimension)

    def forward(self, x):
        y = self.relu(self.norm1(self.conv1(x)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.se(self.norm3(self.conv3(y)))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
128 |
--------------------------------------------------------------------------------
/models/model_factory.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 |
5 | import os
6 |
7 | import numpy as np
8 | import torch
9 |
10 | from misc.log import log_string
11 | from models.discriminator.resnetD import FCDor
12 | from models.vcrnet.vcrnet import VCRNet
13 | from models.vlpdnet.vlpdnet import vLPDNet
14 |
15 |
def model_factory(params):
    """Build the place-recognition model plus the optional discriminator and registration nets.

    Args:
        params: configuration object. Reads ``model_params.model``, ``model_params.feature_size``,
            ``model_params.reg_model_path``, ``domain_adapt``, ``d_model`` and ``is_register``.

    Returns:
        ``(model, device, d_model, vcr_model)``. ``d_model`` is None unless ``domain_adapt``;
        ``vcr_model`` is None unless ``is_register``.

    Raises:
        NotImplementedError: for an unknown ``model_params.model`` or ``d_model``.
    """
    if 'vLPDNet' in params.model_params.model:
        model = vLPDNet(params)
    else:
        raise NotImplementedError('Model not implemented: {}'.format(params.model_params.model))

    if hasattr(model, 'print_info'):
        model.print_info()

    # The logged figure is para * 4 bytes / 1e6, i.e. the fp32 weight size in MB.
    para = sum(p.numel() for p in model.parameters())
    log_string('Model {} : params: {:4f}M'.format(model._get_name(), para * 4 / 1000 / 1000))

    d_model = None
    if params.domain_adapt:
        if params.d_model == 'FCDor':
            d_model = FCDor(feature_dim=params.model_params.feature_size, n_layers=3, D=3)
        else:
            raise NotImplementedError('d_model not implemented: {}'.format(params.d_model))

        if hasattr(d_model, 'print_info'):
            d_model.print_info()

    vcr_model = None
    if params.is_register:
        vcr_model = VCRNet(params.model_params)
        if params.model_params.reg_model_path != "":
            checkpoint_dict = torch.load(params.model_params.reg_model_path, map_location=torch.device('cpu'))
            vcr_model.load_state_dict(checkpoint_dict, strict=True)
            log_string('load vcr_model with {}'.format(params.model_params.reg_model_path))

    # Move the model to the proper device before configuring the optimizer
    if torch.cuda.is_available():
        device = "cuda"
        model.to(device)
        if d_model is not None:
            d_model = d_model.to(device)
        if vcr_model is not None:
            vcr_model = vcr_model.to(device)
            # Only wrap a real network: DataParallel(None) would hand callers a non-None,
            # unusable object when registration is disabled.
            if torch.cuda.device_count() > 1:
                vcr_model = torch.nn.DataParallel(vcr_model)
    else:
        device = "cpu"

    log_string('Model device: {}'.format(device))

    return model, device, d_model, vcr_model
61 |
62 |
def load_weights(weights, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer / scheduler) state from a checkpoint file.

    The checkpoint may be a raw state dict or a dict holding ``state_dict`` and optionally
    ``epoch``, ``optimizer`` and ``scheduler`` entries. Only parameters whose keys exist in
    ``model.state_dict()`` are loaded; other saved keys are ignored.

    Args:
        weights: checkpoint path; ``None`` or ``""`` means "nothing to load".
        model: module to load into.
        optimizer: optional optimizer restored from ``checkpoint['optimizer']``.
        scheduler: optional scheduler restored from ``checkpoint['scheduler']``.

    Returns:
        The epoch to resume from: ``checkpoint['epoch'] + 1``, or 0 when unavailable.

    Raises:
        FileNotFoundError: if ``weights`` is given but does not exist.
    """
    starting_epoch = 0

    if weights is None or weights == "":
        return starting_epoch

    if not os.path.exists(weights):
        log_string("error: checkpoint not exists!")
        # torch.load would fail right after this anyway; fail with a message naming the path.
        raise FileNotFoundError('checkpoint not found: {}'.format(weights))

    checkpoint_dict = torch.load(weights, map_location=torch.device('cpu'))
    saved_state_dict = checkpoint_dict['state_dict'] if 'state_dict' in checkpoint_dict else checkpoint_dict

    model_state_dict = model.state_dict()  # state_dict of the already-constructed network
    # Keep only entries the current model knows about, then load with every key present.
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_state_dict}
    model_state_dict.update(saved_state_dict)
    model.load_state_dict(model_state_dict, strict=True)

    if 'epoch' in checkpoint_dict:
        starting_epoch = checkpoint_dict['epoch'] + 1

    if optimizer is not None and 'optimizer' in checkpoint_dict:
        optimizer.load_state_dict(checkpoint_dict['optimizer'])

    if scheduler is not None and 'scheduler' in checkpoint_dict:
        scheduler.load_state_dict(checkpoint_dict['scheduler'])

    log_string("load checkpoint " + weights + " starting_epoch: " + str(starting_epoch))
    torch.cuda.empty_cache()

    return starting_epoch
96 |
--------------------------------------------------------------------------------
/models/vcrnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/vcrnet/__init__.py
--------------------------------------------------------------------------------
/models/vcrnet/vcrnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from misc.utils import ModelParams
6 | from models.vcrnet.transformer import Transformer
7 | from models.vlpdnet.lpdnet_model import LPDNet, LPDNetOrign
8 |
9 |
class PoseSolver(nn.Module):
    """Estimate a rigid transform between source and target clouds from their point embeddings.

    A Transformer residually refines both embeddings, EPCOR turns them into virtual
    corresponding points plus per-point weights, and SVD_Weighted solves for (R, t).
    """

    def __init__(self, model_params, origin=False):
        super(PoseSolver, self).__init__()

        self.pointer = Transformer(model_params=model_params)

        self.head = EPCOR(model_params=model_params)

        self.svd = SVD_Weighted()

        self.loss = torch.nn.functional.mse_loss

        # origin=True: forward() returns the SVD pose plus outlier-masked point sets, no losses.
        self.origin = origin

    def forward(self, src, tgt, src_embedding, tgt_embedding, positive_T, svd=True):
        """Find correspondences and (optionally) solve the pose.

        Shape note from the original code (not checked here):
            src [B,1,num,3], tgt [B,posi_num,num,3],
            src_embedding [B,1,C,num,1], tgt_embedding [B,posi_num,C,num,1].

        Args:
            positive_T: ground-truth transforms viewable as (-1, 4, 4), or None.
            svd: when False, the SVD pose and the losses are skipped.

        Returns:
            * origin mode: dict with 'rotation_ab_pred', 'translation_ab_pred', 'srcK', 'src_corrK'.
            * ``positive_T is None``: tuple ``(rotation_ab_pred, translation_ab_pred)``.
            * otherwise: dict with predicted / ground-truth poses, 'loss_point', 'loss_pose', 'mask_tgt'.
            The predicted pose entries are None whenever ``svd`` is False.
        """
        batch, posi_num, num_points, C = tgt.size()
        # Coordinates channel-first: [*, C, num].
        src = src.transpose(-2, -1).contiguous().view(-1, C, num_points)
        tgt = tgt.transpose(-2, -1).contiguous().view(-1, C, num_points)
        C = tgt_embedding.size(2)
        # The source embedding is tiled once per positive.
        src_embedding = src_embedding.squeeze(-1).repeat(1, posi_num, 1, 1).contiguous().view(-1, C, num_points)
        tgt_embedding = tgt_embedding.squeeze(-1).contiguous().view(-1, C, num_points)

        src_embedding_p, tgt_embedding_p = self.pointer(src_embedding, tgt_embedding)
        src_embedding = src_embedding + src_embedding_p
        tgt_embedding = tgt_embedding + tgt_embedding_p

        src_corr, src_weight, outlier_src_mask, mask_tgt = self.head(src_embedding, tgt_embedding, src, tgt)

        # Bound up front so the svd=False paths return None instead of raising UnboundLocalError.
        rotation_ab_pred, translation_ab_pred = None, None
        if svd:
            # Solved without gradients: loss_pose carries no gradient; cor_loss (on src_corr) does.
            with torch.no_grad():
                rotation_ab_pred, translation_ab_pred = self.svd(src, src_corr, src_weight)

        if self.origin:
            b, _, num = tgt.size()
            mask = outlier_src_mask.contiguous().view(b, 1, num).repeat(1, 3, 1)
            srcK = torch.masked_fill(src, mask, 0)
            src_corrK = torch.masked_fill(src_corr, mask, 0)
            return {'rotation_ab_pred': rotation_ab_pred, 'translation_ab_pred': translation_ab_pred,
                    'srcK': srcK, 'src_corrK': src_corrK}

        if positive_T is None:
            return rotation_ab_pred, translation_ab_pred

        if not svd:
            return {'rotation_ab_pred': None, 'translation_ab_pred': None,
                    'rotation_ab': None, 'translation_ab': None,
                    'loss_point': None, 'loss_pose': None, 'mask_tgt': mask_tgt}

        loss = self.cor_loss(src, src_corr, outlier_src_mask, positive_T)

        rotation_ab, translation_ab, loss_pose = self.pose_loss(rotation_ab_pred, translation_ab_pred, positive_T)

        return {'rotation_ab_pred': rotation_ab_pred, 'translation_ab_pred': translation_ab_pred,
                'rotation_ab': rotation_ab, 'translation_ab': translation_ab,
                'loss_point': loss, 'loss_pose': loss_pose, 'mask_tgt': mask_tgt}

    def pose_loss(self, rotation_ab_pred, translation_ab_pred, gt_T):
        """MSE(R_pred^T R_gt, I) + MSE(t_pred, t_gt).

        Returns:
            (gt rotation (N, 3, 3), gt translation (N, 3), scalar loss).
        """
        gt_T = gt_T.view(-1, 4, 4)
        rotation_ab = gt_T[:, :3, :3]
        translation_ab = gt_T[:, :3, 3]
        batch_size = translation_ab.size(0)
        rel_rot = torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab)
        # Build the identity on the inputs' device/dtype (previously hard-coded to CUDA).
        identity = torch.eye(3, device=rel_rot.device, dtype=rel_rot.dtype).unsqueeze(0).repeat(batch_size, 1, 1)
        loss_pose = F.mse_loss(rel_rot, identity) + F.mse_loss(translation_ab_pred, translation_ab)
        return rotation_ab, translation_ab, loss_pose

    def cor_loss(self, src, src_corr, outlier_src_mask, gt_T):
        """MSE between ground-truth-transformed source points and their virtual correspondences.

        Source outliers (``outlier_src_mask``) are zeroed in both sets before comparing.
        """
        gt_T = gt_T.view(-1, 4, 4)
        rotation_ab = gt_T[:, :3, :3]
        translation_ab = gt_T[:, :3, 3]

        transformed_srcK = torch.matmul(rotation_ab, src) + translation_ab.unsqueeze(2)

        b, _, num = transformed_srcK.size()
        mask = outlier_src_mask.contiguous().view(b, 1, num).repeat(1, 3, 1)

        transformed_srcK = torch.masked_fill(transformed_srcK, mask, 0)
        src_corrK = torch.masked_fill(src_corr, mask, 0)

        loss = self.loss(transformed_srcK, src_corrK)

        return loss
97 |
98 |
class VCRNet(nn.Module):
    """Registration network: a shared point-embedding network followed by PoseSolver(origin=True)."""

    def __init__(self, model_params: ModelParams):
        super(VCRNet, self).__init__()
        featnet = model_params.featnet.lower()
        if featnet == "lpdnet":
            self.emb_nn = LPDNet(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
        elif featnet == "lpdnetorigin":
            self.emb_nn = LPDNetOrign(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
        else:
            # Fail fast: without emb_nn the module cannot run forward().
            raise NotImplementedError('featnet not implemented: {}'.format(model_params.featnet))
        self.solver = PoseSolver(model_params, origin=True)

    def forward(self, *input):
        """Register ``input[0]`` (source) onto ``input[1]`` (target).

        Returns:
            ``(srcK, src_corrK, rotation_ab, translation_ab, rotation_ba, translation_ba)``, where
            the ``ba`` pose is the inverse of the predicted ``ab`` pose.
        """
        src = input[0]
        tgt = input[1]

        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        # Insert the singleton "positive" axis expected by PoseSolver.
        src = src.unsqueeze(1)
        tgt = tgt.unsqueeze(1)
        src_embedding = src_embedding.unsqueeze(1)
        tgt_embedding = tgt_embedding.unsqueeze(1)

        # input: #[B,1,num,3] [B,posi_num,num,3] [B,1,C,num,1] [B,posi_num,C,num,1]
        reg_dict = self.solver(src, tgt, src_embedding, tgt_embedding, None)
        rotation_ab = reg_dict['rotation_ab_pred']
        translation_ab = reg_dict['translation_ab_pred']
        # Inverse transform: R_ba = R_ab^T, t_ba = -R_ab^T t_ab.
        rotation_ba = rotation_ab.transpose(2, 1).contiguous()
        translation_ba = -torch.matmul(rotation_ba, translation_ab.unsqueeze(2)).squeeze(2)
        srcK = reg_dict['srcK']
        src_corrK = reg_dict['src_corrK']

        return srcK, src_corrK, rotation_ab, translation_ab, rotation_ba, translation_ba
132 |
133 |
class SVD_Weighted(nn.Module):
    """Weighted Kabsch/SVD solver for the rigid transform between corresponding point sets."""

    def __init__(self):
        super(SVD_Weighted, self).__init__()

    def forward(self, src, src_corr, weights: torch.Tensor = None):
        """Compute rigid transforms between two point sets.

        Args:
            src (torch.Tensor): (B, 3, M) source points.
            src_corr (torch.Tensor): (B, 3, M) corresponding points.
            weights (torch.Tensor, optional): (B, M) per-correspondence weights; uniform when None.

        Returns:
            ``(rotation (B, 3, 3), translation (B, 3))`` such that ``R @ src + t ~= src_corr``.
        """
        _EPS = 1e-5  # To prevent division by zero
        a = src.transpose(-1, -2)
        b = src_corr.transpose(-1, -2)
        batch_size, num_src, dims = a.size()

        if weights is None:
            # Uniform weights on the inputs' device/dtype so the solver also runs on CPU.
            weights = torch.ones(size=(batch_size, num_src), device=a.device, dtype=a.dtype) / num_src

        weights_normalized = weights[..., None] / (torch.sum(weights[..., None], dim=1, keepdim=True) + _EPS)
        centroid_a = torch.sum(a * weights_normalized, dim=1)
        centroid_b = torch.sum(b * weights_normalized, dim=1)
        a_centered = a - centroid_a[:, None, :]
        b_centered = b - centroid_b[:, None, :]
        cov = a_centered.transpose(-2, -1) @ (b_centered * weights_normalized)

        # Compute rotation using Kabsch algorithm. Will compute two copies with +/-V[:,:3]
        # and choose based on determinant to avoid flips
        u, s, v = torch.svd(cov, some=False, compute_uv=True)
        rot_mat_pos = v @ u.transpose(-1, -2)
        v_neg = v.clone()
        v_neg[:, :, 2] *= -1
        rot_mat_neg = v_neg @ u.transpose(-1, -2)
        rot_mat = torch.where(torch.det(rot_mat_pos)[:, None, None] > 0, rot_mat_pos, rot_mat_neg)
        assert torch.all(torch.det(rot_mat) > 0)

        # Compute translation (uncenter centroid)
        translation = -rot_mat @ centroid_a[:, :, None] + centroid_b[:, :, None]

        return rot_mat, translation.view(batch_size, 3)
179 |
180 |
class EPCOR(nn.Module):
    """Generate virtual corresponding points (VCP) from embedding similarity.

    Training mode: every source point gets a soft correspondence (softmax over all targets).
    Eval mode: correspondences are sparsified by ``get_sparse_w`` (top-K per row, with
    low-score source/target points masked as outliers).
    """

    def __init__(self, model_params):
        super(EPCOR, self).__init__()
        self.model_params = model_params
        # Masks from the most recent forward pass, exposed through return_mask().
        self.mask_src = None
        self.mask_tgt = None

    def forward(self, *input):
        """input = (src_emb, tgt_emb, src, tgt) -> (src_corr, src_weight, outlier_src_mask, mask_tgt)."""
        src_emb = input[0]
        tgt_emb = input[1]
        src = input[2]
        tgt = input[3]

        if self.training:
            src_corr, src_weight, outlier_src_mask, mask_tgt = self.getCopairALL(src, src_emb, tgt, tgt_emb)
        else:
            src_corr, src_weight, outlier_src_mask, mask_tgt = self.selectCom_adap(src, src_emb, tgt, tgt_emb)

        self.mask_src, self.mask_tgt = outlier_src_mask, mask_tgt
        return src_corr, src_weight, outlier_src_mask, mask_tgt

    def get_sparse_w(self, pairwise_distance, tau, K):
        """Sparsify a (B, N, M) similarity matrix into correspondence weights.

        Args:
            pairwise_distance: (B, N, M) similarity scores (higher = more similar); the callers in
                this class pass negative squared embedding distances.
            tau: outlier fraction. A target (source) point is masked when its summed softmax score
                is strictly below the ``int(tau * M)``-th (``int(tau * N)``-th) smallest score.
            K: number of highest-scoring targets kept per source row.

        Returns:
            s: (B, 1) fraction of masked source + target points.
            src_weight: (B, N) weights of non-outlier source points, normalized to sum to 1.
            mask_src: (B, N, 1) bool, True for source outliers.
            mask_tgt: (B, 1, M) bool, True for target outliers.
            src_tgt_weight: (B, N, M) sparse, row-normalized correspondence weights.
        """
        batch_size, num_points, num_points_t = pairwise_distance.size()
        tgt_K = int(num_points_t * tau)
        src_K = int(num_points * tau)

        # Target outliers: columns receiving little total mass from the row-wise softmax.
        scoresSoftCol = torch.softmax(pairwise_distance, dim=2)  # [b,num,num_t]
        scoresColSum = torch.sum(scoresSoftCol, dim=1, keepdim=True)  # [b,1,num_t]
        topmin = scoresColSum.topk(k=tgt_K, dim=-1, largest=False)[0][..., -1].unsqueeze(-1)
        mask_tgt = scoresColSum < topmin

        # Source outliers: rows receiving little total mass from the column-wise softmax.
        scoresSoftRow = torch.softmax(pairwise_distance, dim=1)  # [b,num,num_t]
        scoresRowSum = torch.sum(scoresSoftRow, dim=2, keepdim=True)  # [b,num,1]
        topmin = scoresRowSum.topk(k=src_K, dim=1, largest=False)[0][:, -1, :].unsqueeze(-1)
        mask_src = scoresRowSum < topmin

        s = (torch.sum(mask_src, dim=-2) + torch.sum(mask_tgt, dim=-1)) / torch.full(
            size=(batch_size, 1), fill_value=num_points + num_points_t).to(mask_tgt.device).type_as(pairwise_distance)

        # Drop every entry in an outlier row/column and every entry outside the per-row top-K.
        # The comparison is parenthesized on purpose: `mask + scores < topk` parses as
        # `(mask + scores) < topk`, which un-masks outlier entries instead of masking them.
        topk = scoresSoftCol.topk(k=K, dim=2)[0][..., -1].unsqueeze(-1)
        mask = mask_src | mask_tgt | (scoresSoftCol < topk)

        src_tgt_weight_sparse = torch.masked_fill(scoresSoftCol, mask, 0)
        scoresColSum_sparse = torch.sum(src_tgt_weight_sparse, dim=-1, keepdim=True)  # [B, num, 1]
        # Avoid division by zero for rows with no surviving entries.
        scoresColSum_sparse = torch.masked_fill(scoresColSum_sparse, scoresColSum_sparse < 1e-5, 1e-5)
        src_tgt_weight = torch.div(src_tgt_weight_sparse, scoresColSum_sparse)

        # Per-source weights: surviving mass per row, zeroed for source outliers, normalized.
        val_sum = torch.sum(src_tgt_weight, dim=-1, keepdim=True)  # [B, num, 1]
        val_sum_inlier = torch.masked_fill(val_sum, mask_src, 0.0).squeeze(-1)
        sum_ = torch.sum(val_sum_inlier, dim=[-1], keepdim=True)
        src_weight = torch.div(val_sum_inlier, sum_)

        return s, src_weight, mask_src, mask_tgt, src_tgt_weight

    def selectCom_adap(self, src, src_emb, tgt, tgt_emb, tau=0.3):
        """Eval-time sparse correspondences: get_sparse_w with K=1 (best target per source point)."""
        # Negative squared embedding distance between every source and target point.
        inner = -2 * torch.matmul(src_emb.transpose(2, 1).contiguous(), tgt_emb)
        xx = torch.sum(src_emb ** 2, dim=1, keepdim=True).transpose(2, 1).contiguous()
        yy = torch.sum(tgt_emb ** 2, dim=1, keepdim=True)
        src_tgt_weight = -xx - inner - yy

        _, src_weight, mask_src, mask_tgt, src_tgt_weight = self.get_sparse_w(src_tgt_weight, tau, K=1)

        src_corr = torch.matmul(tgt, src_tgt_weight.transpose(2, 1).contiguous())

        return src_corr, src_weight, mask_src, mask_tgt

    def getCopairALL(self, src, src_emb, tgt, tgt_emb):
        """Training-time dense correspondences: softmax-weighted average of all target points.

        Returns uniform source weights and all-False masks.
        """
        batch_size, _, num_points = src.size()
        # Negative squared embedding distance between every source and target point.
        inner = -2 * torch.matmul(src_emb.transpose(2, 1).contiguous(), tgt_emb)
        xx = torch.sum(src_emb ** 2, dim=1, keepdim=True).transpose(2, 1).contiguous()
        yy = torch.sum(tgt_emb ** 2, dim=1, keepdim=True)
        pairwise_distance = -xx - inner - yy

        scores = torch.softmax(pairwise_distance, dim=2)  # [b,num,num_t]
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        src_weight = torch.ones(size=(batch_size, num_points), device=src_corr.device) / num_points

        outlier_src_mask = torch.full((batch_size, num_points, 1), False, device=src_corr.device, dtype=torch.bool)
        # NOTE(review): shaped (b, num, 1) like the source mask, not (b, 1, num_t) like the mask
        # from get_sparse_w; kept as-is in case callers rely on it.
        mask_tgt = torch.full((batch_size, num_points, 1), False, device=src_corr.device, dtype=torch.bool)

        return src_corr, src_weight, outlier_src_mask, mask_tgt

    def return_mask(self):
        """Return ``(mask_src, mask_tgt)`` from the last forward pass (None before the first call)."""
        return self.mask_src, self.mask_tgt
286 |
--------------------------------------------------------------------------------
/models/vlpdnet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/models/vlpdnet/__init__.py
--------------------------------------------------------------------------------
/models/vlpdnet/lpdnet_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from misc.point_utils import get_graph_feature_Origin, get_graph_feature, knn
6 |
7 |
class LPDNetOrign(nn.Module):
    """LPDNet point-feature extractor, variant with BatchNorm after every convolution.

    Stages: two point-wise MLPs, an edge convolution on a feature-space graph, an aggregation
    over the xyz (cartesian) kNN graph, then three point-wise MLPs up to ``emb_dims``.
    """

    def __init__(self, emb_dims=512, channels=(64, 64, 128)):
        # `channels` is only indexed, so any 3-element sequence works; the tuple default
        # avoids a shared mutable default argument.
        super(LPDNetOrign, self).__init__()
        # One LeakyReLU instance is shared by every stage (it holds no parameters).
        self.act_f = nn.LeakyReLU(inplace=True, negative_slope=1e-2)
        self.k = 20  # neighbours per point for both graphs
        self.emb_dims = emb_dims
        self.convDG1 = nn.Sequential(nn.Conv2d(channels[0] * 2, channels[0], kernel_size=1, bias=False),
                                     nn.BatchNorm2d(channels[0]), self.act_f)
        self.convDG2 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
                                     nn.BatchNorm2d(channels[0]), self.act_f)
        self.convSN1 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
                                     nn.BatchNorm2d(channels[0]), self.act_f)
        self.convSN2 = nn.Sequential(nn.Conv2d(channels[0], channels[0], kernel_size=1, bias=False),
                                     nn.BatchNorm2d(channels[0]), self.act_f)
        self.conv1_lpd = nn.Sequential(nn.Conv1d(3, channels[0], kernel_size=1, bias=False),
                                       nn.BatchNorm1d(channels[0]), self.act_f)
        self.conv2_lpd = nn.Sequential(nn.Conv1d(channels[0], channels[0], kernel_size=1, bias=False),
                                       nn.BatchNorm1d(channels[0]), self.act_f)
        self.conv3_lpd = nn.Sequential(nn.Conv1d(channels[0], channels[1], kernel_size=1, bias=False),
                                       nn.BatchNorm1d(channels[1]), self.act_f)
        self.conv4_lpd = nn.Sequential(nn.Conv1d(channels[1], channels[2], kernel_size=1, bias=False),
                                       nn.BatchNorm1d(channels[2]), self.act_f)
        self.conv5_lpd = nn.Sequential(nn.Conv1d(channels[2], self.emb_dims, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(self.emb_dims), self.act_f)

    def forward(self, x):
        """Input x: [B,1,num,num_dims] (num_dims must be 3 for conv1_lpd); output: [b,emb_dims,num,1]."""
        x = torch.squeeze(x, dim=1).transpose(2, 1)  # [B,num_dims,num]

        xInit3d = x  # raw coordinates, reused to build the spatial kNN graph

        x = self.conv1_lpd(x)
        x = self.conv2_lpd(x)

        # Serial structure
        # Dynamic graph CNN for feature space
        x = get_graph_feature_Origin(x, k=self.k)  # [b,channel*2,num,20]
        x = self.convDG1(x)  # [b,channel,num,20]
        x = self.convDG2(x)  # [b,channel,num,20]
        x = x.max(dim=-1, keepdim=True)[0]  # [b,channel,num,1]

        # Spatial Neighborhood fusion for cartesian space
        idx = knn(xInit3d, k=self.k)
        x = get_graph_feature_Origin(x, idx=idx, k=self.k, cat=False)  # [b,channel,num,20]
        x = self.convSN1(x)  # [b,channel,num,20]
        x = self.convSN2(x)  # [b,channel,num,20]
        x = x.max(dim=-1, keepdim=True)[0].squeeze(-1)  # [b,channel,num]

        x = self.conv3_lpd(x)  # [b,channel,num]
        x = self.conv4_lpd(x)  # [b,128,num]
        x = self.conv5_lpd(x)  # [b,emb_dims,num]
        x = x.unsqueeze(-1)  # [b,emb_dims,num,1]

        return x
64 |
65 |
class LPDNet(nn.Module):
    """LPDNet per-point feature extractor.

    Pipeline:
    - two pointwise 1x1 convolutions;
    - a dynamic-graph block over feature-space neighbours;
    - a spatial-neighbourhood block whose neighbour indices come from the input
      coordinates.

    The three intermediate features are concatenated and projected to
    ``emb_dims`` channels per point.
    """

    def __init__(self, emb_dims=1024, channels=(64, 128, 256)):
        """
        Args:
            emb_dims: number of output channels per point.
            channels: three widths (c0, c1, c2) for the stem, the dynamic-graph
                block and the spatial-neighbourhood block. Any indexable sequence
                works; the default is a tuple so no mutable list is shared
                between calls.
        """
        super(LPDNet, self).__init__()
        self.negative_slope = 0.0  # LeakyReLU with slope 0 behaves like ReLU
        self.k = 20  # neighbours per point for both graph blocks
        self.emb_dims = emb_dims
        # get_graph_feature yields 2*c0 channels from the c0-wide stem output (see forward).
        self.convDG1 = nn.Sequential(nn.Conv2d(channels[0] * 2, channels[1], kernel_size=1, bias=True),
                                     nn.LeakyReLU(negative_slope=self.negative_slope))
        self.convDG2 = nn.Sequential(nn.Conv2d(channels[1], channels[1], kernel_size=1, bias=True),
                                     nn.LeakyReLU(negative_slope=self.negative_slope))
        # get_graph_feature on the c1-wide features yields 2*c1 channels.
        self.convSN1 = nn.Sequential(nn.Conv2d(channels[1] * 2, channels[2], kernel_size=1, bias=True),
                                     nn.LeakyReLU(negative_slope=self.negative_slope))

        self.conv1_lpd = nn.Conv1d(3, channels[0], kernel_size=1, bias=True)
        self.conv2_lpd = nn.Conv1d(channels[0], channels[0], kernel_size=1, bias=True)
        # Fusion input is cat(x1, x2, x3) = c1 + c1 + c2 channels.
        self.conv3_lpd = nn.Conv1d((channels[1] + channels[1] + channels[2]), self.emb_dims, kernel_size=1, bias=True)

    def forward(self, x):
        """Compute per-point embeddings.

        Args:
            x: point tensor shaped [B, 1, num, 3]. Dim 1 is squeezed and the rest
                transposed to [B, 3, num]; conv1_lpd expects 3 input channels.

        Returns:
            Tensor shaped [B, emb_dims, num, 1].

        Raises:
            ValueError: if the input does not reduce to a 3-D [B, 3, num] tensor.
        """
        in_shape = tuple(x.shape)
        x = torch.squeeze(x, dim=1).transpose(2, 1)  # [B, num_dims, num]
        if x.dim() != 3:
            raise ValueError(
                "expected input shaped [B, 1, num, 3], got {}".format(in_shape))
        coords = x  # input coordinates, used for the spatial kNN below

        x = F.leaky_relu(self.conv1_lpd(x), negative_slope=self.negative_slope)
        x = F.leaky_relu(self.conv2_lpd(x), negative_slope=self.negative_slope)

        # Serial structure.
        # Dynamic graph CNN in feature space.
        x = get_graph_feature(x, k=self.k)  # [B, c0*2, num, k]
        x = self.convDG1(x)  # [B, c1, num, k]
        x1 = x.max(dim=-1, keepdim=True)[0]  # [B, c1, num, 1]
        x = self.convDG2(x)  # [B, c1, num, k]
        x2 = x.max(dim=-1, keepdim=True)[0]  # [B, c1, num, 1]

        # Spatial neighbourhood fusion: neighbour indices come from the coordinates.
        idx = knn(coords, k=self.k)
        x = get_graph_feature(x2.squeeze(-1), idx=idx, k=self.k)  # [B, c1*2, num, k]
        x = self.convSN1(x)  # [B, c2, num, k]
        x3 = x.max(dim=-1, keepdim=True)[0]  # [B, c2, num, 1]

        x = torch.cat((x1, x2, x3), dim=1).squeeze(-1)  # [B, 2*c1 + c2, num]
        x = F.leaky_relu(self.conv3_lpd(x), negative_slope=self.negative_slope)  # [B, emb_dims, num]
        return x.unsqueeze(-1)  # [B, emb_dims, num, 1]
118 |
--------------------------------------------------------------------------------
/models/vlpdnet/vlpdnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import math
4 |
5 | import MinkowskiEngine as ME
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.nn.parallel
11 | import torch.utils.data
12 |
13 | from models.minkloc3d.minkpool import MinkPool
14 | from models.vcrnet.vcrnet import PoseSolver
15 | from models.vlpdnet.lpdnet_model import LPDNet, LPDNetOrign
16 |
17 |
class vLPDNet(nn.Module):
    """LPDNet-based place-recognition model.

    Components:
    - ``emb_nn`` (LPDNet or LPDNetOrign) produces per-point features;
    - ``net_vlad`` (MinkPool) aggregates them into a global embedding;
    - ``pose_solver`` (PoseSolver) optionally estimates a source -> target
      registration.
    """

    def __init__(self, params):
        """Build the sub-networks from ``params``.

        Reads ``params.model_params`` (featnet, emb_dims, lpd_channels) and
        ``params.lpd_fixed``, ``params.is_register``, ``params.domain_adapt``.

        Raises:
            ValueError: if ``model_params.featnet`` names an unsupported feature network.
        """
        super(vLPDNet, self).__init__()
        self.params = params
        model_params = params.model_params
        self.featnet = model_params.featnet
        featnet = model_params.featnet.lower()
        if featnet == "lpdnet":
            self.emb_nn = LPDNet(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
        elif featnet == "lpdnetorigin":
            self.emb_nn = LPDNetOrign(emb_dims=model_params.emb_dims, channels=model_params.lpd_channels)
        else:
            # Without emb_nn the model cannot run, so fail at construction time.
            raise ValueError("featnet error: unsupported featnet '{}'".format(model_params.featnet))
        self.net_vlad = MinkPool(model_params, in_channels=model_params.emb_dims)

        self.is_register = params.is_register
        if params.domain_adapt:
            # Domain adaptation disables the registration branch of forward() and
            # forces the feature extractor to be trainable.
            self.is_register = False
            self.params.lpd_fixed = False
        # pose_solver follows params.is_register (not self.is_register), so it is
        # still built when domain adaptation disabled registration in forward().
        if params.is_register:
            self.pose_solver = PoseSolver(model_params=model_params)
        else:
            self.pose_solver = None

        # Freeze the feature extractor. This runs after the domain-adaptation
        # override so the final lpd_fixed value is honoured. Module.requires_grad_
        # is needed here: assigning `module.requires_grad = False` only creates a
        # plain attribute and leaves every parameter trainable.
        if self.params.lpd_fixed:
            self.emb_nn.requires_grad_(False)
            self.emb_nn.eval()

    def forward(self, source_batch=None, target_batch=None, gt_T=None):
        """Embed the target cloud and, optionally, register source to target.

        Args:
            source_batch: dict with a "cloud" tensor, or None to skip registration.
            target_batch: dict with a "cloud" tensor. ``cloud.unsqueeze(1)`` is fed
                to emb_nn, which expects [B, 1, N, 3], so cloud is presumably
                [B, N, 3].
            gt_T: forwarded unchanged to pose_solver.

        Returns:
            dict with:
            - 'embeddings': the net_vlad output, or its 'g_fea' entry when net_vlad
              returns a dict;
            - 'reg_dict': the pose_solver output, or None when registration is
              skipped.
        """
        import contextlib

        do_register = self.is_register and source_batch is not None

        if self.params.lpd_fixed:
            # The frozen extractor always runs in eval mode and without autograd.
            self.emb_nn.eval()
            grad_ctx = torch.no_grad()
        else:
            grad_ctx = contextlib.nullcontext()

        with grad_ctx:
            if do_register:
                source = source_batch["cloud"].unsqueeze(1)
                feat_x_s = self.emb_nn(source)  # [B, C, N, 1]
            target = target_batch["cloud"].unsqueeze(1)
            feat_x_t = self.emb_nn(target)  # [B, C, N, 1]

        fs = []
        pcls = []
        if do_register:
            reg_dict = self.pose_solver(source, target, feat_x_s.unsqueeze(1), feat_x_t.unsqueeze(1), gt_T,
                                        svd=False)
            # Per sample, keep only the target points whose mask_tgt entry is False.
            mask_tgt = reg_dict['mask_tgt'].squeeze(-1)
            for i in range(mask_tgt.shape[0]):
                keep = ~mask_tgt[i]
                fs.append(feat_x_t[i][:, keep, :])
                pcls.append(target[i][:, keep, :])
        else:
            reg_dict = None
            for i in range(feat_x_t.shape[0]):
                fs.append(feat_x_t[i])
                pcls.append(target[i])

        g_fea = self.net_vlad(fs, pcls)
        if isinstance(g_fea, dict):
            g_fea = g_fea['g_fea']

        return {
            'embeddings': g_fea,
            'reg_dict': reg_dict
        }

    def registration(self, src, tgt, src_embedding, tgt_embedding, gt_T):
        """Run pose_solver on precomputed embeddings and return its output dict."""
        reg_dict = self.pose_solver(src, tgt, src_embedding, tgt_embedding, gt_T)

        return reg_dict

    def registration_only(self, src, tgt, gt_T):
        """Embed ``src`` and ``tgt`` (each [B, 1, N, 3]) and register them with pose_solver.

        The embeddings are permuted from [B, C, N, 1] to [B, 1, C, N] before the call.

        Raises:
            RuntimeError: if the model was built without a PoseSolver.
        """
        if self.pose_solver is None:
            raise RuntimeError("registration_only requires a PoseSolver (params.is_register must be set)")
        src_embedding = self.emb_nn(src).permute(0, 3, 1, 2)
        tgt_embedding = self.emb_nn(tgt).permute(0, 3, 1, 2)
        reg_dict = self.pose_solver(src, tgt, src_embedding, tgt_embedding, gt_T)
        return reg_dict
115 |
116 |
class NetVLADLoupe(nn.Module):
    """NetVLAD aggregation layer producing one global descriptor per cloud.

    Input to ``forward`` is [B, feature_size, num_points, 1]; the output is
    [B, output_dim]. When ``gating`` is True, a GatingContext is applied last.
    """

    def __init__(self, feature_size, cluster_size, output_dim,
                 gating=True, add_batch_norm=True, is_training=True):
        super(NetVLADLoupe, self).__init__()
        self.feature_size = feature_size
        self.output_dim = output_dim
        self.is_training = is_training
        self.gating = gating
        self.add_batch_norm = add_batch_norm
        self.cluster_size = cluster_size
        self.softmax = nn.Softmax(dim=-1)
        self.cluster_weights = nn.Parameter(torch.randn(
            feature_size, cluster_size) * 1 / math.sqrt(feature_size))
        self.cluster_weights2 = nn.Parameter(torch.randn(
            1, feature_size, cluster_size) * 1 / math.sqrt(feature_size))
        self.hidden1_weights = nn.Parameter(
            torch.randn(cluster_size * feature_size, output_dim) * 1 / math.sqrt(feature_size))

        if add_batch_norm:
            self.cluster_biases = None
            self.bn1 = nn.BatchNorm1d(cluster_size)
        else:
            self.cluster_biases = nn.Parameter(torch.randn(
                cluster_size) * 1 / math.sqrt(feature_size))
            self.bn1 = None

        self.bn2 = nn.BatchNorm1d(output_dim)

        # forward() applies self.context_gating whenever gating is enabled, so the
        # module must exist; previously it was never created and forward failed.
        if gating:
            self.context_gating = GatingContext(output_dim, add_batch_norm=add_batch_norm)

        self.weight_initialization_all()

    def forward(self, x, coord=None):
        """Aggregate per-point features into a descriptor.

        Args:
            x: features shaped [B, feature_size, num_points, 1].
            coord: unused; kept for interface compatibility.

        Returns:
            Tensor shaped [B, output_dim].
        """
        max_samples = x.size(2)  # number of points
        x = x.transpose(1, 3).contiguous()
        x = x.view((-1, max_samples, self.feature_size))  # [B, N, F]

        activation = torch.matmul(x, self.cluster_weights)  # [B, N, K]
        if self.add_batch_norm:
            # BatchNorm1d normalises over all points of all clouds at once.
            activation = activation.view(-1, self.cluster_size)
            activation = self.bn1(activation)
            activation = activation.view(-1, max_samples, self.cluster_size)
        else:
            activation = activation + self.cluster_biases
        activation = self.softmax(activation)  # soft assignment, [B, N, K]

        a_sum = activation.sum(-2, keepdim=True)  # [B, 1, K]
        a = a_sum * self.cluster_weights2  # [B, F, K]

        activation = torch.transpose(activation, 2, 1)  # [B, K, N]
        vlad = torch.matmul(activation, x)  # [B, K, F]
        vlad = torch.transpose(vlad, 2, 1)  # [B, F, K]
        vlad = vlad - a

        # Intra-normalisation per cluster, then global L2 normalisation.
        vlad = F.normalize(vlad, dim=1, p=2)
        vlad = vlad.reshape((-1, self.cluster_size * self.feature_size))
        vlad = F.normalize(vlad, dim=1, p=2)

        vlad = torch.matmul(vlad, self.hidden1_weights)  # [B, output_dim]
        vlad = self.bn2(vlad)

        if self.gating:
            vlad = self.context_gating(vlad)

        return vlad

    def weight_initialization_all(self):
        """Apply weight_init to this module and every submodule.

        self.modules() already recurses, so no special handling of containers
        is needed.
        """
        for m in self.modules():
            self.weight_init(m)

    def weight_init(self, m):
        """Kaiming-init conv weights; reset batch-norm affine params to (1, 0)."""
        if isinstance(m, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
207 |
208 |
class GatingContext(nn.Module):
    """Context gating: returns ``x * sigmoid(g)`` with ``g = x @ W``.

    ``g`` is batch-normalised when ``add_batch_norm`` is True; otherwise a
    learned bias is added to it.
    """

    def __init__(self, dim, add_batch_norm=True):
        super(GatingContext, self).__init__()
        self.dim = dim
        self.add_batch_norm = add_batch_norm
        self.gating_weights = nn.Parameter(
            torch.randn(dim, dim) * 1 / math.sqrt(dim))
        self.sigmoid = nn.Sigmoid()

        if add_batch_norm:
            self.gating_biases = None
            self.bn1 = nn.BatchNorm1d(dim)
        else:
            self.gating_biases = nn.Parameter(
                torch.randn(dim) * 1 / math.sqrt(dim))
            self.bn1 = None

        self.weight_initialization_all()

    def weight_initialization_all(self):
        """Apply weight_init to this module and every submodule.

        self.modules() already recurses, so no special handling of containers
        is needed.
        """
        for m in self.modules():
            self.weight_init(m)

    def weight_init(self, m):
        """Kaiming-init conv weights; reset batch-norm affine params to (1, 0)."""
        if isinstance(m, (nn.Conv1d, nn.Conv2d)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Gate ``x`` ([batch, dim]) element-wise; the output has the same shape."""
        gates = torch.matmul(x, self.gating_weights)

        if not self.add_batch_norm:
            gates = gates + self.gating_biases
        elif gates.size(0) != 1:
            # Batch norm is skipped for single-sample batches (BatchNorm1d in
            # training mode rejects a batch of one).
            gates = self.bn1(gates)

        return x * self.sigmoid(gates)
264 |
265 |
def extract_features(model,
                     xyz,
                     rgb=None,
                     normal=None,
                     voxel_size=0.05,
                     device=None,
                     skip_check=False,
                     is_eval=True):
    """Voxelise a point cloud, run ``model`` on it and return (kept points, features).

    Args:
        model: callable taking an ``ME.SparseTensor`` and returning an object with ``.F``.
        xyz: N x 3 array of point coordinates.
        rgb: optional N x 3 colours in [0, 1]; shifted to [-0.5, 0.5].
        normal: optional N x 3 normals in [-1, 1]; scaled to [-0.5, 0.5].
            When both rgb and normal are None, an N x 1 vector of ones is the input feature.
        voxel_size: edge length used to quantise xyz.
        device: target device; defaults to "cuda:0" when available, else "cpu".
        skip_check: currently unused.
        is_eval: if True, put ``model`` in eval mode first.

    Returns:
        (xyz of the points kept after quantisation, ``model(...).F``).

    Example:
        model = model.to(device)
        xyz, feats = extract_features(model, xyz)
    """
    if is_eval:
        model.eval()

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    feature_parts = []
    if rgb is not None:
        feature_parts.append(rgb - 0.5)
    if normal is not None:
        feature_parts.append(normal / 2)
    if not feature_parts:
        # Neither colour nor normals: use a constant one-channel feature.
        feature_parts.append(np.ones((len(xyz), 1)))
    point_feats = np.hstack(feature_parts)

    # Quantise to voxel indices, keeping one representative point per voxel.
    voxel_coords = np.floor(xyz / voxel_size)
    voxel_coords, keep_idx = ME.utils.sparse_quantize(voxel_coords, return_index=True)

    # Single-sample batch in MinkowskiEngine's batched coordinate layout.
    batched_coords = ME.utils.batched_coordinates([voxel_coords])
    kept_xyz = xyz[keep_idx]

    feat_tensor = torch.tensor(point_feats[keep_idx], dtype=torch.float32)
    batched_coords = batched_coords.clone().detach()

    sparse_input = ME.SparseTensor(feat_tensor, coordinates=batched_coords, device=device)
    return kept_xyz, model(sparse_input).F
325 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Author: Zhijian Qiao
2 | # Shanghai Jiao Tong University
3 | # Code adapted from PointNetVlad code: https://github.com/jac99/MinkLoc3D.git
4 |
5 |
6 | import argparse
7 | import os
8 | import sys
9 |
10 | import torch
11 |
12 | sys.path.append(os.path.join(os.path.abspath("./"), "../"))
13 | from training.trainer import Trainer
14 | from misc.utils import MinkLocParams
15 | from misc.log import log_string
16 | from dataloader.dataset_utils import make_dataloaders
17 | import pdb
18 |
if __name__ == '__main__':
    # Command-line interface for training.
    parser = argparse.ArgumentParser(description='Train Minkowski Net embeddings using BatchHard negative mining')
    parser.add_argument('--config', type=str, required=True, help='Path to configuration file')
    parser.add_argument('--model_config', type=str, required=True,
                        help='Path to the model-specific configuration file')
    parser.add_argument('--debug', dest='debug', action='store_true')
    parser.add_argument('--qzj_debug', action='store_true')
    parser.add_argument('--checkpoint', type=str, required=False, help='Trained model weights', default="")
    parser.add_argument('--visualize', dest='visualize', action='store_true')
    parser.set_defaults(debug=False, visualize=False)
    args = parser.parse_args()

    startup_messages = (
        'Training config path: {}'.format(args.config),
        'Model config path: {}'.format(args.model_config),
        'Debug mode: {}'.format(args.debug),
        'Visualize: {}'.format(args.visualize),
    )
    for message in startup_messages:
        log_string(message)

    params = MinkLocParams(args.config, args.model_config)
    if args.qzj_debug:
        # Small batches for quick local debugging runs.
        params.batch_size = 4
        if params.is_register:
            params.reg.batch_size = 4
    params.print()

    debug = args.debug
    if debug:
        torch.autograd.set_detect_anomaly(True)

    dataloaders = make_dataloaders(params, debug=debug)
    trainer = Trainer(dataloaders, params, debug=debug, visualize=args.visualize, checkpoint=args.checkpoint)
    if debug:
        # Drop into the debugger once everything is constructed.
        pdb.set_trace()

    trainer.do_train()
52 |
--------------------------------------------------------------------------------
/training/optimizer_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from misc.utils import MinkLocParams
4 |
5 |
def optimizer_factory(params: "MinkLocParams", model, d_model):
    """Create Adam optimizers for the main model and (optionally) the discriminator.

    Args:
        params: training parameters. Reads weight_decay, lr, lpd_fixed, domain_adapt
            and d_lr; reads d_weight_decay only when weight decay is enabled.
        model: main network. When params.lpd_fixed is set it must expose ``emb_nn``,
            whose parameters are placed in a separate group with lr 0.0.
        d_model: discriminator; only used when params.domain_adapt is set.

    Returns:
        (optimizer, optimizer_d); optimizer_d is None without domain adaptation.
    """
    # A weight_decay of None or 0 disables weight decay for both optimizers.
    use_weight_decay = not (params.weight_decay is None or params.weight_decay == 0)
    weight_decay = params.weight_decay if use_weight_decay else 0

    if params.lpd_fixed:
        # Keep the feature extractor in the optimizer but never update it (lr 0.0).
        # This now applies whether or not weight decay is enabled.
        emb_param_ids = {id(p) for p in model.emb_nn.parameters()}
        base_params = [p for p in model.parameters() if id(p) not in emb_param_ids]
        optimizer = torch.optim.Adam([
            {'params': base_params},
            {'params': list(model.emb_nn.parameters()), 'lr': 0.0}], lr=params.lr, weight_decay=weight_decay)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=params.lr, weight_decay=weight_decay)

    if params.domain_adapt:
        d_weight_decay = params.d_weight_decay if use_weight_decay else 0
        optimizer_d = torch.optim.Adam(d_model.parameters(), lr=params.d_lr, weight_decay=d_weight_decay)
    else:
        optimizer_d = None

    if use_weight_decay:
        print("optimizer use Adam with weight_decay {} and lr {}".format(params.weight_decay, params.lr))
        if optimizer_d is not None:
            print("optimizer_d use Adam with weight_decay {} and lr {}".format(params.d_weight_decay, params.d_lr))
    else:
        print("optimizer use Adam with lr {}".format(params.lr))
        if optimizer_d is not None:
            print("optimizer_d use Adam with lr {}".format(params.d_lr))

    return optimizer, optimizer_d
30 |
31 |
def scheduler_factory(params: "MinkLocParams", optimizer, optimizer_d=None):
    """Create LR schedulers for the main optimizer and, if enabled, the discriminator's.

    Args:
        params: reads scheduler ('CosineAnnealingLR', 'MultiStepLR' or None),
            epochs/min_lr (cosine), scheduler_milestones (multi-step) and domain_adapt.
        optimizer: main optimizer.
        optimizer_d: discriminator optimizer, or None.

    Returns:
        (scheduler, scheduler_d). Both are None when params.scheduler is None.
        scheduler_d is None unless domain adaptation is on and optimizer_d is given.

    Raises:
        NotImplementedError: for an unsupported params.scheduler.
    """
    if params.scheduler is None:
        return None, None

    if params.scheduler == 'CosineAnnealingLR':
        def make_scheduler(opt):
            return torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=params.epochs + 1,
                                                              eta_min=params.min_lr)
    elif params.scheduler == 'MultiStepLR':
        def make_scheduler(opt):
            return torch.optim.lr_scheduler.MultiStepLR(opt, params.scheduler_milestones, gamma=0.1)
    else:
        raise NotImplementedError('Unsupported LR scheduler: {}'.format(params.scheduler))

    scheduler = make_scheduler(optimizer)
    # Do not fall back to the main optimizer: that would attach a second
    # scheduler to it and adjust its learning rate twice.
    if params.domain_adapt and optimizer_d is not None:
        scheduler_d = make_scheduler(optimizer_d)
    else:
        scheduler_d = None

    return scheduler, scheduler_d
54 |
--------------------------------------------------------------------------------
/weights/vlpdnet-registration.t7:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qiaozhijian/vLPD-Net/233d484c3becc562f7bf3ea8a1bdee711eef58c3/weights/vlpdnet-registration.t7
--------------------------------------------------------------------------------