├── .gitignore
├── LICENSE
├── README.md
├── config
├── evaluate
│ └── once.yaml
└── gapr
│ └── train.yaml
├── datasets
├── dataloders
│ ├── __init__.py
│ ├── augments
│ │ ├── __init__.py
│ │ ├── augment.py
│ │ └── utils.py
│ ├── collates
│ │ ├── __init__.py
│ │ ├── lprcollate.py
│ │ └── utils.py
│ ├── lprdataloader.py
│ └── samplers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── batch.py
│ │ ├── hetero.py
│ │ ├── homo.py
│ │ ├── lprbatchsampler.py
│ │ └── utils.py
└── lprdataset.py
├── evaluate
├── once.py
└── utils.py
├── loss
├── __init__.py
├── base.py
├── gapr.py
├── lprloss.py
├── overlap.py
├── point.py
└── triplet.py
├── media
├── description.png
└── pipeline.png
├── misc
└── utils.py
├── models
├── __init__.py
├── gapr.py
├── lprmodel.py
└── utils
│ ├── aggregation
│ └── gem.py
│ ├── extraction
│ └── mink
│ │ ├── minkfpn.py
│ │ ├── resnet.py
│ │ └── utils.py
│ └── transformers
│ └── transgeo.py
├── pretrain
├── GAPR.pth
└── config.yaml
├── results
├── evaluate
│ └── readme.txt
└── weights
│ └── readme.txt
├── scripts
├── add_path.sh
├── clean.sh
└── train.sh
└── train
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # results of weights and evaluation
132 | results/evaluate/20*
133 | results/weights/20*
134 |
135 | # .vscode
136 | *.vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 SYSU RAPID Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAPR
2 |
3 | ## Introduction
4 | [RA-L 23] Heterogeneous Deep Metric Learning for Ground and Aerial Point Cloud-Based Place Recognition
5 |
6 | In this paper, we propose a heterogeneous deep metric learning pipeline for ground and aerial point cloud-based place recognition in large-scale environments. The pipeline extracts local features from ground and aerial raw point clouds by a sparse convolution module. The local features are processed by transformer encoders to capture the overlaps between ground and aerial point clouds, and then transformed to unified descriptors for retrieval purposes by backpropagation of heterogeneous loss functions. To facilitate training and provide a reliable benchmark, a large-scale dataset is also proposed, which is collected from well-equipped ground and aerial robotic platforms. We demonstrate the superiority of the proposed method by comparing it with existing well-performing methods. We also show in the experimental results that our method is capable of detecting loop closures in a collaborative ground and aerial robotic system.
7 |
8 |
9 |
10 |

11 |
12 |
13 | Task Illustration
14 |
15 |
16 |
17 |
18 |

