├── .gitignore
├── LICENSE
├── README.md
├── database.py
├── eval.py
├── gen_desc
│   ├── .vscode
│   │   ├── launch.json
│   │   ├── settings.json
│   │   └── tasks.json
│   ├── CMakeLists.txt
│   ├── conf
│   │   └── sem_config.yaml
│   ├── gen_cloud
│   │   ├── CMakeLists.txt
│   │   ├── genData.cpp
│   │   ├── genData.hpp
│   │   ├── semanticConf.cpp
│   │   ├── semanticConf.hpp
│   │   └── types.hpp
│   └── kitti_gen.cpp
├── gen_pairs.py
├── gen_pairs_kitti360.py
├── net.py
├── pic
│   └── pipeline.png
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | __pycache__ 131 | 132 | data 133 | runs 134 | result 135 | model 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Lilin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RINet 2 | Code for RA-L 2022 paper [RINet: Efficient 3D Lidar-Based Place Recognition Using Rotation Invariant Neural Network](https://ieeexplore.ieee.org/document/9712221) 3 | 4 | ![pipeline](./pic/pipeline.png) 5 | 6 | ## Citation 7 | 8 | ``` 9 | @ARTICLE{9712221, 10 | author={Li, Lin and Kong, Xin and Zhao, Xiangrui and Huang, Tianxin and Li, Wanlong and Wen, Feng and Zhang, Hongbo and Liu, Yong}, 11 | journal={IEEE Robotics and Automation Letters}, 12 | title={{RINet: Efficient 3D Lidar-Based Place Recognition Using Rotation Invariant Neural Network}}, 13 | year={2022}, 14 | volume={7}, 15 | number={2}, 16 | pages={4321-4328}, 17 | doi={10.1109/LRA.2022.3150499}} 18 | ``` 19 | 20 | ## Environment 21 | ### Conda 22 | ``` 23 | conda create -n rinet python=3.7 24 | conda activate rinet 25 | conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia 26 | conda install tqdm scikit-learn matplotlib tensorboard 27 | ``` 28 | 29 | ## Usage 30 | ### Preprocessing 31 | You can directly use the [descriptors we provide](https://drive.google.com/file/d/1Do36bYZ_LBM209WYvXc_WCOBzrLI7uYu/view?usp=sharing), or you can generate descriptors by yourself according to the descriptions below: 32 | Requirements: [OpenCV](https://opencv.org/), [PCL](https://pointclouds.org/) and [yaml-cpp](https://github.com/jbeder/yaml-cpp). 
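33 | 
34 | Either way, each scan yields a 12×360 float matrix (12 merged semantic classes × 1° azimuth bins; entries are point ranges in meters, capped at 50 m, matching `in_channels=12` in net.py). A quick sanity check on the provided `.npy` files (path assumed from the data layout below):
35 | ```
36 | import numpy as np
37 | descs = np.load("data/desc_kitti/00.npy")  # one 12x360 descriptor per scan
38 | print(descs.shape, descs.max())            # expected: (num_scans, 12, 360), max <= 50
39 | ```
40 | To build the generator yourself: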
41 | ```
42 | cd gen_desc && mkdir build && cd build && cmake .. && make -j4
43 | ```
44 | If the compilation succeeds, execute the following command to generate the descriptors (all descriptors are saved to a single binary file "output_file.bin"):
45 | 
46 | ```
47 | ./kitti_gen cloud_folder label_folder output_file.bin
48 | ```
49 | ### Training
50 | #### Data structure
51 | ```
52 | data
53 | |---desc_kitti
54 | |   |---00.npy
55 | |   |---01.npy
56 | |   |---...
57 | |---gt_kitti
58 | |   |---00.npz
59 | |   |---01.npz
60 | |   |---...
61 | |---pose_kitti
62 | |   |---00.txt
63 | |   |---02.txt
64 | |   |---...
65 | |---pairs_kitti
66 | |   |---...
67 | ```
68 | You can download the [provided preprocessed data](https://drive.google.com/file/d/1Do36bYZ_LBM209WYvXc_WCOBzrLI7uYu/view?usp=sharing).
69 | 
70 | #### Training model
71 | ```
72 | python train.py --seq='00'
73 | ```
74 | 
75 | ### Testing
76 | Pretrained models can be downloaded from this [link](https://drive.google.com/file/d/1pjoTRlenQJUCDJevMgQsELaL_FbRdPXo/view?usp=sharing).
77 | ```
78 | python eval.py
79 | ```
80 | 
81 | ## Raw Data
82 | We provide the [raw data](https://drive.google.com/file/d/1N19ZYVKoOvVzrTxiokl6oSR-pEezVPZJ/view?usp=sharing) of the tables and curves in the paper, including the compared methods DiSCO and Locus. Raw data for other methods can be found in this [repository](https://github.com/lilin-hitcrt/SSC).
83 | 
--------------------------------------------------------------------------------
/database.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import torch
3 | import os
4 | import numpy as np
5 | import random
6 | from matplotlib import pyplot as plt
7 | import json
8 | 
9 | 
10 | 
11 | class SigmoidDataset_eval(Dataset):
12 |     def __init__(self, sequs=['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10'], neg_ratio=1, desc_folder="./data/desc_kitti", gt_folder="./data/gt_kitti", eva_ratio=0.1) -> None:
13 |         super().__init__()
14 |         print(sequs)
15 |         self.descs = []
16 |         self.gt_pos = []
17 |         self.gt_neg = []
18 |         self.pos_nums = [0]
19 |         self.neg_num = 0
20 |         self.pos_num = 0
21 |         for seq in sequs:
22 |             desc_file = os.path.join(desc_folder, seq+'.npy')
23 |             gt_file = os.path.join(gt_folder, seq+'.npz')
24 |             self.descs.append(np.load(desc_file))
25 |             gt = np.load(gt_file)
26 |             pos = gt['pos'][-int(len(gt['pos'])*eva_ratio):]  # last eva_ratio of pairs -> validation split
27 |             neg = gt['neg'][-int(len(gt['neg'])*eva_ratio):]
28 |             self.gt_pos.append(pos)
29 |             self.gt_neg.append(neg)
30 |             self.pos_num += len(self.gt_pos[-1])
31 |             self.pos_nums.append(self.pos_num)
32 |         self.neg_num = int(neg_ratio*self.pos_num)
33 | 
34 |     def __len__(self):
35 |         return self.pos_num+self.neg_num
36 | 
37 |     def __getitem__(self, idx):
38 |         if torch.is_tensor(idx):
39 |             idx = idx.tolist()
40 |         pair = [-1, -1, 0]
41 |         if idx >= self.pos_num:
42 |             id_seq = random.randint(0, len(self.gt_neg)-1)
43 |             id = random.randint(0, len(self.gt_neg[id_seq])-1)
44 |             pair = self.gt_neg[int(id_seq)][id]
45 |             out = {"desc1": self.descs[int(id_seq)][int(
46 |                 pair[0])]/50., "desc2": self.descs[int(id_seq)][int(pair[1])]/50., 'label': pair[2]}
47 |             return out
48 |         for i in range(1, len(self.pos_nums)):
49 |             if self.pos_nums[i] > idx:
50 |                 pair = self.gt_pos[i-1][idx-self.pos_nums[i-1]]
51 |                 out = {"desc1": self.descs[i-1][int(
52 |                     pair[0])]/50., "desc2": self.descs[i-1][int(pair[1])]/50., 'label': pair[2]}
53 |                 return out
54 | 
55 | 
56 | class SigmoidDataset_train(Dataset):
57 |     def __init__(self,
sequs=['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10'], neg_ratio=1, desc_folder="./data/desc_kitti", gt_folder="./data/gt_kitti", eva_ratio=0.1) -> None: 58 | super().__init__() 59 | print(sequs) 60 | self.descs = [] 61 | self.gt_pos = [] 62 | self.gt_neg = [] 63 | self.pos_nums = [0] 64 | self.neg_num = 0 65 | self.pos_num = 0 66 | for seq in sequs: 67 | desc_file = os.path.join(desc_folder, seq+'.npy') 68 | gt_file = os.path.join(gt_folder, seq+'.npz') 69 | self.descs.append(np.load(desc_file)) 70 | gt = np.load(gt_file) 71 | pos = gt['pos'][:-int(len(gt['pos'])*eva_ratio)] 72 | neg = gt['neg'][:-int(len(gt['neg'])*eva_ratio)] 73 | self.gt_pos.append(pos) 74 | self.gt_neg.append(neg) 75 | self.pos_num += len(self.gt_pos[-1]) 76 | self.pos_nums.append(self.pos_num) 77 | self.neg_num = int(neg_ratio*self.pos_num) 78 | 79 | def __len__(self): 80 | return self.pos_num+self.neg_num 81 | 82 | def __getitem__(self, idx): 83 | if torch.is_tensor(idx): 84 | idx = idx.tolist() 85 | pair = [-1, -1, 0] 86 | if idx >= self.pos_num: 87 | id_seq = random.randint(0, len(self.gt_neg)-1) 88 | id = random.randint(0, len(self.gt_neg[id_seq])-1) 89 | pair = self.gt_neg[int(id_seq)][id] 90 | out = {"desc1": self.descs[int(id_seq)][int( 91 | pair[0])]/50., "desc2": self.descs[int(id_seq)][int(pair[1])]/50., 'label': pair[2]} 92 | if random.randint(0, 1) > 0: 93 | self.rand_occ(out["desc1"]) 94 | self.rand_occ(out["desc2"]) 95 | return out 96 | for i in range(1, len(self.pos_nums)): 97 | if self.pos_nums[i] > idx: 98 | pair = self.gt_pos[i-1][idx-self.pos_nums[i-1]] 99 | out = {"desc1": self.descs[i-1][int( 100 | pair[0])]/50., "desc2": self.descs[i-1][int(pair[1])]/50., 'label': pair[2]} 101 | if random.randint(0, 1) > 0: 102 | self.rand_occ(out["desc1"]) 103 | self.rand_occ(out["desc2"]) 104 | return out 105 | 106 | def rand_occ(self, in_desc): 107 | n = random.randint(0, 60) 108 | s = random.randint(0, 360-n) 109 | in_desc[:, s:s+n] *= 0 110 | 111 | 112 | class SigmoidDataset(Dataset): 113 | def __init__(self, sequs=['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10'], neg_ratio=1, desc_folder="./data/desc_kitti", gt_folder="./data/gt_kitti") -> None: 114 | super().__init__() 115 | print(sequs) 116 | self.descs = [] 117 | self.gt_pos = [] 118 | self.gt_neg = [] 119 | self.pos_nums = [0] 120 | self.neg_num = 0 121 | self.pos_num = 0 122 | for seq in sequs: 123 | desc_file = os.path.join(desc_folder, seq+'.npy') 124 | gt_file = os.path.join(gt_folder, seq+'.npz') 125 | self.descs.append(np.load(desc_file)) 126 | gt = np.load(gt_file) 127 | self.gt_pos.append(gt['pos']) 128 | self.gt_neg.append(gt['neg']) 129 | self.pos_num += len(self.gt_pos[-1]) 130 | self.pos_nums.append(self.pos_num) 131 | self.neg_num = int(neg_ratio*self.pos_num) 132 | 133 | def __len__(self): 134 | return self.pos_num+self.neg_num 135 | 136 | def __getitem__(self, idx): 137 | if torch.is_tensor(idx): 138 | idx = idx.tolist() 139 | pair = [-1, -1, 0] 140 | if idx >= self.pos_num: 141 | id_seq = random.randint(0, len(self.gt_neg)-1) 142 | id = random.randint(0, len(self.gt_neg[id_seq])-1) 143 | pair = self.gt_neg[int(id_seq)][id] 144 | out = {"desc1": self.descs[int(id_seq)][int( 145 | pair[0])]/50., "desc2": self.descs[int(id_seq)][int(pair[1])]/50., 'label': pair[2]*1.} 146 | if random.randint(0, 2) > 1: 147 | self.rand_occ(out["desc1"]) 148 | self.rand_occ(out["desc2"]) 149 | return out 150 | for i in range(1, len(self.pos_nums)): 151 | if self.pos_nums[i] > idx: 152 | pair = 
self.gt_pos[i-1][idx-self.pos_nums[i-1]] 153 | out = {"desc1": self.descs[i-1][int(pair[0])]/50., "desc2": self.descs[i-1][int( 154 | pair[1])]/50., 'label': pair[2]*1.} 155 | if random.randint(0, 2) > 1: 156 | self.rand_occ(out["desc1"]) 157 | self.rand_occ(out["desc2"]) 158 | return out 159 | 160 | def rand_occ(self, in_desc): 161 | n = random.randint(0, 60) 162 | s = random.randint(0, 360-n) 163 | in_desc[:, s:s+n] *= 0 164 | 165 | 166 | class evalDataset(Dataset): 167 | def __init__(self, seq="00", desc_folder="./data/desc_kitti", gt_folder="./data/pairs_kitti/neg_100") -> None: 168 | super().__init__() 169 | self.descs = [] 170 | self.pairs = [] 171 | self.num = 0 172 | desc_file = os.path.join(desc_folder, seq+'.npy') 173 | pair_file = os.path.join(gt_folder, seq+'.txt') 174 | self.descs = np.load(desc_file) 175 | self.pairs = np.genfromtxt(pair_file, dtype='int32') 176 | self.num = len(self.pairs) 177 | 178 | def __len__(self): 179 | return self.num 180 | 181 | def __getitem__(self, idx): 182 | if torch.is_tensor(idx): 183 | idx = idx.tolist() 184 | pair = self.pairs[idx] 185 | out = {"desc1": self.descs[int( 186 | pair[0])]/50., "desc2": self.descs[int(pair[1])]/50., 'label': pair[2]} 187 | angle1 = np.random.randint(0, 359) 188 | angle2 = np.random.randint(0, 359) 189 | out["desc1"] = np.roll(out["desc1"], angle1, axis=1) 190 | out["desc2"] = np.roll(out["desc2"], angle2, axis=1) 191 | return out 192 | 193 | 194 | class SigmoidDataset_kitti360(Dataset): 195 | def __init__(self, sequs=['0000', '0002', '0003', '0004', '0005', '0006', '0007', '0009', '0010'], neg_ratio=1, desc_folder="./data/desc_kitti360", gt_folder="./data/gt_kitti360") -> None: 196 | super().__init__() 197 | print(sequs) 198 | self.descs = [] 199 | self.gt_pos = [] 200 | self.gt_neg = [] 201 | self.key_map = [] 202 | self.pos_nums = [0] 203 | self.neg_num = 0 204 | self.pos_num = 0 205 | for seq in sequs: 206 | desc_file = os.path.join(desc_folder, seq+'.npy') 207 | gt_file = os.path.join(gt_folder, seq+'.npz') 208 | self.descs.append(np.load(desc_file)) 209 | self.key_map.append( 210 | json.load(open(os.path.join(desc_folder, seq+'.json')))) 211 | gt = np.load(gt_file) 212 | self.gt_pos.append(gt['pos']) 213 | self.gt_neg.append(gt['neg']) 214 | self.pos_num += len(self.gt_pos[-1]) 215 | self.pos_nums.append(self.pos_num) 216 | self.neg_num = int(neg_ratio*self.pos_num) 217 | 218 | def __len__(self): 219 | return self.pos_num+self.neg_num 220 | 221 | def __getitem__(self, idx): 222 | if torch.is_tensor(idx): 223 | idx = idx.tolist() 224 | pair = [-1, -1, 0] 225 | if idx >= self.pos_num: 226 | id_seq = random.randint(0, len(self.gt_neg)-1) 227 | id = random.randint(0, len(self.gt_neg[id_seq])-1) 228 | pair = self.gt_neg[int(id_seq)][id] 229 | out = {"desc1": self.descs[int(id_seq)][self.key_map[int(id_seq)][str(int( 230 | pair[0]))]]/50., "desc2": self.descs[int(id_seq)][self.key_map[int(id_seq)][str(int(pair[1]))]]/50., 'label': pair[2]} 231 | if random.randint(0, 1) > 0: 232 | self.rand_occ(out["desc1"]) 233 | self.rand_occ(out["desc2"]) 234 | return out 235 | for i in range(1, len(self.pos_nums)): 236 | if self.pos_nums[i] > idx: 237 | pair = self.gt_pos[i-1][idx-self.pos_nums[i-1]] 238 | out = {"desc1": self.descs[i-1][self.key_map[i-1][str(int( 239 | pair[0]))]]/50., "desc2": self.descs[i-1][self.key_map[i-1][str(int(pair[1]))]]/50., 'label': pair[2]} 240 | if random.randint(0, 1) > 0: 241 | self.rand_occ(out["desc1"]) 242 | self.rand_occ(out["desc2"]) 243 | return out 244 | 245 | def rand_occ(self, in_desc): 
246 |         n = random.randint(0, 60)
247 |         s = random.randint(0, 360-n)
248 |         in_desc[:, s:s+n] *= 0  # zero out a random angular sector (simulated occlusion)
249 | 
250 | 
251 | class evalDataset_kitti360(Dataset):
252 |     def __init__(self, seq="0000", desc_folder="./data/desc_kitti360", gt_folder="./data/pairs_kitti360/neg10") -> None:
253 |         super().__init__()
254 |         self.descs = []
255 |         self.pairs = []
256 |         self.num = 0
257 |         desc_file = os.path.join(desc_folder, seq+'.npy')
258 |         pair_file = os.path.join(gt_folder, seq+'.txt')
259 |         self.descs = np.load(desc_file)
260 |         self.pairs = np.genfromtxt(pair_file, dtype='int32')
261 |         self.key_map = json.load(open(os.path.join(desc_folder, seq+'.json')))  # maps sparse KITTI-360 scan ids to row indices in the .npy
262 |         self.num = len(self.pairs)
263 | 
264 |     def __len__(self):
265 |         return self.num
266 | 
267 |     def __getitem__(self, idx):
268 |         if torch.is_tensor(idx):
269 |             idx = idx.tolist()
270 |         pair = self.pairs[idx]
271 |         out = {"desc1": self.descs[self.key_map[str(int(
272 |             pair[0]))]]/50., "desc2": self.descs[self.key_map[str(int(pair[1]))]]/50., 'label': pair[2]}
273 |         return out
274 | 
275 | 
276 | if __name__ == '__main__':
277 |     database = SigmoidDataset_train(
278 |         ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10'], 2)
279 |     print(len(database))
280 |     for i in range(0, len(database)):
281 |         idx = random.randint(0, len(database)-1)
282 |         d = database[idx]
283 |         print(i, d['label'])
284 |         plt.subplot(2, 1, 1)
285 |         plt.imshow(d['desc1'])
286 |         plt.subplot(2, 1, 2)
287 |         plt.imshow(d['desc2'])
288 |         plt.show()
289 | 
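290 | # Note: every dataset divides descriptors by 50 before returning them; the
291 | # generator caps point ranges at 50 m, so network inputs lie in [0, 1].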
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from net import RINet, RINet_attention
3 | from database import evalDataset, evalDataset_kitti360
4 | import numpy as np
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 | import os
8 | from sklearn import metrics
9 | from matplotlib import pyplot as plt
10 | import sys
11 | import time
12 | import argparse
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 | # device=torch.device("cpu")
15 | 
16 | 
17 | def eval(seq='0000', model_file="./model/attention/00.ckpt", data_type='kitti360'):
18 |     net = RINet_attention()
19 |     net.load(model_file)
20 |     if data_type == 'kitti':
21 |         test_dataset = evalDataset(seq)
22 |     elif data_type == 'kitti360':
23 |         test_dataset = evalDataset_kitti360(seq)
24 |     net.to(device=device)
25 |     net.eval()
26 |     testdataloader = DataLoader(
27 |         dataset=test_dataset, batch_size=16384, shuffle=False, num_workers=8)
28 |     pred = []
29 |     gt = []
30 |     with torch.no_grad():
31 |         for i_batch, sample_batch in tqdm(enumerate(testdataloader), total=len(testdataloader), desc="Eval seq "+str(seq)):
32 |             out, _ = net(sample_batch["desc1"].to(
33 |                 device=device), sample_batch["desc2"].to(device=device))
34 |             outlabel = out.cpu().tolist()
35 |             label = sample_batch['label']
36 |             pred.extend(outlabel)
37 |             gt.extend(label.tolist())
38 |     pred = np.nan_to_num(pred)
39 |     save_db = np.array([pred, gt])
40 |     save_db = save_db.T
41 |     if not os.path.exists('result'):
42 |         os.mkdir('result')
43 |     np.savetxt(os.path.join('result', seq+'.txt'), save_db, "%.4f")
44 |     precision, recall, pr_thresholds = metrics.precision_recall_curve(gt, pred)
45 |     plt.plot(recall, precision, color='darkorange', lw=2, label='P-R curve')
46 |     plt.axis([0, 1, 0, 1])
47 |     plt.xlabel('Recall')
48 |     plt.ylabel('Precision')
49 |     plt.title('Precision-Recall Curve')
50 |     plt.legend(loc="lower right")
51 |     F1_score = 2 * precision * recall / (precision + recall)
52 |     F1_score = np.nan_to_num(F1_score)
53 |     F1_max_score = np.max(F1_score)
54 |     print("F1:", F1_max_score)
55 |     plt.show()
56 | 
57 | 
58 | def fast_eval(seq='00', model_file="./model/attention/00.ckpt", desc_file='./data/desc_kitti/00.npy', pair_file='./data/pairs_kitti/neg_100/00.txt', use_l2_dis=False):
59 |     net = RINet_attention()
60 |     net.load(model_file)
61 |     net.to(device=device)
62 |     net.eval()
63 |     print(net)
64 |     desc_o = np.load(desc_file)/50.0
65 |     descs_torch = torch.from_numpy(desc_o).to(device)
66 |     total_time = 0.
67 |     with torch.no_grad():
68 |         torch.cuda.synchronize()
69 |         time1 = time.time()
70 |         descs = net.gen_feature(descs_torch).cpu().numpy()
71 |         torch.cuda.synchronize()
72 |         total_time += (time.time()-time1)
73 |     print("Feature time:", total_time)
74 |     pairs = np.genfromtxt(pair_file, dtype='int32').reshape(-1, 3)
75 |     if use_l2_dis:
76 |         desc1 = descs[pairs[:, 0]]
77 |         desc2 = descs[pairs[:, 1]]
78 |         time1 = time.time()
79 |         diff = desc1-desc2
80 |         diff = 1./np.sum(diff*diff, axis=1)  # similarity = inverse squared L2 feature distance
81 |         print("Score time:", time.time()-time1)
82 |         diff = diff.reshape(-1, 1)
83 |         diff = np.nan_to_num(diff)
84 |         label = pairs[:, 2].reshape(-1, 1)
85 |         # diff_pos=diff[label>0.9]
86 |         # diff_neg=diff[label<0.2]
87 |         # plt.plot(list(range(len(diff_pos))),diff_pos,'b.')
88 |         # plt.plot(list(range(len(diff_pos),len(diff_pos)+len(diff_neg))),diff_neg,'r.')
89 |         # plt.show()
90 |         precision, recall, pr_thresholds = metrics.precision_recall_curve(
91 |             label, diff)
92 |     else:
93 |         desc1 = torch.from_numpy(descs[pairs[:, 0]]).to(device)
94 |         desc2 = torch.from_numpy(descs[pairs[:, 1]]).to(device)
95 |         total_time = 0
96 |         with torch.no_grad():
97 |             torch.cuda.synchronize()
98 |             time1 = time.time()
99 |             scores, _ = net.gen_score(desc1, desc2)
100 |             scores = scores.cpu().numpy()
101 |             torch.cuda.synchronize()
102 |             total_time += (time.time()-time1)
103 |         print("Score time:", total_time)
104 |         gt = pairs[:, 2].reshape(-1, 1)
105 |         np.savetxt("result/"+seq+'.txt',
106 |                    np.concatenate([scores.reshape(-1, 1), gt.reshape(-1, 1)], axis=1))
107 |         precision, recall, pr_thresholds = metrics.precision_recall_curve(
108 |             gt, scores)
109 |     F1_score = 2 * precision * recall / (precision + recall)
110 |     F1_score = np.nan_to_num(F1_score)
111 |     F1_max_score = np.max(F1_score)
112 |     print("F1:", F1_max_score)
113 |     plt.plot(recall, precision, color='darkorange', lw=2, label='P-R curve')
114 |     plt.axis([0, 1, 0, 1])
115 |     plt.xlabel('Recall')
116 |     plt.ylabel('Precision')
117 |     plt.title('Precision-Recall Curve')
118 |     plt.legend(loc="lower right")
119 |     plt.show()
120 | 
121 | 
122 | def recall(seq='00', model_file="./model/attention/00.ckpt", desc_file='./data/desc_kitti/00.npy', pose_file="./data/pose_kitti/00.txt"):
123 |     poses = np.genfromtxt(pose_file)
124 |     poses = poses[:, [3, 11]]  # x and z translation from the 3x4 KITTI pose matrices
125 |     inner = 2*np.matmul(poses, poses.T)
126 |     xx = np.sum(poses**2, 1, keepdims=True)
127 |     dis = xx-inner+xx.T  # pairwise squared distances
128 |     dis = np.sqrt(np.abs(dis))
129 |     id_pos = np.argwhere(dis <= 5)
130 |     id_pos = id_pos[id_pos[:, 0]-id_pos[:, 1] > 50]  # require at least 50 frames between query and match
131 |     pos_dict = {}
132 |     for v in id_pos:
133 |         if v[0] in pos_dict.keys():
134 |             pos_dict[v[0]].append(v[1])
135 |         else:
136 |             pos_dict[v[0]] = [v[1]]
137 |     descs = np.load(desc_file)
138 |     descs /= 50.0
139 |     net = RINet_attention()
140 |     net.load(model_file)
141 |     net.to(device=device)
142 |     net.eval()
143 |     # print(net)
144 |     out_save = []
145 |     recall = np.array([0.]*25)  # recall@1..25
146 |     for v in tqdm(pos_dict.keys()):
147 |         candidates = []
148 |         targets = []
149 |         for c in range(0, v-50):  # candidates exclude the 50 most recent scans
150 |             candidates.append(descs[c])
151 |             targets.append(descs[v])
152 |         candidates = np.array(candidates, dtype='float32')
153 |         targets = np.array(targets, dtype='float32')
154 |         candidates = torch.from_numpy(candidates)
155 |         targets = torch.from_numpy(targets)
156 |         with torch.no_grad():
157 |             out, _ = net(candidates.to(device=device),
158 |                          targets.to(device=device))
159 |             out = out.cpu().numpy()
160 |         ids = np.argsort(-out)
161 |         o = [v]
162 |         o += ids[:25].tolist()
163 |         out_save.append(o)
164 |         for i in range(25):
165 |             if ids[i] in pos_dict[v]:
166 |                 recall[i:] += 1
167 |                 break
168 |     if not os.path.exists('result'):
169 |         os.mkdir('result')
170 |     np.savetxt(os.path.join('result', seq+'_recall.txt'), out_save, fmt='%d')
171 |     recall /= len(pos_dict.keys())
172 |     print(recall)
173 |     plt.plot(list(range(1, len(recall)+1)), recall, marker='o')
174 |     plt.axis([1, 25, 0, 1])
175 |     plt.xlabel('N top retrievals')
176 |     plt.ylabel('Recall')
177 |     plt.show()
178 | 
179 | 
180 | if __name__ == '__main__':
181 |     parser = argparse.ArgumentParser()
182 |     parser.add_argument('--seq', default='08',
183 |                         help='Sequence to eval. [default: 08]')
184 |     parser.add_argument('--dataset', default="kitti",
185 |                         help="Dataset (kitti or kitti360). [default: kitti]")
186 |     parser.add_argument('--model', default="./model/attention/08.ckpt",
187 |                         help='Model file. [default: "./model/attention/08.ckpt"]')
188 |     parser.add_argument('--desc_file', default='./data/desc_kitti/08.npy',
189 |                         help='File of descriptors. [default: ./data/desc_kitti/08.npy]')
190 |     parser.add_argument('--pairs_file', default='./data/pairs_kitti/neg_100/08.txt',
191 |                         help='Candidate pairs. [default: ./data/pairs_kitti/neg_100/08.txt]')
192 |     parser.add_argument('--pose_file', default="./data/pose_kitti/08.txt",
193 |                         help='Pose file (eval_type=recall). [default: ./data/pose_kitti/08.txt]')
194 |     parser.add_argument('--eval_type', default="f1",
195 |                         help='Type of evaluation (f1 or recall). [default: f1]')
196 |     cfg = parser.parse_args()
197 |     if cfg.dataset == "kitti" and cfg.eval_type == "f1":
198 |         fast_eval(seq=cfg.seq, model_file=cfg.model,
199 |                   desc_file=cfg.desc_file, pair_file=cfg.pairs_file)
200 |         # eval(seq=cfg.seq,model_file=cfg.model,data_type=cfg.dataset)
201 |     elif cfg.dataset == "kitti" and cfg.eval_type == "recall":
202 |         recall(cfg.seq, cfg.model, cfg.desc_file, cfg.pose_file)
203 |     elif cfg.dataset == "kitti360" and cfg.eval_type == "f1":
204 |         eval(seq=cfg.seq, model_file=cfg.model, data_type=cfg.dataset)
205 |     else:
206 |         print("Error: unsupported dataset/eval_type combination")
207 | 
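208 | # Typical usage (paths assumed to follow the README data layout):
209 | #   python eval.py --seq=08 --eval_type=f1      # max-F1 over candidate pairs
210 | #   python eval.py --seq=08 --eval_type=recall  # top-N retrieval recall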
--------------------------------------------------------------------------------
/gen_desc/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 |     // Use IntelliSense to learn about possible attributes.
3 |     // Hover to view descriptions of existing attributes.
4 |     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
5 |     "version": "0.2.0",
6 |     "configurations": [
7 |         {
8 |             "name": "kitti_gen",
9 |             "type": "cppdbg",
10 |             "request": "launch",
11 |             "program": "${workspaceFolder}/bin/kitti_gen",
12 |             "args": ["/media/l/yp2/KITTI/odometry/dataset/sequences/00/velodyne/","/media/l/yp2/KITTI/odometry/dataset/sequences/00/labels/","00.bin"],
13 |             "stopAtEntry": false,
14 |             "cwd": "${workspaceFolder}/bin",
15 |             "environment": [],
16 |             "externalConsole": false,
17 |             "MIMode": "gdb",
18 |             "setupCommands": [
19 |                 {
20 |                     "description": "Enable pretty-printing for gdb",
21 |                     "text": "-enable-pretty-printing",
22 |                     "ignoreFailures": true
23 |                 }
24 |             ]
25 |         },
26 |         {
27 |             "name": "test",
28 |             "type": "cppdbg",
29 |             "request": "launch",
30 |             "program": "${workspaceFolder}/bin/test",
31 |             "args": ["10"],
32 |             "stopAtEntry": false,
33 |             "cwd": "${workspaceFolder}/bin",
34 |             "environment": [],
35 |             "externalConsole": false,
36 |             "MIMode": "gdb",
37 |             "setupCommands": [
38 |                 {
39 |                     "description": "Enable pretty-printing for gdb",
40 |                     "text": "-enable-pretty-printing",
41 |                     "ignoreFailures": true
42 |                 }
43 |             ],
44 |             // "preLaunchTask": "C/C++: g++ build active file",
45 |             "miDebuggerPath": "/usr/bin/gdb"
46 |         }
47 |     ]
48 | }
--------------------------------------------------------------------------------
/gen_desc/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 |     "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools",
3 |     "files.associations": {
4 |         "cctype": "cpp",
5 |         "clocale": "cpp",
6 |         "cmath": "cpp",
7 |         "csignal": "cpp",
8 |         "cstdarg": "cpp",
9 |         "cstddef": "cpp",
10 |         "cstdio": "cpp",
11 |         "cstdlib": "cpp",
12 |         "cstring": "cpp",
13 |         "ctime": "cpp",
14 |         "cwchar": "cpp",
15 |         "cwctype": "cpp",
16 |         "array": "cpp",
17 |         "atomic": "cpp",
18 |         "strstream": "cpp",
19 |         "*.tcc": "cpp",
20 |         "bitset": "cpp",
21 |         "chrono": "cpp",
22 |         "complex": "cpp",
23 |         "condition_variable": "cpp",
24 |         "cstdint": "cpp",
25 |         "deque": "cpp",
26 |         "forward_list": "cpp",
27 |         "list": "cpp",
28 |         "unordered_map": "cpp",
29 |         "unordered_set": "cpp",
30 |         "vector": "cpp",
31 |         "exception": "cpp",
32 |         "algorithm": "cpp",
33 |         "functional": "cpp",
34 |         "optional": "cpp",
35 |         "ratio": "cpp",
36 |         "string_view": "cpp",
37 |         "system_error": "cpp",
38 |         "tuple": "cpp",
39 |         "type_traits": "cpp",
40 |         "fstream": "cpp",
41 |         "initializer_list": "cpp",
42 |         "iomanip": "cpp",
43 |         "iosfwd": "cpp",
44 |         "iostream": "cpp",
45 |         "istream": "cpp",
46 |         "limits": "cpp",
47 |         "memory": "cpp",
48 |         "mutex": "cpp",
49 |         "new": "cpp",
50 |         "ostream": "cpp",
51 |         "numeric": "cpp",
52 |         "sstream": "cpp",
53 |         "stdexcept": "cpp",
54 |         "streambuf": "cpp",
55 |         "thread": "cpp",
56 |         "cfenv": "cpp",
57 |         "cinttypes": "cpp",
58 |         "utility": "cpp",
59 |         "typeindex": "cpp",
60 |         "typeinfo": "cpp",
61 |         "bit": "cpp",
62 |         "map": "cpp",
63 |         "set": "cpp",
64 |         "iterator": "cpp",
65 |         "memory_resource": "cpp",
66 |         "random": "cpp",
67 |         "string": "cpp",
68 |         "__bit_reference": "cpp",
69 |         "__config": "cpp",
70 |         "__debug": "cpp",
71 |         "__functional_base": "cpp",
72 |         "__hash_table": "cpp",
73 |         "__locale": "cpp",
74 |         "__mutex_base": "cpp",
75 |         "__split_buffer": "cpp",
76 |         "__tree": "cpp",
77 |         "__tuple": "cpp",
78 |         "ios": "cpp",
79 |         "locale": "cpp",
80 |         "queue": "cpp",
81 |         "stack": "cpp"
82 |     }
83 | }
--------------------------------------------------------------------------------
/gen_desc/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 |     "tasks": [
3 |         {
4 |             "type": "cppbuild",
5 |             "label": "C/C++: g++ build active file",
6 |             "command": "/usr/bin/g++",
7 |             "args": [
8 |                 "-g",
9 |                 "${file}",
10 |                 "-o",
11 |                 "${fileDirname}/${fileBasenameNoExtension}"
12 |             ],
13 |             "options": {
14 |                 "cwd": "${workspaceFolder}"
15 |             },
16 |             "problemMatcher": [
17 |                 "$gcc"
18 |             ],
19 |             "group": {
20 |                 "kind": "build",
21 |                 "isDefault": true
22 |             },
23 |             "detail": "Task generated by the debugger."
24 |         }
25 |     ],
26 |     "version": "2.0.0"
27 | }
--------------------------------------------------------------------------------
/gen_desc/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.0)
2 | project(gen_desc)
3 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin)
4 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/lib)
5 | set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
6 | set(CMAKE_BUILD_TYPE Release)
7 | set(CMAKE_CONFIGURATION_TYPES Debug RelWithDebInfo Release)
8 | set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} -Wall -O3 -march=native")
9 | set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -pg -march=native")
10 | set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -Wall -O3 -march=native")
11 | set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -pg -march=native")
12 | set(CMAKE_CXX_STANDARD 14)
13 | find_package(PCL REQUIRED)
14 | add_subdirectory(gen_cloud)
15 | include_directories(
16 |     ${PCL_INCLUDE_DIRS}
17 | )
18 | add_executable(kitti_gen kitti_gen.cpp)
19 | target_link_libraries(kitti_gen ${PCL_LIBRARIES} gencloud)
--------------------------------------------------------------------------------
/gen_desc/conf/sem_config.yaml:
--------------------------------------------------------------------------------
1 | remap: true # true for KITTI (raw ids are remapped via learning_map); false for KITTI-360
2 | labels: 
3 |   0 : "unlabeled"
4 |   1 : "outlier"
5 |   10: "car"
6 |   11: "bicycle"
7 |   13: "bus"
8 |   15: "motorcycle"
9 |   16: "on-rails"
10 |   18: "truck"
11 |   20: "other-vehicle"
12 |   30: "person"
13 |   31: "bicyclist"
14 |   32: "motorcyclist"
15 |   40: "road"
16 |   44: "parking"
17 |   48: "sidewalk"
18 |   49: "other-ground"
19 |   50: "building"
20 |   51: "fence"
21 |   52: "other-structure"
22 |   60: "lane-marking"
23 |   70: "vegetation"
24 |   71: "trunk"
25 |   72: "terrain"
26 |   80: "pole"
27 |   81: "traffic-sign"
28 |   99: "other-object"
29 |   252: "moving-car"
30 |   253: "moving-bicyclist"
31 |   254: "moving-person"
32 |   255: "moving-motorcyclist"
33 |   256: "moving-on-rails"
34 |   257: "moving-bus"
35 |   258: "moving-truck"
36 |   259: "moving-other-vehicle"
37 | color_map: # bgr
38 |   0 : [0, 0, 0]
39 |   1 : [0, 0, 255]
40 |   10: [245, 150, 100]
41 |   11: [245, 230, 100]
42 |   13: [250, 80, 100]
43 |   15: [150, 60, 30]
44 |   16: [255, 0, 0]
45 |   18: [180, 30, 80]
46 |   20: [255, 0, 0]
47 |   30: [30, 30, 255]
48 |   31: [200, 40, 255]
49 |   32: [90, 30, 150]
50 |   40: [255, 0, 255]
51 |   44: [255, 150, 255]
52 |   48: [75, 0, 75]
53 |   49: [75, 0, 175]
54 |   50: [0, 200, 255]
55 |   51: [50, 120, 255]
56 |   52: [0, 150, 255]
57 |   60: [170, 255, 150]
58 |   70: [0, 175, 0]
59 |   71: [0, 60, 135]
60 |   72: [80, 240, 150]
61 |   80: [150, 240, 255]
62 |   81: [0, 0, 255]
63 |   99: [255, 255, 50]
64 |   252: [245, 150, 100]
65 |   256: [255, 0, 0]
66 |   253: [200, 40, 255]
67 |   254: [30, 30, 255]
68 |   255: [90, 30, 150]
69 |   257: [250, 80, 100]
70 |   258: [180, 30, 80]
71 |   259: [255, 0, 0]
72 | learning_map: # raw label id -> training class id (0-19); applied by semConf::remap()
73 |   0 : 0      # "unlabeled"
"unlabeled" --------------------------mapped 75 | 10: 1 # "car" 76 | 11: 2 # "bicycle" 77 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 78 | 15: 3 # "motorcycle" 79 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 80 | 18: 4 # "truck" 81 | 20: 5 # "other-vehicle" 82 | 30: 6 # "person" 83 | 31: 7 # "bicyclist" 84 | 32: 8 # "motorcyclist" 85 | 40: 9 # "road" 86 | 44: 10 # "parking" 87 | 48: 11 # "sidewalk" 88 | 49: 12 # "other-ground" 89 | 50: 13 # "building" 90 | 51: 14 # "fence" 91 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 92 | 60: 9 # "lane-marking" to "road" ---------------------------------mapped 93 | 70: 15 # "vegetation" 94 | 71: 16 # "trunk" 95 | 72: 17 # "terrain" 96 | 80: 18 # "pole" 97 | 81: 19 # "traffic-sign" 98 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 99 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 100 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 101 | 254: 6 # "moving-person" to "person" ------------------------------mapped 102 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 103 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 104 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 105 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 106 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 107 | learning_map_inv: # inverse of previous map 108 | 0: 0 # "unlabeled", and others ignored 109 | 1: 10 # "car" 110 | 2: 11 # "bicycle" 111 | 3: 15 # "motorcycle" 112 | 4: 18 # "truck" 113 | 5: 20 # "other-vehicle" 114 | 6: 30 # "person" 115 | 7: 31 # "bicyclist" 116 | 8: 32 # "motorcyclist" 117 | 9: 40 # "road" 118 | 10: 44 # "parking" 119 | 11: 48 # "sidewalk" 120 | 12: 49 # "other-ground" 121 | 13: 50 # "building" 122 | 14: 51 # "fence" 123 | 15: 70 # "vegetation" 124 | 16: 71 # "trunk" 125 | 17: 72 # "terrain" 126 | 18: 80 # "pole" 127 | 19: 81 # "traffic-sign" 128 | -------------------------------------------------------------------------------- /gen_desc/gen_cloud/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | set(EXECUTABLE_OUTPUT_PATH ${PROJECT_SOURCE_DIR}/bin) 2 | set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_SOURCE_DIR}/lib) 3 | if (NOT PCL_FOUND) 4 | find_package(PCL REQUIRED) 5 | endif () 6 | if (NOT OPENCV_FOUND) 7 | find_package(OpenCV REQUIRED) 8 | endif () 9 | find_package (yaml-cpp REQUIRED) 10 | find_package(Eigen3 REQUIRED) 11 | file(GLOB SRC_LIST *.cpp) 12 | add_library(gencloud SHARED ${SRC_LIST}) 13 | target_include_directories(gencloud PUBLIC 14 | ${PCL_INCLUDE_DIRS} 15 | ${EIGEN3_INCLUDE_DIRS} 16 | ./ 17 | ) 18 | if (NOT YAML_CPP_LIBRARIES) 19 | set(YAML_CPP_LIBRARIES yaml-cpp) 20 | endif () 21 | target_link_libraries(gencloud 22 | ${PCL_LIBRARIES} 23 | ${YAML_CPP_LIBRARIES} 24 | ${OpenCV_LIBS} 25 | ) 26 | 27 | 28 | 29 | 30 | -------------------------------------------------------------------------------- /gen_desc/gen_cloud/genData.cpp: -------------------------------------------------------------------------------- 1 | #include "genData.hpp" 2 | genData::genData(std::string _cloud_path, std::string _label_path, std::shared_ptr _semconf) 3 | { 4 | this->semconf = _semconf; 5 | cloud_path = _cloud_path; 6 | label_path = _label_path; 7 | label_filenames = listDir(label_path, ".label"); 8 | totaldata = 
8 |     totaldata = label_filenames.size();
9 | }
10 | std::vector<std::string> genData::listDir(std::string path, std::string end)
11 | {
12 |     DIR *pDir;
13 |     struct dirent *ptr;
14 |     std::vector<std::string> files;
15 |     if (!(pDir = opendir(path.c_str())))
16 |     {
17 |         return files;
18 |     }
19 |     std::string subFile;
20 |     while ((ptr = readdir(pDir)) != 0)
21 |     {
22 |         subFile = ptr->d_name;
23 |         auto rt = subFile.find(end);
24 |         if (rt != std::string::npos)
25 |         {
26 |             files.emplace_back(path + subFile);
27 |         }
28 |     }
29 |     std::sort(files.begin(), files.end());
30 |     return files;
31 | }
32 | 
33 | std::vector<std::string> genData::split(const std::string& str, const std::string& delim){
34 |     std::vector<std::string> res;
35 |     if("" == str) return res;
36 |     char * strs = new char[str.length() + 1] ;
37 |     strcpy(strs, str.c_str());
38 |     char * d = new char[delim.length() + 1];
39 |     strcpy(d, delim.c_str());
40 |     char *p = strtok(strs, d);
41 |     while(p) {
42 |         std::string s = p;
43 |         res.push_back(s);
44 |         p = strtok(NULL, d);
45 |     }
46 |     return res;
47 | }
48 | 
49 | CloudLPtr genData::getCloud()
50 | {
51 |     auto cloud_file=split(split(label_filenames[data_id],"/").back(),".")[0]+".bin"; // "xxxxxx.label" -> "xxxxxx.bin"
52 |     return getLCloud(cloud_path+cloud_file, label_filenames[data_id]);
53 | }
54 | CloudLPtr genData::getLCloud(std::string file_cloud, std::string file_label)
55 | {
56 |     CloudLPtr re_cloud(new CloudL);
57 |     std::ifstream in_label(file_label, std::ios::binary);
58 |     if (!in_label.is_open())
59 |     {
60 |         std::cerr << "No file:" << file_label << std::endl;
61 |         exit(-1);
62 |     }
63 |     in_label.seekg(0, std::ios::end);
64 |     uint32_t num_points = in_label.tellg() / sizeof(uint32_t); // one uint32 label per point
65 |     in_label.seekg(0, std::ios::beg);
66 |     std::vector<uint32_t> values_label(num_points);
67 |     in_label.read((char *)&values_label[0], num_points * sizeof(uint32_t));
68 |     std::ifstream in_cloud(file_cloud, std::ios::binary);
69 |     std::vector<float> values_cloud(4 * num_points);
70 |     in_cloud.read((char *)&values_cloud[0], 4 * num_points * sizeof(float));
71 |     re_cloud->points.resize(num_points);
72 |     for (uint32_t i = 0; i < num_points; ++i)
73 |     {
74 |         uint32_t sem_label;
75 |         sem_label = semconf->remap(values_label[i]);
76 |         re_cloud->points[i].x = values_cloud[4 * i];
77 |         re_cloud->points[i].y = values_cloud[4 * i + 1];
78 |         re_cloud->points[i].z = values_cloud[4 * i + 2];
79 |         re_cloud->points[i].label = sem_label;
80 |     }
81 |     in_label.close();
82 |     in_cloud.close();
83 |     return re_cloud;
84 | }
85 | 
86 | bool genData::getData(CloudLPtr &cloud)
87 | {
88 |     if (data_id >= totaldata)
89 |     {
90 |         return false;
91 |     }
92 |     if (cloud == NULL)
93 |     {
94 |         cloud.reset(new CloudL);
95 |     }
96 |     auto label_file=label_filenames[data_id];
97 |     cloud = getCloud();
98 |     data_id++;
99 |     return true;
100 | }
101 | 
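102 | // Note: getCloud() pairs each "xxxxxx.label" file with the "xxxxxx.bin" scan of the
103 | // same stem, so the cloud and label folders must contain matching KITTI filenames.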
--------------------------------------------------------------------------------
/gen_desc/gen_cloud/genData.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | // NOTE: the include targets below are reconstructed; the original '<...>' names were lost in formatting.
3 | #include <dirent.h>
4 | #include <cstring>
5 | #include <string>
6 | #include <vector>
7 | #include <memory>
8 | #include <algorithm>
9 | #include <fstream>
10 | #include <iostream>
11 | #include <pcl/visualization/cloud_viewer.h>
12 | #include "semanticConf.hpp"
13 | #include "types.hpp"
14 | class genData
15 | {
16 | private:
17 |     CloudLPtr getLCloud(std::string file_cloud, std::string file_label);
18 |     CloudLPtr getCloud();
19 |     std::vector<std::string> listDir(std::string path, std::string end);
20 |     std::vector<std::string> split(const std::string& str, const std::string& delim);
21 |     std::string cloud_path,label_path;
22 |     std::vector<std::string> label_filenames;
23 |     std::shared_ptr<semConf> semconf;
24 |     std::shared_ptr<pcl::visualization::CloudViewer> viewer; // (assumed viewer type; unused in this file)
25 |     int data_id=0;
26 | public:
27 |     EIGEN_MAKE_ALIGNED_OPERATOR_NEW
28 |     int totaldata = 0;
29 |     genData(std::string cloud_path,std::string label_path,std::shared_ptr<semConf> semconf);
30 |     bool getData(CloudLPtr &cloud);
31 |     ~genData()=default;
32 | };
--------------------------------------------------------------------------------
/gen_desc/gen_cloud/semanticConf.cpp:
--------------------------------------------------------------------------------
1 | #include "semanticConf.hpp"
2 | 
3 | semConf::semConf(std::string conf_file)
4 | {
5 |     auto data_cfg = YAML::LoadFile(conf_file);
6 |     remap_label = data_cfg["remap"].as<bool>();
7 |     auto color_map = data_cfg["color_map"];
8 |     learning_map = data_cfg["learning_map"];
9 |     label_map.resize(260);
10 |     for (auto it = learning_map.begin(); it != learning_map.end(); ++it)
11 |     {
12 |         label_map[it->first.as<int>()] = it->second.as<int>();
13 |     }
14 |     YAML::const_iterator it;
15 |     for (it = color_map.begin(); it != color_map.end(); ++it)
16 |     {
17 |         // Get label and key
18 |         int key = it->first.as<int>(); // <- key
19 |         Color color = std::make_tuple(
20 |             static_cast<uint8_t>(color_map[key][0].as<int>()),
21 |             static_cast<uint8_t>(color_map[key][1].as<int>()),
22 |             static_cast<uint8_t>(color_map[key][2].as<int>()));
23 |         _color_map[key] = color;
24 |     }
25 |     auto learning_class = data_cfg["learning_map_inv"];
26 |     for (it = learning_class.begin(); it != learning_class.end(); ++it)
27 |     {
28 |         int key = it->first.as<int>(); // <- key
29 |         _argmax_to_rgb[key] = _color_map[learning_class[key].as<int>()];
30 |     }
31 | }
32 | 
33 | int semConf::remap(uint32_t in_label)
34 | {
35 |     if (remap_label)
36 |     {
37 |         return label_map[(int)(in_label & 0x0000ffff)]; // lower 16 bits hold the semantic class; upper 16 the instance id
38 |     }
39 |     else
40 |     {
41 |         return in_label;
42 |     }
43 | }
44 | 
45 | Color semConf::getColor(uint32_t label)
46 | {
47 |     return _argmax_to_rgb[label];
48 | }
49 | CloudCPtr semConf::getColorCloud(CloudLPtr &cloud_in)
50 | {
51 |     CloudCPtr outcloud(new CloudC);
52 |     outcloud->points.resize(cloud_in->points.size());
53 |     for (size_t i = 0; i < outcloud->points.size(); i++)
54 |     {
55 |         outcloud->points[i].x = cloud_in->points[i].x;
56 |         outcloud->points[i].y = cloud_in->points[i].y;
57 |         outcloud->points[i].z = cloud_in->points[i].z;
58 |         auto color = getColor(cloud_in->points[i].label);
59 |         outcloud->points[i].r = std::get<0>(color);
60 |         outcloud->points[i].g = std::get<1>(color);
61 |         outcloud->points[i].b = std::get<2>(color);
62 |     }
63 |     outcloud->height = 1;
64 |     outcloud->width = outcloud->points.size();
65 |     return outcloud;
66 | }
--------------------------------------------------------------------------------
/gen_desc/gen_cloud/semanticConf.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <yaml-cpp/yaml.h>
3 | #include <map>
4 | #include <tuple>
5 | #include "types.hpp"
6 | typedef std::tuple<uint8_t, uint8_t, uint8_t> Color; // (element type reconstructed; the original template arguments were lost in formatting)
7 | class semConf
8 | {
9 | private:
10 |     std::map<int, Color> _color_map, _argmax_to_rgb;
11 |     YAML::Node learning_map;
12 |     std::vector<int> label_map;
13 |     bool remap_label = true;
14 |     semConf();
15 | 
16 | public:
17 |     semConf(std::string conf_file);
18 |     ~semConf() = default;
19 |     int remap(uint32_t in_label);
20 |     Color getColor(uint32_t label);
21 |     CloudCPtr getColorCloud(CloudLPtr &cloud_in);
22 | };
23 | 
--------------------------------------------------------------------------------
/gen_desc/gen_cloud/types.hpp:
--------------------------------------------------------------------------------
1 | #pragma once
2 | #include <pcl/point_types.h>
3 | #include <pcl/point_cloud.h>
4 | #include <pcl/io/pcd_io.h>
5 | typedef pcl::PointXYZL PointL;
6 | typedef pcl::PointCloud<PointL> CloudL;
7 | typedef CloudL::Ptr CloudLPtr;
8 | 
9 | typedef pcl::PointXYZ Point;
10 | typedef pcl::PointCloud<Point> Cloud;
11 | typedef Cloud::Ptr CloudPtr;
12 | 
13 | typedef pcl::PointXYZRGB PointC;
14 | typedef pcl::PointCloud<PointC> CloudC;
15 | typedef CloudC::Ptr CloudCPtr;
16 | 
--------------------------------------------------------------------------------
/gen_desc/kitti_gen.cpp:
--------------------------------------------------------------------------------
1 | #include <iostream>
2 | #include <fstream>
3 | #include <cmath>
4 | #include <pcl/visualization/cloud_viewer.h>
5 | #include "semanticConf.hpp"
6 | #include "genData.hpp"
7 | 
8 | int main(int argc,char** argv){
9 |     if(argc<4){
10 |         std::cout<<"Usage: ./kitti_gen cloud_folder label_folder output_file"<<std::endl;
11 |         return -1;
12 |     }
13 |     std::string cloud_path=argv[1];
14 |     std::string label_path=argv[2];
15 |     std::vector<int> label_map{0,1,2,3,4,5,5,5,5,6,7,8,8,9,9,10,10,10,11,11}; // (placeholder) training class -> descriptor row; the original values were lost in formatting
16 |     std::vector<bool> label_valid(20,true); // (placeholder) which classes enter the descriptor
17 |     std::vector<bool> use_min(20,true); // (placeholder) keep nearest (true) or farthest point per bin
18 |     std::shared_ptr<semConf> semconf(new semConf("../conf/sem_config.yaml"));
19 |     genData gener(cloud_path,label_path, semconf);
20 |     CloudLPtr cloud(new CloudL);
21 |     int totaldata = gener.totaldata;
22 |     int num=0;
23 |     pcl::visualization::CloudViewer viewer("cloud");
24 |     std::ofstream fout(argv[3],std::ios::binary);
25 |     while (gener.getData(cloud)){
26 |         std::cout<<num<<"/"<<totaldata<<std::endl;
27 |         CloudLPtr cloud_out(new CloudL);
28 |         std::vector<float> dis_list;
29 |         cloud_out->resize((label_map[19]+1)*360); // one row per descriptor class (12) x 360 one-degree azimuth bins
30 |         dis_list.resize(cloud_out->size(),0.f);
31 |         for(auto p:cloud->points){
32 |             if(label_valid[p.label]){
33 |                 int angle=std::floor((std::atan2(p.y,p.x)+M_PI)*180./M_PI); // azimuth in [0,360)
34 |                 if(angle<0||angle>359){
35 |                     continue;
36 |                 }
37 |                 float dis=std::sqrt(p.x*p.x+p.y*p.y);
38 |                 if(dis>50){ // ranges are capped at 50 m (descriptors are later normalized by /50)
39 |                     continue;
40 |                 }
41 |                 auto& q=cloud_out->at(360*label_map[p.label]+angle);
42 |                 if(q.label>0){
43 |                     float dis_temp=std::sqrt(q.x*q.x+q.y*q.y);
44 |                     if(use_min[p.label]){
45 |                         if(dis<dis_temp){
46 |                             q=p;
47 |                             dis_list[360*label_map[p.label]+angle]=dis;
48 |                         }
49 |                     }else{
50 |                         if(dis>dis_temp){
51 |                             q=p;
52 |                             dis_list[360*label_map[p.label]+angle]=dis;
53 |                         }
54 |                     }
55 |                 }else{
56 |                     q=p;
57 |                     dis_list[360*label_map[p.label]+angle]=dis;
58 |                 }
59 |             }
60 |         }
61 |         for(auto dis:dis_list){ // append this scan's descriptor as raw little-endian floats (row-major)
62 |             fout.write((char*)(&dis),sizeof(dis));
63 |         }
64 |         auto ccloud=semconf->getColorCloud(cloud_out);
65 |         viewer.showCloud(ccloud);
66 |         ++num;
67 |     }
68 |     fout.close();
69 |     return 0;
70 | }
--------------------------------------------------------------------------------
/gen_pairs.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib import pyplot as plt
3 | import sys
4 | import os
5 | 
6 | 
7 | def run(seq='00'):
8 |     pose_file = "/media/l/yp2/KITTI/odometry/dataset/poses/"+seq+".txt"
9 |     poses = np.genfromtxt(pose_file)
10 |     poses = poses[:, [3, 11]]
11 |     inner = 2*np.matmul(poses, poses.T)
12 |     xx = np.sum(poses**2, 1, keepdims=True)
13 |     dis = xx-inner+xx.T
14 |     dis = np.sqrt(np.abs(dis))
15 |     id_pos = np.argwhere(dis < 3)   # positives: within 3 m
16 |     id_neg = np.argwhere(dis > 20)  # negatives: farther than 20 m
17 |     # id_pos=id_pos[id_pos[:,0]-id_pos[:,1]>50]
18 |     id_neg = id_neg[id_neg[:, 0] > id_neg[:, 1]]
19 |     id_pos = np.concatenate(
20 |         [id_pos, (id_pos[:, 0]*0+1).reshape(-1, 1)], axis=1)
21 |     id_neg = np.concatenate([id_neg, (id_neg[:, 0]*0).reshape(-1, 1)], axis=1)
22 |     print(id_pos.shape)
23 |     np.savez(seq+'.npz', pos=id_pos, neg=id_neg)
24 | 
25 | 
26 | def run_sigmoid(seq='00'):
27 |     pose_file = "/media/l/yp2/KITTI/odometry/dataset/poses/"+seq+".txt"
28 |     poses = np.genfromtxt(pose_file)
29 |     poses = poses[:, [3, 11]]
30 |     inner = 2*np.matmul(poses, poses.T)
31 |     xx = np.sum(poses**2, 1, keepdims=True)
32 |     dis = xx-inner+xx.T
33 |     dis = np.sqrt(np.abs(dis))
34 |     score = 1.-1./(1+np.exp((10.-dis)/1.5))  # soft label = sigmoid((10-d)/1.5): ~0.9907 at d=3 m, ~0.0013 at d=20 m
35 |     score[dis < 3] = 1
36 |     # plt.imshow(score)
37 |     # plt.show()
38 |     id = np.argwhere(dis > -1)
39 |     id = id[id[:, 0] >= id[:, 1]]
40 |     label = score[(id[:, 0], id[:, 1])]
41 |     label = label.reshape(-1, 1)
42 |     out = np.concatenate((id, label), 1)
43 |     out_pos = out[out[:, 2] > 0.1]
44 |     out_neg =
out[out[:, 2] <= 0.1] 45 | print(out_pos.shape) 46 | print(out_neg.shape) 47 | np.savez(seq+'.npz', pos=out_pos, neg=out_neg) 48 | 49 | 50 | if __name__ == '__main__': 51 | seq = "00" 52 | if len(sys.argv) > 1: 53 | seq = sys.argv[1] 54 | run_sigmoid(seq) 55 | # run(seq) 56 | -------------------------------------------------------------------------------- /gen_pairs_kitti360.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from tqdm import tqdm 4 | import json 5 | import random 6 | from operator import itemgetter 7 | 8 | 9 | def gen_pairs(sequ, neg_num=1): 10 | folder = "/media/l/yp2/KITTI-360/labels/2013_05_28_drive_" + sequ+"_sync" 11 | pose_file = "/media/l/yp2/KITTI-360/data_poses/2013_05_28_drive_"+sequ+"_sync/poses.txt" 12 | label_files = os.listdir(folder) 13 | label_files.sort() 14 | indexs = [int(v.split(".")[0]) for v in label_files] 15 | posedata = np.genfromtxt(pose_file) 16 | pose_indexs = posedata[:, 0] 17 | pose_indexs = [int(v) for v in pose_indexs] 18 | pose = posedata[:, 1:].reshape(-1, 3, 4)[:, 0:2, 3].tolist() 19 | pose_dict = dict(zip(pose_indexs, pose)) 20 | pose_valid = itemgetter(*indexs)(pose_dict) 21 | pose_valid = np.array(pose_valid) 22 | inner = 2*np.matmul(pose_valid, pose_valid.T) 23 | xx = np.sum(pose_valid**2, 1, keepdims=True) 24 | dis = xx-inner+xx.T 25 | dis = np.sqrt(np.abs(dis)) 26 | score = 1.-1./(1+np.exp((10.-dis)/1.5)) 27 | id = np.argwhere(dis > -1) 28 | id = id[id[:, 0] >= id[:, 1]] 29 | label = score[(id[:, 0], id[:, 1])] 30 | label = label.reshape(-1, 1) 31 | indexs = np.array(indexs, dtype='int') 32 | id[:, 0] = indexs[id[:, 0]] 33 | id[:, 1] = indexs[id[:, 1]] 34 | out = np.concatenate((id, label), 1) 35 | out_pos = out[out[:, 2] > 0.1] 36 | out_neg = out[out[:, 2] <= 0.1] 37 | print(out_pos.shape) 38 | print(out_neg.shape) 39 | np.savez(sequ+'.npz', pos=out_pos, neg=out_neg) 40 | 41 | 42 | if __name__ == '__main__': 43 | sequs = ['0000', '0002', '0003', '0004', 44 | '0005', '0006', '0007', '0009', '0010'] 45 | for sequ in tqdm(sequs): 46 | gen_pairs(sequ, 10) 47 | -------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import random 6 | from matplotlib import pyplot as plt 7 | from torch.nn.modules.pooling import AdaptiveAvgPool1d, AvgPool1d, MaxPool1d 8 | 9 | 10 | class RIConv(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size): 12 | super(RIConv, self).__init__() 13 | self.in_channels = in_channels 14 | self.out_channels = out_channels 15 | self.kernel_size = kernel_size 16 | self.conv = nn.Sequential(nn.Conv1d(in_channels=in_channels, out_channels=out_channels, 17 | kernel_size=kernel_size, stride=1), nn.BatchNorm1d(out_channels), nn.LeakyReLU(negative_slope=0.1)) 18 | 19 | def forward(self, x): 20 | x = F.pad(x, [0, self.kernel_size-1], mode='circular') 21 | out = self.conv(x) 22 | return out 23 | 24 | 25 | class RIDowsampling(nn.Module): 26 | def __init__(self, ratio=2): 27 | super(RIDowsampling, self).__init__() 28 | self.ratio = ratio 29 | 30 | def forward(self, x): 31 | y = x[:, :, list(range(0, x.shape[2], self.ratio))].unsqueeze(1) 32 | for i in range(1, self.ratio): 33 | index = list(range(i, x.shape[2], self.ratio)) 34 | y = torch.cat([y, x[:, :, index].unsqueeze(1)], 1) 35 | norm = torch.norm(torch.norm(y, 1, 2), 
1, 2) 36 | idx = torch.argmax(norm, 1) 37 | idx = idx.unsqueeze(1).expand(x.shape[0], self.ratio) 38 | id_matrix = torch.tensor([list(range(self.ratio))]).expand( 39 | x.shape[0], self.ratio).to(device=x.device) 40 | out = y[id_matrix == idx] 41 | return out 42 | 43 | 44 | class RIAttention(nn.Module): 45 | def __init__(self, channels): 46 | super(RIAttention, self).__init__() 47 | self.channels = channels 48 | self.fc = nn.Sequential( 49 | nn.Linear(in_features=self.channels, out_features=self.channels), nn.Sigmoid()) 50 | 51 | def forward(self, x): 52 | x1 = torch.mean(x, 2) 53 | w = self.fc(x1) 54 | w = w.unsqueeze(2) 55 | out = w*x 56 | return out 57 | 58 | 59 | class RINet(nn.Module): 60 | def __init__(self): 61 | super(RINet, self).__init__() 62 | self.conv1 = nn.Sequential(RIConv(in_channels=12, out_channels=12, kernel_size=3), RIConv( 63 | in_channels=12, out_channels=16, kernel_size=3)) 64 | self.conv2 = nn.Sequential(RIDowsampling(3), RIConv( 65 | in_channels=16, out_channels=16, kernel_size=3)) 66 | self.conv3 = nn.Sequential(RIDowsampling(3), RIConv( 67 | in_channels=16, out_channels=32, kernel_size=3)) 68 | self.conv4 = nn.Sequential(RIDowsampling(2), RIConv( 69 | in_channels=32, out_channels=32, kernel_size=3)) 70 | self.conv5 = nn.Sequential(RIDowsampling(2), RIConv( 71 | in_channels=32, out_channels=64, kernel_size=3)) 72 | self.conv6 = nn.Sequential(RIDowsampling(2), RIConv( 73 | in_channels=64, out_channels=128, kernel_size=3)) 74 | self.pool = AdaptiveAvgPool1d(1) 75 | self.linear = nn.Sequential(nn.Linear(in_features=288, out_features=128), nn.LeakyReLU( 76 | negative_slope=0.1), nn.Linear(in_features=128, out_features=1)) 77 | 78 | def forward(self, x, y): 79 | featurexy = self.gen_feature(torch.cat([x, y], dim=0)) 80 | out, diff = self.gen_score( 81 | featurexy[:x.shape[0]], featurexy[x.shape[0]:]) 82 | return out, diff 83 | 84 | def gen_feature(self, xy): 85 | fxy = [] 86 | xy1 = self.conv1(xy) 87 | fxy.append(self.pool(xy1).view(xy.shape[0], -1)) 88 | xy2 = self.conv2(xy1) 89 | fxy.append(self.pool(xy2).view(xy.shape[0], -1)) 90 | xy3 = self.conv3(xy2) 91 | fxy.append(self.pool(xy3).view(xy.shape[0], -1)) 92 | xy4 = self.conv4(xy3) 93 | fxy.append(self.pool(xy4).view(xy.shape[0], -1)) 94 | xy5 = self.conv5(xy4) 95 | fxy.append(self.pool(xy5).view(xy.shape[0], -1)) 96 | xy6 = self.conv6(xy5) 97 | fxy.append(self.pool(xy6).view(xy.shape[0], -1)) 98 | featurexy = torch.cat(fxy, 1) 99 | return featurexy 100 | 101 | def gen_score(self, fx, fy): 102 | diff = torch.abs(fx-fy) 103 | out = self.linear(diff).view(-1) 104 | if not self.training: 105 | out = torch.sigmoid(out) 106 | return out, torch.norm(diff, dim=1) 107 | 108 | def load(self, model_file): 109 | dict = torch.load(model_file) 110 | self.load_state_dict(dict) 111 | 112 | 113 | class RINet_attention(nn.Module): 114 | def __init__(self): 115 | super(RINet_attention, self).__init__() 116 | self.conv1 = nn.Sequential(RIAttention(12), RIConv(in_channels=12, out_channels=12, kernel_size=3), RIAttention( 117 | 12), RIConv(in_channels=12, out_channels=16, kernel_size=3), RIAttention(16)) 118 | self.conv2 = nn.Sequential(RIDowsampling(3), RIConv( 119 | in_channels=16, out_channels=16, kernel_size=3), RIAttention(16)) 120 | self.conv3 = nn.Sequential(RIDowsampling(3), RIConv( 121 | in_channels=16, out_channels=32, kernel_size=3), RIAttention(32)) 122 | self.conv4 = nn.Sequential(RIDowsampling(2), RIConv( 123 | in_channels=32, out_channels=32, kernel_size=3), RIAttention(32)) 124 | self.conv5 = 
nn.Sequential(RIDowsampling(2), RIConv(
125 |             in_channels=32, out_channels=64, kernel_size=3), RIAttention(64))
126 |         self.conv6 = nn.Sequential(RIDowsampling(2), RIConv(
127 |             in_channels=64, out_channels=128, kernel_size=3), RIAttention(128))
128 |         self.pool = AdaptiveAvgPool1d(1)
129 |         self.linear = nn.Sequential(nn.Linear(in_features=288, out_features=128), nn.LeakyReLU(
130 |             negative_slope=0.1), nn.Linear(in_features=128, out_features=1))
131 | 
132 |     def forward(self, x, y):
133 |         featurexy = self.gen_feature(torch.cat([x, y], dim=0))
134 |         out, diff = self.gen_score(
135 |             featurexy[:x.shape[0]], featurexy[x.shape[0]:])
136 |         return out, diff
137 | 
138 |     def gen_feature(self, xy):
139 |         fxy = []
140 |         xy1 = self.conv1(xy)
141 |         fxy.append(self.pool(xy1).view(xy.shape[0], -1))
142 |         xy2 = self.conv2(xy1)
143 |         fxy.append(self.pool(xy2).view(xy.shape[0], -1))
144 |         xy3 = self.conv3(xy2)
145 |         fxy.append(self.pool(xy3).view(xy.shape[0], -1))
146 |         xy4 = self.conv4(xy3)
147 |         fxy.append(self.pool(xy4).view(xy.shape[0], -1))
148 |         xy5 = self.conv5(xy4)
149 |         fxy.append(self.pool(xy5).view(xy.shape[0], -1))
150 |         xy6 = self.conv6(xy5)
151 |         fxy.append(self.pool(xy6).view(xy.shape[0], -1))
152 |         featurexy = torch.cat(fxy, 1)  # concatenated multi-scale global features (16+16+32+32+64+128 = 288 dims)
153 |         return featurexy
154 | 
155 |     def gen_score(self, fx, fy):
156 |         diff = torch.abs(fx-fy)
157 |         out = self.linear(diff).view(-1)
158 |         if not self.training:  # apply sigmoid only at inference; training uses BCE-with-logits
159 |             out = torch.sigmoid(out)
160 |         return out, torch.norm(diff, dim=1)
161 | 
162 |     def load(self, model_file):
163 |         checkpoint = torch.load(model_file)
164 |         self.load_state_dict(checkpoint['state_dict'])
165 | 
166 | 
167 | if __name__ == "__main__":
168 |     net = RINet_attention()
169 |     net.eval()
170 |     a = np.random.random(size=[32, 12, 360])
171 |     b = np.random.random(size=[32, 12, 360])
172 |     c = np.roll(b, random.randint(1, 360), 2)  # c is b rotated by a random yaw angle
173 |     a = torch.from_numpy(np.array(a, dtype='float32'))
174 |     b = torch.from_numpy(np.array(b, dtype='float32'))
175 |     c = torch.from_numpy(np.array(c, dtype='float32'))
176 |     # out1,_=net(a,c)
177 |     # out2,_=net(a,b)
178 |     out3, diff = net(c, b)
179 |     print(diff)
180 |     # print(norm.shape)
181 |     # print(out1)
182 |     # print(out2)
183 |     # print(out3)
184 | 
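185 | # Because every block is built from circular convolutions and global pooling, the
186 | # features of b and its rotated copy c should match almost exactly, so the `diff`
187 | # printed above is expected to be ~0 for every pair (up to floating-point error).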
/pic/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lilin-hitcrt/RINet/0e28c26e015c50385816b2cbe6549583486fd486/pic/pipeline.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """Train RINet (attention variant), holding one KITTI sequence out for testing."""
2 | import torch
3 | from net import RINet, RINet_attention
4 | from database import evalDataset_kitti360, SigmoidDataset_kitti360, SigmoidDataset_train, SigmoidDataset_eval
5 | import numpy as np
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 | from sklearn import metrics
9 | import os
10 | import argparse
11 | # from tensorboardX import SummaryWriter
12 | from torch.utils.tensorboard.writer import SummaryWriter
13 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14 | 
15 | 
16 | def train(cfg):
17 |     writer = SummaryWriter()
18 |     net = RINet_attention()
19 |     net.to(device=device)
20 |     print(net)
21 |     sequs = list(cfg.all_seqs)  # copy, so the argparse default is not mutated
22 |     sequs.remove(cfg.seq)  # train on every sequence except the held-out test one
23 |     train_dataset = SigmoidDataset_train(sequs=sequs, neg_ratio=cfg.neg_ratio,
24 |                                          eva_ratio=cfg.eval_ratio, desc_folder=cfg.desc_folder, gt_folder=cfg.gt_folder)
25 |     test_dataset = SigmoidDataset_eval(sequs=sequs, neg_ratio=cfg.neg_ratio,
26 |                                        eva_ratio=cfg.eval_ratio, desc_folder=cfg.desc_folder, gt_folder=cfg.gt_folder)
27 |     # train_dataset=SigmoidDataset_kitti360(['0009','0003','0007','0002','0004','0006','0010'],1)
28 |     # test_dataset=evalDataset_kitti360('0005')
29 |     batch_size = cfg.batch_size
30 |     train_loader = DataLoader(
31 |         dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=6)
32 |     test_loader = DataLoader(
33 |         dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=6)
34 |     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()),
35 |                                  lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
36 |     epoch = cfg.max_epoch
37 |     starting_epoch = 0
38 |     batch_num = 0
39 |     if cfg.model != "":  # resume from a checkpoint
40 |         checkpoint = torch.load(cfg.model)
41 |         starting_epoch = checkpoint['epoch']
42 |         batch_num = checkpoint['batch_num']
43 |         net.load_state_dict(checkpoint['state_dict'])
44 |         optimizer.load_state_dict(checkpoint['optimizer'])
45 |     for i in range(starting_epoch, epoch):
46 |         net.train()
47 |         pred = []
48 |         gt = []
49 |         for i_batch, sample_batch in tqdm(enumerate(train_loader), total=len(train_loader), desc='Train epoch '+str(i), leave=False):
50 |             optimizer.zero_grad()
51 |             out, diff = net(sample_batch["desc1"].to(
52 |                 device=device), sample_batch["desc2"].to(device=device))
53 |             labels = sample_batch["label"].to(device=device)
54 |             loss1 = torch.nn.functional.binary_cross_entropy_with_logits(
55 |                 out, labels)
56 |             loss2 = labels*diff*diff+(1-labels)*torch.nn.functional.relu(
57 |                 cfg.margin-diff)*torch.nn.functional.relu(cfg.margin-diff)
58 |             loss2 = torch.mean(loss2)
59 |             loss = loss1+loss2
60 |             loss.backward()
61 |             optimizer.step()
62 |             with torch.no_grad():
63 |                 writer.add_scalar(
64 |                     'total loss', loss.cpu().item(), global_step=batch_num)
65 |                 writer.add_scalar('loss1', loss1.cpu().item(),
66 |                                   global_step=batch_num)
67 |                 writer.add_scalar('loss2', loss2.cpu().item(),
68 |                                   global_step=batch_num)
69 |                 batch_num += 1
70 |                 outlabel = out.cpu().numpy()
71 |                 label = sample_batch['label'].cpu().numpy()
72 |                 mask = (label > 0.9906840407) | (label < 0.0012710163)  # score F1 only on near-certain positives/negatives
73 |                 label = label[mask]
74 |                 label[label < 0.5] = 0
75 |                 label[label > 0.5] = 1
76 |                 pred.extend(outlabel[mask].tolist())
77 |                 gt.extend(label.tolist())
78 |         pred = np.array(pred, dtype='float32')
79 |         pred = np.nan_to_num(pred)
80 |         gt = np.array(gt, dtype='float32')
81 |         precision, recall, _ = metrics.precision_recall_curve(gt, pred)
82 |         F1_score = 2 * precision * recall / (precision + recall)
83 |         F1_score = np.nan_to_num(F1_score)
84 |         trainaccur = np.max(F1_score)
85 |         print('Train F1:', trainaccur)
86 |         writer.add_scalar('train f1', trainaccur, global_step=i)
87 |         lastaccur = test(net=net, dataloader=test_loader)
88 |         writer.add_scalar('eval f1', lastaccur, global_step=i)
89 |         print('Eval F1:', lastaccur)
90 |         torch.save({'epoch': i, 'state_dict': net.state_dict(), 'optimizer': optimizer.state_dict(
91 |         ), 'batch_num': batch_num}, os.path.join(cfg.log_dir, cfg.seq, str(i)+'.ckpt'))
92 | 
93 | 
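The objective above combines two terms: `loss1` is binary cross-entropy on the pair logit, and `loss2` is a contrastive loss on the feature distance `diff`, pulling positive pairs toward zero distance and pushing negatives beyond `margin`. The same computation as a standalone function (a sketch mirroring lines 54-59 above; `pair_loss` is a hypothetical name, not part of the repository):

```
import torch
import torch.nn.functional as F

def pair_loss(out, diff, labels, margin=0.2):
    # BCE on the raw pairwise logit (labels may be soft, in [0, 1])
    loss1 = F.binary_cross_entropy_with_logits(out, labels)
    # contrastive term: label * d^2 + (1 - label) * max(0, margin - d)^2
    loss2 = labels * diff ** 2 + (1 - labels) * F.relu(margin - diff) ** 2
    return loss1 + loss2.mean()
```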
94 | def test(net, dataloader):
95 |     net.eval()
96 |     pred = []
97 |     gt = []
98 |     with torch.no_grad():
99 |         for i_batch, sample_batch in tqdm(enumerate(dataloader), total=len(dataloader), desc="Eval", leave=False):
100 |             out, _ = net(sample_batch["desc1"].to(
101 |                 device=device), sample_batch["desc2"].to(device=device))
102 |             out = out.cpu()
103 |             outlabel = out
104 |             label = sample_batch['label']
105 |             mask = (label > 0.9906840407) | (label < 0.0012710163)  # same confident-label filter as in train()
106 |             label = label[mask]
107 |             label[label < 0.5] = 0
108 |             label[label > 0.5] = 1
109 |             pred.extend(outlabel[mask].tolist())
110 |             gt.extend(label.tolist())
111 |     pred = np.array(pred, dtype='float32')
112 |     gt = np.array(gt, dtype='float32')
113 |     pred = np.nan_to_num(pred)
114 |     precision, recall, pr_thresholds = metrics.precision_recall_curve(
115 |         gt, pred)
116 |     F1_score = 2 * precision * recall / (precision + recall)
117 |     F1_score = np.nan_to_num(F1_score)
118 |     testaccur = np.max(F1_score)  # report the best F1 over all PR thresholds
119 |     return testaccur
120 | 
121 | 
122 | if __name__ == '__main__':
123 |     parser = argparse.ArgumentParser()
124 |     parser.add_argument('--log_dir', default='log/',
125 |                         help='Log dir. [default: log]')
126 |     parser.add_argument('--seq', default='00',
127 |                         help='Sequence to test. [default: 00]')
128 |     parser.add_argument('--all_seqs', nargs='+', default=['00', '01', '02', '03', '04', '05', '06', '07', '08',
129 |                         '09', '10'], help="All sequences. [default: ['00','01','02','03','04','05','06','07','08','09','10'] ]")
130 |     parser.add_argument('--neg_ratio', type=float, default=1,
131 |                         help='The proportion of negative samples used during training. [default: 1]')
132 |     parser.add_argument('--eval_ratio', type=float, default=0.1,
133 |                         help='Proportion of samples used for validation. [default: 0.1]')
134 |     parser.add_argument('--desc_folder', default="./data/desc_kitti",
135 |                         help='Folder containing descriptors. [default: ./data/desc_kitti]')
136 |     parser.add_argument('--gt_folder', default="./data/gt_kitti",
137 |                         help='Folder containing gt files. [default: ./data/gt_kitti]')
138 |     parser.add_argument('--model', default="",
139 |                         help='Pretrained model to resume from. [default: ""]')
140 |     parser.add_argument('--max_epoch', type=int, default=20,
141 |                         help='Epochs to run. [default: 20]')
142 |     parser.add_argument('--batch_size', type=int, default=1024,
143 |                         help='Batch size during training. [default: 1024]')
144 |     parser.add_argument('--learning_rate', type=float, default=0.02,
145 |                         help='Initial learning rate. [default: 0.02]')
146 |     parser.add_argument('--weight_decay', type=float,
147 |                         default=1e-6, help='Weight decay. [default: 1e-6]')
148 |     parser.add_argument('--margin', type=float, default=0.2,
149 |                         help='Margin used in contrastive loss. [default: 0.2]')
150 |     cfg = parser.parse_args()
151 |     if not os.path.exists(os.path.join(cfg.log_dir, cfg.seq)):
152 |         os.makedirs(os.path.join(cfg.log_dir, cfg.seq))
153 |     train(cfg)
154 | 
--------------------------------------------------------------------------------
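A note on the checkpoints `train.py` writes: one file per epoch under `log_dir/seq`, each holding the epoch index, model weights, optimizer state, and batch counter, which is exactly what the resume path (`--model`) and `RINet_attention.load` expect. A quick loading sketch (the path is illustrative, assuming sequence 00 was held out and training ran for the default 20 epochs):

```
import torch
from net import RINet_attention

net = RINet_attention()
net.load('log/00/19.ckpt')  # illustrative path: last-epoch checkpoint for held-out sequence 00
net.eval()

ckpt = torch.load('log/00/19.ckpt')
print(ckpt['epoch'], ckpt['batch_num'])  # metadata used when resuming via --model
```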