├── .gitignore ├── LICENSE ├── README.md ├── config ├── evaluate │ └── once.yaml └── gapr │ └── train.yaml ├── datasets ├── dataloders │ ├── __init__.py │ ├── augments │ │ ├── __init__.py │ │ ├── augment.py │ │ └── utils.py │ ├── collates │ │ ├── __init__.py │ │ ├── lprcollate.py │ │ └── utils.py │ ├── lprdataloader.py │ └── samplers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── batch.py │ │ ├── hetero.py │ │ ├── homo.py │ │ ├── lprbatchsampler.py │ │ └── utils.py └── lprdataset.py ├── evaluate ├── once.py └── utils.py ├── loss ├── __init__.py ├── base.py ├── gapr.py ├── lprloss.py ├── overlap.py ├── point.py └── triplet.py ├── media ├── description.png └── pipeline.png ├── misc └── utils.py ├── models ├── __init__.py ├── gapr.py ├── lprmodel.py └── utils │ ├── aggregation │ └── gem.py │ ├── extraction │ └── mink │ │ ├── minkfpn.py │ │ ├── resnet.py │ │ └── utils.py │ └── transformers │ └── transgeo.py ├── pretrain ├── GAPR.pth └── config.yaml ├── results ├── evaluate │ └── readme.txt └── weights │ └── readme.txt ├── scripts ├── add_path.sh ├── clean.sh └── train.sh └── train └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # results of weights and evaluation 132 | results/evaluate/20* 133 | results/weights/20* 134 | 135 | # .vscode 136 | *.vscode/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 SYSU RAPID Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAPR 2 | 3 | ## Introduction 4 | [RA-L 23] Heterogeneous Deep Metric Learning for Ground and Aerial Point Cloud-Based Place Recognition 5 | 6 | In this paper, we propose a heterogeneous deep metric learning pipeline for ground and aerial point cloud-based place recognition in large-scale environments.The pipeline extracts local features from ground and aerial raw point clouds by a sparse convolution module. The local features are processed by transformer encoders to capture the overlaps between ground and aerial point clouds, and then transformed to unified descriptors for retrieval purposes by backpropagation of heterogeneous loss functions.To facilitate training and provide a reliable benchmark, a large-scale dataset is also proposed, which is collected from well-equipped ground and aerial robotic platforms. We demonstrate the superiority of the proposed method by comparing it with existing well-performed methods. We also show that our method is capable to detect loop closures in a collaborative ground and aerial robotic system in the experimental results. 7 | 8 |
9 |
10 | 11 |
12 | 13 | Task Illustration 14 |
15 | 16 |
17 |
18 | 19 |
20 | 21 | GAPR Pipeline 22 |
23 | 24 | 25 | ## Contributors 26 | [Yingrui Jie 揭英睿](https://github.com/yingruijie), 27 | [Yilin Zhu 朱奕霖](https://github.com/inntoy), and 28 | [Hui Cheng 成慧](https://cse.sysu.edu.cn/content/2504) from 29 | [SYSU RAPID Lab](http://lab.sysu-robotics.com). 30 | 31 | ## Citation 32 | ```tex 33 | @ARTICLE{10173571, 34 | author={Jie, Yingrui and Zhu, Yilin and Cheng, Hui}, 35 | journal={IEEE Robotics and Automation Letters}, 36 | title={Heterogeneous Deep Metric Learning for Ground and Aerial Point Cloud-Based Place Recognition}, 37 | year={2023}, 38 | volume={}, 39 | number={}, 40 | pages={1-8}, 41 | doi={10.1109/LRA.2023.3292623}} 42 | ``` 43 | 44 | 45 | # Usage 46 | ## Environment 47 | This project has been tested on a system with Ubuntu 18.04. Main dependencies include: CUDA >= 10.2; PyTorch >= 1.9.1; MinkowskiEngine >= 0.5.4; Opne3D >= 0.15.2. Please set up the requirments as follows. 48 | 1. Install [cuda-10.2](https://developer.nvidia.com/cuda-10.2-download-archive). 49 | 50 | 2. Create the anaconda environment. 51 | ```sh 52 | conda create -n gapr python=3.8 53 | conda activate gapr 54 | ``` 55 | 3. [PyTorch](https://pytorch.org/). 56 | ```sh 57 | pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 58 | ``` 59 | 60 | 4. [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine). 61 | ```sh 62 | conda install openblas-devel -c anaconda 63 | git clone https://github.com/NVIDIA/MinkowskiEngine.git 64 | cd MinkowskiEngine 65 | export CXX=g++-7 66 | git checkout v0.5.4 67 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas 68 | cd .. 69 | ``` 70 | 71 | 5. Install requirements. 72 | ```sh 73 | # install setuptools firstly to avoid some bugs 74 | pip install setuptools==58.0.4 75 | pip install tqdm open3d tensorboard pandas matplotlib pillow ptflops timm==0.9.2 76 | ``` 77 | 78 | 6. Download this repository. 79 | ```sh 80 | git clone https://github.com/SYSU-RoboticsLab/GAPR.git 81 | cd GAPR 82 | ``` 83 | Add the python path before running codes: 84 | ```sh 85 | export PYTHONPATH=$PYTHONPATH:/PATH_TO_CODE/GAPR 86 | ``` 87 | 88 | ## Dataset 89 | Please download our [benchmark dataset](https://pan.baidu.com/s/1TsxSNZVkGwpZjBM0eNXglw?pwd=zxx4) and unpack the tar file. 90 | Run the following command to check the dataset (`train` and `evaluate`). 91 | ```sh 92 | python datasets/lprdataset.py --dataset /PATH_TO_DATASET/benchmark/train 93 | python datasets/lprdataset.py --dataset /PATH_TO_DATASET/benchmark/evaluate 94 | ``` 95 | 96 | ## Evaluate 97 | 1. Change the path of dataset in `config/evaluate/once.yaml`. 98 | ```yaml 99 | # ... 100 | dataloaders: 101 | evaluate: 102 | dataset: /PATH_TO_DATASET/benchmark/evaluate 103 | # ... 104 | ``` 105 | 2. We provide a pretrain weights for evaluation. 106 | ```sh 107 | python evaluate/once.py --weights pretrain/GAPR.pth --yaml config/evaluate/once.yaml 108 | ``` 109 | Parameter `weights` is used to set the path of model weights. The results are saved at `results/evaluate/YYMMDD_HHMMSS`. 110 | ## Train 111 | 1. Change the path of dataset in `config/gapr/train.yaml`. 112 | ```yaml 113 | # ... 114 | dataloaders: 115 | train: 116 | dataset: /PATH_TO_DATASET/benchmark/train 117 | # ... 118 | ``` 119 | 2. Select the GPUs and start training. For example, here GPU 1,3 are sued and `nproc_per_node` is set to 2 (number of selected GPUs). 120 | ```sh 121 | CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train/train.py --yaml config/gapr/train.yaml 122 | ``` 123 | The training weights are saved at `results/weights/YYMMDD_HHMMSS`. 124 | 125 | # Acknowledgement 126 | We acknowledge the authors of [MinkLoc3D](https://github.com/jac99/MinkLoc3D) for their excellent codebase which has been used as a starting point for this project. 127 | -------------------------------------------------------------------------------- /config/evaluate/once.yaml: -------------------------------------------------------------------------------- 1 | dataloaders: 2 | evaluate: 3 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/evaluate 4 | collate: 5 | name: MetricCollate 6 | augment: 7 | name: EvaluateAugment 8 | rotate_cmd: zxy10 9 | translate_delta: 0.0 10 | if_jrr: no 11 | sampler: 12 | name: BatchSample 13 | batch_size: 1 14 | batch_size_limit: null 15 | batch_expansion_rate: null 16 | # sample kw 17 | shuffle: false 18 | max_batches: null 19 | num_workers: 1 20 | 21 | -------------------------------------------------------------------------------- /config/gapr/train.yaml: -------------------------------------------------------------------------------- 1 | dataloaders: 2 | train: 3 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/train 4 | collate: 5 | name: MetricCollate 6 | augment: 7 | name: TrainAugment 8 | rotate_cmd: zxy10 9 | translate_delta: 1.0 10 | if_jrr: no 11 | sampler: 12 | name: HeteroTripletSample 13 | batch_size: 16 14 | batch_size_limit: 32 15 | batch_expansion_rate: 1.4 16 | max_batches: null 17 | num_workers: 4 18 | 19 | method: 20 | model: 21 | name: GAPR 22 | debug: no 23 | minkfpn: 24 | quant_size: 0.6 25 | in_channels: 1 26 | out_channels: 256 27 | num_top_down: 1 28 | conv0_kernel_size: 5 29 | layers: [1, 1, 1] 30 | planes: [32, 64, 64] 31 | pctrans: 32 | dim: 256 33 | num_heads: 2 34 | mlp_ratio: 4 35 | depth: 1 36 | qkv_bias: yes 37 | init_values: null 38 | drop: 0.0 39 | attn_drop: 0.0 40 | drop_path_rate: 0.0 41 | meangem: 42 | p: 3.0 43 | eps: 0.000001 44 | loss: 45 | name: GAPRLoss 46 | batch_loss: 47 | margin: 1.0 48 | style: hard 49 | point_loss: 50 | margin: 10.0 51 | style: soft 52 | corr_dist: 2.0 53 | sample_num: 64 54 | pos_dist: 2.1 55 | neg_dist: 20.0 56 | overlap_loss: 57 | corr_dist: 2.0 58 | point_loss_scale: 0.5 59 | overlap_loss_scale: 1.0 60 | 61 | train: 62 | lr: 0.001 63 | epochs: 40 64 | weight_decay: 0.001 65 | batch_expansion_th: 0.7 # no used 66 | scheduler_milestones: [15, 30] # no used 67 | 68 | dist: 69 | backend: nccl 70 | find_unused_parameters: no 71 | 72 | results: 73 | weights: results/weights 74 | logs: null # no used 75 | -------------------------------------------------------------------------------- /datasets/dataloders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/__init__.py -------------------------------------------------------------------------------- /datasets/dataloders/augments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/augments/__init__.py -------------------------------------------------------------------------------- /datasets/dataloders/augments/augment.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | import open3d as o3d 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import open3d as o3d 7 | from typing import List, Dict, Tuple 8 | from datasets.dataloders.augments.utils import * 9 | 10 | class Augment: 11 | """ 12 | # Wrapper for Pointcloud Augment 13 | """ 14 | def __init__(self, name:str, rotate_cmd:str, translate_delta:float, if_jrr:bool): 15 | print("Augment: name=%s, rotate=%s, translate=%.3f, jrr=%s " % (name, rotate_cmd, translate_delta, if_jrr)) 16 | self.rotate = RandomRotation(rotate_cmd) 17 | self.translate = RandomTranslation(translate_delta) 18 | 19 | if if_jrr: raise NotImplementedError("Augment: jrr is currently not implemented.") 20 | self.jrr = None 21 | 22 | def __call__(self, e:torch.Tensor): 23 | # jrr 24 | if self.jrr is not None: e0 = self.jrr(e) 25 | else: e0 = e 26 | # rotate 27 | e1, rotms = self.rotate(e0) 28 | # translate 29 | e2, trans = self.translate(e1) 30 | # align data type and device 31 | e2 = e2.to(e.device).type_as(e).contiguous() 32 | rotms = rotms.to(e.device).type_as(e).contiguous() 33 | trans = trans.to(e.device).type_as(e).contiguous() 34 | return e2, rotms, trans 35 | 36 | -------------------------------------------------------------------------------- /datasets/dataloders/augments/utils.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import numpy as np 5 | import math 6 | import random 7 | import torch 8 | 9 | 10 | class RandomRotation: 11 | """ 12 | # Random Rotate pointclouds and return matrix 13 | """ 14 | def __init__(self, cmd:str): 15 | if cmd not in [None, "zxy10", "zxy20", "so3"]: raise NotImplementedError("RandomRotate: cmd in [None, zxy10, zxy20,so3]") 16 | self.cmd = cmd 17 | 18 | def getRotateMatrixFromRotateVector(self, axis:torch.Tensor, theta:torch.Tensor)->torch.Tensor: 19 | """ 20 | # Get Rotate Matrix from Rotate Vector\n 21 | ## input \n 22 | axis.size() == [bs, 3] \n 23 | theta.size() == [bs] \n 24 | ## output \n 25 | rotateMatrix.size() == [bs, 3, 3] \n 26 | """ 27 | device = axis.device 28 | bs = axis.size()[0] 29 | # [bs, 3] 30 | axis = axis / torch.norm_except_dim(v=axis, pow=2, dim=0) 31 | # [bs, 1, 3], [bs, 3, 1] 32 | axisH, axisV = axis.unsqueeze(1), axis.unsqueeze(2) 33 | # [bs, 1, 1] 34 | cosTheta = torch.cos(theta).reshape(-1, 1, 1) 35 | # [bs, 1, 1] 36 | sinTheta = torch.sin(theta).reshape(-1, 1, 1) 37 | # [bs, 3, 3] 38 | eye = torch.eye(3, device=device).expand(bs, 3, 3) 39 | # axis^ [bs, 3, 3] 40 | axisCaret = torch.cross(eye, axisH.expand(bs, 3, 3), dim=2) 41 | # so3: R = cos(theta) * I + (1-cos(theta)) * dot(a, aT) + sin(theta) * a^ 42 | r = cosTheta * eye + (1.0-cosTheta) * torch.bmm(axisV, axisH) + sinTheta * axisCaret 43 | return r 44 | 45 | def __call__(self, coords:torch.Tensor): 46 | device = coords.device 47 | BS, PN, D = coords.shape 48 | # initial theta and axis 49 | theta, axis = torch.zeros((BS), device=device), torch.tensor([[0.0, 0.0, 1.0]] * BS, device=device) 50 | if self.cmd == "zxy10": 51 | # theta [-pi, pi] 52 | theta = torch.rand(BS,device=device) * 2 * np.pi - np.pi 53 | alpha = torch.rand(BS,device=device) * 2 * np.pi - np.pi 54 | beta = torch.rand(BS,device=device) * np.pi * 10.0 / 180.0 55 | # alpha_axis is a vector in xOy plane 56 | alpha_axis = torch.stack([torch.sin(alpha), torch.cos(alpha), torch.zeros((BS), device=device)], dim=1) 57 | # print(alpha_axis, beta) 58 | alpha_mat = self.getRotateMatrixFromRotateVector(alpha_axis, beta) 59 | axis = torch.bmm(alpha_mat, axis.unsqueeze(2)).squeeze(2) 60 | elif self.cmd == "zxy20": 61 | # theta [-pi, pi] 62 | theta = torch.rand(BS,device=device) * 2 * np.pi - np.pi 63 | alpha = torch.rand(BS,device=device) * 2 * np.pi - np.pi 64 | beta = torch.rand(BS,device=device) * np.pi * 20.0 / 180.0 65 | # alpha_axis is a vector in xOy plane 66 | alpha_axis = torch.stack([torch.sin(alpha), torch.cos(alpha), torch.zeros((BS), device=device)], dim=1) 67 | # print(alpha_axis, beta) 68 | alpha_mat = self.getRotateMatrixFromRotateVector(alpha_axis, beta) 69 | axis = torch.bmm(alpha_mat, axis.unsqueeze(2)).squeeze(2) 70 | elif self.cmd == "so3": 71 | theta = torch.rand(BS,device=device) * 2 * np.pi - np.pi 72 | axis = torch.rand((BS, 3), device=device) - 0.5 73 | axis = axis / torch.norm_except_dim(axis, dim=1) 74 | 75 | rots_mat = self.getRotateMatrixFromRotateVector(axis, theta).type_as(coords) 76 | coords = torch.bmm(rots_mat, coords.transpose(1,2)).transpose(1, 2) 77 | return coords, rots_mat 78 | 79 | 80 | class RandomTranslation: 81 | """ 82 | # Random Translation 83 | """ 84 | def __init__(self, delta=0.05): 85 | self.delta = delta 86 | 87 | def __call__(self, coords:torch.Tensor): 88 | BS, device = coords.shape[0], coords.device 89 | trans = self.delta * torch.randn(BS, 3, device=device) 90 | return coords + trans.unsqueeze(1), trans 91 | 92 | 93 | class RandomFlip: 94 | def __init__(self, p): 95 | # p = [p_x, p_y, p_z] probability of flipping each axis 96 | assert len(p) == 3 97 | assert 0 < sum(p) <= 1, 'sum(p) must be in (0, 1] range, is: {}'.format(sum(p)) 98 | self.p = p 99 | self.p_cum_sum = np.cumsum(p) 100 | 101 | def __call__(self, coords): 102 | r = random.random() 103 | if r <= self.p_cum_sum[0]: 104 | # Flip the first axis 105 | coords[..., 0] = -coords[..., 0] 106 | elif r <= self.p_cum_sum[1]: 107 | # Flip the second axis 108 | coords[..., 1] = -coords[..., 1] 109 | elif r <= self.p_cum_sum[2]: 110 | # Flip the third axis 111 | coords[..., 2] = -coords[..., 2] 112 | 113 | return coords 114 | 115 | class RandomScale: 116 | def __init__(self, min, max): 117 | self.scale = max - min 118 | self.bias = min 119 | 120 | def __call__(self, coords): 121 | s = self.scale * np.random.rand(1) + self.bias 122 | return coords * s.astype(np.float32) 123 | 124 | class RandomShear: 125 | def __init__(self, delta=0.1): 126 | self.delta = delta 127 | 128 | def __call__(self, coords): 129 | T = np.eye(3) + self.delta * np.random.randn(3, 3) 130 | return coords @ T.astype(np.float32) 131 | 132 | 133 | class JitterPoints: 134 | def __init__(self, sigma=0.01, clip=None, p=1.): 135 | assert 0 < p <= 1. 136 | assert sigma > 0. 137 | 138 | self.sigma = sigma 139 | self.clip = clip 140 | self.p = p 141 | 142 | def __call__(self, e:torch.Tensor): 143 | """ Randomly jitter points. jittering is per point. 144 | Input: 145 | BxNx3 array, original batch of point clouds 146 | Return: 147 | BxNx3 array, jittered batch of point clouds 148 | """ 149 | 150 | sample_shape = (e.shape[0],) 151 | if self.p < 1.: 152 | # Create a mask for points to jitter 153 | m = torch.distributions.categorical.Categorical(probs=torch.tensor([1 - self.p, self.p])) 154 | mask = m.sample(sample_shape=sample_shape) 155 | else: 156 | mask = torch.ones(sample_shape, dtype=torch.int64 ) 157 | 158 | mask = mask == 1 159 | jitter = self.sigma * torch.randn_like(e[mask]) 160 | 161 | if self.clip is not None: 162 | jitter = torch.clamp(jitter, min=-self.clip, max=self.clip) 163 | 164 | e[mask] = e[mask] + jitter 165 | return e 166 | 167 | 168 | class RemoveRandomPoints: 169 | def __init__(self, r): 170 | if type(r) is list or type(r) is tuple: 171 | assert len(r) == 2 172 | assert 0 <= r[0] <= 1 173 | assert 0 <= r[1] <= 1 174 | self.r_min = float(r[0]) 175 | self.r_max = float(r[1]) 176 | else: 177 | assert 0 <= r <= 1 178 | self.r_min = None 179 | self.r_max = float(r) 180 | 181 | def __call__(self, e:torch.Tensor): 182 | n = len(e) 183 | if self.r_min is None: 184 | r = self.r_max 185 | else: 186 | # Randomly select removal ratio 187 | r = random.uniform(self.r_min, self.r_max) 188 | 189 | mask = np.random.choice(range(n), size=int(n*r), replace=False) # select elements to remove 190 | e[mask] = torch.zeros_like(e[mask]) 191 | return e 192 | 193 | 194 | class RemoveRandomBlock: 195 | """ 196 | Randomly remove part of the point cloud. Similar to PyTorch RandomErasing but operating on 3D point clouds. 197 | Erases fronto-parallel cuboid. 198 | Instead of erasing we set coords of removed points to (0, 0, 0) to retain the same number of points 199 | """ 200 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)): 201 | self.p = p 202 | self.scale = scale 203 | self.ratio = ratio 204 | 205 | def get_params(self, coords:torch.Tensor): 206 | # Find point cloud 3D bounding box 207 | flattened_coords = coords.contiguous().view(-1, 3) 208 | min_coords, _ = torch.min(flattened_coords, dim=0) 209 | max_coords, _ = torch.max(flattened_coords, dim=0) 210 | span = max_coords - min_coords 211 | area = span[0] * span[1] 212 | erase_area = random.uniform(self.scale[0], self.scale[1]) * area 213 | aspect_ratio = random.uniform(self.ratio[0], self.ratio[1]) 214 | 215 | h = math.sqrt(erase_area * aspect_ratio) 216 | w = math.sqrt(erase_area / aspect_ratio) 217 | 218 | x = min_coords[0] + random.uniform(0, 1) * (span[0] - w) 219 | y = min_coords[1] + random.uniform(0, 1) * (span[1] - h) 220 | 221 | return x, y, w, h 222 | 223 | def __call__(self, coords): 224 | if random.random() < self.p: 225 | x, y, w, h = self.get_params(coords) # Fronto-parallel cuboid to remove 226 | mask = (x < coords[..., 0]) & (coords[..., 0] < x+w) & (y < coords[..., 1]) & (coords[..., 1] < y+h) 227 | coords[mask] = torch.zeros_like(coords[mask]) 228 | return coords 229 | -------------------------------------------------------------------------------- /datasets/dataloders/collates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/collates/__init__.py -------------------------------------------------------------------------------- /datasets/dataloders/collates/lprcollate.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | from time import sleep 5 | import torch 6 | import numpy as np 7 | import open3d as o3d 8 | import torch.nn.functional as F 9 | import matplotlib.pyplot as plt 10 | from typing import Any, Dict, List 11 | from datasets.lprdataset import LPRDataset 12 | from scipy.spatial.transform import Rotation as R 13 | # common 14 | from datasets.dataloders.collates.utils import align_pcs, in_sorted_array, triplet_mask 15 | 16 | def LPRCollate(dataset:LPRDataset, augment, name:str, **kw): 17 | """ 18 | # Wrapper for all collate_fn 19 | ## Format 20 | ``` 21 | def collate_fn(data_list): 22 | # data_list defined in LPRdataset.__getitem__() 23 | # data_list = [ 24 | # [label0, pc0], 25 | # [label1, pc1] 26 | # ] 27 | ... 28 | # data 29 | data:Dict[str, torch.Tensor] = {"data":data} 30 | 31 | mask:Dict[str, torch.Tensor] = {"mask":mask} 32 | return data, mask 33 | ``` 34 | ## Training 35 | ``` 36 | for data, mask in dataloader: 37 | data = tensors2device(data, device) 38 | output = model(data) 39 | loss = loss_fn(output, mask) 40 | ... 41 | ``` 42 | """ 43 | if name == "BaseCollate": 44 | def BaseCollate(data_list): 45 | """ 46 | # BaseCollate 47 | * align clouds 48 | * convert labels to tensor 49 | """ 50 | clouds = [e[1] for e in data_list] 51 | clouds = align_pcs(clouds) 52 | clouds = torch.stack(clouds, dim=0) 53 | 54 | labels = torch.tensor([e[0] for e in data_list]) 55 | data:Dict[str, torch.Tensor] = {"clouds":clouds} 56 | mask:Dict[str, torch.Tensor] = {"labels":labels} 57 | return data, mask 58 | return BaseCollate 59 | 60 | elif name == "MetricCollate": 61 | def MetricCollate(data_list): 62 | """ 63 | # Metric Learning Collate Function 64 | """ 65 | # constructs a batch object 66 | raw_coords = [e[1] for e in data_list] 67 | labels = [e[0] for e in data_list] 68 | # align points number 69 | raw_coords = align_pcs(raw_coords) 70 | # Tensor: raw_coords: [BS, PN, 3], 71 | raw_coords = torch.stack(raw_coords, dim=0) 72 | BS, device = raw_coords.shape[0], raw_coords.device 73 | 74 | # get tums from dataset 75 | tums = [] 76 | for ndx in labels: tums.append(dataset.get_tum(ndx)) 77 | tums = np.asarray(tums) 78 | # Tensor: raw_rotms: [BS, 3, 3], raw_trans: [BS, 3] 79 | # tf_global2raw is [raw_rotms, raw_trans] 80 | raw_trans = torch.tensor(tums[:, 1:4], device=device).type_as(raw_coords) 81 | raw_rotms = torch.tensor(R.from_quat(tums[:, 4:8]).as_matrix(), device=device).type_as(raw_coords) 82 | 83 | # apply augment 84 | if augment is not None: 85 | # Tensor: aug_coords: [BS, PN, 3], aug_rotms: [BS, 3, 3], aug_trans: [BS, 3] 86 | # coords = aug_rotms * raw_coords + aug_trans, tf_aug2raw is [aug_rotms, aug_trans] 87 | coords, aug_rotms, aug_trans = augment(raw_coords.clone()) 88 | # compute tf_global2aug = tf_global2raw * tf_aug2raw.inverse() 89 | # tf_aug2raw.inverse() = [aug_rotms.T, -aug_rotms.T*aug_trans] 90 | aug_rotms_inv, aug_trans_inv = aug_rotms.transpose(1,2), -torch.bmm(aug_rotms.transpose(1,2), aug_trans.unsqueeze(2)).squeeze(2) 91 | # trans = raw_trans + raw_rotms * aug_rotms.inverse() * aug_trans 92 | trans = raw_trans + torch.bmm(raw_rotms, aug_trans_inv.unsqueeze(2)).squeeze(2) 93 | # rotms = raw_rotms * aug_rotms.inverse() 94 | rotms = torch.bmm(raw_rotms, aug_rotms_inv) 95 | else: 96 | trans, rotms, coords = raw_trans.clone(), raw_rotms.clone(), raw_coords.clone() 97 | 98 | # set feats to 1, or color if rgb pointcloud 99 | feats = torch.ones((coords.shape[0], coords.shape[1], 1)) 100 | # compute positives and negatives mask 101 | positives_mask, negatives_mask = triplet_mask(dataset, labels) 102 | 103 | # get geneous 104 | geneous = dataset.get_all_geneous()[labels] 105 | geneous = torch.tensor(geneous, dtype=torch.int) 106 | # get labels 107 | labels = torch.tensor(labels) 108 | # write to data and mask 109 | data:Dict[str, torch.Tensor] = {"coords":coords, "feats":feats, "geneous":geneous} 110 | mask:Dict[str, torch.Tensor] = {"labels":labels, "geneous":geneous, "rotms":rotms, "trans":trans, "positives":positives_mask, "negatives":negatives_mask} 111 | return data, mask 112 | return MetricCollate 113 | else: 114 | raise NotImplementedError("LPRCollate: %s not implemented" % name) 115 | 116 | -------------------------------------------------------------------------------- /datasets/dataloders/collates/utils.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import open3d as o3d 8 | from typing import List, Dict, Tuple 9 | from datasets.lprdataset import LPRDataset 10 | 11 | def align_pcs(pcs:List[torch.Tensor], align_size:int=None)->List[torch.Tensor]: 12 | """ 13 | # align points number in pointclouds 14 | ## Input 15 | * pcs 16 | * align_size: points number \n 17 | ## Output 18 | * newpcs 19 | """ 20 | if align_size is None: 21 | # if None, find the max size 22 | max_size = 0 23 | for pc in pcs: 24 | if pc.size()[0] > max_size: max_size = pc.size()[0] 25 | align_size = max_size 26 | else: 27 | for pc in pcs: 28 | assert pc.size()[0] <= align_size, "LPRCollate: pc.size()[0] <= align_size" 29 | 30 | newpcs:List[torch.Tensor] = [] 31 | for pc in pcs: 32 | # zero padding 33 | newpcs.append(F.pad(pc, (0,0,0,align_size-pc.size()[0]), "constant", 0)) 34 | return newpcs 35 | 36 | def display_inlier_outlier(cloud, ind): 37 | inlier_cloud = cloud.select_by_index(ind) 38 | outlier_cloud = cloud.select_by_index(ind, invert=True) 39 | 40 | print("Showing outliers (red) and inliers (gray): ") 41 | outlier_cloud.paint_uniform_color([1, 0, 0]) 42 | inlier_cloud.paint_uniform_color([0.8, 0.8, 0.8]) 43 | o3d.visualization.draw_geometries([inlier_cloud, outlier_cloud], window_name='Open3D Removal Outlier', width=1920, 44 | height=1080, left=50, top=50, point_show_normal=False, mesh_show_wireframe=False, 45 | mesh_show_back_face=False) 46 | 47 | def in_sorted_array(e: int, array: np.ndarray) -> bool: 48 | pos = np.searchsorted(array, e) 49 | if pos == len(array) or pos == -1: 50 | return False 51 | else: 52 | return array[pos] == e 53 | 54 | def triplet_mask(dataset:LPRDataset, labels:List[int])->Tuple[torch.Tensor, torch.Tensor]: 55 | positives_mask = [[in_sorted_array(e, np.sort(np.asarray(dataset.get_positives(label)))) for e in labels] for label in labels] 56 | negatives_mask = [[not in_sorted_array(e, np.sort(np.asarray(dataset.get_non_negatives(label)))) for e in labels] for label in labels] 57 | positives_mask = torch.tensor(positives_mask) 58 | negatives_mask = torch.tensor(negatives_mask) 59 | return positives_mask, negatives_mask -------------------------------------------------------------------------------- /datasets/dataloders/lprdataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import argparse 5 | from torch.utils.data import DataLoader 6 | import torch.distributed as dist 7 | from typing import Dict, List, Any 8 | from time import sleep 9 | from datasets.dataloders.augments.augment import Augment 10 | from datasets.dataloders.samplers.lprbatchsampler import LPRBatchSampler 11 | from datasets.dataloders.collates.lprcollate import LPRCollate 12 | from datasets.lprdataset import LPRDataset 13 | 14 | from misc.utils import str2bool 15 | 16 | def LPRDataLoader(**kw): 17 | """ 18 | Create dataloaders 19 | """ 20 | 21 | augment = None 22 | if "augment" in kw: augment = Augment(**kw["augment"]) 23 | 24 | dataset = LPRDataset( 25 | rootpath=kw["dataset"], 26 | ) 27 | 28 | sampler = LPRBatchSampler( 29 | dataset=dataset, 30 | **kw["sampler"] 31 | ) 32 | # Collate function collates items into a batch and applies a 'set transform' on the entire batch 33 | collate = LPRCollate(dataset=dataset, augment=augment, **kw["collate"]) 34 | dataloader = DataLoader( 35 | dataset, 36 | batch_sampler=sampler, 37 | collate_fn=collate, 38 | num_workers=kw["num_workers"], 39 | pin_memory=True 40 | ) 41 | return dataloader 42 | 43 | 44 | def parse_opt()->dict: 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--yaml", type=str, default="config/dataloader.yaml") 47 | parser.add_argument("--local_rank", type=int, default=None) 48 | opt = parser.parse_args() 49 | opt = vars(opt) 50 | f = open(opt["yaml"], encoding="utf-8") 51 | kw:Dict[str, Any] = yaml.load(f, Loader=yaml.FullLoader) #读取yaml文件 52 | f.close() 53 | kw.update(opt) 54 | return kw 55 | 56 | def test_lprataloader(**kw): 57 | if kw["local_rank"] is not None: 58 | local_rank = int(os.environ["LOCAL_RANK"]) 59 | torch.cuda.set_device(local_rank) 60 | dist.init_process_group(backend="nccl") 61 | 62 | 63 | dataloader = LPRDataLoader(**kw["dataloader"]) 64 | for epoch in range(kw["show"]["epoch"]): 65 | print("epoch", epoch) 66 | for data, mask in dataloader: 67 | if kw["show"]["data"]: 68 | for e in data: print(e, "\n", data[e]) 69 | if kw["show"]["mask"]: 70 | for e in mask: print(e, "\n", mask[e]) 71 | sleep(kw["show"]["sleep"]) 72 | return 73 | 74 | if __name__ == "__main__": 75 | test_lprataloader(**parse_opt()) -------------------------------------------------------------------------------- /datasets/dataloders/samplers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/samplers/__init__.py -------------------------------------------------------------------------------- /datasets/dataloders/samplers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class BaseSample(object, metaclass=abc.ABCMeta): 4 | """ 5 | # Base class for all sample 6 | """ 7 | @abc.abstractmethod 8 | def __init__(self): 9 | pass 10 | 11 | @abc.abstractmethod 12 | def __call__(self): 13 | """ 14 | # Generate heterogeneous indices of batches 15 | """ 16 | pass 17 | 18 | @abc.abstractmethod 19 | def get_k(self)->int: 20 | """ 21 | # Ensure batch_size % k == 0 22 | """ 23 | pass 24 | -------------------------------------------------------------------------------- /datasets/dataloders/samplers/batch.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from typing import List 4 | from datasets.lprdataset import LPRDataset 5 | from datasets.dataloders.samplers.base import BaseSample 6 | 7 | class BatchSample(BaseSample): 8 | """ 9 | # Batch sampling for dataset 10 | """ 11 | def __init__(self, dataset:LPRDataset, shuffle:bool, max_batches:int): 12 | print("Sampling Mechanism: BatchSample") 13 | self.dataset = dataset 14 | self.max_batches = max_batches 15 | self.k = 1 16 | self.shuffle = shuffle 17 | 18 | def get_k(self): 19 | return self.k 20 | 21 | def __call__(self, batch_size:int) -> List[List[int]]: 22 | indices = self.dataset.get_indices() 23 | indices = np.sort(indices) 24 | if self.shuffle: random.shuffle(indices) 25 | # remove tail 26 | indices = indices[:indices.shape[0] - indices.shape[0] % batch_size] 27 | # reshape to (batches, batch_size) & tolist 28 | batch_idx = indices.reshape((-1, batch_size)).tolist() 29 | return batch_idx 30 | 31 | -------------------------------------------------------------------------------- /datasets/dataloders/samplers/hetero.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from typing import List 4 | from copy import deepcopy 5 | from datasets.lprdataset import LPRDataset 6 | from datasets.dataloders.samplers.base import BaseSample 7 | 8 | class HeteroTripletSample(BaseSample): 9 | def __init__(self, dataset:LPRDataset, max_batches:int): 10 | """ 11 | # Sampling mechanism for heterogeneous data 12 | """ 13 | print("Sampling Mechanism: HeteroTripletSample") 14 | self.dataset = dataset 15 | self.max_batches = max_batches 16 | self.k = 2 17 | 18 | def get_k(self) -> int: 19 | return self.k 20 | 21 | def __call__(self, batch_size:int) -> List[List[int]]: 22 | """ 23 | # Generate heterogeneous indices of batches 24 | ## Input 25 | * batch_size 26 | ## Output 27 | batch_idx: [[batch0], [batch1], ..., [batchn]] 28 | """ 29 | 30 | assert self.k == 2, "HeteroTripletSample: sampler can sample only k=3 elements from the same class" 31 | assert batch_size >= self.k, "HeteroTripletSample: batch_size >= k" 32 | assert batch_size%self.k == 0, "HeteroTripletSample: batch_size%k == 0" 33 | 34 | # batches indices of an training epoch 35 | batch_idx:List[List[int]] = [] 36 | # unused 37 | unused_elements_ndx:List[int] = self.dataset.get_indices().tolist() 38 | # current 39 | current_batch:List[int] = [] 40 | 41 | # items with heterogeneous positive samples in dataset 42 | anchors:List[int] = self.dataset.get_anchors().tolist() 43 | 44 | while True: 45 | anchor = random.choice(anchors) 46 | anchor_geneous = self.dataset.get_all_geneous()[anchor] 47 | current_batch.append(anchor) 48 | anchors.remove(anchor) 49 | unused_elements_ndx.remove(anchor) 50 | 51 | unused_elements_ndx_np = np.asarray(unused_elements_ndx) 52 | positives = self.dataset.get_positives(anchor) 53 | 54 | for gid, _ in enumerate(self.dataset.get_geneous_names()): 55 | if gid == anchor_geneous: continue 56 | geneous_positives = np.intersect1d(positives, self.dataset.get_homoindices(gid)) 57 | assert geneous_positives.shape[0] > 0, "HeteroTripletSampler: gpos.shape[0] = 0" 58 | unused_geneous_positives = np.intersect1d( 59 | unused_elements_ndx_np, 60 | geneous_positives 61 | ) 62 | this_geneous_positive:int=None 63 | if len(unused_geneous_positives) != 0: 64 | this_geneous_positive = random.choice(unused_geneous_positives.tolist()) 65 | unused_elements_ndx.remove(this_geneous_positive) 66 | else: 67 | this_geneous_positive = random.choice(geneous_positives.tolist()) 68 | 69 | current_batch.append(this_geneous_positive) 70 | 71 | if this_geneous_positive in anchors: anchors.remove(this_geneous_positive) 72 | 73 | if len(current_batch) >= batch_size: 74 | assert len(current_batch) % self.k == 0 75 | batch_idx.append(deepcopy(current_batch)) 76 | current_batch = [] 77 | if (self.max_batches is not None) and (len(batch_idx) >= self.max_batches): 78 | break 79 | 80 | if len(unused_elements_ndx) == 0 or len(anchors) == 0: 81 | break 82 | return batch_idx 83 | -------------------------------------------------------------------------------- /datasets/dataloders/samplers/homo.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import random 5 | import numpy as np 6 | from typing import List 7 | from datasets.lprdataset import LPRDataset 8 | from datasets.dataloders.samplers.base import BaseSample 9 | 10 | class HomoTripletSample(BaseSample): 11 | def __init__(self, dataset:LPRDataset, max_batches:int): 12 | """ 13 | # Homogeneous sampling 14 | * Sampler returning list of indices to form a mini-batch 15 | * Samples elements in groups consisting of k=2 similar elements (positives) 16 | * Batch has the following structure: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k 17 | ## Input 18 | * dataset 19 | * max_batches 20 | """ 21 | print("Sampling Mechanism: HomoTripletSample") 22 | self.dataset = dataset 23 | self.max_batches = max_batches 24 | self.k = 2 25 | 26 | def get_k(self) -> int: 27 | return self.k 28 | 29 | def __call__(self, batch_size:int) -> List[List[int]]: 30 | assert self.k == 2, "HomoTripletSample: sampler can sample only k=2 elements from the same class" 31 | assert batch_size >= 2*self.k, "HomoTripletSample: batch_size > 2*k" 32 | assert batch_size%self.k == 0, "HomoTripletSample: batch_size%k == 0" 33 | 34 | # Generate training/evaluation batches. 35 | # batch_idx holds indexes of elements in each batch as a list of lists 36 | batch_idx:List[List[int]] = [] 37 | 38 | unused_elements_ndx:List[int] = self.dataset.get_indices().tolist() 39 | 40 | current_batch:List[int] = [] 41 | 42 | while True: 43 | if len(current_batch) >= batch_size or len(unused_elements_ndx) == 0: 44 | # Flush out batch, when it has a desired size, or a smaller batch, when there's no more 45 | # elements to process 46 | if len(current_batch) >= 2*self.k: 47 | # Ensure there're at least two groups of similar elements, otherwise, it would not be possible 48 | # to find negative examples in the batch 49 | assert len(current_batch) % self.k == 0, "HomoTripletSample: Incorrect bach size: {}".format(len(current_batch)) 50 | batch_idx.append(current_batch) 51 | current_batch = [] 52 | if (self.max_batches is not None) and (len(batch_idx) >= self.max_batches): 53 | break 54 | if len(unused_elements_ndx) == 0: 55 | break 56 | 57 | # Add k=2 similar elements to the batch 58 | selected_element = random.choice(unused_elements_ndx) 59 | 60 | unused_elements_ndx.remove(selected_element) 61 | 62 | positives = list(self.dataset.get_positives(selected_element)) 63 | if len(positives) == 0: 64 | # Broken dataset element without any positives 65 | continue 66 | unused_positives = [e for e in positives if e in unused_elements_ndx] 67 | # If there're unused elements similar to selected_element, sample from them 68 | # otherwise sample from all similar elements 69 | if len(unused_positives) > 0: 70 | second_positive = random.choice(unused_positives) 71 | unused_elements_ndx.remove(second_positive) 72 | else: 73 | second_positive = random.choice(list(positives)) 74 | current_batch += [selected_element, second_positive] 75 | 76 | for batch in batch_idx: 77 | assert len(batch) % self.k == 0, "Incorrect bach size: {}".format(len(batch)) 78 | 79 | return batch_idx -------------------------------------------------------------------------------- /datasets/dataloders/samplers/lprbatchsampler.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Any 2 | import torch.distributed as dist 3 | from torch.utils.data import Sampler 4 | from datasets.lprdataset import LPRDataset 5 | from datasets.dataloders.samplers.utils import broadcast_batch_idx 6 | 7 | class LPRBatchSampler(Sampler[List[int]]): 8 | """ 9 | # Wrapper for all sampler 10 | """ 11 | def __init__( 12 | self, 13 | dataset:LPRDataset, 14 | name:str, 15 | batch_size:int, 16 | batch_size_limit:int, 17 | batch_expansion_rate:float, 18 | **kw, 19 | ): 20 | """ 21 | # Re-generate batch indices 22 | # Input 23 | * dataset 24 | * batch_size: initial batch size 25 | * sample: `["BaseSample", "HomoTripletSample", "HeteroTripletSample", "RandomSample"]` 26 | * batch_size_limit: max batch size 27 | * batch_expansion_rate 28 | * max_batches 29 | """ 30 | # sample factory 31 | self.sample_fn = None 32 | if name == "BatchSample": 33 | from datasets.dataloders.samplers.batch import BatchSample 34 | self.sample_fn = BatchSample(dataset=dataset, **kw) 35 | elif name == "HomoTripletSample": 36 | from datasets.dataloders.samplers.homo import HomoTripletSample 37 | self.sample_fn = HomoTripletSample(dataset=dataset, **kw) 38 | elif name == "HeteroTripletSample": 39 | from datasets.dataloders.samplers.hetero import HeteroTripletSample 40 | self.sample_fn = HeteroTripletSample(dataset=dataset, **kw) 41 | else: 42 | raise NotImplementedError("LPRBatchSampler: %s sample_fn not implemented" % name) 43 | 44 | # gpu mode 45 | self.use_dist = False 46 | if dist.is_initialized(): 47 | # multi-gpu 48 | self.use_dist = True 49 | if dist.get_rank() == 0: print("LPRBatchSampler: multi-gpu mode") 50 | else: 51 | # single-gpu 52 | print("LPRBatchSampler: sigle-gpu mode") 53 | 54 | 55 | self.batch_size = batch_size - batch_size%self.sample_fn.get_k() 56 | self.batch_size_limit = batch_size_limit 57 | self.batch_expansion_rate = batch_expansion_rate 58 | if batch_expansion_rate is not None: 59 | assert batch_expansion_rate > 1., "LPRBatchSampler: batch_expansion_rate must be greater than 1" 60 | assert batch_size <= batch_size_limit, "LPRBatchSampler: batch_size_limit must be greater or equal to batch_size" 61 | 62 | self.batch_idx = [] 63 | 64 | 65 | def __iter__(self): 66 | """ 67 | # Generate A Bacth_idx 68 | """ 69 | # multi-gpu 70 | if self.use_dist: 71 | gen_rank = 0 72 | all_batch_idx:List[List[int]] = None 73 | if dist.get_rank() == gen_rank: 74 | # generate all_batch_idx 75 | all_batch_idx = self.sample_fn(self.batch_size) 76 | else: pass 77 | # broadcast all_batch_idx to all process 78 | self.batch_idx = broadcast_batch_idx( 79 | batch_size=self.batch_size, 80 | all_batch_idx=all_batch_idx, 81 | gen_rank=gen_rank 82 | ) 83 | # single-gpu 84 | else: 85 | self.batch_idx = self.sample_fn(self.batch_size) 86 | 87 | 88 | for batch in self.batch_idx: yield batch 89 | 90 | 91 | def __len__(self): 92 | return len(self.batch_idx) 93 | 94 | def expand_batch(self): 95 | """ 96 | # Expand batch_size by batch_expansion_rate 97 | """ 98 | if self.batch_expansion_rate is None: 99 | print("LPRBatchSampler: WARNING batch_expansion_rate is None") 100 | return 101 | 102 | if self.batch_size >= self.batch_size_limit: 103 | return 104 | 105 | old_batch_size = self.batch_size 106 | self.batch_size = int(self.batch_size * self.batch_expansion_rate) 107 | self.batch_size = min(self.batch_size, self.batch_size_limit) 108 | 109 | self.batch_size = self.batch_size - self.batch_size%self.sample_fn.get_k() 110 | 111 | print("LPRBatchSampler: Batch size increased from: {} to {}".format(old_batch_size, self.batch_size)) 112 | -------------------------------------------------------------------------------- /datasets/dataloders/samplers/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.distributed as dist 4 | from typing import List 5 | 6 | def broadcast_batch_idx( 7 | batch_size:int, 8 | all_batch_idx:List[List[int]], 9 | gen_rank:int, 10 | ): 11 | """ 12 | # Assign all_batch_idx to all processes evenly 13 | ## Input 14 | * batch_size 15 | * all_batch_idx 16 | * gen_rank: process id that genenrates all_batch_idx 17 | """ 18 | assert dist.is_available() and dist.is_initialized(), "broadcast: sampler broadcast must be dist.is_initialized()" 19 | rank, world_size = dist.get_rank(), dist.get_world_size() 20 | assert gen_rank < world_size, "broadcast: sampler gen_rank >= word_size" 21 | 22 | # num_size[0] is the number of batch, num_size[1] is the batch_size, broadcast for initilize 23 | num_size = torch.tensor([0, 0], dtype=torch.int64).to(rank) 24 | 25 | broadcast_batch_idx:torch.Tensor = None 26 | 27 | if rank == gen_rank: 28 | assert all_batch_idx is not None, "broadcast_batch_idx: rank == gen_rank and batch_idx is None" 29 | # print("HeteroBatchSampler: rank %d generating batch idx" % rank) 30 | # deepcopy all_batch_idx before writing 31 | all_batch_idx = copy.deepcopy(all_batch_idx) 32 | all_batch_idx = [e for e in all_batch_idx if len(e)==batch_size] 33 | # remove tail to ensure all_batch_idx % world_size == 0 34 | all_batch_idx = all_batch_idx[:len(all_batch_idx)-len(all_batch_idx)%world_size] 35 | # print("cut len = {}, each = {}".format(num_cut_batchs, num_cut_batchs/num_replicas)) 36 | broadcast_batch_idx = torch.tensor(all_batch_idx).detach().type(torch.int64).to(rank) 37 | # record num_size 38 | num_size[0], num_size[1] = broadcast_batch_idx.size() 39 | else: 40 | pass 41 | 42 | # broadcast num_size 43 | dist.broadcast(num_size, gen_rank) 44 | # print("rank {} num_size = {}".format(rank, num_size)) 45 | 46 | # initialize batch_idx according to num_size 47 | if rank == gen_rank: 48 | pass 49 | else: 50 | broadcast_batch_idx = torch.zeros((num_size[0], num_size[1])).detach().type(torch.int64).to(rank) 51 | 52 | # print("rank {} broadcast_batch_idx = {}".format(rank, broadcast_batch_idx.size())) 53 | dist.broadcast(broadcast_batch_idx, gen_rank) 54 | 55 | # broadcast_batch_idx = [[int(c) for c in r] for r in broadcast_batch_idx.cpu()] 56 | broadcast_batch_idx = broadcast_batch_idx.cpu().numpy().tolist() 57 | assert len(broadcast_batch_idx)%world_size == 0, "broadcast: len(broadcast_batch_idx)%world_size != 0" 58 | avg_num = int(len(broadcast_batch_idx)/world_size) 59 | batch_idx = broadcast_batch_idx[rank*avg_num: (rank+1)*avg_num] 60 | 61 | return batch_idx -------------------------------------------------------------------------------- /datasets/lprdataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import argparse 5 | import numpy as np 6 | import open3d as o3d 7 | from tqdm import tqdm 8 | from torch.utils.data import Dataset 9 | import matplotlib.pyplot as plt 10 | from typing import List 11 | from misc.utils import str2bool 12 | 13 | class LPRDataset(Dataset): 14 | """ 15 | # Dataset wrapper for LPRDataset 16 | """ 17 | def __init__(self, rootpath:str,): 18 | self.rootpath = rootpath 19 | assert os.path.exists(self.rootpath), "Cannot access rootpath {}".format(self.rootpath) 20 | print("LPRDataset: {}".format(self.rootpath)) 21 | # 0: ground, 1: aerial 22 | self.geneous_names = ["ground", "aerial"] 23 | self.Ng = len(self.geneous_names) 24 | self.geneous = np.load(os.path.join(self.rootpath, "geneous.npy")) 25 | self.Nm = self.geneous.shape[0] 26 | # self.homoindices 27 | # ground: [0, 2, 3, 5, 8, ...] 28 | # aerial: [1, 4, 6, 7, 9, ...] 29 | self.homoindices = [[] for _ in self.geneous_names] 30 | for ndx in range(self.Nm): 31 | self.homoindices[self.geneous[ndx]].append(ndx) 32 | self.homoindices = [np.asarray(e) for e in self.homoindices] 33 | 34 | # tum format (Nm, 8) [t, x, y, z, qx, qy, qz, qw] 35 | self.tum = np.load(os.path.join(self.rootpath, "tum.npy")) 36 | assert self.Nm == self.tum.shape[0], "LPRDataset: self.Nm != self.tum.shape[0]" 37 | 38 | # make self check files 39 | self.checkpath = os.path.join(self.rootpath, "selfcheck") 40 | if not os.path.exists(self.checkpath): os.mkdir(self.checkpath) 41 | 42 | self.anchors:np.ndarray = None 43 | # load data 44 | self.pcs = [np.load(os.path.join(self.rootpath, "items", "%06d"%ndx, "pointcloud.npy")) for ndx in range(self.Nm)] 45 | self.positives = [np.load(os.path.join(self.rootpath, "items", "%06d"%ndx, "positives.npy")) for ndx in range(self.Nm)] 46 | self.non_negatives = [np.load(os.path.join(self.rootpath, "items", "%06d"%ndx, "non_negatives.npy")) for ndx in range(self.Nm)] 47 | self.get_anchors() 48 | 49 | def __len__(self): 50 | return self.Nm 51 | 52 | def __getitem__(self, ndx): 53 | # Load point cloud and apply transform 54 | pc = torch.tensor(self.get_pc(ndx)) 55 | return ndx, pc 56 | 57 | def get_indices(self) -> np.ndarray: 58 | return np.arange(self.Nm) 59 | 60 | def get_homoindices(self, geneous_id:int) -> np.ndarray: 61 | return np.copy(self.homoindices[geneous_id]) 62 | 63 | def get_geneous_names(self) -> List[str]: 64 | return self.geneous_names 65 | 66 | def get_all_geneous(self) -> np.ndarray: 67 | return np.copy(self.geneous) 68 | 69 | def get_positives(self, ndx:int) -> np.ndarray: 70 | return np.copy(self.positives[ndx]) 71 | 72 | def get_non_negatives(self, ndx:int) -> np.ndarray: 73 | return np.copy(self.non_negatives[ndx]) 74 | 75 | def get_tum(self, ndx:int): 76 | return np.copy(self.tum[ndx]) 77 | 78 | def get_correspondences(self, source_ndx:int, target_ndx:int) -> np.ndarray: 79 | path = os.path.join( 80 | self.rootpath, 81 | "items", 82 | "%06d"%source_ndx, 83 | "correspondence", 84 | "%06d.npy"%target_ndx 85 | ) 86 | return np.load(path) 87 | 88 | def get_pc(self, ndx) -> np.ndarray: 89 | return np.copy(self.pcs[ndx]) 90 | 91 | def get_anchors(self) -> np.ndarray: 92 | """ 93 | # Get indices of items with heterogeneous positive samples in dataset 94 | """ 95 | if self.anchors is not None: return np.copy(self.anchors) 96 | print("LPRDataset: self.anchors is None, generating") 97 | anchors = [] 98 | for i in self.get_indices(): 99 | positives = self.get_positives(i) 100 | is_anchor = True 101 | for gid, gname in enumerate(self.geneous_names): 102 | if np.intersect1d(positives, self.get_homoindices(gid)).shape[0] == 0: 103 | is_anchor = False 104 | break 105 | if is_anchor: anchors.append(i) 106 | 107 | self.anchors = np.asarray(anchors) 108 | 109 | for gid, gname in enumerate(self.get_geneous_names()): 110 | ganchors = np.intersect1d( 111 | self.anchors, 112 | self.get_homoindices(gid) 113 | ).shape[0] 114 | print("LPRDataset: %s has %d anchors" % (gname, ganchors)) 115 | 116 | return np.copy(self.anchors) 117 | 118 | def check_hetero_triplet(self): 119 | """ 120 | # Count hetero triplet number 121 | """ 122 | print("LPRDataset: check hetero triplet") 123 | # multi_geneous_positives 124 | mgp = np.zeros((self.Ng, self.Ng)) 125 | mgn = np.zeros((self.Ng, self.Ng)) 126 | for i in tqdm(self.get_indices()): 127 | sgid = self.geneous[i] 128 | positives = self.get_positives(i) 129 | non_negative = self.get_non_negatives(i) 130 | for tgid in range(self.Ng): 131 | mgp[sgid][tgid] += np.intersect1d(positives, self.homoindices[tgid]).shape[0] 132 | # mgn[sgid][tgid] += np.intersect1d(non_negative, self.homoindices[tgid]).shape[0] 133 | mgn[sgid][tgid] += np.intersect1d( 134 | np.setdiff1d(self.get_indices(), non_negative), 135 | self.homoindices[tgid] 136 | ).shape[0] 137 | mgp = mgp/np.array([self.homoindices[0].shape[0], self.homoindices[1].shape[0]]) 138 | mgn = mgn/np.array([self.homoindices[0].shape[0], self.homoindices[1].shape[0]]) 139 | print("Avg positive:") 140 | print(str(mgp)) 141 | print("Avg negative:") 142 | print(str(mgn)) 143 | return 144 | 145 | def check_positives(self, step:int=1): 146 | print("LPRDataset: check_positives") 147 | pos_map:dict[str, np.ndarray] = {} 148 | for sgid, source in enumerate(self.get_geneous_names()): 149 | for tgid, target in enumerate(self.get_geneous_names()): 150 | keyname = "%s-%s" % (source, target) 151 | npos = [] 152 | nmap = [] 153 | sindices = self.get_homoindices(sgid) 154 | tindices = self.get_homoindices(tgid) 155 | for sndx in tqdm(sindices, desc=keyname): 156 | this_npos = np.intersect1d( 157 | self.get_positives(sndx), 158 | tindices 159 | ).shape[0] 160 | this_npos = int(this_npos/step)*step 161 | if this_npos in npos: nmap[npos.index(this_npos)] += 1 162 | else: 163 | npos.append(this_npos) 164 | nmap.append(1) 165 | this_pos_map = np.asarray([npos, nmap]) 166 | sort_ndx = np.argsort(this_pos_map[0]) 167 | this_pos_map = this_pos_map[:, sort_ndx] 168 | pos_map[keyname] = this_pos_map 169 | 170 | plt.figure(figsize=(7,4)) 171 | 172 | plt.grid() 173 | for keyname in pos_map: 174 | plt.plot(pos_map[keyname][0], pos_map[keyname][1]) 175 | 176 | 177 | plt.xlabel("Number of positive samples in database") 178 | plt.ylabel("Number of queries") 179 | plt.legend(list(pos_map.keys())) 180 | plt.show() 181 | 182 | return 183 | 184 | def check_pn(self): 185 | print("LPRDataset: check points number") 186 | for gid, geneous in enumerate(self.geneous_names): 187 | if self.get_homoindices(gid).shape[0] == 0: continue 188 | pn = 0 189 | for i in self.homoindices[gid]: 190 | pn += self.get_pc(i).shape[0] 191 | avgpn = pn / self.homoindices[gid].shape[0] 192 | print("%s avg pn = %.3f" % (geneous, avgpn)) 193 | return 194 | 195 | 196 | def show_submaps(self, N=10): 197 | """ 198 | # visualize some submaps 199 | """ 200 | anchors = self.get_anchors() 201 | 202 | ganchors = np.intersect1d( 203 | anchors, 204 | self.get_homoindices(1) 205 | ) 206 | for _ in range(N): 207 | a = random.choice(ganchors) 208 | p = random.choice( 209 | np.intersect1d( 210 | self.get_homoindices(0), 211 | self.get_positives(a) 212 | ) 213 | ) 214 | pcda = o3d.geometry.PointCloud() 215 | pcda.points = o3d.utility.Vector3dVector(self.get_pc(a)) 216 | 217 | pcdp = o3d.geometry.PointCloud() 218 | pcdp.points = o3d.utility.Vector3dVector(self.get_pc(p) + np.asarray([-70, 0, 0])) 219 | 220 | o3d.visualization.draw_geometries( 221 | [pcda, pcdp], 222 | window_name="left: gorund, right: aerial" 223 | ) 224 | 225 | 226 | 227 | 228 | def test_dataset(): 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument("--dataset", type=str, required=True) 231 | parser.add_argument("--check_positives", type=str2bool, default=True) 232 | parser.add_argument("--check_hetero_triplet", type=str2bool, default=True) 233 | parser.add_argument("--check_pn", type=str2bool, default=True) 234 | parser.add_argument("--get_anchors", type=str2bool, default=True) 235 | parser.add_argument("--show_submaps", type=str2bool, default=5) 236 | opt = parser.parse_args() 237 | opt = vars(opt) 238 | lprdataset = LPRDataset(rootpath=opt["dataset"]) 239 | 240 | if opt["check_positives"]: lprdataset.check_positives() 241 | if opt["check_hetero_triplet"]: lprdataset.check_hetero_triplet() 242 | if opt["check_pn"]: lprdataset.check_pn() 243 | if opt["get_anchors"]: lprdataset.get_anchors() 244 | if opt["show_submaps"] > 0: lprdataset.show_submaps(opt["show_submaps"]) 245 | 246 | return 247 | 248 | if __name__ == "__main__": 249 | test_dataset() -------------------------------------------------------------------------------- /evaluate/once.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import yaml 6 | from datasets.dataloders.lprdataloader import LPRDataLoader 7 | from models.lprmodel import LPRModel 8 | from evaluate.utils import get_embeddings, get_hetero_topN_recall, get_hetero_recall_precision, show_closest 9 | 10 | from tqdm import tqdm 11 | from misc.utils import get_datetime 12 | import matplotlib.pyplot as plt 13 | 14 | def parse_opt()->dict: 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--weights", type=str, required=True) 17 | parser.add_argument("--yaml", type=str, required=True) 18 | parser.add_argument("--tn", type=int, default=30) 19 | parser.add_argument("--rp", type=int, default=100) 20 | parser.add_argument("--save", type=str, default="results/evaluate/") 21 | opt = parser.parse_args() 22 | opt = vars(opt) 23 | f = open(opt["yaml"], encoding="utf-8") 24 | lpreval = yaml.load(f, Loader=yaml.FullLoader) 25 | 26 | lpreval.update(opt) 27 | return lpreval 28 | 29 | 30 | def feat_l2d_mat(embeddings: np.ndarray) -> np.ndarray: 31 | Nm, Fs = embeddings.shape 32 | distance = np.linalg.norm(embeddings.reshape((Nm, 1, Fs)) - embeddings.reshape((1, Nm, Fs)), axis=2) 33 | distance += np.eye(Nm)*(np.max(distance)+1) # eye 34 | return distance 35 | 36 | 37 | def main(**kw): 38 | 39 | dataloader = LPRDataLoader(**kw["dataloaders"]["evaluate"]) 40 | 41 | device:str = None 42 | if torch.cuda.is_available(): device = "cuda" 43 | else: device = "cpu" 44 | print("Device: {}".format(device)) 45 | assert os.path.exists(kw["weights"]), "Cannot open network weights: {}".format(kw["weights"]) 46 | print("Loading weights: {}".format(kw["weights"])) 47 | 48 | lprmodel = LPRModel() 49 | lprmodel.load(kw["weights"], device) 50 | 51 | # check savepath 52 | savepath = None 53 | if kw["save"] is not None: 54 | assert os.path.exists(kw["save"]), "Path does not exist, please run: mkdir " + kw["save"] 55 | savepath = os.path.join(kw["save"], get_datetime()) 56 | os.mkdir(savepath) 57 | print("Save path:", savepath) 58 | 59 | # recall-precision 60 | if kw["rp"] < 1: 61 | print("Evaluation of Recall-Precision: Skip.") 62 | else: 63 | print("Evaluation of Recall-Precision: %d steps." % kw["rp"]) 64 | distance = feat_l2d_mat(get_embeddings(lprmodel, dataloader, device, print_stats=False)) 65 | rp = get_hetero_recall_precision(dataloader.dataset, distance, num_eval=kw["rp"]) 66 | plt.figure() 67 | 68 | plt.xlim(-0.1, 1.1) 69 | plt.ylim(-0.1, 1.1) 70 | plt.grid() 71 | for st in rp: 72 | plt.plot(rp[st]["xy"][0], rp[st]["xy"][1]) 73 | for i, d in enumerate(rp[st]["ds"]): 74 | plt.annotate( 75 | text="%.2f"%d, 76 | xy=(rp[st]["xy"][0][i], rp[st]["xy"][1][i]), 77 | xytext=(rp[st]["xy"][0][i], rp[st]["xy"][1][i]), 78 | fontsize=10, 79 | ) 80 | plt.xlabel("recall") 81 | plt.ylabel("precision") 82 | plt.legend(list(rp)) 83 | if savepath is not None: plt.savefig(os.path.join(savepath, "recall-precision.png")) 84 | plt.close() 85 | 86 | 87 | # average topN-recall 88 | if kw["tn"] < 1: 89 | print("Evaluation of TopN-Recall: Skip.") 90 | else: 91 | topNs = [] 92 | print("Evaluation of TopN-Recall: %d epochs, the average is taken." % kw["tn"]) 93 | print("(The results are saved each epoch. Enter Ctrl+C to stop.)") 94 | iterator = tqdm(range(kw["tn"])) 95 | for _ in iterator: 96 | # get descriptor distance 97 | distance = feat_l2d_mat(get_embeddings(lprmodel, dataloader, device, print_stats=False)) 98 | # append to all topN recall 99 | topNs.append(get_hetero_topN_recall(dataloader.dataset, distance)) 100 | # take average values 101 | tn = {} 102 | for e in topNs[0]: tn[e] = np.stack([topN[e] for topN in topNs], axis=0).mean(axis=0) 103 | 104 | plt.figure() 105 | plt.grid() 106 | for e in tn: plt.plot(tn[e]) 107 | plt.xlabel("TopN") 108 | plt.ylabel("Recall") 109 | plt.legend(list(tn)) 110 | if savepath is not None: plt.savefig(os.path.join(savepath, "topN-recall.png")) 111 | plt.close() 112 | 113 | stats = "Top1-Recall: " 114 | for e in list(tn.keys()): stats += "%s:%.3f|" % (e, tn[e][0]) 115 | iterator.set_postfix_str(stats) 116 | 117 | return 118 | 119 | 120 | if __name__ == "__main__": 121 | main(**parse_opt()) -------------------------------------------------------------------------------- /evaluate/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import open3d as o3d 7 | from tqdm import tqdm 8 | from typing import List 9 | from time import sleep 10 | from sklearn.neighbors import KDTree 11 | from sklearn import manifold 12 | from torch.utils.data import DataLoader 13 | from datasets.lprdataset import LPRDataset 14 | from models.lprmodel import LPRModel 15 | from misc.utils import tensors2device 16 | 17 | 18 | 19 | 20 | 21 | 22 | def get_embeddings(lprmodel: LPRModel, dataloader: DataLoader, device:str, print_stats:bool=True): 23 | lprmodel.model = lprmodel.model.to(device) 24 | lprmodel.model.eval() 25 | embeddings = [] 26 | if print_stats: iterater = tqdm(dataloader, desc="Getting embedding") 27 | else: iterater=dataloader 28 | for data, mask in iterater: 29 | data = tensors2device(data, device) 30 | with torch.no_grad(): 31 | output = lprmodel(data) 32 | assert "embeddings" in output, "Evaluate: no embeddings in model output" 33 | embeddings.append(output["embeddings"].clone().detach().cpu().numpy()) 34 | # visualize_cnn_feats_scores(output["feats"], output["scores"]) 35 | # visualize_cnn_feats(output["coords"], output["feats"], data["coords"]) 36 | # visualize_scores(output["coords"], output["scores"], data["coords"]) 37 | 38 | embeddings = np.concatenate(embeddings, axis=0) 39 | if print_stats: print("Embeddings size = ", embeddings.shape) 40 | # np.save("examples/minkloc3d/results/embeddings", embeddings) 41 | return embeddings 42 | 43 | def get_topN_recall_curve( 44 | dataset:LPRDataset, 45 | distance:np.ndarray, 46 | source_indices:np.ndarray, 47 | target_indices:np.ndarray, 48 | topN:int=10 49 | ): 50 | # top_one_percent = int(target_indices.shape[0]/100) 51 | topN_count = np.zeros((topN,), dtype=np.int32) 52 | topN_recall = np.zeros((topN,)) 53 | for j in range(topN): 54 | for sndx in source_indices: 55 | real_positive = np.intersect1d( 56 | dataset.get_positives(sndx), 57 | target_indices 58 | ) 59 | # print("true_positives", true_positives) 60 | pred_positive = target_indices[np.argsort(distance[sndx][target_indices])][:j] 61 | 62 | recall_positive = np.intersect1d( 63 | pred_positive, 64 | real_positive 65 | ) 66 | if recall_positive.shape[0] != 0: 67 | topN_count[j] += 1 68 | topN_recall[j] = float(topN_count[j]) / float(source_indices.shape[0]) 69 | 70 | return topN_recall[1:] 71 | 72 | def get_hetero_topN_recall( 73 | dataset:LPRDataset, 74 | distance:np.ndarray, 75 | savepath:str=None, 76 | show:bool=False 77 | ): 78 | assert len(dataset) == distance.shape[0], "Evaluate: len(datasets) == embeddings.shape[0]" 79 | 80 | # source gen 81 | geneous_names = dataset.get_geneous_names() 82 | 83 | all_topN_recall = {} 84 | for sgid, source in enumerate(geneous_names): 85 | sgindices = dataset.get_homoindices(sgid) 86 | for tgid, target in enumerate(geneous_names): 87 | tgindices = dataset.get_homoindices(tgid) 88 | # print(source, "-", target) 89 | if sgindices.shape[0] == 0 or tgindices.shape[0] == 0: 90 | print("no instance in source or target, continue") 91 | continue 92 | topN_recall = get_topN_recall_curve(dataset, distance, sgindices, tgindices) 93 | all_topN_recall["{}-{}".format(source, target)] = topN_recall 94 | # print("all-all") 95 | all_topN_recall["all-all"] = get_topN_recall_curve(dataset, distance, dataset.get_indices(), dataset.get_indices()) 96 | 97 | plt.figure() 98 | 99 | # plt.ylim(-0.1, 1.1) 100 | plt.grid() 101 | for topN_recall in all_topN_recall: 102 | plt.plot(all_topN_recall[topN_recall]) 103 | plt.xlabel("topN") 104 | plt.ylabel("recall") 105 | plt.legend(list(all_topN_recall)) 106 | if savepath is not None: plt.savefig(os.path.join(savepath, "topN-recall.png")) 107 | elif show: plt.show() 108 | else: plt.close() 109 | return all_topN_recall 110 | 111 | def get_recall_precision_curve( 112 | dataset:LPRDataset, 113 | distance:np.ndarray, 114 | source_indices:np.ndarray, 115 | target_indices:np.ndarray, 116 | num_eval:int, 117 | ): 118 | # rp = np.array([[0.0, 1.0], [1.0, 0.0]]) 119 | rp = np.empty((0,2)) 120 | ds = np.linspace(np.min(distance)-0.01, np.max(distance)+0.01, num_eval) 121 | for threshold in ds: 122 | threshold_rp = np.empty((0,2)) 123 | for i in source_indices: 124 | real_positive = np.intersect1d( 125 | dataset.get_positives(i), 126 | target_indices, 127 | ) 128 | pred_positive = np.intersect1d( 129 | np.where(distance[i] < threshold)[0], 130 | target_indices, 131 | ) 132 | 133 | # if real_positive.shape[0] == 0 or pred_positive.shape[0] == 0: continue 134 | tp = np.intersect1d(real_positive, pred_positive).shape[0] 135 | fn = np.setdiff1d(real_positive, pred_positive).shape[0] 136 | fp = np.setdiff1d(pred_positive, real_positive).shape[0] 137 | # tqdmiter.write(str(tp)+" "+str(fn)+" "+str(fp)) 138 | recall, precision = 0., 0. 139 | if tp == 0: 140 | if fn == 0 and fp == 0: continue 141 | elif fn == 0 and fp != 0: recall, precision = 1., 0. 142 | elif fn != 0 and fp == 0: recall, precision = 0., 1. 143 | else: recall, precision = 0., 0. 144 | else: 145 | recall = float(tp)/float(tp+fn) 146 | precision = float(tp)/float(tp+fp) 147 | # this_rp.append([recall, pricision]) 148 | threshold_rp = np.concatenate([threshold_rp, np.asarray([[recall, precision]])], axis=0) 149 | 150 | if threshold_rp.shape[0] == 0: continue 151 | threshold_rp = np.mean(np.asarray(threshold_rp), axis=0) 152 | 153 | # tqdm_iter.set_postfix(recall=threshold_rp[0], precision=threshold_rp[1]) 154 | rp = np.concatenate([rp, threshold_rp.reshape(1,2)], axis=0) 155 | # [N, 2] -> [2, N] 156 | rp = rp.T 157 | indices = np.argsort(rp[0]) 158 | rp = rp[:, indices] 159 | ds = ds[indices] 160 | return rp, ds 161 | 162 | def get_hetero_recall_precision( 163 | dataset:LPRDataset, 164 | distance:np.ndarray, 165 | savepath:str=None, 166 | num_eval:int=100, 167 | show:bool=False 168 | ): 169 | 170 | assert len(dataset) == distance.shape[0], "Evaluate: len(datasets) == embeddings.shape[0]" 171 | Nm = distance.shape[0] 172 | distance = distance.copy() + np.eye(Nm)*(np.max(distance)+0.01) 173 | 174 | all_rp = {} 175 | 176 | geneous_names = dataset.get_geneous_names() 177 | for sgid, source in enumerate(geneous_names): 178 | sgndx_all = dataset.get_homoindices(sgid) 179 | for tgid, target in enumerate(geneous_names): 180 | tgndx_all = dataset.get_homoindices(tgid) 181 | st = "{}-{}".format(source, target) 182 | # print(st) 183 | if sgndx_all.shape[0] == 0 or tgndx_all.shape[0] == 0: 184 | print("no instance in source or target, continue") 185 | continue 186 | # recall-pricision 187 | all_rp[st] = {} 188 | all_rp[st]["xy"], all_rp[st]["ds"] = get_recall_precision_curve(dataset, distance, sgndx_all, tgndx_all, num_eval) 189 | # print("all-all") 190 | all_rp["all-all"] = {} 191 | all_rp["all-all"]["xy"], all_rp["all-all"]["ds"] = get_recall_precision_curve(dataset, distance, dataset.get_indices(), dataset.get_indices(), num_eval) 192 | 193 | plt.figure(figsize=(20,20)) 194 | 195 | plt.xlim(-0.1, 1.1) 196 | plt.ylim(-0.1, 1.1) 197 | plt.grid() 198 | for st in all_rp: 199 | plt.plot(all_rp[st]["xy"][0], all_rp[st]["xy"][1]) 200 | for i, d in enumerate(all_rp[st]["ds"]): 201 | plt.annotate( 202 | text="%.2f"%d, 203 | xy=(all_rp[st]["xy"][0][i], all_rp[st]["xy"][1][i]), 204 | xytext=(all_rp[st]["xy"][0][i], all_rp[st]["xy"][1][i]), 205 | fontsize=10, 206 | ) 207 | plt.xlabel("recall") 208 | plt.ylabel("precision") 209 | plt.legend(list(all_rp)) 210 | if savepath is not None: plt.savefig(os.path.join(savepath, "recall-precision.png")) 211 | elif show: plt.show() 212 | else: plt.close() 213 | return all_rp 214 | 215 | 216 | def show_closest(dataset:LPRDataset, distance:np.ndarray): 217 | print("Show Closest Submaps") 218 | 219 | geneous_names = dataset.get_geneous_names() 220 | for sgid, source in enumerate(geneous_names): 221 | sgndx_all = dataset.get_homoindices(sgid) 222 | for tgid, target in enumerate(geneous_names): 223 | if source == target: continue 224 | 225 | tgndx_all = dataset.get_homoindices(tgid) 226 | for _ in range(10): 227 | d = 4.0 228 | anchor, closest = None, None 229 | while d > 2.0: 230 | anchor = random.choice(sgndx_all) 231 | closest = tgndx_all[np.argsort(distance[anchor, tgndx_all])][0] 232 | d = distance[anchor][closest] 233 | 234 | 235 | suc = "False" 236 | if closest in dataset.get_positives(anchor): suc = "True" 237 | 238 | anchor_pcd = o3d.geometry.PointCloud() 239 | anchor_pcd.points = o3d.utility.Vector3dVector(dataset.get_pc(anchor) - np.asarray([40,0,0])) 240 | 241 | closest_pcd = o3d.geometry.PointCloud() 242 | closest_pcd.points = o3d.utility.Vector3dVector(dataset.get_pc(closest) + np.asarray([40,0,0])) 243 | 244 | o3d.visualization.draw_geometries( 245 | [anchor_pcd, closest_pcd], 246 | window_name="%s-%s: result=%s, distance=%.3f"%(source, target, suc, d) 247 | ) 248 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/loss/__init__.py -------------------------------------------------------------------------------- /loss/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class BaseLoss(object, metaclass=abc.ABCMeta): 4 | """ 5 | # Base class for all loss functions 6 | 7 | """ 8 | @abc.abstractmethod 9 | def __init__(self): 10 | pass 11 | 12 | @abc.abstractmethod 13 | def __call__(self): 14 | pass 15 | 16 | @abc.abstractmethod 17 | def print_stats(self): 18 | pass 19 | -------------------------------------------------------------------------------- /loss/gapr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import open3d as o3d 5 | from typing import List, Dict, Any, Tuple 6 | from torch.utils.tensorboard import SummaryWriter 7 | import random 8 | 9 | 10 | from loss.base import BaseLoss 11 | from loss.triplet import BatchTripletLoss 12 | from loss.point import PointTripletLoss 13 | from loss.overlap import OverlapLoss 14 | 15 | class GAPRLoss(BaseLoss): 16 | def __init__(self, batch_loss:Dict, point_loss:Dict, overlap_loss:Dict, point_loss_scale:float, overlap_loss_scale:float): 17 | super().__init__() 18 | print("GAPRLoss: point_loss_scale=%.2f overlap_loss_scale=%.2f"%(point_loss_scale, overlap_loss_scale)) 19 | self.batch_loss = BatchTripletLoss(**batch_loss) 20 | self.point_loss = PointTripletLoss(**point_loss) 21 | self.overlap_loss = OverlapLoss(**overlap_loss) 22 | self.point_loss_scale = point_loss_scale 23 | self.overlap_loss_scale = overlap_loss_scale 24 | 25 | def __call__(self, 26 | # model 27 | embeddings:torch.Tensor, 28 | coords:List[torch.Tensor], 29 | feats:List[torch.Tensor], 30 | scores:List[torch.Tensor], 31 | # mask 32 | rotms:torch.Tensor, 33 | trans:torch.Tensor, 34 | positives_mask:torch.Tensor, 35 | negatives_mask:torch.Tensor, 36 | geneous:torch.Tensor 37 | ): 38 | # get global coords 39 | device, BS = embeddings.device, embeddings.shape[0] 40 | rotms, trans = rotms.to(device), trans.to(device) 41 | # R*p + T 42 | global_coords = [torch.mm(rotms[ndx], coords[ndx].clone().detach().transpose(0,1)).transpose(0,1) + trans[ndx].unsqueeze(0) for ndx in range(BS)] 43 | # compute point loss 44 | point_loss, point_stats = self.point_loss(feats, global_coords, positives_mask) 45 | # compute attention loss 46 | overlap_loss, overlap_stats = self.overlap_loss(scores, global_coords, positives_mask, geneous) 47 | # compute batch loss 48 | batch_loss, batch_stats = self.batch_loss(embeddings, embeddings, positives_mask, negatives_mask) 49 | 50 | stats = {"batch":batch_stats, "point":point_stats, "overlap":overlap_stats} 51 | # stats.update(mean_point_stats_show) 52 | return batch_loss+self.point_loss_scale*point_loss+self.overlap_loss_scale*overlap_loss, stats 53 | 54 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 55 | self.batch_loss.print_stats(epoch, phase, writer, stats["batch"]) 56 | self.point_loss.print_stats(epoch, phase, writer, stats["point"]) 57 | self.overlap_loss.print_stats(epoch, phase, writer, stats["overlap"]) 58 | # print("point_consistence_loss: pos=%.3f, neg=%.3f" % (stats["pos_l2ds"], stats["neg_l2ds"])) 59 | return -------------------------------------------------------------------------------- /loss/lprloss.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | from misc.utils import tensors2numbers 3 | from torch.utils.tensorboard import SummaryWriter 4 | 5 | class LPRLoss: 6 | """ 7 | # Wrapper loss function 8 | """ 9 | def __init__(self, name:str, **kw): 10 | self.name = name 11 | if self.name == "GAPRLoss": 12 | from loss.gapr import GAPRLoss 13 | self.loss_fn = GAPRLoss(**kw) 14 | else: 15 | raise NotImplementedError("LPRLoss: loss_fn %s not implemented" % self.name) 16 | 17 | def __call__(self, output:Dict[str, Any], mask:Dict[str, Any]): 18 | loss, stats = None, None 19 | if self.name == "GAPRLoss": 20 | assert set(["embeddings", "coords", "feats", "scores"]) <= set(output.keys()) 21 | assert set(["positives", "negatives", "rotms", "trans", "geneous"]) <= set(mask.keys()) 22 | loss, stats = self.loss_fn( 23 | output["embeddings"], 24 | output["coords"], 25 | output["feats"], 26 | output["scores"], 27 | mask["rotms"], 28 | mask["trans"], 29 | mask["positives"], 30 | mask["negatives"], 31 | mask["geneous"] 32 | ) 33 | else: 34 | raise NotImplementedError("LPRLoss: loss_fn %s not implemented" % self.name) 35 | 36 | assert loss is not None and stats is not None 37 | stats = tensors2numbers(stats) 38 | return loss, stats 39 | 40 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 41 | """ 42 | # visualize stats 43 | """ 44 | if self.name == "GAPRLoss": 45 | self.loss_fn.print_stats(epoch, phase, writer, stats) 46 | else: 47 | raise NotImplementedError("LPRLoss: loss_fn %s.print_stats() not implemented" % self.name) -------------------------------------------------------------------------------- /loss/overlap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open3d as o3d 3 | import numpy as np 4 | from typing import List, Dict, Any 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | from loss.base import BaseLoss 8 | 9 | class OverlapLoss(BaseLoss): 10 | def __init__(self, corr_dist:float): 11 | super().__init__() 12 | print("OverlapLoss: corr_dist=%.2f"%(corr_dist)) 13 | self.corr_dist = corr_dist 14 | def __call__(self, scores:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor, geneous:torch.Tensor): 15 | source_indices, target_indices = torch.where(positives_mask == True) 16 | keep_0 = geneous[source_indices] != geneous[target_indices] 17 | keep_1 = source_indices < target_indices 18 | select_indices = np.where(keep_0 & keep_1) 19 | source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist() 20 | losses = [] 21 | stats = {"fitness":[],"inpair_min":[], "inpair_mean":[], "inpair_max":[], "nopair_min":[], "nopair_mean":[], "nopair_max":[]} 22 | if len(source_indices) == 0: 23 | # no hetero positive pair, refer to sampler 24 | return torch.zeros((1,), device=scores[0].device).type_as(scores[0]), {"loss":0.0,"fitness":0.0,"inpair_min":0.0, "inpair_mean":0.0, "inpair_max":0.0, "nopair_min":0.0, "nopair_mean":0.0, "nopair_max":0.0} 25 | 26 | for sndx, tndx in zip(source_indices, target_indices): 27 | # construct pcd from coords 28 | source_pcd = o3d.geometry.PointCloud() 29 | source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy()) 30 | source_pcd.paint_uniform_color([0,0,1]) 31 | target_pcd = o3d.geometry.PointCloud() 32 | target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy()) 33 | target_pcd.paint_uniform_color([1,0,0]) 34 | reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4)) 35 | corr_set = np.asarray(reg_p2p.correspondence_set) 36 | assert reg_p2p.fitness > 0.05 37 | 38 | 39 | Ns, Nt = coords[sndx].shape[0], coords[tndx].shape[0] 40 | source_inpair_indices, source_nopair_indices = corr_set[:, 0], np.setdiff1d(np.arange(Ns), corr_set[:, 0]) 41 | target_inpair_indices, target_nopair_indices = corr_set[:, 1], np.setdiff1d(np.arange(Nt), corr_set[:, 1]) 42 | 43 | 44 | source_loss = scores[sndx][source_nopair_indices].mean() + 1.0 - scores[sndx][source_inpair_indices].mean() 45 | target_loss = scores[tndx][target_nopair_indices].mean() + 1.0 - scores[tndx][target_inpair_indices].mean() 46 | losses += [source_loss, target_loss] 47 | 48 | stats["fitness"] += [reg_p2p.fitness] 49 | stats["inpair_min"] += [scores[sndx][source_inpair_indices].min().item(), scores[tndx][target_inpair_indices].min().item()] 50 | stats["inpair_mean"] += [scores[sndx][source_inpair_indices].mean().item(), scores[tndx][target_inpair_indices].mean().item()] 51 | stats["inpair_max"] += [scores[sndx][source_inpair_indices].max().item(), scores[tndx][target_inpair_indices].max().item()] 52 | stats["nopair_min"] += [scores[sndx][source_nopair_indices].min().item(), scores[tndx][target_nopair_indices].min().item()] 53 | stats["nopair_mean"] += [scores[sndx][source_nopair_indices].mean().item(), scores[tndx][target_nopair_indices].mean().item()] 54 | stats["nopair_max"] += [scores[sndx][source_nopair_indices].max().item(), scores[tndx][target_nopair_indices].max().item()] 55 | 56 | loss = torch.stack(losses).mean() 57 | avg_stats = {e: np.mean(stats[e]) for e in stats} 58 | avg_stats["loss"] = loss.item() 59 | 60 | return loss, avg_stats 61 | 62 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 63 | print("OverlapLoss: %.3f" % (stats["loss"])) 64 | print("Overlap: %.3f, %.3f, %.3f | Non-overlap: %.3f, %.3f, %.3f | Fitness: %.3f" % ( 65 | stats["inpair_min"], stats["inpair_mean"], stats["inpair_max"], 66 | stats["nopair_min"], stats["nopair_mean"], stats["nopair_max"], 67 | stats["fitness"] 68 | )) 69 | return -------------------------------------------------------------------------------- /loss/point.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import random 4 | import numpy as np 5 | import open3d as o3d 6 | from typing import List, Dict, Any 7 | from loss.base import BaseLoss 8 | from torch.utils.tensorboard import SummaryWriter 9 | 10 | def get_max_per_row(mat:torch.Tensor, mask:torch.Tensor): 11 | non_zero_rows = torch.any(mask, dim=1) 12 | mat_masked = mat.clone() 13 | mat_masked[~mask] = 0 14 | return torch.max(mat_masked, dim=1), non_zero_rows 15 | 16 | 17 | def get_min_per_row(mat:torch.Tensor, mask:torch.Tensor): 18 | non_inf_rows = torch.any(mask, dim=1) 19 | mat_masked = mat.clone() 20 | mat_masked[~mask] = float("inf") 21 | return torch.min(mat_masked, dim=1), non_inf_rows 22 | 23 | 24 | class PointTripletLoss(BaseLoss): 25 | def __init__(self, margin:float, style:str, corr_dist:float, sample_num:int, pos_dist:float, neg_dist:float): 26 | super().__init__() 27 | assert style in ["soft", "hard"] 28 | print("PointTripletLoss: margin=%.1f, style=%s" % (margin, style)) 29 | self.margin = margin 30 | self.style = style 31 | self.corr_dist = corr_dist 32 | self.sample_num = sample_num 33 | self.pos_dist = pos_dist 34 | self.neg_dist = neg_dist 35 | 36 | def __call__(self, feats:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor): 37 | source_indices, target_indices = torch.where(positives_mask == True) 38 | select_indices = torch.where(source_indices < target_indices) 39 | source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist() 40 | 41 | losses = [] 42 | point_stats = { 43 | "fitness":[], 44 | "triplet_num":[], 45 | "non_zero_triplet_num":[], 46 | "pos_min":[], 47 | "pos_mean":[], 48 | "pos_max":[], 49 | "neg_min":[], 50 | "neg_mean":[], 51 | "neg_max":[] 52 | } 53 | 54 | for sndx, tndx in zip(source_indices, target_indices): 55 | # construct pcd from coords 56 | source_pcd = o3d.geometry.PointCloud() 57 | source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy()) 58 | source_pcd.paint_uniform_color([0,0,1]) 59 | target_pcd = o3d.geometry.PointCloud() 60 | target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy()) 61 | target_pcd.paint_uniform_color([1,0,0]) 62 | # o3d.visualization.draw_geometries([source_pcd, target_pcd]) 63 | # icp and get points set 64 | reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4)) 65 | corr_set = np.asarray(reg_p2p.correspondence_set) 66 | assert reg_p2p.fitness > 0.05 67 | # sample, Ns = 64 68 | sample_indices = np.random.choice(corr_set.shape[0], min(corr_set.shape[0], self.sample_num)) 69 | Ns = sample_indices.shape[0] 70 | # sample_set: 71 | # [ [s0, s1, s2, ... ], Ns 72 | # [t0, t1, t2, ... ] ] Ns 73 | sample_set = corr_set[sample_indices].T.tolist() 74 | 75 | # sample coords and feats 76 | scoord, tcoord = coords[sndx][sample_set[0]], coords[tndx][sample_set[1]] 77 | sfeat, tfeat = feats[sndx][sample_set[0]], feats[tndx][sample_set[1]] 78 | # Ns * Ns 79 | coord_dist = torch.norm(scoord.unsqueeze(1) - tcoord.unsqueeze(0), dim=2) 80 | # Ns * Ns 81 | feat_dist = torch.norm(sfeat.unsqueeze(1) - tfeat.unsqueeze(0), dim=2) 82 | # get hardest positive and negative 83 | (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(feat_dist, coord_dist < self.pos_dist) 84 | (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(feat_dist, coord_dist > self.neg_dist) 85 | # positive <=> anchor <=> negative 86 | a_keep_idx = torch.where(a1p_keep & a2n_keep)[0] 87 | triplet_num = a_keep_idx.shape[0] 88 | if triplet_num == 0: continue 89 | 90 | anc_ind = torch.arange(Ns).to(hardest_positive_indices.device)[a_keep_idx] 91 | pos_ind = hardest_positive_indices[a_keep_idx] 92 | neg_ind = hardest_negative_indices[a_keep_idx] 93 | 94 | triplet_dist = torch.norm(sfeat[anc_ind] - tfeat[pos_ind], dim=1) - torch.norm(sfeat[anc_ind] - tfeat[neg_ind], dim=1) 95 | 96 | non_zero_triplet_num = torch.where((triplet_dist + self.margin) > 0)[0].shape[0] 97 | 98 | 99 | if self.style == "hard": 100 | this_pair_loss = F.relu(triplet_dist + self.margin).mean() 101 | elif self.style == "soft": 102 | this_pair_loss = torch.log(1+self.margin*torch.exp(triplet_dist)).mean() 103 | else: 104 | raise NotImplementedError(f"PointTripletLoss: unkown style {self.style}") 105 | # this_pair_loss = F.relu(triplet_dist).mean() 106 | 107 | losses.append(this_pair_loss) 108 | 109 | point_stats["fitness"].append(reg_p2p.fitness) 110 | point_stats["triplet_num"].append(triplet_num) 111 | point_stats["non_zero_triplet_num"].append(non_zero_triplet_num) 112 | point_stats["pos_min"].append(hardest_positive_dist[a_keep_idx].min().item()) 113 | point_stats["pos_mean"].append(hardest_positive_dist[a_keep_idx].mean().item()) 114 | point_stats["pos_max"].append(hardest_positive_dist[a_keep_idx].max().item()) 115 | point_stats["neg_min"].append(hardest_negative_dist[a_keep_idx].min().item()) 116 | point_stats["neg_mean"].append(hardest_negative_dist[a_keep_idx].mean().item()) 117 | point_stats["neg_max"].append(hardest_negative_dist[a_keep_idx].max().item()) 118 | 119 | avg_point_stats = {e: np.mean(point_stats[e]) for e in point_stats} 120 | loss = torch.stack(losses).mean() 121 | avg_point_stats["loss"] = loss.item() 122 | 123 | return loss, avg_point_stats 124 | 125 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 126 | print("PointTripletLoss: %.3f" % (stats["loss"])) 127 | print( 128 | "Positive: %.3f, %.3f, %.3f | Negative: %.3f, %.3f, %.3f | Triplet: %.1f/%.1f | Fitness:%.3f" % 129 | ( 130 | stats["pos_min"], stats["pos_mean"] ,stats["pos_max"], 131 | stats["neg_min"], stats["neg_mean"] ,stats["neg_max"], 132 | stats["triplet_num"], stats["non_zero_triplet_num"], stats["fitness"] 133 | ) 134 | ) 135 | return 136 | 137 | 138 | ########################### PointConstrativeLoss from LoGG3D ################################# 139 | 140 | def hashM(arr, M): 141 | if isinstance(arr, np.ndarray): 142 | N, D = arr.shape 143 | else: 144 | N, D = len(arr[0]), len(arr) 145 | 146 | hash_vec = np.zeros(N, dtype=np.int64) 147 | for d in range(D): 148 | if isinstance(arr, np.ndarray): 149 | hash_vec += arr[:, d] * M**d 150 | else: 151 | hash_vec += arr[d] * M**d 152 | return hash_vec 153 | 154 | 155 | def pdist(A, B, dist_type="L2"): 156 | if dist_type == "L2": 157 | D2 = torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 158 | return torch.sqrt(D2 + 1e-7) 159 | elif dist_type == "SquareL2": 160 | return torch.sum((A.unsqueeze(1) - B.unsqueeze(0)).pow(2), 2) 161 | else: 162 | raise NotImplementedError("Not implemented") 163 | 164 | 165 | # class PointContrastiveLoss(BaseLoss): 166 | class PointContrastiveLoss(BaseLoss): 167 | def __init__(self, corr_dist:float, pos_margin:float, neg_margin:float, neg_weight:float, num_pos:int, num_hn_samples:int): 168 | super().__init__() 169 | self.corr_dist = corr_dist 170 | self.pos_margin = pos_margin 171 | self.neg_margin = neg_margin 172 | self.neg_weight = neg_weight 173 | self.num_pos = num_pos 174 | self.num_hn_samples = num_hn_samples 175 | 176 | 177 | def __call__(self, feats:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor): 178 | source_indices, target_indices = torch.where(positives_mask == True) 179 | select_indices = torch.where(source_indices < target_indices) 180 | source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist() 181 | 182 | losses = [] 183 | stats = {} 184 | 185 | for sndx, tndx in zip(source_indices, target_indices): 186 | # construct pcd from coords 187 | source_pcd = o3d.geometry.PointCloud() 188 | source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy()) 189 | source_pcd.paint_uniform_color([0,0,1]) 190 | target_pcd = o3d.geometry.PointCloud() 191 | target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy()) 192 | target_pcd.paint_uniform_color([1,0,0]) 193 | # o3d.visualization.draw_geometries([source_pcd, target_pcd]) 194 | # icp and get points set 195 | reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4)) 196 | corr_set = np.asarray(reg_p2p.correspondence_set) 197 | assert reg_p2p.fitness > 0.05 198 | losses.append(self.point_contrastive_loss(feats[sndx], feats[tndx], corr_set)) 199 | 200 | loss = torch.stack(losses).mean() 201 | stats["loss"] = loss.item() 202 | 203 | return loss, stats 204 | 205 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 206 | print("PointContrastiveLoss: %.3f" % (stats["loss"])) 207 | return 208 | 209 | def point_contrastive_loss(self, 210 | F0:torch.Tensor, F1:torch.Tensor, 211 | positive_pairs:np.ndarray, 212 | # point_pos_margin:float, point_neg_margin:float, 213 | # point_neg_weight:float, 214 | # num_pos=128, num_hn_samples=2048 215 | ): 216 | """ 217 | Randomly select "num-pos" positive pairs. 218 | Find the hardest-negative (from a random subset of num_hn_samples) for each point in a positive pair. 219 | Calculate contrastive loss on the tuple (p1,p2,hn1,hn2) 220 | Based on: https://github.com/chrischoy/FCGF/blob/master/lib/trainer.py 221 | """ 222 | N0, N1 = len(F0), len(F1) 223 | N_pos_pairs = len(positive_pairs) 224 | hash_seed = max(N0, N1) 225 | sel0 = np.random.choice(N0, min(N0, self.num_hn_samples), replace=False) 226 | sel1 = np.random.choice(N1, min(N1, self.num_hn_samples), replace=False) 227 | 228 | if N_pos_pairs > self.num_pos: 229 | pos_sel = np.random.choice(N_pos_pairs, self.num_pos, replace=False) 230 | sample_pos_pairs = positive_pairs[pos_sel] 231 | else: 232 | sample_pos_pairs = positive_pairs 233 | 234 | # Find negatives for all F1[positive_pairs[:, 1]] 235 | subF0, subF1 = F0[sel0], F1[sel1] 236 | 237 | pos_ind0 = sample_pos_pairs[:, 0] # .long() 238 | pos_ind1 = sample_pos_pairs[:, 1] # .long() 239 | posF0, posF1 = F0[pos_ind0], F1[pos_ind1] 240 | 241 | D01 = pdist(posF0, subF1, dist_type="L2") 242 | D10 = pdist(posF1, subF0, dist_type="L2") 243 | 244 | D01min, D01ind = D01.min(1) 245 | D10min, D10ind = D10.min(1) 246 | 247 | if not isinstance(positive_pairs, np.ndarray): 248 | positive_pairs = np.array(positive_pairs, dtype=np.int64) 249 | 250 | pos_keys = hashM(positive_pairs, hash_seed) 251 | 252 | D01ind = sel1[D01ind.cpu().numpy()] 253 | D10ind = sel0[D10ind.cpu().numpy()] 254 | neg_keys0 = hashM([pos_ind0, D01ind], hash_seed) 255 | neg_keys1 = hashM([D10ind, pos_ind1], hash_seed) 256 | 257 | mask0 = torch.from_numpy( 258 | np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False))) 259 | mask1 = torch.from_numpy( 260 | np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False))) 261 | pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - self.pos_margin) 262 | neg_loss0 = F.relu(self.neg_margin - D01min[mask0]).pow(2) 263 | neg_loss1 = F.relu(self.neg_margin - D10min[mask1]).pow(2) 264 | 265 | pos_loss = pos_loss.mean() 266 | neg_loss = (neg_loss0.mean() + neg_loss1.mean()) / 2 267 | loss = pos_loss + self.neg_weight * neg_loss 268 | return loss 269 | -------------------------------------------------------------------------------- /loss/triplet.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from typing import List, Dict, Any 7 | from torch.utils.tensorboard import SummaryWriter 8 | from loss.base import BaseLoss 9 | import matplotlib.pylab as plt 10 | 11 | def get_max_per_row(mat:torch.Tensor, mask:torch.Tensor): 12 | non_zero_rows = torch.any(mask, dim=1) 13 | mat_masked = mat.clone() 14 | mat_masked[~mask] = 0 15 | return torch.max(mat_masked, dim=1), non_zero_rows 16 | 17 | def get_min_per_row(mat:torch.Tensor, mask:torch.Tensor): 18 | non_inf_rows = torch.any(mask, dim=1) 19 | mat_masked = mat.clone() 20 | mat_masked[~mask] = float("inf") 21 | return torch.min(mat_masked, dim=1), non_inf_rows 22 | 23 | class TripletMiner: 24 | def __init__(self): 25 | return 26 | def __call__(self, dist_mat:torch.Tensor, positives_mask:torch.Tensor, negatives_mask:torch.Tensor): 27 | # [Ns, Nt] mat 28 | assert dist_mat.shape == positives_mask.shape == negatives_mask.shape 29 | with torch.no_grad(): 30 | # Based on pytorch-metric-learning implementation 31 | (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask) 32 | (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask) 33 | a_keep_idx = torch.where(a1p_keep & a2n_keep)[0] 34 | anc_ind = torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx] 35 | pos_ind = hardest_positive_indices[a_keep_idx] 36 | neg_ind = hardest_negative_indices[a_keep_idx] 37 | 38 | stats = { 39 | "triplet_num" :a_keep_idx.shape[0], 40 | "max_pos_dist" :torch.max(hardest_positive_dist[a_keep_idx]).item(), 41 | "mean_pos_dist":torch.mean(hardest_positive_dist[a_keep_idx]).item(), 42 | "min_pos_dist" :torch.min(hardest_positive_dist[a_keep_idx]).item(), 43 | "max_neg_dist" :torch.max(hardest_negative_dist[a_keep_idx]).item(), 44 | "mean_neg_dist":torch.mean(hardest_negative_dist[a_keep_idx]).item(), 45 | "min_neg_dist" :torch.min(hardest_negative_dist[a_keep_idx]).item(), 46 | } 47 | return anc_ind, pos_ind, neg_ind, stats 48 | 49 | 50 | 51 | class BatchTripletLoss(BaseLoss): 52 | def __init__(self, margin:float, style:str): 53 | super().__init__() 54 | assert style in ["soft", "hard"] 55 | print("BatchTripletLoss: margin=%.1f, style=%s"%(margin, style)) 56 | self.miner = TripletMiner() 57 | self.margin = margin 58 | self.style = style 59 | return 60 | 61 | def __call__(self, 62 | source_feats:torch.Tensor, target_feats:torch.Tensor, 63 | positives_mask:torch.Tensor, negative_mask:torch.Tensor 64 | ): 65 | stats = {} 66 | # get dist l2d mat 67 | dist_mat = torch.norm(source_feats.unsqueeze(1) - target_feats.unsqueeze(0), dim=2) 68 | # miner 69 | anc, pos, neg, miner_stats = self.miner(dist_mat, positives_mask, negative_mask) 70 | stats.update(miner_stats) 71 | pos_dist = torch.norm(source_feats[anc] - target_feats[pos], dim=1) 72 | neg_dist = torch.norm(source_feats[anc] - target_feats[neg], dim=1) 73 | triplet_dist = pos_dist - neg_dist 74 | with torch.no_grad(): 75 | stats["norm"] = torch.norm(source_feats, dim=1).mean().item() 76 | stats["non_zero_triplet_num"] = torch.where((triplet_dist + self.margin) > 0)[0].shape[0] 77 | 78 | if self.style == "hard": 79 | loss = F.relu(triplet_dist + self.margin).mean() 80 | elif self.style == "soft": 81 | loss = torch.log(1+self.margin*torch.exp(triplet_dist)).mean() 82 | else: 83 | raise NotImplementedError(f"BatchTripletLoss: unkown style {self.style}") 84 | 85 | stats["loss"] = loss.item() 86 | return loss, stats 87 | 88 | def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]): 89 | print("TripletLoss: %.3f, Norm: %.3f, All/Non-zero: %.1f/%.1f"%( 90 | stats["loss"], stats["norm"], stats["triplet_num"], stats["non_zero_triplet_num"] 91 | )) 92 | print("Positive: %.3f, %.3f, %.3f | Negative: %.3f, %.3f, %.3f (min, avg, max)"%( 93 | stats["min_pos_dist"], stats["mean_pos_dist"], stats["max_pos_dist"], 94 | stats["min_neg_dist"], stats["mean_neg_dist"], stats["max_neg_dist"], 95 | )) 96 | return 97 | -------------------------------------------------------------------------------- /media/description.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/media/description.png -------------------------------------------------------------------------------- /media/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/media/pipeline.png -------------------------------------------------------------------------------- /misc/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import argparse 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | from typing import Any, Dict, List 7 | 8 | def get_datetime(): 9 | return time.strftime("%Y%m%d_%H%M%S") 10 | 11 | def str2bool(v): 12 | if v.lower() in ("yes", "true", "t", "y", "1"): 13 | return True 14 | elif v.lower() in ("no", "false", "f", "n", "0"): 15 | return False 16 | else: 17 | raise argparse.ArgumentTypeError("Unsupported value encountered.") 18 | 19 | def get_idx_from_string(elem): 20 | """ 21 | 000021.npy -> 21 22 | """ 23 | return int(elem.split(".")[0]) 24 | 25 | 26 | def tensors2numbers(data): 27 | """ 28 | ```python 29 | stats = {e: stats[e].item() if torch.is_tensor(stats[e]) else stats[e] for e in stats} 30 | ``` 31 | """ 32 | if data is None: return data 33 | else: 34 | if torch.is_tensor(data): 35 | return data.item() 36 | elif isinstance(data, list) or isinstance(data, tuple): 37 | for i, _ in enumerate(data): 38 | data[i] = tensors2numbers(data[i]) 39 | return data 40 | elif isinstance(data, dict): 41 | for e in data: 42 | data[e] = tensors2numbers(data[e]) 43 | return data 44 | else: 45 | return data 46 | 47 | def tensors2device(data:Any, device:torch.device): 48 | """ 49 | # [tensor.to(device)] 50 | """ 51 | if data is None: return data 52 | else: 53 | if torch.is_tensor(data): 54 | return data.to(device) 55 | elif isinstance(data, list) or isinstance(data, tuple): 56 | for i, _ in enumerate(data): 57 | data[i] = tensors2device(data[i], device) 58 | return data 59 | elif isinstance(data, dict): 60 | for e in data: 61 | data[e] = tensors2device(data[e], device) 62 | return data 63 | else: 64 | raise NotImplementedError("tensors2device: %s not implemented error"%str(type(data))) 65 | 66 | 67 | def avg_stats(stats:List): 68 | avg = stats[0] 69 | for e in avg: 70 | if isinstance(avg[e], Dict): 71 | this_stats = [stats[i][e] for i in range(len(stats))] 72 | avg[e] = avg_stats(this_stats) 73 | else: 74 | avg[e] = np.mean([stats[i][e] for i in range(len(stats))]) 75 | return avg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/models/__init__.py -------------------------------------------------------------------------------- /models/gapr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Dict 4 | 5 | from models.utils.aggregation.gem import MeanGeM 6 | from models.utils.extraction.mink.minkfpn import MinkFPN 7 | from models.utils.transformers.transgeo import PCTrans 8 | 9 | class GAPR(nn.Module): 10 | def __init__(self, minkfpn:Dict, pctrans:Dict, meangem:Dict, **kw): 11 | super(GAPR, self).__init__() 12 | print("Model: GAPR") 13 | self.minkfpn = MinkFPN(**minkfpn) 14 | self.geneous_names = ["ground", "aerial"] 15 | 16 | self.ground_trans = PCTrans(**pctrans) 17 | self.aerial_trans = PCTrans(**pctrans) 18 | 19 | self.meangem = MeanGeM(**meangem) 20 | 21 | 22 | def forward(self, coords:torch.Tensor, feats:torch.Tensor, geneous:torch.Tensor): 23 | BS = geneous.shape[0] 24 | cnn_coords, cnn_feats = self.minkfpn(coords, feats) 25 | attn_feats, attn_scores = [], [] 26 | for ndx in range(BS): 27 | if self.geneous_names[geneous[ndx].item()] == "ground": 28 | attn_feat, attn_score = self.ground_trans(cnn_feats[ndx].unsqueeze(0)) 29 | attn_feats.append(attn_feat.squeeze(0)) 30 | attn_scores.append(attn_score.squeeze(0)) 31 | elif self.geneous_names[geneous[ndx].item()] == "aerial": 32 | attn_feat, attn_score = self.aerial_trans(cnn_feats[ndx].unsqueeze(0)) 33 | attn_feats.append(attn_feat.squeeze(0)) 34 | attn_scores.append(attn_score.squeeze(0)) 35 | else: raise NotImplementedError 36 | 37 | 38 | batch_feats = torch.stack([self.meangem(feat) for feat in attn_feats], dim=0) 39 | # batch_feats = torch.stack([self.meangem(feat) for feat in cnn_feats], dim=0) 40 | 41 | return cnn_coords, cnn_feats, attn_scores, batch_feats 42 | -------------------------------------------------------------------------------- /models/lprmodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import copy 4 | import yaml 5 | from typing import Dict, Any 6 | 7 | class LPRModel: 8 | """ 9 | # Wrapper for models 10 | """ 11 | def __init__(self): 12 | return 13 | 14 | def construct(self, name:str, **kw): 15 | self.config = copy.deepcopy(kw) 16 | self.config["name"] = name 17 | self.name = name 18 | self.model = None 19 | if self.name == "GAPR": 20 | from models.gapr import GAPR 21 | self.model = GAPR(**kw) 22 | else: 23 | raise NotImplementedError("LPRModel: model %s not implemented" % self.name) 24 | 25 | def __call__(self, data:Dict[str, Any]) -> Dict[str, Any]: 26 | output:Dict[str, Any] = {} 27 | if self.name == "GAPR": 28 | assert set(["coords", "feats", "geneous"]) <= set(data.keys()) 29 | output["coords"], output["feats"], output["scores"], output["embeddings"] = self.model(data["coords"], data["feats"], data["geneous"]) 30 | else: 31 | raise NotImplementedError("LPRModel: model %s not implemented" % self.name) 32 | return output 33 | 34 | def save(self, path:str): 35 | pth_file_dict = {"config":self.config, "weight": self.model.module.state_dict()} 36 | torch.save(pth_file_dict, path) 37 | return 38 | 39 | def load(self, path, device): 40 | pth_file_dict = torch.load(path, map_location=device) 41 | print("LPRModel: load\n", pth_file_dict["config"]) 42 | self.construct(**pth_file_dict["config"]) 43 | self.model.load_state_dict(pth_file_dict["weight"]) 44 | return 45 | 46 | 47 | def import_and_save(self, config_path: str, weight_path:str, save_path:str): 48 | """ 49 | import models from other project and save it 50 | """ 51 | assert (os.path.exists(config_path)) and (os.path.exists(weight_path)) and (not os.path.exists(save_path)) 52 | pth_file_dict = {} 53 | # load weights 54 | pth_file_dict["weight"] = torch.load(weight_path) 55 | # load config 56 | f = open(config_path, encoding="utf-8") 57 | pth_file_dict["config"] = yaml.load(f, Loader=yaml.FullLoader) #读取yaml文件 58 | f.close() 59 | torch.save(pth_file_dict, save_path) 60 | return 61 | -------------------------------------------------------------------------------- /models/utils/aggregation/gem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GeM(nn.Module): 5 | def __init__(self, pn=256, p=3, eps=1e-6): 6 | super(GeM, self).__init__() 7 | self.p = nn.Parameter(torch.ones(1) * p) 8 | self.eps = eps 9 | self.f = nn.AvgPool1d(pn) # pn = 256 10 | def forward(self, x:torch.Tensor): 11 | temp = x.clamp(min=self.eps).pow(self.p) 12 | temp = self.f(temp) 13 | temp = temp.pow(1./self.p) 14 | # 防止把第一维压缩掉 15 | temp = temp.squeeze(dim=2) 16 | return temp 17 | 18 | class MeanGeM(nn.Module): 19 | def __init__(self, p:float, eps:float): 20 | # p=3, eps=0.0000001 21 | super(MeanGeM, self).__init__() 22 | self.p = nn.Parameter(torch.ones(1) * p) 23 | self.eps = eps 24 | def forward(self, x:torch.Tensor): 25 | # x: [pn, fs] 26 | x = x.clamp(min=self.eps).pow(self.p) 27 | x = x.mean(dim=0) 28 | x = x.pow(1./self.p) 29 | return x -------------------------------------------------------------------------------- /models/utils/extraction/mink/minkfpn.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import torch 5 | import torch.nn as nn 6 | import MinkowskiEngine as ME 7 | from MinkowskiEngine.modules.resnet_block import BasicBlock 8 | from models.utils.extraction.mink.resnet import ResNetBase 9 | from models.utils.extraction.mink.utils import minkowski_decomposed, minkowski_sparse 10 | from typing import List 11 | 12 | class MinkFPN(ResNetBase): 13 | # Feature Pyramid Network (FPN) architecture implementation using Minkowski ResNet building blocks 14 | def __init__(self, # minkloc3d 15 | quant_size:float, # 0.6 16 | in_channels:int, # 1 17 | out_channels:int, # 256 18 | num_top_down:int, # 1 19 | conv0_kernel_size:int, # 5 20 | layers:List, # (1, 1, 1) 21 | planes:List, # (32, 64, 64) 22 | block=BasicBlock, # defualt 23 | ): 24 | self.quant_size = quant_size 25 | assert len(layers) == len(planes) 26 | assert 1 <= len(layers) 27 | assert 0 <= num_top_down <= len(layers) 28 | self.num_bottom_up = len(layers) 29 | self.num_top_down = num_top_down 30 | self.conv0_kernel_size = conv0_kernel_size 31 | self.block = block 32 | self.layers = layers 33 | self.planes = planes 34 | self.lateral_dim = out_channels 35 | self.init_dim = planes[0] 36 | ResNetBase.__init__(self, in_channels, out_channels, D=3) 37 | 38 | def network_initialization(self, in_channels, out_channels, D): 39 | assert len(self.layers) == len(self.planes) 40 | assert len(self.planes) == self.num_bottom_up 41 | 42 | self.convs = nn.ModuleList() # Bottom-up convolutional blocks with stride=2 43 | self.bn = nn.ModuleList() # Bottom-up BatchNorms 44 | self.blocks = nn.ModuleList() # Bottom-up blocks 45 | self.tconvs = nn.ModuleList() # Top-down tranposed convolutions 46 | self.conv1x1 = nn.ModuleList() # 1x1 convolutions in lateral connections 47 | 48 | # The first convolution is special case, with kernel size = 5 49 | self.inplanes = self.planes[0] 50 | self.conv0 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, 51 | dimension=D) 52 | self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) 53 | 54 | for plane, layer in zip(self.planes, self.layers): 55 | self.convs.append(ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)) 56 | self.bn.append(ME.MinkowskiBatchNorm(self.inplanes)) 57 | self.blocks.append(self._make_layer(self.block, plane, layer)) 58 | 59 | # Lateral connections 60 | for i in range(self.num_top_down): 61 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - i], self.lateral_dim, kernel_size=1, 62 | stride=1, dimension=D)) 63 | self.tconvs.append(ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2, 64 | stride=2, dimension=D)) 65 | # There's one more lateral connection than top-down TConv blocks 66 | if self.num_top_down < self.num_bottom_up: 67 | # Lateral connection from Conv block 1 or above 68 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[-1 - self.num_top_down], self.lateral_dim, kernel_size=1, 69 | stride=1, dimension=D)) 70 | else: 71 | # Lateral connection from Con0 block 72 | self.conv1x1.append(ME.MinkowskiConvolution(self.planes[0], self.lateral_dim, kernel_size=1, 73 | stride=1, dimension=D)) 74 | 75 | self.relu = ME.MinkowskiReLU(inplace=True) 76 | 77 | def forward(self, coords:torch.Tensor, feats:torch.Tensor): 78 | # Sparse Quant 79 | x = minkowski_sparse(coords, feats, self.quant_size) 80 | 81 | # *** BOTTOM-UP PASS *** 82 | # First bottom-up convolution is special (with bigger kernel) 83 | feature_maps = [] 84 | x = self.conv0(x) 85 | x = self.bn0(x) 86 | x = self.relu(x) 87 | if self.num_top_down == self.num_bottom_up: 88 | feature_maps.append(x) 89 | 90 | # BOTTOM-UP PASS 91 | for ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)): 92 | x = conv(x) # Downsample (conv stride=2 with 2x2x2 kernel) 93 | x = bn(x) 94 | x = self.relu(x) 95 | x = block(x) 96 | if self.num_bottom_up - 1 - self.num_top_down <= ndx < len(self.convs) - 1: 97 | feature_maps.append(x) 98 | 99 | assert len(feature_maps) == self.num_top_down 100 | 101 | x = self.conv1x1[0](x) 102 | 103 | # TOP-DOWN PASS 104 | for ndx, tconv in enumerate(self.tconvs): 105 | x = tconv(x) # Upsample using transposed convolution 106 | x = x + self.conv1x1[ndx+1](feature_maps[-ndx - 1]) 107 | 108 | # Decomposed 109 | cnn_coords, cnn_feats = minkowski_decomposed(x, self.quant_size) 110 | return cnn_coords, cnn_feats 111 | 112 | if __name__ == "__main__": 113 | model = MinkFPN( 114 | quant_size=0.6, 115 | in_channels=1, 116 | out_channels=256, 117 | num_top_down=1, 118 | conv0_kernel_size=5, 119 | layers=[1,1,1], 120 | planes=[32,64,64], 121 | ).cuda() 122 | BS, PN, FS = 16, 23212, 1 123 | coords, feats = torch.rand((BS, PN, 3))*60.0, torch.rand((BS, PN, FS)) 124 | coords, feats = coords.cuda(), feats.cuda() 125 | cnn_coords, cnn_feats = model(coords, feats) 126 | -------------------------------------------------------------------------------- /models/utils/extraction/mink/resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 2 | # 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to 6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies 7 | # of the Software, and to permit persons to whom the Software is furnished to do 8 | # so, subject to the following conditions: 9 | # 10 | # The above copyright notice and this permission notice shall be included in all 11 | # copies or substantial portions of the Software. 12 | # 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | # SOFTWARE. 20 | # 21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural 22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part 23 | # of the code. 24 | 25 | import torch.nn as nn 26 | 27 | import MinkowskiEngine as ME 28 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck 29 | 30 | 31 | class ResNetBase(nn.Module): 32 | block = None 33 | layers = () 34 | init_dim = 64 35 | planes = (64, 128, 256, 512) 36 | 37 | def __init__(self, in_channels, out_channels, D=3): 38 | nn.Module.__init__(self) 39 | self.D = D 40 | assert self.block is not None 41 | 42 | self.network_initialization(in_channels, out_channels, D) 43 | self.weight_initialization() 44 | 45 | def network_initialization(self, in_channels, out_channels, D): 46 | self.inplanes = self.init_dim 47 | self.conv1 = ME.MinkowskiConvolution( 48 | in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) 49 | 50 | self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) 51 | self.relu = ME.MinkowskiReLU(inplace=True) 52 | 53 | self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) 54 | 55 | self.layer1 = self._make_layer( 56 | self.block, self.planes[0], self.layers[0], stride=2) 57 | self.layer2 = self._make_layer( 58 | self.block, self.planes[1], self.layers[1], stride=2) 59 | self.layer3 = self._make_layer( 60 | self.block, self.planes[2], self.layers[2], stride=2) 61 | self.layer4 = self._make_layer( 62 | self.block, self.planes[3], self.layers[3], stride=2) 63 | 64 | self.conv5 = ME.MinkowskiConvolution( 65 | self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) 66 | self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) 67 | 68 | self.glob_avg = ME.MinkowskiGlobalMaxPooling() 69 | 70 | self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) 71 | 72 | def weight_initialization(self): 73 | for m in self.modules(): 74 | if isinstance(m, ME.MinkowskiConvolution): 75 | ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu') 76 | 77 | if isinstance(m, ME.MinkowskiBatchNorm): 78 | nn.init.constant_(m.bn.weight, 1) 79 | nn.init.constant_(m.bn.bias, 0) 80 | 81 | def _make_layer(self, 82 | block, 83 | planes, 84 | blocks, 85 | stride=1, 86 | dilation=1, 87 | bn_momentum=0.1): 88 | downsample = None 89 | if stride != 1 or self.inplanes != planes * block.expansion: 90 | downsample = nn.Sequential( 91 | ME.MinkowskiConvolution( 92 | self.inplanes, 93 | planes * block.expansion, 94 | kernel_size=1, 95 | stride=stride, 96 | dimension=self.D), 97 | ME.MinkowskiBatchNorm(planes * block.expansion)) 98 | layers = [] 99 | layers.append( 100 | block( 101 | self.inplanes, 102 | planes, 103 | stride=stride, 104 | dilation=dilation, 105 | downsample=downsample, 106 | dimension=self.D)) 107 | self.inplanes = planes * block.expansion 108 | for i in range(1, blocks): 109 | layers.append( 110 | block( 111 | self.inplanes, 112 | planes, 113 | stride=1, 114 | dilation=dilation, 115 | dimension=self.D)) 116 | 117 | return nn.Sequential(*layers) 118 | 119 | def forward(self, x): 120 | x = self.conv1(x) 121 | x = self.bn1(x) 122 | x = self.relu(x) 123 | x = self.pool(x) 124 | 125 | x = self.layer1(x) 126 | x = self.layer2(x) 127 | x = self.layer3(x) 128 | x = self.layer4(x) 129 | 130 | x = self.conv5(x) 131 | x = self.bn5(x) 132 | x = self.relu(x) 133 | 134 | x = self.glob_avg(x) 135 | return self.final(x) 136 | 137 | 138 | class ResNet14(ResNetBase): 139 | BLOCK = BasicBlock 140 | LAYERS = (1, 1, 1, 1) 141 | 142 | 143 | class ResNet18(ResNetBase): 144 | BLOCK = BasicBlock 145 | LAYERS = (2, 2, 2, 2) 146 | 147 | 148 | class ResNet34(ResNetBase): 149 | BLOCK = BasicBlock 150 | LAYERS = (3, 4, 6, 3) 151 | 152 | 153 | class ResNet50(ResNetBase): 154 | BLOCK = Bottleneck 155 | LAYERS = (3, 4, 6, 3) 156 | 157 | 158 | class ResNet101(ResNetBase): 159 | BLOCK = Bottleneck 160 | LAYERS = (3, 4, 23, 3) 161 | 162 | -------------------------------------------------------------------------------- /models/utils/extraction/mink/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import MinkowskiEngine as ME 3 | 4 | def minkowski_sparse(coords:torch.Tensor, feats:torch.Tensor, quant_size:float): 5 | device = coords.device 6 | # sparse_quantize 7 | quant_coords, quant_feats = [], [] 8 | for coord, feat in zip(coords.clone().detach().cpu(), feats.clone().detach().cpu()): 9 | quant_coord, quant_feat = ME.utils.sparse_quantize( 10 | coordinates=coord, features=feat, quantization_size=quant_size 11 | ) 12 | quant_coords.append(quant_coord) 13 | quant_feats.append(quant_feat) 14 | 15 | # batch collate 16 | batch_coords, batch_feats = ME.utils.sparse_collate(quant_coords, quant_feats) 17 | # to sparse tensor 18 | sparse_tensor = ME.SparseTensor(features=batch_feats.to(device=device), coordinates=batch_coords.to(device=device)) 19 | return sparse_tensor 20 | 21 | def minkowski_decomposed(sparse_tensor, quant_size): 22 | coords, feats = sparse_tensor.decomposed_coordinates_and_features 23 | # de-quantize coordinates 24 | a = torch.tensor([2]) 25 | coords = [e.double()*quant_size for e in coords] 26 | return coords, feats 27 | -------------------------------------------------------------------------------- /models/utils/transformers/transgeo.py: -------------------------------------------------------------------------------- 1 | # Author: Sijie Zhu, https://github.com/Jeff-Zilence/TransGeo2022 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import torch 5 | import torch.nn as nn 6 | from functools import partial 7 | 8 | from timm.models.vision_transformer import Block 9 | from timm.models.layers import trunc_normal_ 10 | class PCTrans(nn.Module): 11 | def __init__(self, 12 | dim:int, 13 | num_heads:int, 14 | mlp_ratio:int, 15 | depth:int, 16 | qkv_bias:bool, 17 | init_values:float, 18 | drop:float, 19 | attn_drop:float, 20 | drop_path_rate:float 21 | ): 22 | super().__init__() 23 | assert depth >= 1 24 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] 25 | self.blocks = nn.Sequential(*[ 26 | Block( 27 | dim=dim, 28 | num_heads=num_heads, 29 | mlp_ratio=mlp_ratio, 30 | qkv_bias=qkv_bias, 31 | init_values=init_values, 32 | proj_drop=drop, 33 | attn_drop=attn_drop, 34 | drop_path=dpr[i], 35 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 36 | act_layer=nn.GELU 37 | ) 38 | for i in range(depth)]) 39 | self.norm = partial(nn.LayerNorm, eps=1e-6)(dim) 40 | self.num_heads = float(num_heads) 41 | 42 | def forward(self, x:torch.Tensor): 43 | attn_score = None 44 | 45 | for i, blk in enumerate(self.blocks): 46 | attn_x = blk.norm1(x) 47 | if i == len(self.blocks)-1: 48 | # decompose attn forward 49 | B, N, C = attn_x.shape 50 | qkv = blk.attn.qkv(attn_x).reshape(B, N, 3, blk.attn.num_heads, C // blk.attn.num_heads).permute(2, 0, 3, 1, 4) 51 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 52 | 53 | attn = (q @ k.transpose(-2, -1)) * blk.attn.scale 54 | attn = attn.softmax(dim=-1) 55 | 56 | # get attn_score 57 | attn_score = attn.sum(axis=1).sum(axis=1) - self.num_heads 58 | attn_score = torch.sigmoid(attn_score) 59 | 60 | attn = blk.attn.attn_drop(attn) 61 | 62 | attn_x = (attn @ v).transpose(1, 2).reshape(B, N, C) 63 | attn_x = blk.attn.proj(attn_x) 64 | attn_x = blk.attn.proj_drop(attn_x) 65 | else: 66 | attn_x = blk.attn(attn_x) 67 | 68 | x = x + blk.drop_path1(blk.ls1(attn_x)) 69 | x = x + blk.drop_path2(blk.ls2(blk.mlp(blk.norm2(x)))) 70 | x = self.norm(x) 71 | return x, attn_score 72 | 73 | def init_weights_vit_timm(module: nn.Module, name: str = ''): 74 | """ ViT weight initialization, original timm impl (for reproducibility) """ 75 | if isinstance(module, nn.Linear): 76 | trunc_normal_(module.weight, std=.02) 77 | if module.bias is not None: 78 | nn.init.zeros_(module.bias) 79 | elif hasattr(module, 'init_weights'): 80 | module.init_weights() 81 | 82 | def main(): 83 | BS, PN, FS = 1, 2342, 256 84 | # model = deit_small_distilled_patch16_224(save="/home/jieyr/code/TransGeo2022/save") 85 | model = PCTrans( 86 | dim=256, 87 | num_heads=8, 88 | mlp_ratio=4, 89 | qkv_bias=True, 90 | depth=4, 91 | init_values=None, 92 | drop=0.0, 93 | attn_drop=0.0, 94 | drop_path_rate=0.0 95 | ) 96 | feats = torch.rand((BS, PN, FS)) 97 | attn_feats, attn_score = model(feats) 98 | print(attn_feats.size(), attn_score) 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- /pretrain/GAPR.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/pretrain/GAPR.pth -------------------------------------------------------------------------------- /pretrain/config.yaml: -------------------------------------------------------------------------------- 1 | dataloaders: 2 | train: 3 | augment: 4 | if_jrr: false 5 | name: TrainAugment 6 | rotate_cmd: zxy10 7 | translate_delta: 1.0 8 | collate: 9 | name: MetricCollate 10 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/train 11 | num_workers: 4 12 | sampler: 13 | batch_expansion_rate: 1.4 14 | batch_size: 16 15 | batch_size_limit: 32 16 | max_batches: null 17 | name: HeteroTripletSample 18 | dist: 19 | backend: nccl 20 | find_unused_parameters: false 21 | method: 22 | loss: 23 | batch_loss: 24 | margin: 1.0 25 | style: hard 26 | name: GAPRLoss 27 | overlap_loss: 28 | corr_dist: 2.0 29 | overlap_loss_scale: 1.0 30 | point_loss: 31 | corr_dist: 2.0 32 | margin: 10 33 | neg_dist: 20.0 34 | pos_dist: 2.1 35 | sample_num: 64 36 | style: soft 37 | point_loss_scale: 0.5 38 | model: 39 | debug: false 40 | meangem: 41 | eps: 1.0e-06 42 | p: 3.0 43 | minkfpn: 44 | conv0_kernel_size: 5 45 | in_channels: 1 46 | layers: 47 | - 1 48 | - 1 49 | - 1 50 | num_top_down: 1 51 | out_channels: 256 52 | planes: 53 | - 32 54 | - 64 55 | - 64 56 | quant_size: 0.6 57 | name: GAPR 58 | pctrans: 59 | attn_drop: 0.0 60 | depth: 1 61 | dim: 256 62 | drop: 0.0 63 | drop_path_rate: 0.0 64 | init_values: null 65 | mlp_ratio: 4 66 | num_heads: 2 67 | qkv_bias: true 68 | results: 69 | logs: null 70 | weights: /home/jieyr/code/ppr/results/weights/Ablation 71 | train: 72 | batch_expansion_th: 0.7 73 | epochs: 40 74 | lr: 0.001 75 | scheduler_milestones: 76 | - 15 77 | - 30 78 | weight_decay: 0.001 79 | -------------------------------------------------------------------------------- /results/evaluate/readme.txt: -------------------------------------------------------------------------------- 1 | The results of evaluation is saved here. -------------------------------------------------------------------------------- /results/weights/readme.txt: -------------------------------------------------------------------------------- 1 | The weights of training is saved here. -------------------------------------------------------------------------------- /scripts/add_path.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH=$PYTHONPATH:/home/jieyr/Codes/GAPR 2 | -------------------------------------------------------------------------------- /scripts/clean.sh: -------------------------------------------------------------------------------- 1 | rm -r /home/jieyr/Codes/GAPR/results/evaluate/* 2 | rm -r /home/jieyr/Codes/GAPR/results/logs/* 3 | rm -r /home/jieyr/Codes/GAPR/results/weights/* -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train/train.py 2 | -------------------------------------------------------------------------------- /train/train.py: -------------------------------------------------------------------------------- 1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D 2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR 3 | 4 | import os 5 | import time 6 | import argparse 7 | import yaml 8 | import torch 9 | from tqdm import tqdm 10 | from typing import Dict, Any, List 11 | import numpy as np 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | 15 | from datasets.dataloders.lprdataloader import LPRDataLoader 16 | from models.lprmodel import LPRModel 17 | from loss.lprloss import LPRLoss 18 | from misc.utils import get_datetime, tensors2device, avg_stats 19 | 20 | from torch.utils.tensorboard import SummaryWriter 21 | 22 | def parse_opt()->dict: 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--yaml", type=str, required=True) 25 | parser.add_argument("--local_rank", type=int, required=True) 26 | opt = parser.parse_args() 27 | opt = vars(opt) 28 | f = open(opt["yaml"], encoding="utf-8") 29 | lprtrain = yaml.load(f, Loader=yaml.FullLoader) #读取yaml文件 30 | f.close() 31 | return lprtrain 32 | 33 | def main(**kw): 34 | # 初始化torch.distributed 35 | local_rank = int(os.environ["LOCAL_RANK"]) 36 | torch.cuda.set_device(local_rank) 37 | 38 | dist.init_process_group(backend=kw["dist"]["backend"]) # nccl是GPU设备上最快、最推荐的后端 39 | 40 | # get dataloders 41 | dataloaders = {phase: LPRDataLoader(**kw["dataloaders"]["train"]) for phase in kw["dataloaders"]} 42 | # get model 43 | model = LPRModel() 44 | model.construct(**kw["method"]["model"]) 45 | # get loss function 46 | loss_fn = LPRLoss(**kw["method"]["loss"]) 47 | # model to local_rank 48 | model.model = model.model.to(local_rank) 49 | # construct DDP model 50 | model.model = DDP(model.model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=kw["dist"]["find_unused_parameters"]) 51 | # initialize optimizer after construction of DDP model 52 | optimizer = torch.optim.Adam(model.model.parameters(), lr=kw["train"]["lr"], weight_decay=kw["train"]["weight_decay"]) 53 | # get scheduler 54 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, kw["train"]["scheduler_milestones"], gamma=0.1) 55 | 56 | # set results 57 | writer, weights_path = None, None 58 | if local_rank == 0: 59 | model_name = get_datetime() 60 | if kw["results"]["weights"] is not None: 61 | weights_path = os.path.join(kw["results"]["weights"], model_name) 62 | if not os.path.exists(weights_path): os.mkdir(weights_path) 63 | # save config yaml 64 | with open(os.path.join(weights_path, "config.yaml"), "w") as file: 65 | file.write(yaml.dump(dict(kw), allow_unicode=True)) 66 | if kw["results"]["logs"] is not None: 67 | logs_path = os.path.join(kw["results"]["logs"], model_name) 68 | writer = SummaryWriter(logs_path) 69 | 70 | 71 | # get phases from dataloaders 72 | phases = list(dataloaders.keys()) 73 | # visualize len of phases database 74 | if local_rank == 0: 75 | for phase in phases: 76 | print("Dataloder: {} set len = {}".format(phase, len(dataloaders[phase].dataset))) 77 | 78 | itera = None 79 | if local_rank == 0: itera = tqdm(range(kw["train"]["epochs"])) 80 | else: itera = range(kw["train"]["epochs"]) 81 | for epoch in itera: 82 | for phase in phases: 83 | # switch mode 84 | if phase=="train": model.model.train() 85 | else: model.model.eval() 86 | 87 | # wait barrier 88 | dist.barrier() 89 | 90 | phase_stats:List[Dict] = [] 91 | 92 | for data, mask in dataloaders[phase]: 93 | # data to device 94 | data = tensors2device(data, device=local_rank) 95 | # clear grad 96 | optimizer.zero_grad() 97 | 98 | with torch.set_grad_enabled(phase == "train"): 99 | output = model(data) 100 | loss, stats = loss_fn(output, mask) 101 | if phase == "train": 102 | loss.backward() 103 | optimizer.step() 104 | 105 | phase_stats.append(stats) 106 | torch.cuda.empty_cache() 107 | 108 | # ******* PHASE END ******* 109 | # compute mean stats for the epoch 110 | phase_avg_stats = avg_stats(phase_stats) 111 | # print and save stats 112 | if local_rank == 0: loss_fn.print_stats(epoch, phase, writer, phase_avg_stats) 113 | 114 | # ******* EPOCH END ******* 115 | # scheduler 116 | if scheduler is not None: scheduler.step() 117 | 118 | if local_rank == 0 and weights_path is not None: 119 | model.save(os.path.join(weights_path, "{}.pth".format(epoch))) 120 | 121 | if __name__ == "__main__": 122 | main(**parse_opt()) 123 | 124 | # CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node 2 train.py --------------------------------------------------------------------------------