19 |
20 |
21 | GAPR Pipeline
22 |
23 |
24 |
25 | ## Contributors
26 | [Yingrui Jie 揭英睿](https://github.com/yingruijie),
27 | [Yilin Zhu 朱奕霖](https://github.com/inntoy), and
28 | [Hui Cheng 成慧](https://cse.sysu.edu.cn/content/2504) from
29 | [SYSU RAPID Lab](http://lab.sysu-robotics.com).
30 |
31 | ## Citation
32 | ```tex
33 | @ARTICLE{10173571,
34 | author={Jie, Yingrui and Zhu, Yilin and Cheng, Hui},
35 | journal={IEEE Robotics and Automation Letters},
36 | title={Heterogeneous Deep Metric Learning for Ground and Aerial Point Cloud-Based Place Recognition},
37 | year={2023},
38 | volume={},
39 | number={},
40 | pages={1-8},
41 | doi={10.1109/LRA.2023.3292623}}
42 | ```
43 |
44 |
45 | # Usage
46 | ## Environment
47 | This project has been tested on a system with Ubuntu 18.04. Main dependencies include: CUDA >= 10.2; PyTorch >= 1.9.1; MinkowskiEngine >= 0.5.4; Open3D >= 0.15.2. Please set up the requirements as follows.
48 | 1. Install [cuda-10.2](https://developer.nvidia.com/cuda-10.2-download-archive).
49 |
50 | 2. Create the anaconda environment.
51 | ```sh
52 | conda create -n gapr python=3.8
53 | conda activate gapr
54 | ```
55 | 3. [PyTorch](https://pytorch.org/).
56 | ```sh
57 | pip install torch==1.9.1+cu102 torchvision==0.10.1+cu102 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
58 | ```
59 |
60 | 4. [MinkowskiEngine](https://github.com/NVIDIA/MinkowskiEngine).
61 | ```sh
62 | conda install openblas-devel -c anaconda
63 | git clone https://github.com/NVIDIA/MinkowskiEngine.git
64 | cd MinkowskiEngine
65 | export CXX=g++-7
66 | git checkout v0.5.4
67 | python setup.py install --blas_include_dirs=${CONDA_PREFIX}/include --blas=openblas
68 | cd ..
69 | ```
70 |
71 | 5. Install requirements.
72 | ```sh
73 | # install setuptools firstly to avoid some bugs
74 | pip install setuptools==58.0.4
75 | pip install tqdm open3d tensorboard pandas matplotlib pillow ptflops timm==0.9.2
76 | ```
77 |
78 | 6. Download this repository.
79 | ```sh
80 | git clone https://github.com/SYSU-RoboticsLab/GAPR.git
81 | cd GAPR
82 | ```
83 | Add the python path before running codes:
84 | ```sh
85 | export PYTHONPATH=$PYTHONPATH:/PATH_TO_CODE/GAPR
86 | ```
87 |
88 | ## Dataset
89 | Please download our [benchmark dataset](https://pan.baidu.com/s/1TsxSNZVkGwpZjBM0eNXglw?pwd=zxx4) and unpack the tar file.
90 | Run the following command to check the dataset (`train` and `evaluate`).
91 | ```sh
92 | python datasets/lprdataset.py --dataset /PATH_TO_DATASET/benchmark/train
93 | python datasets/lprdataset.py --dataset /PATH_TO_DATASET/benchmark/evaluate
94 | ```
95 |
96 | ## Evaluate
97 | 1. Change the path of dataset in `config/evaluate/once.yaml`.
98 | ```yaml
99 | # ...
100 | dataloaders:
101 | evaluate:
102 | dataset: /PATH_TO_DATASET/benchmark/evaluate
103 | # ...
104 | ```
105 | 2. We provide pretrained weights for evaluation.
106 | ```sh
107 | python evaluate/once.py --weights pretrain/GAPR.pth --yaml config/evaluate/once.yaml
108 | ```
109 | Parameter `weights` is used to set the path of model weights. The results are saved at `results/evaluate/YYMMDD_HHMMSS`.
110 | ## Train
111 | 1. Change the path of dataset in `config/gapr/train.yaml`.
112 | ```yaml
113 | # ...
114 | dataloaders:
115 | train:
116 | dataset: /PATH_TO_DATASET/benchmark/train
117 | # ...
118 | ```
119 | 2. Select the GPUs and start training. For example, here GPUs 1 and 3 are used and `nproc_per_node` is set to 2 (the number of selected GPUs).
120 | ```sh
121 | CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train/train.py --yaml config/gapr/train.yaml
122 | ```
123 | The training weights are saved at `results/weights/YYMMDD_HHMMSS`.
124 |
125 | # Acknowledgement
126 | We acknowledge the authors of [MinkLoc3D](https://github.com/jac99/MinkLoc3D) for their excellent codebase which has been used as a starting point for this project.
127 |
--------------------------------------------------------------------------------
/config/evaluate/once.yaml:
--------------------------------------------------------------------------------
1 | dataloaders:
2 | evaluate:
3 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/evaluate
4 | collate:
5 | name: MetricCollate
6 | augment:
7 | name: EvaluateAugment
8 | rotate_cmd: zxy10
9 | translate_delta: 0.0
10 | if_jrr: no
11 | sampler:
12 | name: BatchSample
13 | batch_size: 1
14 | batch_size_limit: null
15 | batch_expansion_rate: null
16 | # sample kw
17 | shuffle: false
18 | max_batches: null
19 | num_workers: 1
20 |
21 |
--------------------------------------------------------------------------------
/config/gapr/train.yaml:
--------------------------------------------------------------------------------
1 | dataloaders:
2 | train:
3 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/train
4 | collate:
5 | name: MetricCollate
6 | augment:
7 | name: TrainAugment
8 | rotate_cmd: zxy10
9 | translate_delta: 1.0
10 | if_jrr: no
11 | sampler:
12 | name: HeteroTripletSample
13 | batch_size: 16
14 | batch_size_limit: 32
15 | batch_expansion_rate: 1.4
16 | max_batches: null
17 | num_workers: 4
18 |
19 | method:
20 | model:
21 | name: GAPR
22 | debug: no
23 | minkfpn:
24 | quant_size: 0.6
25 | in_channels: 1
26 | out_channels: 256
27 | num_top_down: 1
28 | conv0_kernel_size: 5
29 | layers: [1, 1, 1]
30 | planes: [32, 64, 64]
31 | pctrans:
32 | dim: 256
33 | num_heads: 2
34 | mlp_ratio: 4
35 | depth: 1
36 | qkv_bias: yes
37 | init_values: null
38 | drop: 0.0
39 | attn_drop: 0.0
40 | drop_path_rate: 0.0
41 | meangem:
42 | p: 3.0
43 | eps: 0.000001
44 | loss:
45 | name: GAPRLoss
46 | batch_loss:
47 | margin: 1.0
48 | style: hard
49 | point_loss:
50 | margin: 10.0
51 | style: soft
52 | corr_dist: 2.0
53 | sample_num: 64
54 | pos_dist: 2.1
55 | neg_dist: 20.0
56 | overlap_loss:
57 | corr_dist: 2.0
58 | point_loss_scale: 0.5
59 | overlap_loss_scale: 1.0
60 |
61 | train:
62 | lr: 0.001
63 | epochs: 40
64 | weight_decay: 0.001
65 | batch_expansion_th: 0.7 # not used
66 | scheduler_milestones: [15, 30] # not used
67 |
68 | dist:
69 | backend: nccl
70 | find_unused_parameters: no
71 |
72 | results:
73 | weights: results/weights
74 | logs: null # not used
75 |
--------------------------------------------------------------------------------
/datasets/dataloders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/__init__.py
--------------------------------------------------------------------------------
/datasets/dataloders/augments/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/augments/__init__.py
--------------------------------------------------------------------------------
/datasets/dataloders/augments/augment.py:
--------------------------------------------------------------------------------
1 | import torchvision.transforms as transforms
2 | import open3d as o3d
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import open3d as o3d
7 | from typing import List, Dict, Tuple
8 | from datasets.dataloders.augments.utils import *
9 |
class Augment:
    """
    # Pointcloud augmentation pipeline
    Applies, in order: the optional `jrr` stage (currently always disabled),
    a random rotation and a random translation.
    """
    def __init__(self, name:str, rotate_cmd:str, translate_delta:float, if_jrr:bool):
        print("Augment: name=%s, rotate=%s, translate=%.3f, jrr=%s " % (name, rotate_cmd, translate_delta, if_jrr))
        self.rotate = RandomRotation(rotate_cmd)
        self.translate = RandomTranslation(translate_delta)

        # the jrr stage has no implementation yet, so requesting it is an error
        if if_jrr:
            raise NotImplementedError("Augment: jrr is currently not implemented.")
        self.jrr = None

    def __call__(self, e:torch.Tensor):
        """
        ## Input
        * e: tensor handed to the rotation stage (which unpacks it as [BS, PN, D])
        ## Output
        * augmented coords, rotation matrices, translations; each is moved to
          the device / dtype of `e` and made contiguous
        """
        def like_input(t:torch.Tensor)->torch.Tensor:
            # align data type and device with the input tensor
            return t.to(e.device).type_as(e).contiguous()

        # jrr (skipped while self.jrr is None)
        start = e if self.jrr is None else self.jrr(e)
        # rotate, then translate
        rotated, rotms = self.rotate(start)
        shifted, trans = self.translate(rotated)
        return like_input(shifted), like_input(rotms), like_input(trans)
35 |
36 |
--------------------------------------------------------------------------------
/datasets/dataloders/augments/utils.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import numpy as np
5 | import math
6 | import random
7 | import torch
8 |
9 |
class RandomRotation:
    """
    # Random Rotate pointclouds and return matrix
    Supported commands:
    * None: identity (angle 0 about +z)
    * "zxy10" / "zxy20": angle uniform in [-pi, pi) about an axis obtained by
      tilting +z by up to 10 / 20 degrees around a random axis in the xOy plane
    * "so3": angle uniform in [-pi, pi) about a random unit axis
    """
    # maximum tilt (in degrees) of the rotation axis away from +z per command
    _MAX_TILT_DEG = {"zxy10": 10.0, "zxy20": 20.0}

    def __init__(self, cmd:str):
        if cmd not in [None, "zxy10", "zxy20", "so3"]: raise NotImplementedError("RandomRotate: cmd in [None, zxy10, zxy20,so3]")
        self.cmd = cmd

    def getRotateMatrixFromRotateVector(self, axis:torch.Tensor, theta:torch.Tensor)->torch.Tensor:
        """
        # Get Rotate Matrix from Rotate Vector\n
        ## input \n
        axis.size() == [bs, 3] (normalised internally, must be non-zero) \n
        theta.size() == [bs] \n
        ## output \n
        rotateMatrix.size() == [bs, 3, 3] \n
        """
        device = axis.device
        bs = axis.size()[0]
        # [bs, 3], every row scaled to unit length
        axis = axis / torch.norm(axis, p=2, dim=1, keepdim=True)
        # [bs, 1, 3], [bs, 3, 1]
        axisH, axisV = axis.unsqueeze(1), axis.unsqueeze(2)
        # [bs, 1, 1]
        cosTheta = torch.cos(theta).reshape(-1, 1, 1)
        # [bs, 1, 1]
        sinTheta = torch.sin(theta).reshape(-1, 1, 1)
        # [bs, 3, 3]; same dtype as axis so the cross product below never mixes dtypes
        eye = torch.eye(3, device=device, dtype=axis.dtype).expand(bs, 3, 3)
        # axis^ [bs, 3, 3]: row i is e_i x axis, i.e. the skew-symmetric matrix of axis
        axisCaret = torch.cross(eye, axisH.expand(bs, 3, 3), dim=2)
        # so3: R = cos(theta) * I + (1-cos(theta)) * dot(a, aT) + sin(theta) * a^
        r = cosTheta * eye + (1.0-cosTheta) * torch.bmm(axisV, axisH) + sinTheta * axisCaret
        return r

    def __call__(self, coords:torch.Tensor):
        """
        ## input \n
        coords.size() == [BS, PN, 3] \n
        ## output \n
        rotated coords [BS, PN, 3], rotation matrices [BS, 3, 3] \n
        """
        device = coords.device
        BS, _, _ = coords.shape
        # initial theta and axis: no rotation, about +z
        theta = torch.zeros((BS), device=device)
        axis = torch.tensor([0.0, 0.0, 1.0], device=device).repeat(BS, 1)
        if self.cmd in self._MAX_TILT_DEG:
            # theta [-pi, pi]
            theta = torch.rand(BS,device=device) * 2 * np.pi - np.pi
            alpha = torch.rand(BS,device=device) * 2 * np.pi - np.pi
            beta = torch.rand(BS,device=device) * np.pi * self._MAX_TILT_DEG[self.cmd] / 180.0
            # alpha_axis is a vector in xOy plane
            alpha_axis = torch.stack([torch.sin(alpha), torch.cos(alpha), torch.zeros((BS), device=device)], dim=1)
            # tilt the +z axis by beta around alpha_axis
            alpha_mat = self.getRotateMatrixFromRotateVector(alpha_axis, beta)
            axis = torch.bmm(alpha_mat, axis.unsqueeze(2)).squeeze(2)
        elif self.cmd == "so3":
            theta = torch.rand(BS,device=device) * 2 * np.pi - np.pi
            axis = torch.rand((BS, 3), device=device) - 0.5
            # normalise every sample's axis on its own (per row); normalising per
            # column across the batch collapses the axis onto the cube diagonals
            # when BS == 1
            # NOTE(review): sampling in a cube and normalising is not uniform on the sphere
            axis = axis / torch.norm(axis, p=2, dim=1, keepdim=True)

        rots_mat = self.getRotateMatrixFromRotateVector(axis, theta).type_as(coords)
        coords = torch.bmm(rots_mat, coords.transpose(1,2)).transpose(1, 2)
        return coords, rots_mat
78 |
79 |
class RandomTranslation:
    """
    # Random Translation
    Shifts every pointcloud of a batch by its own Gaussian offset scaled by `delta`.
    """
    def __init__(self, delta=0.05):
        self.delta = delta

    def __call__(self, coords:torch.Tensor):
        """
        ## output
        shifted coords and the [BS, 3] offsets that were added
        """
        batch = coords.shape[0]
        # one 3-vector per batch entry, broadcast over the points dimension
        offsets = self.delta * torch.randn(batch, 3, device=coords.device)
        shifted = coords + offsets[:, None, :]
        return shifted, offsets
91 |
92 |
class RandomFlip:
    """
    Flip at most one coordinate axis per call, chosen from one random draw.
    """
    def __init__(self, p):
        # p = [p_x, p_y, p_z] probability of flipping each axis
        assert len(p) == 3
        assert 0 < sum(p) <= 1, 'sum(p) must be in (0, 1] range, is: {}'.format(sum(p))
        self.p = p
        self.p_cum_sum = np.cumsum(p)

    def __call__(self, coords):
        draw = random.random()
        # the first cumulative bound that covers the draw selects the axis to negate;
        # if none does, coords are returned untouched
        for axis_index, upper in enumerate(self.p_cum_sum):
            if draw <= upper:
                coords[..., axis_index] = -coords[..., axis_index]
                break

        return coords
114 |
class RandomScale:
    """
    Multiply coords by a single random factor (cast to float32) drawn between `min` and `max`.
    """
    def __init__(self, min, max):
        # stored as offset + width so a unit uniform draw maps onto the range
        self.bias = min
        self.scale = max - min

    def __call__(self, coords):
        factor = self.scale * np.random.rand(1) + self.bias
        factor = factor.astype(np.float32)
        return coords * factor
123 |
class RandomShear:
    """
    Right-multiply coords by the 3x3 identity perturbed with Gaussian noise scaled by `delta`.
    """
    def __init__(self, delta=0.1):
        self.delta = delta

    def __call__(self, coords):
        perturbation = self.delta * np.random.randn(3, 3)
        shear = (np.eye(3) + perturbation).astype(np.float32)
        return coords @ shear
131 |
132 |
class JitterPoints:
    """
    Add Gaussian noise (std `sigma`, optionally clipped to [-clip, clip]) in place.
    """
    def __init__(self, sigma=0.01, clip=None, p=1.):
        assert 0 < p <= 1.
        assert sigma > 0.

        self.sigma = sigma
        self.clip = clip
        self.p = p

    def __call__(self, e:torch.Tensor):
        """ Randomly jitter the input tensor in place and return it.
            The selection mask has one entry per element along the first
            dimension of `e`; each entry is selected with probability `p`.
        """
        mask_shape = (e.shape[0],)
        if self.p < 1.:
            # draw 0/1 per entry of the first dimension; 1 means "jitter this entry"
            chooser = torch.distributions.categorical.Categorical(probs=torch.tensor([1 - self.p, self.p]))
            picked = chooser.sample(sample_shape=mask_shape)
        else:
            picked = torch.ones(mask_shape, dtype=torch.int64 )

        picked = picked == 1
        noise = self.sigma * torch.randn_like(e[picked])

        if self.clip is not None:
            noise = torch.clamp(noise, min=-self.clip, max=self.clip)

        e[picked] = e[picked] + noise
        return e
166 |
167 |
class RemoveRandomPoints:
    """
    Zero out a random fraction of the entries along the first dimension, in place.
    `r` is either a fixed ratio or a (min, max) pair from which a ratio is drawn per call.
    """
    def __init__(self, r):
        if type(r) in (list, tuple):
            assert len(r) == 2
            assert 0 <= r[0] <= 1
            assert 0 <= r[1] <= 1
            self.r_min = float(r[0])
            self.r_max = float(r[1])
        else:
            assert 0 <= r <= 1
            # a missing lower bound marks the fixed-ratio mode
            self.r_min = None
            self.r_max = float(r)

    def __call__(self, e:torch.Tensor):
        total = len(e)
        # fixed ratio, or one drawn uniformly from [r_min, r_max]
        ratio = self.r_max if self.r_min is None else random.uniform(self.r_min, self.r_max)

        # indices to erase, sampled without replacement
        erased = np.random.choice(range(total), size=int(total*ratio), replace=False)
        e[erased] = torch.zeros_like(e[erased])
        return e
192 |
193 |
class RemoveRandomBlock:
    """
    With probability `p`, zero out (in place) every point whose x/y lies strictly
    inside a random axis-aligned rectangle of the cloud's xy bounding box.
    Points are set to (0, 0, 0) rather than dropped, so the point count is kept.
    `scale` bounds the rectangle area as a fraction of the bounding-box area and
    `ratio` bounds its aspect ratio.
    """
    def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3)):
        self.p = p
        self.scale = scale
        self.ratio = ratio

    def get_params(self, coords:torch.Tensor):
        """Return (x, y, w, h): lower corner and extent of the rectangle to erase."""
        # bounding box over all points
        points = coords.contiguous().view(-1, 3)
        lower, _ = torch.min(points, dim=0)
        upper, _ = torch.max(points, dim=0)
        extent = upper - lower
        box_area = extent[0] * extent[1]
        target_area = random.uniform(self.scale[0], self.scale[1]) * box_area
        aspect = random.uniform(self.ratio[0], self.ratio[1])

        h = math.sqrt(target_area * aspect)
        w = math.sqrt(target_area / aspect)

        # place the rectangle's lower corner at a random position
        x = lower[0] + random.uniform(0, 1) * (extent[0] - w)
        y = lower[1] + random.uniform(0, 1) * (extent[1] - h)

        return x, y, w, h

    def __call__(self, coords):
        if not random.random() < self.p:
            return coords
        x, y, w, h = self.get_params(coords)
        xs, ys = coords[..., 0], coords[..., 1]
        inside = (x < xs) & (xs < x+w) & (y < ys) & (ys < y+h)
        coords[inside] = torch.zeros_like(coords[inside])
        return coords
229 |
--------------------------------------------------------------------------------
/datasets/dataloders/collates/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/collates/__init__.py
--------------------------------------------------------------------------------
/datasets/dataloders/collates/lprcollate.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | from time import sleep
5 | import torch
6 | import numpy as np
7 | import open3d as o3d
8 | import torch.nn.functional as F
9 | import matplotlib.pyplot as plt
10 | from typing import Any, Dict, List
11 | from datasets.lprdataset import LPRDataset
12 | from scipy.spatial.transform import Rotation as R
13 | # common
14 | from datasets.dataloders.collates.utils import align_pcs, in_sorted_array, triplet_mask
15 |
def LPRCollate(dataset:LPRDataset, augment, name:str, **kw):
    """
    # Wrapper for all collate_fn
    Returns the collate function selected by `name` ("BaseCollate" or
    "MetricCollate"); any other name raises NotImplementedError.
    `dataset` and `augment` are captured by the returned closure; `**kw` is
    accepted but not used here.
    ## Format
    ```
    def collate_fn(data_list):
        # data_list defined in LPRdataset.__getitem__()
        # data_list = [
        #     [label0, pc0],
        #     [label1, pc1]
        # ]
        ...
        # data
        data:Dict[str, torch.Tensor] = {"data":data}

        mask:Dict[str, torch.Tensor] = {"mask":mask}
        return data, mask
    ```
    ## Training
    ```
    for data, mask in dataloader:
        data = tensors2device(data, device)
        output = model(data)
        loss = loss_fn(output, mask)
        ...
    ```
    """
    if name == "BaseCollate":
        def BaseCollate(data_list):
            """
            # BaseCollate
            * align clouds
            * convert labels to tensor
            """
            # e[1] is the pointcloud, e[0] the label of every sample
            clouds = [e[1] for e in data_list]
            # zero-pad all clouds to the same number of points so they can be stacked
            clouds = align_pcs(clouds)
            clouds = torch.stack(clouds, dim=0)

            labels = torch.tensor([e[0] for e in data_list])
            data:Dict[str, torch.Tensor] = {"clouds":clouds}
            mask:Dict[str, torch.Tensor] = {"labels":labels}
            return data, mask
        return BaseCollate

    elif name == "MetricCollate":
        def MetricCollate(data_list):
            """
            # Metric Learning Collate Function
            Builds the model input (`coords`, `feats`, `geneous`) and the loss
            input (`labels`, `geneous`, pose `rotms` / `trans`, `positives` /
            `negatives` masks) for one batch.
            """
            # constructs a batch object
            raw_coords = [e[1] for e in data_list]
            labels = [e[0] for e in data_list]
            # align points number
            raw_coords = align_pcs(raw_coords)
            # Tensor: raw_coords: [BS, PN, 3],
            raw_coords = torch.stack(raw_coords, dim=0)
            BS, device = raw_coords.shape[0], raw_coords.device

            # get tums from dataset
            # NOTE(review): columns 1:4 are read as translation and 4:8 as a quaternion,
            # presumably TUM trajectory format (t, x, y, z, qx, qy, qz, qw) -- confirm
            tums = []
            for ndx in labels: tums.append(dataset.get_tum(ndx))
            tums = np.asarray(tums)
            # Tensor: raw_rotms: [BS, 3, 3], raw_trans: [BS, 3]
            # tf_global2raw is [raw_rotms, raw_trans]
            raw_trans = torch.tensor(tums[:, 1:4], device=device).type_as(raw_coords)
            raw_rotms = torch.tensor(R.from_quat(tums[:, 4:8]).as_matrix(), device=device).type_as(raw_coords)

            # apply augment
            if augment is not None:
                # Tensor: aug_coords: [BS, PN, 3], aug_rotms: [BS, 3, 3], aug_trans: [BS, 3]
                # coords = aug_rotms * raw_coords + aug_trans, tf_aug2raw is [aug_rotms, aug_trans]
                # the clone keeps raw_coords untouched by the augmentation
                coords, aug_rotms, aug_trans = augment(raw_coords.clone())
                # compute tf_global2aug = tf_global2raw * tf_aug2raw.inverse()
                # tf_aug2raw.inverse() = [aug_rotms.T, -aug_rotms.T*aug_trans]
                aug_rotms_inv, aug_trans_inv = aug_rotms.transpose(1,2), -torch.bmm(aug_rotms.transpose(1,2), aug_trans.unsqueeze(2)).squeeze(2)
                # trans = raw_trans + raw_rotms * aug_trans_inv
                #       = raw_trans - raw_rotms * aug_rotms.T * aug_trans
                trans = raw_trans + torch.bmm(raw_rotms, aug_trans_inv.unsqueeze(2)).squeeze(2)
                # rotms = raw_rotms * aug_rotms.inverse()
                rotms = torch.bmm(raw_rotms, aug_rotms_inv)
            else:
                # no augmentation: the pose and the coords are passed through as copies
                trans, rotms, coords = raw_trans.clone(), raw_rotms.clone(), raw_coords.clone()

            # set feats to 1, or color if rgb pointcloud
            feats = torch.ones((coords.shape[0], coords.shape[1], 1))
            # compute positives and negatives mask
            positives_mask, negatives_mask = triplet_mask(dataset, labels)

            # get geneous
            # NOTE(review): presumably a per-sample source id (e.g. ground vs aerial);
            # indexing with a list assumes get_all_geneous() returns an array -- confirm
            geneous = dataset.get_all_geneous()[labels]
            geneous = torch.tensor(geneous, dtype=torch.int)
            # get labels
            labels = torch.tensor(labels)
            # write to data and mask
            data:Dict[str, torch.Tensor] = {"coords":coords, "feats":feats, "geneous":geneous}
            mask:Dict[str, torch.Tensor] = {"labels":labels, "geneous":geneous, "rotms":rotms, "trans":trans, "positives":positives_mask, "negatives":negatives_mask}
            return data, mask
        return MetricCollate
    else:
        raise NotImplementedError("LPRCollate: %s not implemented" % name)
115 |
116 |
--------------------------------------------------------------------------------
/datasets/dataloders/collates/utils.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import open3d as o3d
8 | from typing import List, Dict, Tuple
9 | from datasets.lprdataset import LPRDataset
10 |
def align_pcs(pcs:List[torch.Tensor], align_size:int=None)->List[torch.Tensor]:
    """
    # zero-pad pointclouds to a common number of points
    ## Input
    * pcs: list of tensors whose first dimension is the points number
    * align_size: target points number; when None the largest cloud decides \n
    ## Output
    * newpcs: padded copies, in the same order as pcs
    """
    if align_size is not None:
        # an explicit target must be able to hold every cloud
        for pc in pcs:
            assert pc.size()[0] <= align_size, "LPRCollate: pc.size()[0] <= align_size"
    else:
        align_size = max((pc.size()[0] for pc in pcs), default=0)

    # pad only at the end of the points dimension; the last dimension is untouched
    return [
        F.pad(pc, (0,0,0,align_size-pc.size()[0]), "constant", 0)
        for pc in pcs
    ]
35 |
def display_inlier_outlier(cloud, ind):
    """
    # show the points selected by `ind` in gray and all remaining points in red
    Opens an Open3D visualization window.
    NOTE(review): `cloud` is presumably an open3d.geometry.PointCloud -- confirm.
    """
    kept = cloud.select_by_index(ind)
    rejected = cloud.select_by_index(ind, invert=True)

    print("Showing outliers (red) and inliers (gray): ")
    rejected.paint_uniform_color([1, 0, 0])
    kept.paint_uniform_color([0.8, 0.8, 0.8])
    window_options = dict(
        window_name='Open3D Removal Outlier',
        width=1920,
        height=1080,
        left=50,
        top=50,
        point_show_normal=False,
        mesh_show_wireframe=False,
        mesh_show_back_face=False,
    )
    o3d.visualization.draw_geometries([kept, rejected], **window_options)
46 |
def in_sorted_array(e: int, array: np.ndarray) -> bool:
    """
    # Membership test by binary search
    ## Input
    * e: value to look for
    * array: 1-D array sorted in ascending order (may be empty)
    ## Output
    * True if `e` occurs in `array`, as a plain Python bool
    """
    # leftmost insertion position; it equals len(array) when e is larger than
    # every element and is never negative
    pos = np.searchsorted(array, e)
    return bool(pos < len(array) and array[pos] == e)
53 |
def triplet_mask(dataset:"LPRDataset", labels:List[int])->Tuple[torch.Tensor, torch.Tensor]:
    """
    # Pairwise positive / negative masks for one batch
    ## Input
    * dataset: provides get_positives(label) and get_non_negatives(label)
    * labels: dataset indices of the batch elements
    ## Output
    * positives_mask[i, j]: labels[j] is listed in get_positives(labels[i])
    * negatives_mask[i, j]: labels[j] is NOT listed in get_non_negatives(labels[i])
    Both are boolean tensors of size [len(labels), len(labels)].
    """
    batch = np.asarray(labels)
    size = len(labels)
    positives_mask = np.zeros((size, size), dtype=bool)
    negatives_mask = np.zeros((size, size), dtype=bool)
    # query the dataset once per row and test the whole batch against it at once,
    # instead of re-fetching and re-sorting the lists for every (row, column) pair
    for row, label in enumerate(labels):
        positives_mask[row] = np.isin(batch, np.asarray(dataset.get_positives(label)))
        negatives_mask[row] = ~np.isin(batch, np.asarray(dataset.get_non_negatives(label)))
    return torch.tensor(positives_mask), torch.tensor(negatives_mask)
--------------------------------------------------------------------------------
/datasets/dataloders/lprdataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import torch
4 | import argparse
5 | from torch.utils.data import DataLoader
6 | import torch.distributed as dist
7 | from typing import Dict, List, Any
8 | from time import sleep
9 | from datasets.dataloders.augments.augment import Augment
10 | from datasets.dataloders.samplers.lprbatchsampler import LPRBatchSampler
11 | from datasets.dataloders.collates.lprcollate import LPRCollate
12 | from datasets.lprdataset import LPRDataset
13 |
14 | from misc.utils import str2bool
15 |
def LPRDataLoader(**kw):
    """
    Create a dataloader from keyword configuration.

    Reads `kw["dataset"]` (dataset root path), `kw["sampler"]` and
    `kw["collate"]` (keyword dicts), `kw["num_workers"]`, and the optional
    `kw["augment"]` keyword dict. Returns a torch DataLoader driven by an
    LPRBatchSampler.
    """
    # augmentation is optional; without it the collate function receives None
    augment = Augment(**kw["augment"]) if "augment" in kw else None

    dataset = LPRDataset(rootpath=kw["dataset"])
    batch_sampler = LPRBatchSampler(dataset=dataset, **kw["sampler"])
    # Collate function collates items into a batch and applies a 'set transform' on the entire batch
    collate_fn = LPRCollate(dataset=dataset, augment=augment, **kw["collate"])

    return DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,
        num_workers=kw["num_workers"],
        pin_memory=True
    )
42 |
43 |
def parse_opt()->dict:
    """
    Parse the command line and merge it into the YAML configuration.

    Command-line options:
    * --yaml: path of the YAML file (default "config/dataloader.yaml")
    * --local_rank: optional int (default None)

    Returns the dict loaded from the YAML file, updated with the parsed
    options under the keys "yaml" and "local_rank" (command-line values win
    on key clashes). An empty YAML file yields just the options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml", type=str, default="config/dataloader.yaml")
    parser.add_argument("--local_rank", type=int, default=None)
    opt = vars(parser.parse_args())
    # read the yaml file; the context manager closes it even if parsing fails
    # NOTE: FullLoader is meant for trusted local configs only; do not point
    # --yaml at untrusted input
    with open(opt["yaml"], encoding="utf-8") as f:
        kw:Dict[str, Any] = yaml.load(f, Loader=yaml.FullLoader) or {}
    kw.update(opt)
    return kw
55 |
def test_lprataloader(**kw):
    """
    # Smoke test: iterate the dataloader and optionally print its batches
    ## Input (keyword arguments)
    * local_rank: when not None, the process group is initialised (nccl) and the
      CUDA device is taken from the LOCAL_RANK environment variable
    * dataloader: keyword arguments forwarded to `LPRDataLoader`
    * show: dict with `epoch` (number of passes), `data` / `mask` (print switches)
      and `sleep` (seconds to pause after each batch)
    """
    distributed = kw["local_rank"] is not None
    if distributed:
        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        dist.init_process_group(backend="nccl")

    loader = LPRDataLoader(**kw["dataloader"])
    for epoch in range(kw["show"]["epoch"]):
        print("epoch", epoch)
        for data, mask in loader:
            # Each entry is printed as its key followed by its value on a new line
            if kw["show"]["data"]:
                for key in data:
                    print(key, "\n", data[key])
            if kw["show"]["mask"]:
                for key in mask:
                    print(key, "\n", mask[key])
            sleep(kw["show"]["sleep"])
    return
73 |
# Script entry point: run the dataloader smoke test with the parsed options
if __name__ == "__main__":
    test_lprataloader(**parse_opt())
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/datasets/dataloders/samplers/__init__.py
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
class BaseSample(object, metaclass=abc.ABCMeta):
    """
    # Base class for all sampling mechanisms
    * A concrete sample is a callable that turns a batch size into the list of
      index batches of one epoch
    * Subclasses must implement `__init__`, `__call__` and `get_k`
    """
    @abc.abstractmethod
    def __init__(self):
        pass

    @abc.abstractmethod
    def __call__(self, batch_size:int):
        """
        # Generate the indices of all batches
        ## Input
        * batch_size: number of elements per batch
        ## Output
        * batch_idx: [[batch0], [batch1], ..., [batchn]]
        """
        pass

    @abc.abstractmethod
    def get_k(self)->int:
        """
        # Group size of the sampler
        * The batch size used with this sampler must satisfy batch_size % k == 0
        """
        pass
24 |
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/batch.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from typing import List
4 | from datasets.lprdataset import LPRDataset
5 | from datasets.dataloders.samplers.base import BaseSample
6 |
class BatchSample(BaseSample):
    """
    # Plain batch sampling
    * Cuts the (optionally shuffled) dataset indices into equally sized batches
    * Elements that do not fill the last batch are dropped
    * `max_batches` is stored but not applied by `__call__`
    """
    def __init__(self, dataset:LPRDataset, shuffle:bool, max_batches:int):
        print("Sampling Mechanism: BatchSample")
        self.dataset = dataset
        self.shuffle = shuffle
        self.max_batches = max_batches
        # every element forms its own group
        self.k = 1

    def get_k(self):
        return self.k

    def __call__(self, batch_size:int) -> List[List[int]]:
        """
        ## Input
        * batch_size
        ## Output
        * batch_idx: [[batch0], [batch1], ..., [batchn]], each of length batch_size
        """
        ordered = np.sort(self.dataset.get_indices())
        if self.shuffle:
            random.shuffle(ordered)
        # keep only as many elements as fit into complete batches
        usable = ordered.shape[0] - ordered.shape[0] % batch_size
        return ordered[:usable].reshape((-1, batch_size)).tolist()
30 |
31 |
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/hetero.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from typing import List
4 | from copy import deepcopy
5 | from datasets.lprdataset import LPRDataset
6 | from datasets.dataloders.samplers.base import BaseSample
7 |
class HeteroTripletSample(BaseSample):
    def __init__(self, dataset:LPRDataset, max_batches:int):
        """
        # Sampling mechanism for heterogeneous data
        ## Input
        * dataset
        * max_batches: stop after this many batches, None for no limit
        """
        print("Sampling Mechanism: HeteroTripletSample")
        self.dataset = dataset
        self.max_batches = max_batches
        self.k = 2

    def get_k(self) -> int:
        return self.k

    def __call__(self, batch_size:int) -> List[List[int]]:
        """
        # Generate heterogeneous indices of batches
        * Each step draws a random anchor and, for every other geneous, one of the
          anchor's positives from that geneous (unused elements are preferred)
        * Only complete batches are returned; an unfinished batch is discarded
        ## Input
        * batch_size
        ## Output
        batch_idx: [[batch0], [batch1], ..., [batchn]]
        """

        assert self.k == 2, "HeteroTripletSample: sampler can sample only k=2 elements from the same class"
        assert batch_size >= self.k, "HeteroTripletSample: batch_size >= k"
        assert batch_size%self.k == 0, "HeteroTripletSample: batch_size%k == 0"

        # batches indices of an training epoch
        batch_idx:List[List[int]] = []
        # elements that were not put into any batch yet
        unused_elements_ndx:List[int] = self.dataset.get_indices().tolist()
        # batch currently being filled
        current_batch:List[int] = []

        # items with heterogeneous positive samples in dataset
        anchors:List[int] = self.dataset.get_anchors().tolist()

        # loop invariants: the dataset getters return copies, so fetch them once
        all_geneous = self.dataset.get_all_geneous()
        homoindices = [
            self.dataset.get_homoindices(gid)
            for gid, _ in enumerate(self.dataset.get_geneous_names())
        ]

        # an empty anchor list yields no batch instead of failing in random.choice
        while len(anchors) > 0 and len(unused_elements_ndx) > 0:
            anchor = random.choice(anchors)
            anchor_geneous = all_geneous[anchor]
            current_batch.append(anchor)
            anchors.remove(anchor)
            unused_elements_ndx.remove(anchor)

            unused_elements_ndx_np = np.asarray(unused_elements_ndx)
            positives = self.dataset.get_positives(anchor)

            for gid, geneous_indices in enumerate(homoindices):
                if gid == anchor_geneous: continue
                geneous_positives = np.intersect1d(positives, geneous_indices)
                assert geneous_positives.shape[0] > 0, "HeteroTripletSampler: gpos.shape[0] = 0"
                unused_geneous_positives = np.intersect1d(
                    unused_elements_ndx_np,
                    geneous_positives
                )
                if len(unused_geneous_positives) != 0:
                    this_geneous_positive = random.choice(unused_geneous_positives.tolist())
                    unused_elements_ndx.remove(this_geneous_positive)
                else:
                    # every positive was used already: reuse one of them
                    this_geneous_positive = random.choice(geneous_positives.tolist())

                current_batch.append(this_geneous_positive)

                # an element that served as positive is not drawn as anchor any more
                if this_geneous_positive in anchors: anchors.remove(this_geneous_positive)

            if len(current_batch) >= batch_size:
                assert len(current_batch) % self.k == 0
                # current_batch is rebound right below, so no copy is required
                batch_idx.append(current_batch)
                current_batch = []
                if (self.max_batches is not None) and (len(batch_idx) >= self.max_batches):
                    break

        return batch_idx
83 |
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/homo.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import random
5 | import numpy as np
6 | from typing import List
7 | from datasets.lprdataset import LPRDataset
8 | from datasets.dataloders.samplers.base import BaseSample
9 |
class HomoTripletSample(BaseSample):
    def __init__(self, dataset:LPRDataset, max_batches:int):
        """
        # Homogeneous sampling
        * Produces lists of indices, one list per mini-batch
        * Elements are added in groups of k=2 similar elements (positives)
        * Batch layout: item1_1, ..., item1_k, item2_1, ... item2_k, itemn_1, ..., itemn_k
        ## Input
        * dataset
        * max_batches: stop after this many batches, None for no limit
        """
        print("Sampling Mechanism: HomoTripletSample")
        self.dataset = dataset
        self.max_batches = max_batches
        self.k = 2

    def get_k(self) -> int:
        return self.k

    def __call__(self, batch_size:int) -> List[List[int]]:
        """
        ## Input
        * batch_size
        ## Output
        * batch_idx: [[batch0], [batch1], ..., [batchn]]
        """
        assert self.k == 2, "HomoTripletSample: sampler can sample only k=2 elements from the same class"
        assert batch_size >= 2*self.k, "HomoTripletSample: batch_size > 2*k"
        assert batch_size%self.k == 0, "HomoTripletSample: batch_size%k == 0"

        # finished batches, the batch being filled and the not yet used elements
        batches:List[List[int]] = []
        pending:List[int] = []
        remaining:List[int] = self.dataset.get_indices().tolist()

        while True:
            exhausted = len(remaining) == 0
            if exhausted or len(pending) >= batch_size:
                # Emit the batch once it is full, or a shorter one when the elements ran out.
                # At least two groups are required, otherwise the batch holds no negatives.
                if len(pending) >= 2*self.k:
                    assert len(pending) % self.k == 0, "HomoTripletSample: Incorrect bach size: {}".format(len(pending))
                    batches.append(pending)
                    pending = []
                    if (self.max_batches is not None) and (len(batches) >= self.max_batches):
                        break
                if exhausted:
                    break

            # Draw one element and pair it with one of its positives
            first = random.choice(remaining)
            remaining.remove(first)

            candidates = list(self.dataset.get_positives(first))
            if len(candidates) == 0:
                # element without any positive: it is consumed but not batched
                continue

            # Prefer a positive that was not used yet, fall back to any positive
            fresh = [c for c in candidates if c in remaining]
            if len(fresh) > 0:
                second = random.choice(fresh)
                remaining.remove(second)
            else:
                second = random.choice(candidates)
            pending.extend((first, second))

        for batch in batches:
            assert len(batch) % self.k == 0, "Incorrect bach size: {}".format(len(batch))

        return batches
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/lprbatchsampler.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Any
2 | import torch.distributed as dist
3 | from torch.utils.data import Sampler
4 | from datasets.lprdataset import LPRDataset
5 | from datasets.dataloders.samplers.utils import broadcast_batch_idx
6 |
class LPRBatchSampler(Sampler[List[int]]):
    """
    # Batch sampler facade
    * Delegates batch generation to the sampling mechanism selected by name
    * With an initialised process group, rank 0 generates the batches and every
      process receives its share through `broadcast_batch_idx`
    """
    def __init__(
        self,
        dataset:LPRDataset,
        name:str,
        batch_size:int,
        batch_size_limit:int,
        batch_expansion_rate:float,
        **kw,
    ):
        """
        # Input
        * dataset
        * name: one of `["BatchSample", "HomoTripletSample", "HeteroTripletSample"]`
        * batch_size: initial batch size, rounded down to a multiple of the sample's k
        * batch_size_limit: max batch size
        * batch_expansion_rate: growth factor used by `expand_batch`, None disables it
        * kw: remaining arguments forwarded to the sampling mechanism
        """
        self.sample_fn = self._make_sample_fn(name, dataset, kw)

        # gpu mode: distributed when a process group has been initialised
        self.use_dist = dist.is_initialized()
        if self.use_dist:
            if dist.get_rank() == 0:
                print("LPRBatchSampler: multi-gpu mode")
        else:
            print("LPRBatchSampler: sigle-gpu mode")

        self.batch_size = batch_size - batch_size%self.sample_fn.get_k()
        self.batch_size_limit = batch_size_limit
        self.batch_expansion_rate = batch_expansion_rate
        if batch_expansion_rate is not None:
            assert batch_expansion_rate > 1., "LPRBatchSampler: batch_expansion_rate must be greater than 1"
            assert batch_size <= batch_size_limit, "LPRBatchSampler: batch_size_limit must be greater or equal to batch_size"

        # filled by __iter__, empty until the first iteration
        self.batch_idx = []

    @staticmethod
    def _make_sample_fn(name:str, dataset:LPRDataset, kw:Dict[str, Any]):
        """
        # Sample factory: import and instantiate the mechanism called `name`
        """
        if name == "BatchSample":
            from datasets.dataloders.samplers.batch import BatchSample as sample_cls
        elif name == "HomoTripletSample":
            from datasets.dataloders.samplers.homo import HomoTripletSample as sample_cls
        elif name == "HeteroTripletSample":
            from datasets.dataloders.samplers.hetero import HeteroTripletSample as sample_cls
        else:
            raise NotImplementedError("LPRBatchSampler: %s sample_fn not implemented" % name)
        return sample_cls(dataset=dataset, **kw)

    def __iter__(self):
        """
        # Generate the batches of one epoch and yield them one by one
        """
        if not self.use_dist:
            # single-gpu: this process generates and consumes all batches
            self.batch_idx = self.sample_fn(self.batch_size)
        else:
            # multi-gpu: only gen_rank samples, the others pass None to the broadcast
            gen_rank = 0
            all_batch_idx:List[List[int]] = None
            if dist.get_rank() == gen_rank:
                all_batch_idx = self.sample_fn(self.batch_size)
            self.batch_idx = broadcast_batch_idx(
                batch_size=self.batch_size,
                all_batch_idx=all_batch_idx,
                gen_rank=gen_rank
            )

        yield from self.batch_idx

    def __len__(self):
        # number of batches produced by the most recent __iter__ call
        return len(self.batch_idx)

    def expand_batch(self):
        """
        # Expand batch_size by batch_expansion_rate
        * The result is capped at batch_size_limit and rounded down to a multiple of k
        """
        rate = self.batch_expansion_rate
        if rate is None:
            print("LPRBatchSampler: WARNING batch_expansion_rate is None")
            return

        if self.batch_size >= self.batch_size_limit:
            return

        previous = self.batch_size
        grown = min(int(previous * rate), self.batch_size_limit)
        self.batch_size = grown - grown%self.sample_fn.get_k()

        print("LPRBatchSampler: Batch size increased from: {} to {}".format(previous, self.batch_size))
112 |
--------------------------------------------------------------------------------
/datasets/dataloders/samplers/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | import torch.distributed as dist
4 | from typing import List
5 |
def broadcast_batch_idx(
    batch_size:int,
    all_batch_idx:List[List[int]],
    gen_rank:int,
):
    """
    # Assign all_batch_idx to all processes evenly
    * Batches whose length differs from batch_size are dropped, then trailing
      batches are dropped until the count is a multiple of the world size
    * Every process receives a contiguous, equally long slice of the result
    ## Input
    * batch_size
    * all_batch_idx: batches generated on gen_rank, ignored (may be None) elsewhere
    * gen_rank: process id that genenrates all_batch_idx
    ## Output
    * batch_idx: the batches assigned to the calling process, [] when nothing is left
    """
    assert dist.is_available() and dist.is_initialized(), "broadcast: sampler broadcast must be dist.is_initialized()"
    rank, world_size = dist.get_rank(), dist.get_world_size()
    assert gen_rank < world_size, "broadcast: sampler gen_rank >= word_size"

    # num_size[0] is the number of batches, num_size[1] is the batch size.
    # NOTE(review): the global rank is used as the device index; confirm this is
    # the intended device when processes span several nodes.
    num_size = torch.tensor([0, 0], dtype=torch.int64).to(rank)

    payload:torch.Tensor = None

    if rank == gen_rank:
        assert all_batch_idx is not None, "broadcast_batch_idx: rank == gen_rank and batch_idx is None"
        # the filter builds a new list, the caller's list is left untouched
        full_batches = [e for e in all_batch_idx if len(e)==batch_size]
        # remove tail to ensure len(full_batches) % world_size == 0
        full_batches = full_batches[:len(full_batches)-len(full_batches)%world_size]
        # reshape keeps the tensor 2-D even when no batch is left, so its size
        # can always be unpacked into (number of batches, batch size)
        payload = torch.tensor(full_batches, dtype=torch.int64).reshape(-1, batch_size).to(rank)
        num_size[0], num_size[1] = payload.size()

    # tell every process how large the index tensor is
    dist.broadcast(num_size, gen_rank)
    num_batches, width = (int(v) for v in num_size.tolist())

    # every process sees the same num_size, so all of them leave here together
    if num_batches == 0:
        return []

    # receiving processes allocate a buffer of the announced size
    if rank != gen_rank:
        payload = torch.zeros((num_batches, width), dtype=torch.int64).to(rank)

    dist.broadcast(payload, gen_rank)

    rows = payload.cpu().numpy().tolist()
    assert len(rows)%world_size == 0, "broadcast: len(broadcast_batch_idx)%world_size != 0"
    avg_num = len(rows)//world_size
    batch_idx = rows[rank*avg_num: (rank+1)*avg_num]

    return batch_idx
--------------------------------------------------------------------------------
/datasets/lprdataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import argparse
5 | import numpy as np
6 | import open3d as o3d
7 | from tqdm import tqdm
8 | from torch.utils.data import Dataset
9 | import matplotlib.pyplot as plt
10 | from typing import List
11 | from misc.utils import str2bool
12 |
class LPRDataset(Dataset):
    """
    # Dataset wrapper for LPRDataset
    * Files read from rootpath: `geneous.npy`, `tum.npy` and, per item,
      `items/<6-digit index>/{pointcloud,positives,non_negatives}.npy`
    * All point clouds, positives and non-negatives are loaded at construction
    * The getters return copies of the stored arrays
    """
    def __init__(self, rootpath:str,):
        self.rootpath = rootpath
        assert os.path.exists(self.rootpath), "Cannot access rootpath {}".format(self.rootpath)
        print("LPRDataset: {}".format(self.rootpath))
        # 0: ground, 1: aerial
        self.geneous_names = ["ground", "aerial"]
        self.Ng = len(self.geneous_names)
        # geneous id of every item
        self.geneous = np.load(os.path.join(self.rootpath, "geneous.npy"))
        self.Nm = self.geneous.shape[0]
        # self.homoindices: item indices grouped by geneous id, e.g.
        # ground: [0, 2, 3, 5, 8, ...]
        # aerial: [1, 4, 6, 7, 9, ...]
        self.homoindices = [[] for _ in self.geneous_names]
        for ndx in range(self.Nm):
            self.homoindices[self.geneous[ndx]].append(ndx)
        self.homoindices = [np.asarray(e) for e in self.homoindices]

        # tum format (Nm, 8) [t, x, y, z, qx, qy, qz, qw]
        self.tum = np.load(os.path.join(self.rootpath, "tum.npy"))
        assert self.Nm == self.tum.shape[0], "LPRDataset: self.Nm != self.tum.shape[0]"

        # make self check files; exist_ok avoids a failure when several
        # processes create the directory at the same time
        self.checkpath = os.path.join(self.rootpath, "selfcheck")
        os.makedirs(self.checkpath, exist_ok=True)

        self.anchors:np.ndarray = None
        # load data
        self.pcs = self._load_items("pointcloud.npy")
        self.positives = self._load_items("positives.npy")
        self.non_negatives = self._load_items("non_negatives.npy")
        # fills the self.anchors cache
        self.get_anchors()

    def _load_items(self, filename:str) -> List[np.ndarray]:
        """
        # Load `filename` from the folder of every item, in index order
        """
        return [
            np.load(os.path.join(self.rootpath, "items", "%06d"%ndx, filename))
            for ndx in range(self.Nm)
        ]

    def _per_source_average(self, totals:np.ndarray) -> np.ndarray:
        """
        # Divide row s of `totals` by the number of items of geneous s
        * totals[s][t] is a sum over the items of source geneous s, so the
          average per source item divides by the item count of s (the row)
        * An empty geneous divides by 1, which keeps its all-zero row at zero
        """
        counts = np.asarray([indices.shape[0] for indices in self.homoindices], dtype=np.float64)
        return totals / np.maximum(counts, 1.0).reshape(-1, 1)

    def __len__(self):
        return self.Nm

    def __getitem__(self, ndx):
        # Returns the index together with the point cloud as a tensor
        pc = torch.tensor(self.get_pc(ndx))
        return ndx, pc

    def get_indices(self) -> np.ndarray:
        return np.arange(self.Nm)

    def get_homoindices(self, geneous_id:int) -> np.ndarray:
        return np.copy(self.homoindices[geneous_id])

    def get_geneous_names(self) -> List[str]:
        return self.geneous_names

    def get_all_geneous(self) -> np.ndarray:
        return np.copy(self.geneous)

    def get_positives(self, ndx:int) -> np.ndarray:
        return np.copy(self.positives[ndx])

    def get_non_negatives(self, ndx:int) -> np.ndarray:
        return np.copy(self.non_negatives[ndx])

    def get_tum(self, ndx:int):
        return np.copy(self.tum[ndx])

    def get_correspondences(self, source_ndx:int, target_ndx:int) -> np.ndarray:
        """
        # Load `items/<source>/correspondence/<target>.npy` from disk
        """
        path = os.path.join(
            self.rootpath,
            "items",
            "%06d"%source_ndx,
            "correspondence",
            "%06d.npy"%target_ndx
        )
        return np.load(path)

    def get_pc(self, ndx) -> np.ndarray:
        return np.copy(self.pcs[ndx])

    def get_anchors(self) -> np.ndarray:
        """
        # Get indices of items with heterogeneous positive samples in dataset
        * An item is an anchor when it has at least one positive in every geneous
        * The result is computed once and cached in self.anchors
        """
        if self.anchors is not None: return np.copy(self.anchors)
        print("LPRDataset: self.anchors is None, generating")
        anchors = []
        for i in self.get_indices():
            positives = self.get_positives(i)
            is_anchor = all(
                np.intersect1d(positives, self.homoindices[gid]).shape[0] != 0
                for gid in range(self.Ng)
            )
            if is_anchor: anchors.append(i)

        self.anchors = np.asarray(anchors)

        for gid, gname in enumerate(self.get_geneous_names()):
            ganchors = np.intersect1d(
                self.anchors,
                self.get_homoindices(gid)
            ).shape[0]
            print("LPRDataset: %s has %d anchors" % (gname, ganchors))

        return np.copy(self.anchors)

    def check_hetero_triplet(self):
        """
        # Count hetero triplet number
        * Prints two (Ng, Ng) matrices: entry [s][t] is the average number of
          positives / negatives in geneous t per item of geneous s
        * Negatives of an item are all indices that are not in its non-negatives
        """
        print("LPRDataset: check hetero triplet")
        # multi_geneous_positives / multi_geneous_negatives
        mgp = np.zeros((self.Ng, self.Ng))
        mgn = np.zeros((self.Ng, self.Ng))
        for i in tqdm(self.get_indices()):
            sgid = self.geneous[i]
            positives = self.get_positives(i)
            negatives = np.setdiff1d(self.get_indices(), self.get_non_negatives(i))
            for tgid in range(self.Ng):
                mgp[sgid][tgid] += np.intersect1d(positives, self.homoindices[tgid]).shape[0]
                mgn[sgid][tgid] += np.intersect1d(negatives, self.homoindices[tgid]).shape[0]
        mgp = self._per_source_average(mgp)
        mgn = self._per_source_average(mgn)
        print("Avg positive:")
        print(str(mgp))
        print("Avg negative:")
        print(str(mgn))
        return

    def check_positives(self, step:int=1):
        """
        # Plot, per source-target geneous pair, how many queries have how many positives
        ## Input
        * step: bin width, positive counts are rounded down to a multiple of step
        """
        print("LPRDataset: check_positives")
        pos_map = {}
        for sgid, source in enumerate(self.get_geneous_names()):
            for tgid, target in enumerate(self.get_geneous_names()):
                keyname = "%s-%s" % (source, target)
                sindices = self.get_homoindices(sgid)
                tindices = self.get_homoindices(tgid)
                binned = []
                for sndx in tqdm(sindices, desc=keyname):
                    this_npos = np.intersect1d(
                        self.get_positives(sndx),
                        tindices
                    ).shape[0]
                    binned.append(int(this_npos/step)*step)
                # np.unique returns the sorted distinct bins and how often each occurs
                npos, nmap = np.unique(np.asarray(binned), return_counts=True)
                pos_map[keyname] = np.stack([npos, nmap])

        plt.figure(figsize=(7,4))

        plt.grid()
        for keyname in pos_map:
            plt.plot(pos_map[keyname][0], pos_map[keyname][1])

        plt.xlabel("Number of positive samples in database")
        plt.ylabel("Number of queries")
        plt.legend(list(pos_map.keys()))
        plt.show()

        return

    def check_pn(self):
        """
        # Print the average number of points per point cloud for every geneous
        """
        print("LPRDataset: check points number")
        for gid, geneous in enumerate(self.geneous_names):
            # a geneous without items has no average
            if self.get_homoindices(gid).shape[0] == 0: continue
            pn = 0
            for i in self.homoindices[gid]:
                pn += self.get_pc(i).shape[0]
            avgpn = pn / self.homoindices[gid].shape[0]
            print("%s avg pn = %.3f" % (geneous, avgpn))
        return

    def show_submaps(self, N=10):
        """
        # visualize some submaps
        * Draws N random pairs: an anchor of geneous 1 and one of its positives
          of geneous 0, the positive shifted by -70 along x
        """
        anchors = self.get_anchors()

        # anchors that belong to geneous 1
        ganchors = np.intersect1d(
            anchors,
            self.get_homoindices(1)
        )
        for _ in range(N):
            a = random.choice(ganchors)
            p = random.choice(
                np.intersect1d(
                    self.get_homoindices(0),
                    self.get_positives(a)
                )
            )
            pcda = o3d.geometry.PointCloud()
            pcda.points = o3d.utility.Vector3dVector(self.get_pc(a))

            pcdp = o3d.geometry.PointCloud()
            pcdp.points = o3d.utility.Vector3dVector(self.get_pc(p) + np.asarray([-70, 0, 0]))

            o3d.visualization.draw_geometries(
                [pcda, pcdp],
                window_name="left: gorund, right: aerial"
            )
224 |
225 |
226 |
227 |
def test_dataset():
    """
    # Command-line self check of a dataset
    ## Command line
    * --dataset: dataset root path (required)
    * --check_positives / --check_hetero_triplet / --check_pn / --get_anchors:
      boolean switches of the corresponding LPRDataset checks, default True
    * --show_submaps: number of submap pairs to visualize, 0 or less to skip
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--check_positives", type=str2bool, default=True)
    parser.add_argument("--check_hetero_triplet", type=str2bool, default=True)
    parser.add_argument("--check_pn", type=str2bool, default=True)
    parser.add_argument("--get_anchors", type=str2bool, default=True)
    # the value is compared with 0 and used as a count, so it is parsed as an integer
    parser.add_argument("--show_submaps", type=int, default=5)
    opt = vars(parser.parse_args())
    lprdataset = LPRDataset(rootpath=opt["dataset"])

    if opt["check_positives"]: lprdataset.check_positives()
    if opt["check_hetero_triplet"]: lprdataset.check_hetero_triplet()
    if opt["check_pn"]: lprdataset.check_pn()
    if opt["get_anchors"]: lprdataset.get_anchors()
    if opt["show_submaps"] > 0: lprdataset.show_submaps(opt["show_submaps"])

    return
247 |
# Script entry point: run the dataset self check
if __name__ == "__main__":
    test_dataset()
--------------------------------------------------------------------------------
/evaluate/once.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 | import yaml
6 | from datasets.dataloders.lprdataloader import LPRDataLoader
7 | from models.lprmodel import LPRModel
8 | from evaluate.utils import get_embeddings, get_hetero_topN_recall, get_hetero_recall_precision, show_closest
9 |
10 | from tqdm import tqdm
11 | from misc.utils import get_datetime
12 | import matplotlib.pyplot as plt
13 |
def parse_opt()->dict:
    """
    # Parse command-line options and merge them into the YAML configuration
    ## Command line
    * --weights: path of the network weights (required)
    * --yaml: path of the YAML configuration (required)
    * --tn: integer, default 30 (number of topN-recall epochs in `main`)
    * --rp: integer, default 100 (number of recall-precision steps in `main`)
    * --save: output directory, default `results/evaluate/`
    ## Output
    * the dict loaded from the YAML file, updated with the command-line options
      (command-line keys overwrite YAML keys of the same name)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", type=str, required=True)
    parser.add_argument("--yaml", type=str, required=True)
    parser.add_argument("--tn", type=int, default=30)
    parser.add_argument("--rp", type=int, default=100)
    parser.add_argument("--save", type=str, default="results/evaluate/")
    opt = vars(parser.parse_args())
    # The context manager closes the file, also when parsing raises
    # NOTE(review): yaml.FullLoader can build more than plain data, only load trusted files
    with open(opt["yaml"], encoding="utf-8") as f:
        lpreval = yaml.load(f, Loader=yaml.FullLoader)

    lpreval.update(opt)
    return lpreval
28 |
29 |
def feat_l2d_mat(embeddings: np.ndarray) -> np.ndarray:
    """
    # Pairwise L2 distance matrix of the embeddings
    ## Input
    * embeddings: (Nm, Fs) array, one descriptor per row
    ## Output
    * distance: (Nm, Nm) array; the diagonal is raised to max(distance) + 1 so
      that an item is never its own nearest neighbour
    """
    Nm, Fs = embeddings.shape
    # broadcast rows against columns: (Nm, 1, Fs) - (1, Nm, Fs) -> (Nm, Nm, Fs)
    pairwise_diff = embeddings.reshape((Nm, 1, Fs)) - embeddings.reshape((1, Nm, Fs))
    distance = np.linalg.norm(pairwise_diff, axis=2)
    diagonal_offset = np.eye(Nm) * (np.max(distance) + 1)
    distance += diagonal_offset
    return distance
35 |
36 |
def main(**kw):
    """
    # Evaluate a trained model: recall-precision curve and averaged topN-recall
    ## Input (keyword arguments, see `parse_opt`)
    * dataloaders: dict, its `evaluate` entry is forwarded to `LPRDataLoader`
    * weights: path of the network weights, must exist
    * save: existing directory for the results or None; a sub-directory named
      by `get_datetime()` is created inside it
    * rp: number of recall-precision evaluation steps, values below 1 skip it
    * tn: number of topN-recall epochs, values below 1 skip it
    """

    dataloader = LPRDataLoader(**kw["dataloaders"]["evaluate"])

    # use the GPU when one is available
    device:str = None
    if torch.cuda.is_available(): device = "cuda"
    else: device = "cpu"
    print("Device: {}".format(device))
    assert os.path.exists(kw["weights"]), "Cannot open network weights: {}".format(kw["weights"])
    print("Loading weights: {}".format(kw["weights"]))

    lprmodel = LPRModel()
    lprmodel.load(kw["weights"], device)

    # check savepath
    savepath = None
    if kw["save"] is not None:
        assert os.path.exists(kw["save"]), "Path does not exist, please run: mkdir " + kw["save"]
        savepath = os.path.join(kw["save"], get_datetime())
        os.mkdir(savepath)
    print("Save path:", savepath)

    # recall-precision
    if kw["rp"] < 1:
        print("Evaluation of Recall-Precision: Skip.")
    else:
        print("Evaluation of Recall-Precision: %d steps." % kw["rp"])
        distance = feat_l2d_mat(get_embeddings(lprmodel, dataloader, device, print_stats=False))
        rp = get_hetero_recall_precision(dataloader.dataset, distance, num_eval=kw["rp"])
        plt.figure()

        plt.xlim(-0.1, 1.1)
        plt.ylim(-0.1, 1.1)
        plt.grid()
        # one curve per key of rp; rp[st]["xy"][0] is plotted on the recall axis
        # and rp[st]["xy"][1] on the precision axis
        for st in rp:
            plt.plot(rp[st]["xy"][0], rp[st]["xy"][1])
            # label every curve point with the matching value of rp[st]["ds"]
            for i, d in enumerate(rp[st]["ds"]):
                plt.annotate(
                    text="%.2f"%d,
                    xy=(rp[st]["xy"][0][i], rp[st]["xy"][1][i]),
                    xytext=(rp[st]["xy"][0][i], rp[st]["xy"][1][i]),
                    fontsize=10,
                )
        plt.xlabel("recall")
        plt.ylabel("precision")
        plt.legend(list(rp))
        if savepath is not None: plt.savefig(os.path.join(savepath, "recall-precision.png"))
        plt.close()


    # average topN-recall
    if kw["tn"] < 1:
        print("Evaluation of TopN-Recall: Skip.")
    else:
        topNs = []
        print("Evaluation of TopN-Recall: %d epochs, the average is taken." % kw["tn"])
        print("(The results are saved each epoch. Enter Ctrl+C to stop.)")
        iterator = tqdm(range(kw["tn"]))
        for _ in iterator:
            # get descriptor distance
            distance = feat_l2d_mat(get_embeddings(lprmodel, dataloader, device, print_stats=False))
            # append to all topN recall
            topNs.append(get_hetero_topN_recall(dataloader.dataset, distance))
            # take average values over all epochs evaluated so far
            tn = {}
            for e in topNs[0]: tn[e] = np.stack([topN[e] for topN in topNs], axis=0).mean(axis=0)

            # redraw the figure every epoch so the saved file holds the running average
            plt.figure()
            plt.grid()
            for e in tn: plt.plot(tn[e])
            plt.xlabel("TopN")
            plt.ylabel("Recall")
            plt.legend(list(tn))
            if savepath is not None: plt.savefig(os.path.join(savepath, "topN-recall.png"))
            plt.close()

            # show the first entry of every averaged curve in the progress bar
            stats = "Top1-Recall: "
            for e in list(tn.keys()): stats += "%s:%.3f|" % (e, tn[e][0])
            iterator.set_postfix_str(stats)

    return
118 |
119 |
# Script entry point: run the evaluation with the parsed options
if __name__ == "__main__":
    main(**parse_opt())
--------------------------------------------------------------------------------
/evaluate/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | import open3d as o3d
7 | from tqdm import tqdm
8 | from typing import List
9 | from time import sleep
10 | from sklearn.neighbors import KDTree
11 | from sklearn import manifold
12 | from torch.utils.data import DataLoader
13 | from datasets.lprdataset import LPRDataset
14 | from models.lprmodel import LPRModel
15 | from misc.utils import tensors2device
16 |
17 |
18 |
19 |
20 |
21 |
def get_embeddings(lprmodel: LPRModel, dataloader: DataLoader, device:str, print_stats:bool=True):
    """
    # Run the model over the dataloader and collect the embeddings
    ## Input
    * lprmodel: its `model` is moved to `device` and switched to eval mode
    * dataloader: yields (data, mask) pairs, the mask is not used
    * device
    * print_stats: show a progress bar and print the final shape
    ## Output
    * embeddings: the per-batch `output["embeddings"]` arrays concatenated along axis 0
    """
    lprmodel.model = lprmodel.model.to(device)
    lprmodel.model.eval()

    batches = tqdm(dataloader, desc="Getting embedding") if print_stats else dataloader

    collected = []
    for data, _ in batches:
        data = tensors2device(data, device)
        # inference only, no gradients are tracked
        with torch.no_grad():
            output = lprmodel(data)
        assert "embeddings" in output, "Evaluate: no embeddings in model output"
        collected.append(output["embeddings"].clone().detach().cpu().numpy())

    embeddings = np.concatenate(collected, axis=0)
    if print_stats:
        print("Embeddings size = ", embeddings.shape)
    return embeddings
42 |
def get_topN_recall_curve(
    dataset:"LPRDataset",
    distance:np.ndarray,
    source_indices:np.ndarray,
    target_indices:np.ndarray,
    topN:int=10
):
    """
    # Recall as a function of the number of retrieved candidates
    * A query counts as recalled at cutoff j when at least one of its j nearest
      targets (by `distance`) is one of its positives among the targets
    ## Input
    * dataset: provides `get_positives(ndx)`
    * distance: distance[i][j] is the descriptor distance between items i and j
    * source_indices: indices of the queries
    * target_indices: indices of the database items
    * topN: cutoffs 1 .. topN-1 are evaluated
    ## Output
    * array of length topN-1, entry j-1 is the recall at cutoff j
    ## Raises
    * ZeroDivisionError when source_indices is empty (and topN > 0)
    """
    num_sources = source_indices.shape[0]
    if topN > 0 and num_sources == 0:
        raise ZeroDivisionError("float division by zero")

    # topN_count[j]: number of queries recalled within their j nearest targets
    topN_count = np.zeros((topN,), dtype=np.int32)
    # only the first topN-1 candidates can influence the evaluated cutoffs
    max_rank = max(topN - 1, 0)

    # rank the targets once per query instead of once per (cutoff, query) pair
    for sndx in source_indices:
        real_positive = np.intersect1d(
            dataset.get_positives(sndx),
            target_indices
        )
        if real_positive.shape[0] == 0:
            continue
        ranked = target_indices[np.argsort(distance[sndx][target_indices])][:max_rank]
        hits = np.flatnonzero(np.isin(ranked, real_positive))
        if hits.shape[0] == 0:
            continue
        # first positive at position r (0-based) -> recalled for every cutoff j > r
        topN_count[hits[0] + 1:] += 1

    topN_recall = np.zeros((topN,))
    if topN > 0:
        topN_recall = topN_count.astype(np.float64) / float(num_sources)

    return topN_recall[1:]
71 |
def get_hetero_topN_recall(
    dataset:LPRDataset,
    distance:np.ndarray,
    savepath:str=None,
    show:bool=False
):
    """
    # TopN-recall curves for every source-target geneous pair and for all items
    ## Input
    * dataset
    * distance: square matrix with one row per dataset item
    * savepath: directory that receives `topN-recall.png`, None to skip saving
    * show: display the figure when it is not saved
    ## Output
    * dict mapping "<source>-<target>" (and "all-all") to the recall curve
    """
    assert len(dataset) == distance.shape[0], "Evaluate: len(datasets) == embeddings.shape[0]"

    names = dataset.get_geneous_names()

    curves = {}
    for sgid, source in enumerate(names):
        queries = dataset.get_homoindices(sgid)
        for tgid, target in enumerate(names):
            database = dataset.get_homoindices(tgid)
            # a pair without queries or without database items has no curve
            if queries.shape[0] == 0 or database.shape[0] == 0:
                print("no instance in source or target, continue")
                continue
            curves["{}-{}".format(source, target)] = get_topN_recall_curve(
                dataset, distance, queries, database
            )
    curves["all-all"] = get_topN_recall_curve(
        dataset, distance, dataset.get_indices(), dataset.get_indices()
    )

    plt.figure()
    plt.grid()
    for curve in curves.values():
        plt.plot(curve)
    plt.xlabel("topN")
    plt.ylabel("recall")
    plt.legend(list(curves))
    # the figure is either saved, shown or discarded
    if savepath is not None:
        plt.savefig(os.path.join(savepath, "topN-recall.png"))
    elif show:
        plt.show()
    else:
        plt.close()
    return curves
110 |
def get_recall_precision_curve(
    dataset:LPRDataset,
    distance:np.ndarray,
    source_indices:np.ndarray,
    target_indices:np.ndarray,
    num_eval:int,
):
    """
    # Recall/precision curve obtained by sweeping a distance threshold

    `num_eval` thresholds are spread linearly over
    `[min(distance) - 0.01, max(distance) + 0.01]`. For each threshold and each
    query, the targets closer than the threshold are the predicted positives;
    they are compared with `dataset.get_positives(query)` restricted to
    `target_indices`. Per-query (recall, precision) pairs are averaged per
    threshold. Queries with neither real nor predicted positives are ignored,
    and a threshold for which every query is ignored contributes no point.

    Args:
        dataset: object providing `get_positives(index)`.
        distance: pairwise distance matrix indexed by dataset index.
        source_indices: dataset indices used as queries.
        target_indices: dataset indices used as the retrieval database.
        num_eval: number of thresholds to evaluate.

    Returns:
        (rp, ds): `rp` has shape `(2, K)` with recall in row 0 and precision in
        row 1, sorted by recall; `ds` has shape `(K,)` and holds the threshold
        that produced each column of `rp`.
    """
    ds = np.linspace(np.min(distance)-0.01, np.max(distance)+0.01, num_eval)
    # Ground truth does not depend on the threshold: compute it once per query.
    real_positives = [
        np.intersect1d(dataset.get_positives(i), target_indices)
        for i in source_indices
    ]

    kept_rp, kept_ds = [], []
    for threshold in ds:
        per_query = []
        for i, real_positive in zip(source_indices, real_positives):
            pred_positive = np.intersect1d(
                np.where(distance[i] < threshold)[0],
                target_indices,
            )
            # Both arrays are unique (np.intersect1d), so counts follow from tp.
            tp = np.intersect1d(real_positive, pred_positive).shape[0]
            fn = real_positive.shape[0] - tp
            fp = pred_positive.shape[0] - tp

            if tp == 0:
                if fn == 0 and fp == 0: continue          # nothing to score
                elif fn == 0: recall, precision = 1., 0.  # only false alarms
                elif fp == 0: recall, precision = 0., 1.  # only misses
                else: recall, precision = 0., 0.
            else:
                recall = float(tp)/float(tp+fn)
                precision = float(tp)/float(tp+fp)
            per_query.append([recall, precision])

        if len(per_query) == 0: continue
        kept_rp.append(np.mean(np.asarray(per_query), axis=0))
        # Remember which threshold produced this point so that `ds` stays
        # aligned with `rp` even when some thresholds are skipped.
        kept_ds.append(threshold)

    # [N, 2] -> [2, N]
    rp = np.asarray(kept_rp, dtype=np.float64).reshape(-1, 2).T
    kept_ds = np.asarray(kept_ds, dtype=np.float64)
    indices = np.argsort(rp[0])
    return rp[:, indices], kept_ds[indices]
161 |
def get_hetero_recall_precision(
    dataset:LPRDataset,
    distance:np.ndarray,
    savepath:str=None,
    num_eval:int=100,
    show:bool=False
):
    """
    # Recall/precision curves for every source/target group pair plus "all-all"

    The diagonal of `distance` is raised above the matrix maximum on a copy
    before evaluation (the caller's matrix is not modified).

    Args:
        dataset: object providing `get_geneous_names()`, `get_homoindices(gid)`,
            `get_indices()` and `get_positives(index)`; `len(dataset)` must
            equal `distance.shape[0]`.
        distance: pairwise distance matrix indexed by dataset index.
        savepath: directory in which "recall-precision.png" is written when given.
        num_eval: number of distance thresholds evaluated per curve.
        show: display the figure when `savepath` is None.

    Returns:
        Dict mapping "<source>-<target>" (and "all-all") to
        `{"xy": rp, "ds": thresholds}` as returned by
        `get_recall_precision_curve`. Pairs with an empty group are skipped.
    """
    assert len(dataset) == distance.shape[0], "Evaluate: len(datasets) == embeddings.shape[0]"
    Nm = distance.shape[0]
    # `+` already allocates a new array, so the input matrix stays untouched.
    distance = distance + np.eye(Nm)*(np.max(distance)+0.01)

    all_rp = {}

    geneous_names = dataset.get_geneous_names()
    for sgid, source in enumerate(geneous_names):
        sgndx_all = dataset.get_homoindices(sgid)
        for tgid, target in enumerate(geneous_names):
            tgndx_all = dataset.get_homoindices(tgid)
            st = "{}-{}".format(source, target)
            if sgndx_all.shape[0] == 0 or tgndx_all.shape[0] == 0:
                print("no instance in source or target, continue")
                continue
            xy, ds = get_recall_precision_curve(dataset, distance, sgndx_all, tgndx_all, num_eval)
            all_rp[st] = {"xy": xy, "ds": ds}
    xy, ds = get_recall_precision_curve(dataset, distance, dataset.get_indices(), dataset.get_indices(), num_eval)
    all_rp["all-all"] = {"xy": xy, "ds": ds}

    fig = plt.figure(figsize=(20,20))

    plt.xlim(-0.1, 1.1)
    plt.ylim(-0.1, 1.1)
    plt.grid()
    for st in all_rp:
        recalls, precisions = all_rp[st]["xy"][0], all_rp[st]["xy"][1]
        plt.plot(recalls, precisions)
        # Label every curve point with the threshold that produced it.
        for i, d in enumerate(all_rp[st]["ds"]):
            plt.annotate(
                text="%.2f"%d,
                xy=(recalls[i], precisions[i]),
                xytext=(recalls[i], precisions[i]),
                fontsize=10,
            )
    plt.xlabel("recall")
    plt.ylabel("precision")
    plt.legend(list(all_rp))
    if savepath is not None:
        plt.savefig(os.path.join(savepath, "recall-precision.png"))
        # Release the figure after saving; otherwise every evaluation run
        # leaves one more open 20x20 figure behind.
        plt.close(fig)
    elif show:
        plt.show()
    else:
        plt.close(fig)
    return all_rp
214 |
215 |
def show_closest(
    dataset:LPRDataset,
    distance:np.ndarray,
    num_samples:int=10,
    max_distance:float=2.0,
    offset:float=40.0
):
    """
    # Visualise random anchors next to their closest submap from another group

    For every ordered pair of different groups, `num_samples` anchors are drawn
    at random among the source submaps whose nearest target submap is not
    farther than `max_distance`. Anchor and nearest target are shown side by
    side in an Open3D window (shifted by -/+ `offset` along x); the window
    title reports whether the match is listed in `dataset.get_positives(anchor)`.

    Args:
        dataset: object providing `get_geneous_names()`, `get_homoindices(gid)`,
            `get_positives(index)` and `get_pc(index)`.
        distance: pairwise distance matrix indexed by dataset index.
        num_samples: number of anchor/closest pairs shown per group pair.
        max_distance: anchors whose closest target is farther are never shown.
        offset: shift applied along x so the two clouds do not overlap.
    """
    print("Show Closest Submaps")

    geneous_names = dataset.get_geneous_names()
    shift = np.asarray([offset, 0, 0])
    for sgid, source in enumerate(geneous_names):
        sgndx_all = dataset.get_homoindices(sgid)
        for tgid, target in enumerate(geneous_names):
            if source == target: continue

            tgndx_all = dataset.get_homoindices(tgid)
            if sgndx_all.shape[0] == 0 or tgndx_all.shape[0] == 0:
                print("no instance in source or target, continue")
                continue

            # Pre-select the admissible anchors. Drawing uniformly from them is
            # equivalent to rejection sampling, but cannot spin forever when no
            # anchor has a target within `max_distance`.
            nearest_dist = np.min(distance[sgndx_all][:, tgndx_all], axis=1)
            candidates = sgndx_all[np.logical_not(nearest_dist > max_distance)]
            if candidates.shape[0] == 0:
                print("no %s-%s pair within distance %.3f, continue"%(source, target, max_distance))
                continue

            for _ in range(num_samples):
                anchor = random.choice(candidates)
                closest = tgndx_all[np.argsort(distance[anchor, tgndx_all])][0]
                d = distance[anchor][closest]

                suc = "True" if closest in dataset.get_positives(anchor) else "False"

                anchor_pcd = o3d.geometry.PointCloud()
                anchor_pcd.points = o3d.utility.Vector3dVector(dataset.get_pc(anchor) - shift)

                closest_pcd = o3d.geometry.PointCloud()
                closest_pcd.points = o3d.utility.Vector3dVector(dataset.get_pc(closest) + shift)

                o3d.visualization.draw_geometries(
                    [anchor_pcd, closest_pcd],
                    window_name="%s-%s: result=%s, distance=%.3f"%(source, target, suc, d)
                )
248 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/loss/__init__.py
--------------------------------------------------------------------------------
/loss/base.py:
--------------------------------------------------------------------------------
1 | import abc
2 |
class BaseLoss(metaclass=abc.ABCMeta):
    """
    # Abstract interface shared by all loss functions

    Concrete losses must override `__init__`, `__call__` and `print_stats`;
    instantiating a class that leaves any of them abstract raises TypeError.
    """

    @abc.abstractmethod
    def __init__(self):
        """Set up the loss; the base implementation does nothing."""

    @abc.abstractmethod
    def __call__(self):
        """Compute the loss; the base implementation does nothing."""

    @abc.abstractmethod
    def print_stats(self):
        """Report statistics; the base implementation does nothing."""
19 |
--------------------------------------------------------------------------------
/loss/gapr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import open3d as o3d
5 | from typing import List, Dict, Any, Tuple
6 | from torch.utils.tensorboard import SummaryWriter
7 | import random
8 |
9 |
10 | from loss.base import BaseLoss
11 | from loss.triplet import BatchTripletLoss
12 | from loss.point import PointTripletLoss
13 | from loss.overlap import OverlapLoss
14 |
class GAPRLoss(BaseLoss):
    """
    # Weighted sum of the batch triplet, point triplet and overlap losses

    total = batch_loss + point_loss_scale * point_loss + overlap_loss_scale * overlap_loss
    """
    def __init__(self, batch_loss:Dict, point_loss:Dict, overlap_loss:Dict, point_loss_scale:float, overlap_loss_scale:float):
        """
        Args:
            batch_loss: keyword arguments forwarded to `BatchTripletLoss`.
            point_loss: keyword arguments forwarded to `PointTripletLoss`.
            overlap_loss: keyword arguments forwarded to `OverlapLoss`.
            point_loss_scale: weight of the point loss in the total.
            overlap_loss_scale: weight of the overlap loss in the total.
        """
        super().__init__()
        print("GAPRLoss: point_loss_scale=%.2f overlap_loss_scale=%.2f"%(point_loss_scale, overlap_loss_scale))
        self.batch_loss = BatchTripletLoss(**batch_loss)
        self.point_loss = PointTripletLoss(**point_loss)
        self.overlap_loss = OverlapLoss(**overlap_loss)
        self.point_loss_scale = point_loss_scale
        self.overlap_loss_scale = overlap_loss_scale

    def __call__(self,
        # model
        embeddings:torch.Tensor,
        coords:List[torch.Tensor],
        feats:List[torch.Tensor],
        scores:List[torch.Tensor],
        # mask
        rotms:torch.Tensor,
        trans:torch.Tensor,
        positives_mask:torch.Tensor,
        negatives_mask:torch.Tensor,
        geneous:torch.Tensor
    ):
        """
        Combine the three sub-losses for one batch.

        The per-sample `coords` are detached and mapped with `R * p + T`
        (using `rotms[i]` and `trans[i]`) before being handed to the point and
        overlap losses.

        Returns:
            (loss, stats) where `stats` has the keys "batch", "point" and
            "overlap" holding the statistics of the respective sub-loss.
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]
        rotms = rotms.to(device)
        trans = trans.to(device)

        # R*p + T, evaluated sample by sample
        global_coords = []
        for ndx in range(batch_size):
            local_points = coords[ndx].detach().clone()
            rotated = torch.mm(rotms[ndx], local_points.t()).t()
            global_coords.append(rotated + trans[ndx][None])

        # point loss, then overlap (attention) loss, then batch loss
        point_loss, point_stats = self.point_loss(feats, global_coords, positives_mask)
        overlap_loss, overlap_stats = self.overlap_loss(scores, global_coords, positives_mask, geneous)
        batch_loss, batch_stats = self.batch_loss(embeddings, embeddings, positives_mask, negatives_mask)

        total = batch_loss + self.point_loss_scale*point_loss
        total = total + self.overlap_loss_scale*overlap_loss
        stats = {"batch":batch_stats, "point":point_stats, "overlap":overlap_stats}
        return total, stats

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """Delegate reporting to the three sub-losses, each with its own stats entry."""
        for sub_loss, key in (
            (self.batch_loss, "batch"),
            (self.point_loss, "point"),
            (self.overlap_loss, "overlap"),
        ):
            sub_loss.print_stats(epoch, phase, writer, stats[key])
        return
--------------------------------------------------------------------------------
/loss/lprloss.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | from misc.utils import tensors2numbers
3 | from torch.utils.tensorboard import SummaryWriter
4 |
class LPRLoss:
    """
    # Wrapper that selects a loss function by name

    Only "GAPRLoss" is supported; any other name raises NotImplementedError.
    """
    def __init__(self, name:str, **kw):
        """Store `name` and build the matching loss from the keyword arguments `kw`."""
        self.name = name
        if self.name != "GAPRLoss":
            raise NotImplementedError("LPRLoss: loss_fn %s not implemented" % self.name)
        # Imported lazily, only when this loss is actually requested.
        from loss.gapr import GAPRLoss
        self.loss_fn = GAPRLoss(**kw)

    def __call__(self, output:Dict[str, Any], mask:Dict[str, Any]):
        """
        Evaluate the wrapped loss.

        Args:
            output: model output; must contain "embeddings", "coords", "feats"
                and "scores".
            mask: batch annotations; must contain "positives", "negatives",
                "rotms", "trans" and "geneous".

        Returns:
            (loss, stats) with `stats` passed through `tensors2numbers`.
        """
        if self.name != "GAPRLoss":
            raise NotImplementedError("LPRLoss: loss_fn %s not implemented" % self.name)

        required_output = {"embeddings", "coords", "feats", "scores"}
        required_mask = {"positives", "negatives", "rotms", "trans", "geneous"}
        assert required_output <= set(output.keys())
        assert required_mask <= set(mask.keys())

        loss, stats = self.loss_fn(
            output["embeddings"],
            output["coords"],
            output["feats"],
            output["scores"],
            mask["rotms"],
            mask["trans"],
            mask["positives"],
            mask["negatives"],
            mask["geneous"]
        )

        assert loss is not None and stats is not None
        return loss, tensors2numbers(stats)

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """
        # visualize stats
        """
        if self.name != "GAPRLoss":
            raise NotImplementedError("LPRLoss: loss_fn %s.print_stats() not implemented" % self.name)
        self.loss_fn.print_stats(epoch, phase, writer, stats)
--------------------------------------------------------------------------------
/loss/overlap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import open3d as o3d
3 | import numpy as np
4 | from typing import List, Dict, Any
5 | from torch.utils.tensorboard import SummaryWriter
6 |
7 | from loss.base import BaseLoss
8 |
class OverlapLoss(BaseLoss):
    """
    # Overlap loss on per-point scores of positive pairs from different groups

    For every positive pair whose two samples have different `geneous` values,
    point correspondences are obtained with Open3D ICP (`corr_dist` as the
    maximum correspondence distance, identity as the initial transform). For
    each cloud of the pair the loss is
    `mean(scores of unmatched points) + 1 - mean(scores of matched points)`.
    """
    _STAT_KEYS = ("fitness", "inpair_min", "inpair_mean", "inpair_max", "nopair_min", "nopair_mean", "nopair_max")

    def __init__(self, corr_dist:float):
        """
        Args:
            corr_dist: maximum correspondence distance handed to ICP.
        """
        super().__init__()
        print("OverlapLoss: corr_dist=%.2f"%(corr_dist))
        self.corr_dist = corr_dist

    @staticmethod
    def _mean_or_zero(values:torch.Tensor):
        """Mean of `values`; a differentiable zero when `values` is empty (avoids NaN)."""
        return values.mean() if values.numel() > 0 else values.sum()

    def __call__(self, scores:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor, geneous:torch.Tensor):
        """
        Args:
            scores: per-sample point scores, indexable by point index.
            coords: per-sample point coordinates; presumably already expressed
                in a common frame by the caller — TODO confirm.
            positives_mask: square boolean mask of positive pairs in the batch.
            geneous: per-sample group id; only pairs with different ids are used.

        Returns:
            (loss, stats). When the batch holds no usable pair, a zero tensor
            of shape (1,) and all-zero stats are returned.
        """
        source_indices, target_indices = torch.where(positives_mask == True)
        keep_0 = geneous[source_indices] != geneous[target_indices]
        # each unordered pair is visited once
        keep_1 = source_indices < target_indices
        # torch.where (not np.where): the operands are torch tensors and may
        # live on a non-CPU device.
        select_indices = torch.where(keep_0 & keep_1)
        source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist()

        if len(source_indices) == 0:
            # no hetero positive pair, refer to sampler
            zero_stats = {e: 0.0 for e in self._STAT_KEYS}
            zero_stats["loss"] = 0.0
            return torch.zeros((1,), device=scores[0].device).type_as(scores[0]), zero_stats

        losses = []
        stats = {e: [] for e in self._STAT_KEYS}
        for sndx, tndx in zip(source_indices, target_indices):
            # construct pcd from coords
            source_pcd = o3d.geometry.PointCloud()
            source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy())
            source_pcd.paint_uniform_color([0,0,1])
            target_pcd = o3d.geometry.PointCloud()
            target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy())
            target_pcd.paint_uniform_color([1,0,0])
            reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4))
            corr_set = np.asarray(reg_p2p.correspondence_set)
            assert reg_p2p.fitness > 0.05

            Ns, Nt = coords[sndx].shape[0], coords[tndx].shape[0]
            source_inpair_indices, source_nopair_indices = corr_set[:, 0], np.setdiff1d(np.arange(Ns), corr_set[:, 0])
            target_inpair_indices, target_nopair_indices = corr_set[:, 1], np.setdiff1d(np.arange(Nt), corr_set[:, 1])

            # gather each score subset once; reused for the loss and the stats
            inpair = (scores[sndx][source_inpair_indices], scores[tndx][target_inpair_indices])
            nopair = (scores[sndx][source_nopair_indices], scores[tndx][target_nopair_indices])

            for matched, unmatched in zip(inpair, nopair):
                # A cloud whose points are all matched has no unmatched score
                # to penalise: that term is zero instead of NaN.
                losses.append(self._mean_or_zero(unmatched) + 1.0 - matched.mean())

            stats["fitness"].append(reg_p2p.fitness)
            for prefix, subsets in (("inpair", inpair), ("nopair", nopair)):
                for values in subsets:
                    if values.numel() == 0: continue  # min/max undefined on empty
                    stats[prefix+"_min"].append(values.min().item())
                    stats[prefix+"_mean"].append(values.mean().item())
                    stats[prefix+"_max"].append(values.max().item())

        loss = torch.stack(losses).mean()
        avg_stats = {e: (np.mean(stats[e]) if len(stats[e]) > 0 else 0.0) for e in stats}
        avg_stats["loss"] = loss.item()

        return loss, avg_stats

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """Print the loss, the (min, mean, max) scores of matched/unmatched points and the ICP fitness."""
        print("OverlapLoss: %.3f" % (stats["loss"]))
        print("Overlap: %.3f, %.3f, %.3f | Non-overlap: %.3f, %.3f, %.3f | Fitness: %.3f" % (
            stats["inpair_min"], stats["inpair_mean"], stats["inpair_max"],
            stats["nopair_min"], stats["nopair_mean"], stats["nopair_max"],
            stats["fitness"]
        ))
        return
--------------------------------------------------------------------------------
/loss/point.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import random
4 | import numpy as np
5 | import open3d as o3d
6 | from typing import List, Dict, Any
7 | from loss.base import BaseLoss
8 | from torch.utils.tensorboard import SummaryWriter
9 |
def get_max_per_row(mat:torch.Tensor, mask:torch.Tensor):
    """
    Row-wise maximum of `mat` restricted to the entries selected by `mask`.

    Unselected entries count as 0. Returns `(torch.max(..., dim=1), has_entry)`
    where `has_entry[i]` tells whether row `i` has at least one selected entry.
    `mat` itself is left untouched.
    """
    has_entry = mask.any(dim=1)
    masked = mat.masked_fill(~mask, 0)
    return torch.max(masked, dim=1), has_entry
15 |
16 |
def get_min_per_row(mat:torch.Tensor, mask:torch.Tensor):
    """
    Row-wise minimum of `mat` restricted to the entries selected by `mask`.

    Unselected entries count as +inf. Returns `(torch.min(..., dim=1), has_entry)`
    where `has_entry[i]` tells whether row `i` has at least one selected entry.
    `mat` itself is left untouched.
    """
    has_entry = mask.any(dim=1)
    masked = mat.masked_fill(~mask, float("inf"))
    return torch.min(masked, dim=1), has_entry
22 |
23 |
class PointTripletLoss(BaseLoss):
    """
    # Triplet loss on point features of positive pairs

    For every positive pair, point correspondences are obtained with Open3D ICP
    and up to `sample_num` of them are sampled. Among the sampled points, a
    target point closer than `pos_dist` to the anchor (in coordinates) is a
    positive and one farther than `neg_dist` is a negative; the hardest
    positive and hardest negative (in feature distance) form the triplet.
    """
    _STAT_KEYS = (
        "fitness",
        "triplet_num",
        "non_zero_triplet_num",
        "pos_min",
        "pos_mean",
        "pos_max",
        "neg_min",
        "neg_mean",
        "neg_max",
    )

    def __init__(self, margin:float, style:str, corr_dist:float, sample_num:int, pos_dist:float, neg_dist:float):
        """
        Args:
            margin: triplet margin.
            style: "hard" (relu) or "soft" (log-exp) formulation.
            corr_dist: maximum correspondence distance handed to ICP.
            sample_num: maximum number of correspondences sampled per pair.
            pos_dist: coordinate distance below which a point is a positive.
            neg_dist: coordinate distance above which a point is a negative.
        """
        super().__init__()
        assert style in ["soft", "hard"]
        print("PointTripletLoss: margin=%.1f, style=%s" % (margin, style))
        self.margin = margin
        self.style = style
        self.corr_dist = corr_dist
        self.sample_num = sample_num
        self.pos_dist = pos_dist
        self.neg_dist = neg_dist

    def __call__(self, feats:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor):
        """
        Args:
            feats: per-sample point features, indexable by point index.
            coords: per-sample point coordinates; presumably already expressed
                in a common frame by the caller — TODO confirm.
            positives_mask: square boolean mask of positive pairs in the batch.

        Returns:
            (loss, stats). When no pair yields a valid triplet, a scalar zero
            and all-zero stats are returned.
        """
        source_indices, target_indices = torch.where(positives_mask == True)
        # each unordered pair is visited once
        select_indices = torch.where(source_indices < target_indices)
        source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist()

        losses = []
        point_stats = {e: [] for e in self._STAT_KEYS}

        for sndx, tndx in zip(source_indices, target_indices):
            # construct pcd from coords
            source_pcd = o3d.geometry.PointCloud()
            source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy())
            source_pcd.paint_uniform_color([0,0,1])
            target_pcd = o3d.geometry.PointCloud()
            target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy())
            target_pcd.paint_uniform_color([1,0,0])
            # icp and get points set
            reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4))
            corr_set = np.asarray(reg_p2p.correspondence_set)
            assert reg_p2p.fitness > 0.05
            # Sample without replacement: the size is already clamped to the
            # number of correspondences, and duplicates would only shrink the
            # effective sample.
            sample_indices = np.random.choice(
                corr_set.shape[0], min(corr_set.shape[0], self.sample_num), replace=False
            )
            Ns = sample_indices.shape[0]
            # sample_set:
            # [ [s0, s1, s2, ... ],   Ns
            #   [t0, t1, t2, ... ] ]  Ns
            sample_set = corr_set[sample_indices].T.tolist()

            # sample coords and feats
            scoord, tcoord = coords[sndx][sample_set[0]], coords[tndx][sample_set[1]]
            sfeat, tfeat = feats[sndx][sample_set[0]], feats[tndx][sample_set[1]]
            # Ns * Ns
            coord_dist = torch.norm(scoord.unsqueeze(1) - tcoord.unsqueeze(0), dim=2)
            # Ns * Ns
            feat_dist = torch.norm(sfeat.unsqueeze(1) - tfeat.unsqueeze(0), dim=2)
            # get hardest positive and negative
            (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(feat_dist, coord_dist < self.pos_dist)
            (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(feat_dist, coord_dist > self.neg_dist)
            # positive <=> anchor <=> negative: keep anchors that have both
            a_keep_idx = torch.where(a1p_keep & a2n_keep)[0]
            triplet_num = a_keep_idx.shape[0]
            if triplet_num == 0: continue

            anc_ind = torch.arange(Ns).to(hardest_positive_indices.device)[a_keep_idx]
            pos_ind = hardest_positive_indices[a_keep_idx]
            neg_ind = hardest_negative_indices[a_keep_idx]

            triplet_dist = torch.norm(sfeat[anc_ind] - tfeat[pos_ind], dim=1) - torch.norm(sfeat[anc_ind] - tfeat[neg_ind], dim=1)

            non_zero_triplet_num = torch.where((triplet_dist + self.margin) > 0)[0].shape[0]

            if self.style == "hard":
                this_pair_loss = F.relu(triplet_dist + self.margin).mean()
            elif self.style == "soft":
                this_pair_loss = torch.log(1+self.margin*torch.exp(triplet_dist)).mean()
            else:
                raise NotImplementedError(f"PointTripletLoss: unkown style {self.style}")

            losses.append(this_pair_loss)

            kept_pos = hardest_positive_dist[a_keep_idx]
            kept_neg = hardest_negative_dist[a_keep_idx]
            point_stats["fitness"].append(reg_p2p.fitness)
            point_stats["triplet_num"].append(triplet_num)
            point_stats["non_zero_triplet_num"].append(non_zero_triplet_num)
            point_stats["pos_min"].append(kept_pos.min().item())
            point_stats["pos_mean"].append(kept_pos.mean().item())
            point_stats["pos_max"].append(kept_pos.max().item())
            point_stats["neg_min"].append(kept_neg.min().item())
            point_stats["neg_mean"].append(kept_neg.mean().item())
            point_stats["neg_max"].append(kept_neg.max().item())

        if len(losses) == 0:
            # No positive pair, or no pair produced a triplet: torch.stack([])
            # would raise, so report a zero loss with zero stats instead.
            zero_stats = {e: 0.0 for e in self._STAT_KEYS}
            zero_stats["loss"] = 0.0
            return feats[0].new_zeros(()), zero_stats

        avg_point_stats = {e: np.mean(point_stats[e]) for e in point_stats}
        loss = torch.stack(losses).mean()
        avg_point_stats["loss"] = loss.item()

        return loss, avg_point_stats

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """Print the loss, the (min, mean, max) hardest distances, triplet counts and ICP fitness."""
        print("PointTripletLoss: %.3f" % (stats["loss"]))
        print(
            "Positive: %.3f, %.3f, %.3f | Negative: %.3f, %.3f, %.3f | Triplet: %.1f/%.1f | Fitness:%.3f" %
            (
                stats["pos_min"], stats["pos_mean"] ,stats["pos_max"],
                stats["neg_min"], stats["neg_mean"] ,stats["neg_max"],
                stats["triplet_num"], stats["non_zero_triplet_num"], stats["fitness"]
            )
        )
        return
136 |
137 |
########################### PointContrastiveLoss from LoGG3D #################################
139 |
def hashM(arr, M):
    """
    Combine the D components of N index tuples into one int64 key each.

    `arr` is either an ndarray of shape (N, D) or a sequence of D equally long
    index sequences. The key of a tuple is `sum(component[d] * M**d)`.
    """
    if isinstance(arr, np.ndarray):
        count = arr.shape[0]
        columns = [arr[:, d] for d in range(arr.shape[1])]
    else:
        count = len(arr[0])
        columns = [arr[d] for d in range(len(arr))]

    keys = np.zeros(count, dtype=np.int64)
    for power, column in enumerate(columns):
        keys += column * M**power
    return keys
153 |
154 |
def pdist(A, B, dist_type="L2"):
    """
    Pairwise distances between the rows of `A` and the rows of `B`.

    "SquareL2" returns the squared Euclidean distance; "L2" returns
    `sqrt(squared + 1e-7)`. Any other `dist_type` raises NotImplementedError.
    """
    if dist_type != "L2" and dist_type != "SquareL2":
        raise NotImplementedError("Not implemented")

    diff = A.unsqueeze(1) - B.unsqueeze(0)
    squared = diff.pow(2).sum(dim=2)
    if dist_type == "SquareL2":
        return squared
    return torch.sqrt(squared + 1e-7)
163 |
164 |
165 | # class PointContrastiveLoss(BaseLoss):
class PointContrastiveLoss(BaseLoss):
    """
    # Hardest-negative contrastive loss on point features of positive pairs

    Point correspondences of every positive pair are obtained with Open3D ICP
    and fed to `point_contrastive_loss`.
    """
    def __init__(self, corr_dist:float, pos_margin:float, neg_margin:float, neg_weight:float, num_pos:int, num_hn_samples:int):
        """
        Args:
            corr_dist: maximum correspondence distance handed to ICP.
            pos_margin: margin subtracted from the squared positive distance.
            neg_margin: margin the hardest-negative distance should exceed.
            neg_weight: weight of the negative term.
            num_pos: maximum number of positive pairs sampled.
            num_hn_samples: size of the random subset searched for negatives.
        """
        super().__init__()
        self.corr_dist = corr_dist
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.neg_weight = neg_weight
        self.num_pos = num_pos
        self.num_hn_samples = num_hn_samples

    def __call__(self, feats:List[torch.Tensor], coords:List[torch.Tensor], positives_mask:torch.Tensor):
        """
        Args:
            feats: per-sample point features, indexable by point index.
            coords: per-sample point coordinates; presumably already expressed
                in a common frame by the caller — TODO confirm.
            positives_mask: square boolean mask of positive pairs in the batch.

        Returns:
            (loss, stats) with `stats == {"loss": float}`. A scalar zero is
            returned when the batch holds no positive pair.
        """
        source_indices, target_indices = torch.where(positives_mask == True)
        # each unordered pair is visited once
        select_indices = torch.where(source_indices < target_indices)
        source_indices, target_indices = source_indices[select_indices].tolist(), target_indices[select_indices].tolist()

        losses = []
        stats = {}

        for sndx, tndx in zip(source_indices, target_indices):
            # construct pcd from coords
            source_pcd = o3d.geometry.PointCloud()
            source_pcd.points = o3d.utility.Vector3dVector(coords[sndx].clone().detach().cpu().numpy())
            source_pcd.paint_uniform_color([0,0,1])
            target_pcd = o3d.geometry.PointCloud()
            target_pcd.points = o3d.utility.Vector3dVector(coords[tndx].clone().detach().cpu().numpy())
            target_pcd.paint_uniform_color([1,0,0])
            # icp and get points set
            reg_p2p = o3d.pipelines.registration.registration_icp(source_pcd, target_pcd, self.corr_dist, np.eye(4))
            corr_set = np.asarray(reg_p2p.correspondence_set)
            assert reg_p2p.fitness > 0.05
            losses.append(self.point_contrastive_loss(feats[sndx], feats[tndx], corr_set))

        if len(losses) == 0:
            # torch.stack([]) would raise: report a zero loss instead.
            stats["loss"] = 0.0
            return feats[0].new_zeros(()), stats

        loss = torch.stack(losses).mean()
        stats["loss"] = loss.item()

        return loss, stats

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """Print the loss value."""
        print("PointContrastiveLoss: %.3f" % (stats["loss"]))
        return

    def point_contrastive_loss(self,
        F0:torch.Tensor, F1:torch.Tensor,
        positive_pairs:np.ndarray,
    ):
        """
        Randomly select "num-pos" positive pairs.
        Find the hardest-negative (from a random subset of num_hn_samples) for each point in a positive pair.
        Calculate contrastive loss on the tuple (p1,p2,hn1,hn2)
        Based on: https://github.com/chrischoy/FCGF/blob/master/lib/trainer.py
        """
        # Normalise up front: the pairs are fancy-indexed below, which needs
        # an ndarray.
        if not isinstance(positive_pairs, np.ndarray):
            positive_pairs = np.array(positive_pairs, dtype=np.int64)

        N0, N1 = len(F0), len(F1)
        N_pos_pairs = len(positive_pairs)
        hash_seed = max(N0, N1)
        # random candidate pools for the hardest-negative search
        sel0 = np.random.choice(N0, min(N0, self.num_hn_samples), replace=False)
        sel1 = np.random.choice(N1, min(N1, self.num_hn_samples), replace=False)

        if N_pos_pairs > self.num_pos:
            pos_sel = np.random.choice(N_pos_pairs, self.num_pos, replace=False)
            sample_pos_pairs = positive_pairs[pos_sel]
        else:
            sample_pos_pairs = positive_pairs

        # Find negatives for all F1[positive_pairs[:, 1]]
        subF0, subF1 = F0[sel0], F1[sel1]

        pos_ind0 = sample_pos_pairs[:, 0]
        pos_ind1 = sample_pos_pairs[:, 1]
        posF0, posF1 = F0[pos_ind0], F1[pos_ind1]

        D01 = pdist(posF0, subF1, dist_type="L2")
        D10 = pdist(posF1, subF0, dist_type="L2")

        D01min, D01ind = D01.min(1)
        D10min, D10ind = D10.min(1)

        pos_keys = hashM(positive_pairs, hash_seed)

        # map pool-local indices back to point indices
        D01ind = sel1[D01ind.cpu().numpy()]
        D10ind = sel0[D10ind.cpu().numpy()]
        neg_keys0 = hashM([pos_ind0, D01ind], hash_seed)
        neg_keys1 = hashM([D10ind, pos_ind1], hash_seed)

        # discard "negatives" that are actually known positive pairs
        mask0 = torch.from_numpy(
            np.logical_not(np.isin(neg_keys0, pos_keys, assume_unique=False)))
        mask1 = torch.from_numpy(
            np.logical_not(np.isin(neg_keys1, pos_keys, assume_unique=False)))
        pos_loss = F.relu((posF0 - posF1).pow(2).sum(1) - self.pos_margin)
        neg_loss0 = F.relu(self.neg_margin - D01min[mask0]).pow(2)
        neg_loss1 = F.relu(self.neg_margin - D10min[mask1]).pow(2)

        pos_loss = pos_loss.mean()
        neg_loss = (neg_loss0.mean() + neg_loss1.mean()) / 2
        loss = pos_loss + self.neg_weight * neg_loss
        return loss
269 |
--------------------------------------------------------------------------------
/loss/triplet.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | from typing import List, Dict, Any
7 | from torch.utils.tensorboard import SummaryWriter
8 | from loss.base import BaseLoss
9 | import matplotlib.pylab as plt
10 |
def get_max_per_row(mat:torch.Tensor, mask:torch.Tensor):
    """
    Largest masked entry of every row.

    Entries where `mask` is False are replaced by 0 on a copy of `mat`.
    Returns `(mat_masked.max(dim=1), rows_with_any)`; `rows_with_any[i]` is
    False when row `i` of `mask` selects nothing.
    """
    rows_with_any = torch.any(mask, dim=1)
    excluded = torch.logical_not(mask)
    working = mat.clone()
    working[excluded] = 0
    return working.max(dim=1), rows_with_any
16 |
def get_min_per_row(mat:torch.Tensor, mask:torch.Tensor):
    """
    Smallest masked entry of every row.

    Entries where `mask` is False are replaced by +inf on a copy of `mat`.
    Returns `(mat_masked.min(dim=1), rows_with_any)`; `rows_with_any[i]` is
    False when row `i` of `mask` selects nothing.
    """
    rows_with_any = torch.any(mask, dim=1)
    excluded = torch.logical_not(mask)
    working = mat.clone()
    working[excluded] = float("inf")
    return working.min(dim=1), rows_with_any
22 |
class TripletMiner:
    """
    # Batch-hard triplet miner

    For every anchor row of a distance matrix, picks the farthest positive and
    the closest negative. Anchors lacking a positive or a negative are dropped.
    """
    def __init__(self):
        return

    def __call__(self, dist_mat:torch.Tensor, positives_mask:torch.Tensor, negatives_mask:torch.Tensor):
        """
        Args:
            dist_mat: [Ns, Nt] distance matrix.
            positives_mask: [Ns, Nt] boolean mask of positives.
            negatives_mask: [Ns, Nt] boolean mask of negatives.

        Returns:
            (anc_ind, pos_ind, neg_ind, stats): index tensors of the mined
            triplets and a dict with the triplet count and the max/mean/min of
            the mined positive and negative distances. The distance statistics
            are 0.0 when no triplet could be mined.
        """
        # [Ns, Nt] mat
        assert dist_mat.shape == positives_mask.shape == negatives_mask.shape
        with torch.no_grad():
            # Based on pytorch-metric-learning implementation
            (hardest_positive_dist, hardest_positive_indices), a1p_keep = get_max_per_row(dist_mat, positives_mask)
            (hardest_negative_dist, hardest_negative_indices), a2n_keep = get_min_per_row(dist_mat, negatives_mask)
            a_keep_idx = torch.where(a1p_keep & a2n_keep)[0]
            anc_ind = torch.arange(dist_mat.size(0)).to(hardest_positive_indices.device)[a_keep_idx]
            pos_ind = hardest_positive_indices[a_keep_idx]
            neg_ind = hardest_negative_indices[a_keep_idx]

            kept_pos = hardest_positive_dist[a_keep_idx]
            kept_neg = hardest_negative_dist[a_keep_idx]
            stats = {"triplet_num": a_keep_idx.shape[0]}
            for prefix, kept in (("pos", kept_pos), ("neg", kept_neg)):
                # max/min of an empty tensor raise, so fall back to 0.0 when
                # the batch yields no valid triplet.
                has_triplet = kept.numel() > 0
                stats["max_%s_dist"%prefix] = torch.max(kept).item() if has_triplet else 0.0
                stats["mean_%s_dist"%prefix] = torch.mean(kept).item() if has_triplet else 0.0
                stats["min_%s_dist"%prefix] = torch.min(kept).item() if has_triplet else 0.0
        return anc_ind, pos_ind, neg_ind, stats
48 |
49 |
50 |
class BatchTripletLoss(BaseLoss):
    """
    # Batch-hard triplet loss on global descriptors

    Triplets are mined with `TripletMiner` from the pairwise L2 distances
    between `source_feats` and `target_feats`.
    """
    def __init__(self, margin:float, style:str):
        """
        Args:
            margin: triplet margin.
            style: "hard" (relu) or "soft" (log-exp) formulation.
        """
        super().__init__()
        assert style in ["soft", "hard"]
        print("BatchTripletLoss: margin=%.1f, style=%s"%(margin, style))
        self.miner = TripletMiner()
        self.margin = margin
        self.style = style
        return

    def __call__(self,
        source_feats:torch.Tensor, target_feats:torch.Tensor,
        positives_mask:torch.Tensor, negative_mask:torch.Tensor
    ):
        """
        Args:
            source_feats: [Ns, D] anchor descriptors.
            target_feats: [Nt, D] descriptors positives/negatives are drawn from.
            positives_mask: [Ns, Nt] boolean mask of positives.
            negative_mask: [Ns, Nt] boolean mask of negatives.

        Returns:
            (loss, stats). The loss is a differentiable zero when the miner
            finds no triplet.
        """
        stats = {}
        # get dist l2d mat
        dist_mat = torch.norm(source_feats.unsqueeze(1) - target_feats.unsqueeze(0), dim=2)
        # miner
        anc, pos, neg, miner_stats = self.miner(dist_mat, positives_mask, negative_mask)
        stats.update(miner_stats)
        pos_dist = torch.norm(source_feats[anc] - target_feats[pos], dim=1)
        neg_dist = torch.norm(source_feats[anc] - target_feats[neg], dim=1)
        triplet_dist = pos_dist - neg_dist
        with torch.no_grad():
            stats["norm"] = torch.norm(source_feats, dim=1).mean().item()
            stats["non_zero_triplet_num"] = torch.where((triplet_dist + self.margin) > 0)[0].shape[0]

        if self.style == "hard":
            per_triplet = F.relu(triplet_dist + self.margin)
        elif self.style == "soft":
            per_triplet = torch.log(1+self.margin*torch.exp(triplet_dist))
        else:
            raise NotImplementedError(f"BatchTripletLoss: unkown style {self.style}")

        # mean() over zero triplets is NaN; sum() over them is a zero that
        # still belongs to the autograd graph.
        loss = per_triplet.mean() if per_triplet.numel() > 0 else per_triplet.sum()

        stats["loss"] = loss.item()
        return loss, stats

    def print_stats(self, epoch:int, phase:str, writer:SummaryWriter, stats:Dict[str, Any]):
        """Print the loss, descriptor norm, triplet counts and mined distance statistics."""
        print("TripletLoss: %.3f, Norm: %.3f, All/Non-zero: %.1f/%.1f"%(
            stats["loss"], stats["norm"], stats["triplet_num"], stats["non_zero_triplet_num"]
        ))
        print("Positive: %.3f, %.3f, %.3f | Negative: %.3f, %.3f, %.3f (min, avg, max)"%(
            stats["min_pos_dist"], stats["mean_pos_dist"], stats["max_pos_dist"],
            stats["min_neg_dist"], stats["mean_neg_dist"], stats["max_neg_dist"],
        ))
        return
97 |
--------------------------------------------------------------------------------
/media/description.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/media/description.png
--------------------------------------------------------------------------------
/media/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/media/pipeline.png
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import argparse
4 | import numpy as np
5 | import matplotlib.pyplot as plt
6 | from typing import Any, Dict, List
7 |
def get_datetime():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    stamp_format = "%Y%m%d_%H%M%S"
    return time.strftime(stamp_format, time.localtime())
10 |
def str2bool(v):
    """
    Convert a command-line style flag into a bool.

    Accepted spellings (case-insensitive, surrounding whitespace ignored):
    "yes", "true", "t", "y", "1" -> True and "no", "false", "f", "n", "0" -> False.
    A value that already is a bool is returned unchanged, so the function can
    be used as an argparse `type=` together with a bool default.

    Raises:
        argparse.ArgumentTypeError: if the value is not one of the spellings above.
    """
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Unsupported value encountered.")
18 |
def get_idx_from_string(elem):
    """
    Parse the integer index in front of the first "." of a file name.

    000021.npy -> 21
    """
    stem, _, _ = elem.partition(".")
    return int(stem)
24 |
25 |
def tensors2numbers(data):
    """
    Recursively replace every tensor in a nested structure by `tensor.item()`.

    ```python
    stats = {e: stats[e].item() if torch.is_tensor(stats[e]) else stats[e] for e in stats}
    ```

    Lists and dicts are converted in place and the same object is returned.
    Tuples are immutable, so a new tuple (or namedtuple of the same type) is
    returned for them. `None` and any other value are returned unchanged.
    """
    if data is None:
        return data
    if torch.is_tensor(data):
        return data.item()
    if isinstance(data, list):
        for i, elem in enumerate(data):
            data[i] = tensors2numbers(elem)
        return data
    if isinstance(data, tuple):
        # item assignment is not possible on a tuple: rebuild it instead
        converted = [tensors2numbers(elem) for elem in data]
        if hasattr(data, "_fields"):
            # namedtuple: constructor takes the fields as positional arguments
            return type(data)(*converted)
        return tuple(converted)
    if isinstance(data, dict):
        for key in data:
            data[key] = tensors2numbers(data[key])
        return data
    return data
46 |
def tensors2device(data:Any, device:torch.device):
    """
    Recursively move every tensor in a nested structure to `device`.

    Lists and dicts are updated in place and the same object is returned.
    Tuples are immutable, so a new tuple (or namedtuple of the same type) is
    returned for them. `None` is passed through unchanged.

    Raises:
        NotImplementedError: for any leaf that is neither None, a tensor,
            a list, a tuple nor a dict.
    """
    if data is None:
        return data
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, list):
        for i, elem in enumerate(data):
            data[i] = tensors2device(elem, device)
        return data
    if isinstance(data, tuple):
        # item assignment is not possible on a tuple: rebuild it instead
        moved = [tensors2device(elem, device) for elem in data]
        if hasattr(data, "_fields"):
            # namedtuple: constructor takes the fields as positional arguments
            return type(data)(*moved)
        return tuple(moved)
    if isinstance(data, dict):
        for key in data:
            data[key] = tensors2device(data[key], device)
        return data
    raise NotImplementedError("tensors2device: %s not implemented error"%str(type(data)))
65 |
66 |
def avg_stats(stats:List):
    """
    Average a list of (possibly nested) stats dicts key by key.

    The keys are taken from the first entry; every entry is expected to
    provide the same keys. Dict values are averaged recursively, all other
    values with `np.mean`.

    The result is a new dict: the input entries are not modified.
    An empty list yields an empty dict.
    """
    if not stats:
        return {}
    avg = {}
    for key, first_value in stats[0].items():
        column = [entry[key] for entry in stats]
        if isinstance(first_value, dict):
            avg[key] = avg_stats(column)
        else:
            avg[key] = np.mean(column)
    return avg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/models/__init__.py
--------------------------------------------------------------------------------
/models/gapr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import Dict
4 |
5 | from models.utils.aggregation.gem import MeanGeM
6 | from models.utils.extraction.mink.minkfpn import MinkFPN
7 | from models.utils.transformers.transgeo import PCTrans
8 |
class GAPR(nn.Module):
    """
    MinkFPN backbone, one PCTrans transformer per source type
    ("ground" / "aerial") and a MeanGeM aggregation on top.
    """
    def __init__(self, minkfpn:Dict, pctrans:Dict, meangem:Dict, **kw):
        """
        minkfpn, pctrans, meangem: keyword arguments forwarded to MinkFPN,
        PCTrans and MeanGeM. Both transformers are built from the same
        `pctrans` config. Extra keyword arguments are ignored.
        """
        super().__init__()
        print("Model: GAPR")
        self.minkfpn = MinkFPN(**minkfpn)
        # geneous[ndx] is used as an index into this list
        self.geneous_names = ["ground", "aerial"]

        self.ground_trans = PCTrans(**pctrans)
        self.aerial_trans = PCTrans(**pctrans)

        self.meangem = MeanGeM(**meangem)

    def forward(self, coords:torch.Tensor, feats:torch.Tensor, geneous:torch.Tensor):
        """
        Returns (cnn_coords, cnn_feats, attn_scores, batch_feats):
        the per-sample backbone outputs, the per-sample attention scores of
        the selected transformer and the stacked MeanGeM embeddings.
        """
        batch_size = geneous.shape[0]
        cnn_coords, cnn_feats = self.minkfpn(coords, feats)

        attn_feats, attn_scores = [], []
        for sample_ndx in range(batch_size):
            # pick the transformer that matches the source type of this sample
            source_name = self.geneous_names[geneous[sample_ndx].item()]
            if source_name == "ground":
                transformer = self.ground_trans
            elif source_name == "aerial":
                transformer = self.aerial_trans
            else:
                raise NotImplementedError

            # transformers work on a batch dimension: add it, then drop it again
            sample_feat, sample_score = transformer(cnn_feats[sample_ndx].unsqueeze(0))
            attn_feats.append(sample_feat.squeeze(0))
            attn_scores.append(sample_score.squeeze(0))

        # one global embedding per sample
        batch_feats = torch.stack([self.meangem(feat) for feat in attn_feats], dim=0)

        return cnn_coords, cnn_feats, attn_scores, batch_feats
42 |
--------------------------------------------------------------------------------
/models/lprmodel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import copy
4 | import yaml
5 | from typing import Dict, Any
6 |
class LPRModel:
    """
    # Wrapper for models

    Holds the network (`self.model`), its name and the config it was built
    from, and saves / restores both in a single .pth file of the form
    {"config": ..., "weight": state_dict}.
    """
    def __init__(self):
        # filled by construct() / load(); defined here so that attribute
        # access before construction does not raise AttributeError
        self.config = None
        self.name = None
        self.model = None

    def construct(self, name:str, **kw):
        """Build the network called `name` from keyword config `kw`."""
        # keep an independent copy of the config for save()
        self.config = copy.deepcopy(kw)
        self.config["name"] = name
        self.name = name
        self.model = None
        if self.name == "GAPR":
            from models.gapr import GAPR
            self.model = GAPR(**kw)
        else:
            raise NotImplementedError("LPRModel: model %s not implemented" % self.name)

    def __call__(self, data:Dict[str, Any]) -> Dict[str, Any]:
        """
        Run the network on `data` and return its outputs as a dict with the
        keys "coords", "feats", "scores" and "embeddings".
        """
        output:Dict[str, Any] = {}
        if self.name == "GAPR":
            assert set(["coords", "feats", "geneous"]) <= set(data.keys())
            output["coords"], output["feats"], output["scores"], output["embeddings"] = self.model(data["coords"], data["feats"], data["geneous"])
        else:
            raise NotImplementedError("LPRModel: model %s not implemented" % self.name)
        return output

    def save(self, path:str):
        """
        Save config and weights to `path`.

        The weights are always stored without the "module." prefix: a
        DataParallel / DistributedDataParallel wrapper is unwrapped first,
        a plain model is saved as it is.
        """
        net = self.model
        if isinstance(net, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            net = net.module
        pth_file_dict = {"config":self.config, "weight": net.state_dict()}
        torch.save(pth_file_dict, path)
        return

    def load(self, path, device):
        """Rebuild the network from the config stored in `path` and load its weights."""
        pth_file_dict = torch.load(path, map_location=device)
        print("LPRModel: load\n", pth_file_dict["config"])
        # the stored config contains "name", which construct() takes as first argument
        self.construct(**pth_file_dict["config"])
        self.model.load_state_dict(pth_file_dict["weight"])
        return


    def import_and_save(self, config_path: str, weight_path:str, save_path:str):
        """
        import models from other project and save it

        Reads the weights from `weight_path` and the yaml config from
        `config_path` and writes both into `save_path`, which must not exist yet.
        """
        assert (os.path.exists(config_path)) and (os.path.exists(weight_path)) and (not os.path.exists(save_path))
        pth_file_dict = {}
        # load weights
        pth_file_dict["weight"] = torch.load(weight_path)
        # load config; the context manager closes the file even if parsing fails
        # NOTE(review): yaml.FullLoader should only be used on trusted config files
        with open(config_path, encoding="utf-8") as f:
            pth_file_dict["config"] = yaml.load(f, Loader=yaml.FullLoader)
        torch.save(pth_file_dict, save_path)
        return
61 |
--------------------------------------------------------------------------------
/models/utils/aggregation/gem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
class GeM(nn.Module):
    """
    Generalized-mean pooling with a learnable exponent `p`, pooling the
    last axis with an `nn.AvgPool1d` of kernel size `pn`.
    """
    def __init__(self, pn=256, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        # average over windows of pn elements (pn = 256 by default)
        self.f = nn.AvgPool1d(pn)

    def forward(self, x:torch.Tensor):
        # clamp so that the power is only taken of values >= eps
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = self.f(powered)
        # squeeze only dim 2, so that the first dimension is never dropped
        return pooled.pow(1./self.p).squeeze(dim=2)
17 |
class MeanGeM(nn.Module):
    """
    Generalized-mean pooling over dim 0 with a learnable exponent `p`
    (example values: p=3, eps=0.0000001).
    """
    def __init__(self, p:float, eps:float):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x:torch.Tensor):
        # x: [pn, fs] -> [fs]
        powered = x.clamp(min=self.eps).pow(self.p)
        averaged = powered.mean(dim=0)
        return averaged.pow(1./self.p)
--------------------------------------------------------------------------------
/models/utils/extraction/mink/minkfpn.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import torch
5 | import torch.nn as nn
6 | import MinkowskiEngine as ME
7 | from MinkowskiEngine.modules.resnet_block import BasicBlock
8 | from models.utils.extraction.mink.resnet import ResNetBase
9 | from models.utils.extraction.mink.utils import minkowski_decomposed, minkowski_sparse
10 | from typing import List
11 |
class MinkFPN(ResNetBase):
    """
    Feature Pyramid Network (FPN) built from Minkowski ResNet blocks.

    Takes dense batches of coordinates and features, quantizes them with
    `quant_size`, runs a bottom-up / top-down sparse network and returns
    the decomposed (per-sample) coordinates and features.
    """
    def __init__(self,
        quant_size:float,           # e.g. 0.6
        in_channels:int,            # e.g. 1
        out_channels:int,           # e.g. 256
        num_top_down:int,           # e.g. 1
        conv0_kernel_size:int,      # e.g. 5
        layers:List,                # e.g. (1, 1, 1)
        planes:List,                # e.g. (32, 64, 64)
        block=BasicBlock,           # default residual block
    ):
        assert len(layers) == len(planes)
        assert 1 <= len(layers)
        assert 0 <= num_top_down <= len(layers)
        # plain attributes only: they are set before nn.Module.__init__ runs
        # inside ResNetBase.__init__, which then calls network_initialization
        self.quant_size = quant_size
        self.num_bottom_up = len(layers)
        self.num_top_down = num_top_down
        self.conv0_kernel_size = conv0_kernel_size
        self.block = block
        self.layers = layers
        self.planes = planes
        self.lateral_dim = out_channels
        self.init_dim = planes[0]
        ResNetBase.__init__(self, in_channels, out_channels, D=3)

    def network_initialization(self, in_channels, out_channels, D):
        """Create all layers; called by ResNetBase.__init__."""
        assert len(self.layers) == len(self.planes)
        assert len(self.planes) == self.num_bottom_up

        self.convs = nn.ModuleList()    # bottom-up convolutions with stride=2
        self.bn = nn.ModuleList()       # bottom-up batch norms
        self.blocks = nn.ModuleList()   # bottom-up residual blocks
        self.tconvs = nn.ModuleList()   # top-down transposed convolutions
        self.conv1x1 = nn.ModuleList()  # 1x1 convolutions in lateral connections

        # the first convolution is a special case with its own kernel size
        self.inplanes = self.planes[0]
        self.conv0 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=self.conv0_kernel_size, dimension=D
        )
        self.bn0 = ME.MinkowskiBatchNorm(self.inplanes)

        # self.inplanes is updated by _make_layer, so every stage must be
        # created in this order: conv, batch norm, residual block
        for stage_planes, stage_blocks in zip(self.planes, self.layers):
            self.convs.append(
                ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D)
            )
            self.bn.append(ME.MinkowskiBatchNorm(self.inplanes))
            self.blocks.append(self._make_layer(self.block, stage_planes, stage_blocks))

        # lateral connections paired with a top-down transposed convolution
        for level in range(self.num_top_down):
            self.conv1x1.append(
                ME.MinkowskiConvolution(self.planes[-1 - level], self.lateral_dim, kernel_size=1, stride=1, dimension=D)
            )
            self.tconvs.append(
                ME.MinkowskiConvolutionTranspose(self.lateral_dim, self.lateral_dim, kernel_size=2, stride=2, dimension=D)
            )

        # there is one more lateral connection than top-down blocks:
        # from conv block 1 or above, or from the conv0 block
        if self.num_top_down < self.num_bottom_up:
            last_lateral_in = self.planes[-1 - self.num_top_down]
        else:
            last_lateral_in = self.planes[0]
        self.conv1x1.append(
            ME.MinkowskiConvolution(last_lateral_in, self.lateral_dim, kernel_size=1, stride=1, dimension=D)
        )

        self.relu = ME.MinkowskiReLU(inplace=True)

    def forward(self, coords:torch.Tensor, feats:torch.Tensor):
        """Return (cnn_coords, cnn_feats) as produced by minkowski_decomposed."""
        # sparse quantization
        x = minkowski_sparse(coords, feats, self.quant_size)

        # *** BOTTOM-UP PASS ***
        # the first bottom-up convolution is special (own kernel size)
        skips = []
        x = self.relu(self.bn0(self.conv0(x)))
        if self.num_top_down == self.num_bottom_up:
            skips.append(x)

        # stages whose output is reused by a lateral connection
        first_kept = self.num_bottom_up - 1 - self.num_top_down
        last_stage = len(self.convs) - 1
        for stage_ndx, (conv, bn, block) in enumerate(zip(self.convs, self.bn, self.blocks)):
            x = conv(x)     # downsample (stride=2 convolution with kernel size 2)
            x = bn(x)
            x = self.relu(x)
            x = block(x)
            if first_kept <= stage_ndx < last_stage:
                skips.append(x)

        assert len(skips) == self.num_top_down

        x = self.conv1x1[0](x)

        # *** TOP-DOWN PASS ***
        # skips are consumed from the last stored one to the first
        for lateral_ndx, tconv in enumerate(self.tconvs, start=1):
            x = tconv(x)    # upsample using transposed convolution
            x = x + self.conv1x1[lateral_ndx](skips[-lateral_ndx])

        # decompose into per-sample lists
        cnn_coords, cnn_feats = minkowski_decomposed(x, self.quant_size)
        return cnn_coords, cnn_feats
111 |
if __name__ == "__main__":
    # Smoke test: run one random batch through the network (needs CUDA)
    fpn_config = dict(
        quant_size=0.6,
        in_channels=1,
        out_channels=256,
        num_top_down=1,
        conv0_kernel_size=5,
        layers=[1,1,1],
        planes=[32,64,64],
    )
    model = MinkFPN(**fpn_config).cuda()
    BS, PN, FS = 16, 23212, 1
    # random coordinates in [0, 60) and random features in [0, 1)
    coords = (torch.rand((BS, PN, 3))*60.0).cuda()
    feats = torch.rand((BS, PN, FS)).cuda()
    cnn_coords, cnn_feats = model(coords, feats)
126 |
--------------------------------------------------------------------------------
/models/utils/extraction/mink/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
2 | #
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to
6 | # use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7 | # of the Software, and to permit persons to whom the Software is furnished to do
8 | # so, subject to the following conditions:
9 | #
10 | # The above copyright notice and this permission notice shall be included in all
11 | # copies or substantial portions of the Software.
12 | #
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | # SOFTWARE.
20 | #
21 | # Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
22 | # Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
23 | # of the code.
24 |
25 | import torch.nn as nn
26 |
27 | import MinkowskiEngine as ME
28 | from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
29 |
30 |
class ResNetBase(nn.Module):
    """
    Base class of the Minkowski ResNet family.

    Subclasses configure the network through the class attributes below
    (or override `network_initialization` completely).
    """
    block = None                    # residual block class, must be set
    layers = ()                     # number of blocks per stage
    init_dim = 64                   # channels after the first convolution
    planes = (64, 128, 256, 512)    # channels per stage

    def __init__(self, in_channels, out_channels, D=3):
        nn.Module.__init__(self)
        self.D = D
        assert self.block is not None

        self.network_initialization(in_channels, out_channels, D)
        self.weight_initialization()

    def network_initialization(self, in_channels, out_channels, D):
        """Create stem, four residual stages, final convolution, pooling and linear head."""
        self.inplanes = self.init_dim
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D)

        self.bn1 = ME.MinkowskiBatchNorm(self.inplanes)
        self.relu = ME.MinkowskiReLU(inplace=True)

        self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D)

        # layer1 .. layer4, each with stride 2
        for stage_ndx in range(4):
            stage = self._make_layer(
                self.block, self.planes[stage_ndx], self.layers[stage_ndx], stride=2)
            setattr(self, "layer%d" % (stage_ndx + 1), stage)

        self.conv5 = ME.MinkowskiConvolution(
            self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D)
        self.bn5 = ME.MinkowskiBatchNorm(self.inplanes)

        # NOTE: despite its name this attribute is a global *max* pooling
        self.glob_avg = ME.MinkowskiGlobalMaxPooling()

        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

    def weight_initialization(self):
        """Kaiming-normal init for convolutions, weight 1 / bias 0 for batch norms."""
        for module in self.modules():
            if isinstance(module, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(module.kernel, mode='fan_out', nonlinearity='relu')

            if isinstance(module, ME.MinkowskiBatchNorm):
                nn.init.constant_(module.bn.weight, 1)
                nn.init.constant_(module.bn.bias, 0)

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    stride=1,
                    dilation=1,
                    bn_momentum=0.1):
        """
        Build one stage of `blocks` residual blocks and update `self.inplanes`
        to the output channels of the stage. `bn_momentum` is accepted but
        not used.
        """
        out_planes = planes * block.expansion

        # projection for the shortcut of the first block, needed when
        # resolution or channel count changes
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    dimension=self.D),
                ME.MinkowskiBatchNorm(out_planes))

        stage = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                dilation=dilation,
                downsample=downsample,
                dimension=self.D)
        ]
        # all following blocks take the output channels of the first one
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(
                block(
                    self.inplanes,
                    planes,
                    stride=1,
                    dilation=dilation,
                    dimension=self.D))

        return nn.Sequential(*stage)

    def forward(self, x):
        # stem
        x = self.pool(self.relu(self.bn1(self.conv1(x))))

        # residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.relu(self.bn5(self.conv5(x)))

        # global pooling followed by the linear head
        return self.final(self.glob_avg(x))
136 |
137 |
class ResNet14(ResNetBase):
    """ResNet with BasicBlock and (1, 1, 1, 1) blocks per stage."""
    # ResNetBase reads the lower-case attributes `block` and `layers`;
    # with only the upper-case names set, its `assert self.block is not None` fails
    block = BasicBlock
    layers = (1, 1, 1, 1)
    # upper-case names kept for code that still refers to them
    BLOCK = block
    LAYERS = layers
141 |
142 |
class ResNet18(ResNetBase):
    """ResNet with BasicBlock and (2, 2, 2, 2) blocks per stage."""
    # ResNetBase reads the lower-case attributes `block` and `layers`;
    # with only the upper-case names set, its `assert self.block is not None` fails
    block = BasicBlock
    layers = (2, 2, 2, 2)
    # upper-case names kept for code that still refers to them
    BLOCK = block
    LAYERS = layers
146 |
147 |
class ResNet34(ResNetBase):
    """ResNet with BasicBlock and (3, 4, 6, 3) blocks per stage."""
    # ResNetBase reads the lower-case attributes `block` and `layers`;
    # with only the upper-case names set, its `assert self.block is not None` fails
    block = BasicBlock
    layers = (3, 4, 6, 3)
    # upper-case names kept for code that still refers to them
    BLOCK = block
    LAYERS = layers
151 |
152 |
class ResNet50(ResNetBase):
    """ResNet with Bottleneck and (3, 4, 6, 3) blocks per stage."""
    # ResNetBase reads the lower-case attributes `block` and `layers`;
    # with only the upper-case names set, its `assert self.block is not None` fails
    block = Bottleneck
    layers = (3, 4, 6, 3)
    # upper-case names kept for code that still refers to them
    BLOCK = block
    LAYERS = layers
156 |
157 |
class ResNet101(ResNetBase):
    """ResNet with Bottleneck and (3, 4, 23, 3) blocks per stage."""
    # ResNetBase reads the lower-case attributes `block` and `layers`;
    # with only the upper-case names set, its `assert self.block is not None` fails
    block = Bottleneck
    layers = (3, 4, 23, 3)
    # upper-case names kept for code that still refers to them
    BLOCK = block
    LAYERS = layers
161 |
162 |
--------------------------------------------------------------------------------
/models/utils/extraction/mink/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import MinkowskiEngine as ME
3 |
def minkowski_sparse(coords:torch.Tensor, feats:torch.Tensor, quant_size:float):
    """
    Quantize a batch of point clouds and pack it into one ME.SparseTensor.

    Every sample is quantized on detached CPU copies with `quant_size`,
    the samples are collated into one batch and the resulting tensors are
    moved back to the device of `coords`.
    """
    device = coords.device

    # sparse_quantize works per sample on detached CPU copies
    cpu_coords = coords.clone().detach().cpu()
    cpu_feats = feats.clone().detach().cpu()
    quantized = [
        ME.utils.sparse_quantize(
            coordinates=sample_coords, features=sample_feats, quantization_size=quant_size
        )
        for sample_coords, sample_feats in zip(cpu_coords, cpu_feats)
    ]
    quant_coords = [pair[0] for pair in quantized]
    quant_feats = [pair[1] for pair in quantized]

    # batch collate
    batch_coords, batch_feats = ME.utils.sparse_collate(quant_coords, quant_feats)

    # to sparse tensor, on the original device
    return ME.SparseTensor(
        features=batch_feats.to(device=device),
        coordinates=batch_coords.to(device=device),
    )
20 |
def minkowski_decomposed(sparse_tensor, quant_size):
    """
    Split a sparse tensor into per-sample coordinates and features.

    The coordinates are converted to double and multiplied by `quant_size`
    (de-quantization); the features are returned as delivered by
    `sparse_tensor.decomposed_coordinates_and_features`.
    """
    coords, feats = sparse_tensor.decomposed_coordinates_and_features
    # de-quantize coordinates
    coords = [sample_coords.double()*quant_size for sample_coords in coords]
    return coords, feats
27 |
--------------------------------------------------------------------------------
/models/utils/transformers/transgeo.py:
--------------------------------------------------------------------------------
1 | # Author: Sijie Zhu, https://github.com/Jeff-Zilence/TransGeo2022
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import torch
5 | import torch.nn as nn
6 | from functools import partial
7 |
8 | from timm.models.vision_transformer import Block
9 | from timm.models.layers import trunc_normal_
class PCTrans(nn.Module):
    """
    Stack of timm transformer blocks followed by a LayerNorm.

    Besides the transformed features, forward() returns an attention score
    per token, computed from the attention matrix of the last block.
    """
    def __init__(self,
        dim:int,
        num_heads:int,
        mlp_ratio:int,
        depth:int,
        qkv_bias:bool,
        init_values:float,
        drop:float,
        attn_drop:float,
        drop_path_rate:float
    ):
        super().__init__()
        assert depth >= 1
        # drop-path rate grows linearly from 0 to drop_path_rate over the blocks
        drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, depth)]
        layer_norm = partial(nn.LayerNorm, eps=1e-6)
        block_list = [
            Block(
                dim=dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                proj_drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path_rates[block_ndx],
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU
            )
            for block_ndx in range(depth)
        ]
        self.blocks = nn.Sequential(*block_list)
        self.norm = layer_norm(dim)
        self.num_heads = float(num_heads)

    def _attend_with_score(self, attn_module, tokens:torch.Tensor):
        """
        Attention forward written out step by step, so that the attention
        matrix is available. Returns (attention output, attention score).
        """
        B, N, C = tokens.shape
        heads = attn_module.num_heads
        qkv = attn_module.qkv(tokens).reshape(B, N, 3, heads, C // heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)

        attn = ((q @ k.transpose(-2, -1)) * attn_module.scale).softmax(dim=-1)

        # score: sum over axis 1 twice, minus the number of heads, through
        # a sigmoid; taken before attention dropout is applied
        score = torch.sigmoid(attn.sum(axis=1).sum(axis=1) - self.num_heads)

        attn = attn_module.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        out = attn_module.proj_drop(attn_module.proj(out))
        return out, score

    def forward(self, x:torch.Tensor):
        attn_score = None
        last_ndx = len(self.blocks)-1

        for block_ndx, blk in enumerate(self.blocks):
            normed = blk.norm1(x)
            if block_ndx == last_ndx:
                # only the last block exposes its attention matrix
                attn_out, attn_score = self._attend_with_score(blk.attn, normed)
            else:
                attn_out = blk.attn(normed)

            # residual connections around attention and MLP
            x = x + blk.drop_path1(blk.ls1(attn_out))
            x = x + blk.drop_path2(blk.ls2(blk.mlp(blk.norm2(x))))
        x = self.norm(x)
        return x, attn_score
72 |
def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """
    ViT weight initialization, original timm impl (for reproducibility).

    Linear layers get truncated-normal weights (std 0.02) and zero bias;
    any other module is initialized through its own `init_weights()`, if it
    has one. `name` is not used.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, 'init_weights'):
            module.init_weights()
        return

    trunc_normal_(module.weight, std=.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
81 |
def main():
    """Smoke test: run random features through a 4-block PCTrans and print the result."""
    BS, PN, FS = 1, 2342, 256
    trans_config = dict(
        dim=256,
        num_heads=8,
        mlp_ratio=4,
        qkv_bias=True,
        depth=4,
        init_values=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path_rate=0.0
    )
    model = PCTrans(**trans_config)
    feats = torch.rand((BS, PN, FS))
    attn_feats, attn_score = model(feats)
    print(attn_feats.size(), attn_score)

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/pretrain/GAPR.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-RoboticsLab/GAPR/57dbfce2b64ac69cdc4f10a3b480ba917438a2b2/pretrain/GAPR.pth
--------------------------------------------------------------------------------
/pretrain/config.yaml:
--------------------------------------------------------------------------------
1 | dataloaders:
2 | train:
3 | augment:
4 | if_jrr: false
5 | name: TrainAugment
6 | rotate_cmd: zxy10
7 | translate_delta: 1.0
8 | collate:
9 | name: MetricCollate
10 | dataset: /nas/slam/datasets/GAPR/dataset/benchmark/train
11 | num_workers: 4
12 | sampler:
13 | batch_expansion_rate: 1.4
14 | batch_size: 16
15 | batch_size_limit: 32
16 | max_batches: null
17 | name: HeteroTripletSample
18 | dist:
19 | backend: nccl
20 | find_unused_parameters: false
21 | method:
22 | loss:
23 | batch_loss:
24 | margin: 1.0
25 | style: hard
26 | name: GAPRLoss
27 | overlap_loss:
28 | corr_dist: 2.0
29 | overlap_loss_scale: 1.0
30 | point_loss:
31 | corr_dist: 2.0
32 | margin: 10
33 | neg_dist: 20.0
34 | pos_dist: 2.1
35 | sample_num: 64
36 | style: soft
37 | point_loss_scale: 0.5
38 | model:
39 | debug: false
40 | meangem:
41 | eps: 1.0e-06
42 | p: 3.0
43 | minkfpn:
44 | conv0_kernel_size: 5
45 | in_channels: 1
46 | layers:
47 | - 1
48 | - 1
49 | - 1
50 | num_top_down: 1
51 | out_channels: 256
52 | planes:
53 | - 32
54 | - 64
55 | - 64
56 | quant_size: 0.6
57 | name: GAPR
58 | pctrans:
59 | attn_drop: 0.0
60 | depth: 1
61 | dim: 256
62 | drop: 0.0
63 | drop_path_rate: 0.0
64 | init_values: null
65 | mlp_ratio: 4
66 | num_heads: 2
67 | qkv_bias: true
68 | results:
69 | logs: null
70 | weights: /home/jieyr/code/ppr/results/weights/Ablation
71 | train:
72 | batch_expansion_th: 0.7
73 | epochs: 40
74 | lr: 0.001
75 | scheduler_milestones:
76 | - 15
77 | - 30
78 | weight_decay: 0.001
79 |
--------------------------------------------------------------------------------
/results/evaluate/readme.txt:
--------------------------------------------------------------------------------
1 | The evaluation results are saved here.
--------------------------------------------------------------------------------
/results/weights/readme.txt:
--------------------------------------------------------------------------------
1 | The training weights are saved here.
--------------------------------------------------------------------------------
/scripts/add_path.sh:
--------------------------------------------------------------------------------
1 | export PYTHONPATH=$PYTHONPATH:/home/jieyr/Codes/GAPR
2 |
--------------------------------------------------------------------------------
/scripts/clean.sh:
--------------------------------------------------------------------------------
1 | rm -r /home/jieyr/Codes/GAPR/results/evaluate/*
2 | rm -r /home/jieyr/Codes/GAPR/results/logs/*
3 | rm -r /home/jieyr/Codes/GAPR/results/weights/*
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
# Launch distributed training with one process per GPU on 4 GPUs.
# train/train.py requires --yaml; without it argparse aborts before training starts.
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train/train.py --yaml config/gapr/train.yaml
2 |
--------------------------------------------------------------------------------
/train/train.py:
--------------------------------------------------------------------------------
1 | # Author: Jacek Komorowski, https://github.com/jac99/MinkLoc3D
2 | # Modified: Yingrui Jie, https://github.com/SYSU-RoboticsLab/GAPR
3 |
4 | import os
5 | import time
6 | import argparse
7 | import yaml
8 | import torch
9 | from tqdm import tqdm
10 | from typing import Dict, Any, List
11 | import numpy as np
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 |
15 | from datasets.dataloders.lprdataloader import LPRDataLoader
16 | from models.lprmodel import LPRModel
17 | from loss.lprloss import LPRLoss
18 | from misc.utils import get_datetime, tensors2device, avg_stats
19 |
20 | from torch.utils.tensorboard import SummaryWriter
21 |
def parse_opt()->dict:
    """
    Parse the command line and return the content of the training yaml file.

    Arguments:
        --yaml: path of the yaml config file (required).
        --local_rank / --local-rank: accepted so that the script can be
            started by torch.distributed.launch, which passes this argument
            (spelled with a dash by newer versions). The value is not used
            here, so the argument is optional and launchers that do not
            pass it work as well.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml", type=str, required=True)
    parser.add_argument("--local_rank", "--local-rank", dest="local_rank", type=int, default=0)
    opt = vars(parser.parse_args())
    # read the yaml file; the context manager closes it even if parsing fails
    with open(opt["yaml"], encoding="utf-8") as f:
        lprtrain = yaml.load(f, Loader=yaml.FullLoader)
    return lprtrain
32 |
def main(**kw):
    """
    Distributed training loop.

    `kw` is the parsed yaml config with the sections "dist", "dataloaders",
    "method", "results" and "train". One process per GPU is expected; the
    GPU index is read from the LOCAL_RANK environment variable. Results
    (config copy, weights, tensorboard logs) are only written by local rank 0.
    """
    # initialize torch.distributed
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # nccl is the fastest and recommended backend on GPU devices
    dist.init_process_group(backend=kw["dist"]["backend"])

    # get dataloaders: every phase is built from its own config section
    dataloaders = {phase: LPRDataLoader(**kw["dataloaders"][phase]) for phase in kw["dataloaders"]}
    # get model
    model = LPRModel()
    model.construct(**kw["method"]["model"])
    # get loss function
    loss_fn = LPRLoss(**kw["method"]["loss"])
    # model to local_rank
    model.model = model.model.to(local_rank)
    # construct DDP model
    model.model = DDP(model.model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=kw["dist"]["find_unused_parameters"])
    # initialize optimizer after construction of DDP model
    optimizer = torch.optim.Adam(model.model.parameters(), lr=kw["train"]["lr"], weight_decay=kw["train"]["weight_decay"])
    # get scheduler
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, kw["train"]["scheduler_milestones"], gamma=0.1)

    # set results
    writer, weights_path = None, None
    if local_rank == 0:
        model_name = get_datetime()
        if kw["results"]["weights"] is not None:
            weights_path = os.path.join(kw["results"]["weights"], model_name)
            # also creates missing parent directories
            os.makedirs(weights_path, exist_ok=True)
            # save config yaml
            with open(os.path.join(weights_path, "config.yaml"), "w") as file:
                file.write(yaml.dump(dict(kw), allow_unicode=True))
        if kw["results"]["logs"] is not None:
            logs_path = os.path.join(kw["results"]["logs"], model_name)
            writer = SummaryWriter(logs_path)

    # get phases from dataloaders
    phases = list(dataloaders.keys())
    # visualize len of phases database
    if local_rank == 0:
        for phase in phases:
            print("Dataloder: {} set len = {}".format(phase, len(dataloaders[phase].dataset)))

    # progress bar only on local rank 0
    epochs = range(kw["train"]["epochs"])
    itera = tqdm(epochs) if local_rank == 0 else epochs
    for epoch in itera:
        for phase in phases:
            # switch mode
            if phase == "train":
                model.model.train()
            else:
                model.model.eval()

            # wait barrier
            dist.barrier()

            phase_stats:List[Dict] = []

            for data, mask in dataloaders[phase]:
                # data to device
                data = tensors2device(data, device=local_rank)
                # clear grad
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == "train"):
                    output = model(data)
                    loss, stats = loss_fn(output, mask)
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                phase_stats.append(stats)
                torch.cuda.empty_cache()

            # ******* PHASE END *******
            # compute mean stats for the epoch
            phase_avg_stats = avg_stats(phase_stats)
            # print and save stats
            if local_rank == 0:
                loss_fn.print_stats(epoch, phase, writer, phase_avg_stats)

        # ******* EPOCH END *******
        # scheduler
        if scheduler is not None:
            scheduler.step()

        if local_rank == 0 and weights_path is not None:
            model.save(os.path.join(weights_path, "{}.pth".format(epoch)))

    # flush pending tensorboard events
    if writer is not None:
        writer.close()

if __name__ == "__main__":
    main(**parse_opt())

# Example launch (2 GPUs, from the repository root):
# CUDA_VISIBLE_DEVICES="0,1" python -m torch.distributed.launch --nproc_per_node 2 train/train.py --yaml config/gapr/train.yaml
--------------------------------------------------------------------------------