├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── kppr ├── __init__.py ├── config │ ├── config.yaml │ └── oxford_data.yaml ├── data │ └── .gitkeep ├── datasets │ ├── __init__.py │ └── datasets.py ├── experiments │ └── kppr │ │ └── lightning_logs │ │ └── version_0 │ │ ├── business_evaluation_query.txt │ │ ├── oxford_evaluation_query.txt │ │ ├── residential_evaluation_query.txt │ │ └── university_evaluation_query.txt ├── models │ ├── __init__.py │ ├── blocks.py │ ├── loss.py │ └── models.py ├── scripts │ └── vis_results.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ └── utils.py ├── requirements.txt └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ckpt filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,images 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,images 4 | experiments/ 5 | experiments 6 | data/ 7 | data 8 | *.json 9 | *.pickle 10 | ### Images ### 11 | # JPEG 12 | *.jpg 13 | *.jpeg 14 | *.jpe 15 | *.jif 16 | *.jfif 17 | *.jfi 18 | 19 | # JPEG 2000 20 | *.jp2 21 | *.j2k 22 | *.jpf 23 | *.jpx 24 | *.jpm 25 | *.mj2 26 | 27 | # JPEG XR 28 | *.jxr 29 | *.hdp 30 | *.wdp 31 | 32 | # Graphics Interchange Format 33 | *.gif 34 | 35 | # RAW 36 | *.raw 37 | 38 | # Web P 39 | *.webp 40 | 41 | # Portable Network Graphics 42 | *.png 43 | 44 | # Animated Portable Network Graphics 45 | *.apng 46 | 47 | # Multiple-image Network Graphics 48 | *.mng 49 | 50 | # Tagged Image File Format 51 | *.tiff 52 | *.tif 53 | 54 | # Scalable Vector Graphics 55 | *.svg 56 | *.svgz 57 | 58 | # Portable Document Format 59 | *.pdf 60 | 61 | # X BitMap 62 | *.xbm 63 | 64 | # BMP 65 | *.bmp 66 | *.dib 67 | 68 
| # ICO 69 | *.ico 70 | 71 | # 3D Images 72 | *.3dm 73 | *.max 74 | 75 | ### JupyterNotebooks ### 76 | # gitignore template for Jupyter Notebooks 77 | # website: http://jupyter.org/ 78 | 79 | .ipynb_checkpoints 80 | */.ipynb_checkpoints/* 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # Remove previous ipynb_checkpoints 87 | # git rm -r .ipynb_checkpoints/ 88 | 89 | ### Python ### 90 | # Byte-compiled / optimized / DLL files 91 | __pycache__/ 92 | *.py[cod] 93 | *$py.class 94 | 95 | # C extensions 96 | *.so 97 | 98 | # Distribution / packaging 99 | .Python 100 | build/ 101 | develop-eggs/ 102 | dist/ 103 | downloads/ 104 | eggs/ 105 | .eggs/ 106 | lib/ 107 | lib64/ 108 | parts/ 109 | sdist/ 110 | var/ 111 | wheels/ 112 | pip-wheel-metadata/ 113 | share/python-wheels/ 114 | *.egg-info/ 115 | .installed.cfg 116 | *.egg 117 | MANIFEST 118 | 119 | # PyInstaller 120 | # Usually these files are written by a python script from a template 121 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
122 | *.manifest 123 | *.spec 124 | 125 | # Installer logs 126 | pip-log.txt 127 | pip-delete-this-directory.txt 128 | 129 | # Unit test / coverage reports 130 | htmlcov/ 131 | .tox/ 132 | .nox/ 133 | .coverage 134 | .coverage.* 135 | .cache 136 | nosetests.xml 137 | coverage.xml 138 | *.cover 139 | *.py,cover 140 | .hypothesis/ 141 | .pytest_cache/ 142 | pytestdebug.log 143 | 144 | # Translations 145 | *.mo 146 | *.pot 147 | 148 | # Django stuff: 149 | *.log 150 | local_settings.py 151 | db.sqlite3 152 | db.sqlite3-journal 153 | 154 | # Flask stuff: 155 | instance/ 156 | .webassets-cache 157 | 158 | # Scrapy stuff: 159 | .scrapy 160 | 161 | # Sphinx documentation 162 | docs/_build/ 163 | doc/_build/ 164 | 165 | # PyBuilder 166 | target/ 167 | 168 | # Jupyter Notebook 169 | 170 | # IPython 171 | 172 | # pyenv 173 | .python-version 174 | 175 | # pipenv 176 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 177 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 178 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 179 | # install all needed dependencies. 180 | #Pipfile.lock 181 | 182 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 183 | __pypackages__/ 184 | 185 | # Celery stuff 186 | celerybeat-schedule 187 | celerybeat.pid 188 | 189 | # SageMath parsed files 190 | *.sage.py 191 | 192 | # Environments 193 | .env 194 | .venv 195 | env/ 196 | venv/ 197 | ENV/ 198 | env.bak/ 199 | venv.bak/ 200 | pythonenv* 201 | 202 | # Spyder project settings 203 | .spyderproject 204 | .spyproject 205 | 206 | # Rope project settings 207 | .ropeproject 208 | 209 | # mkdocs documentation 210 | /site 211 | 212 | # mypy 213 | .mypy_cache/ 214 | .dmypy.json 215 | dmypy.json 216 | 217 | # Pyre type checker 218 | .pyre/ 219 | 220 | # pytype static type analyzer 221 | .pytype/ 222 | 223 | # profiling data 224 | .prof 225 | 226 | # End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,images 227 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Photogrammetry & Robotics Bonn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KPPR: Exploiting Momentum Contrast for Point Cloud-Based Place Recognition 2 | 3 | ## Installation 4 | 5 | 1. Install all requirements: `pip install -r requirements.txt` 6 | 2. Install this repository: `pip install -e .` 7 | 8 | ## Usage 9 | 10 | ### Training 11 | 12 | All the following commands should be run in `kppr/` 13 | 14 | - Please update the config files (especially the `oxford_data.yaml` to match your data_dir) 15 | - Run the training: `python train.py` 16 | - The output will be saved in `experiments/{EXPERIMENT_ID}` 17 | 18 | ### Testing 19 | 20 | - Test the model by running: `python test.py --checkpoint {PATH/TO/CHECKPOINT.ckpt} --dataset {DATASET} --base_dir {PATH/TO/DATA}`, where `{DATASET}` is e.g. `oxford` 21 | - The output will be saved in the same folder as the checkpoint 22 | - All the results can be visualized with: `python scripts/vis_results.py` 23 | - The numbers reported in the paper are in `experiments/kppr/.../oxford_evaluation_query.txt` 24 | - The pre-trained model can be downloaded [here](https://www.ipb.uni-bonn.de/html/projects/kppr/kppr.ckpt) and should be placed into `experiments/kppr/lightning_logs/version_0/`. 25 | 26 | ## Data 27 | 28 | - The pre-compressed point cloud maps can be downloaded [here](https://www.ipb.uni-bonn.de/html/projects/retriever/oxford_compressed.zip) and should be extracted to `data/` (or simply put a symbolic link). 29 | - For the uncompressed point clouds, please refer to [PointNetVLAD](https://github.com/mikacuy/pointnetvlad).
30 | 31 | ## Citation 32 | 33 | If you use this library for any academic work, please cite the original paper. 34 | 35 | ```bibtex 36 | @article{wiesmann2023ral, 37 | author = {L. Wiesmann and L. Nunes and J. Behley and C. Stachniss}, 38 | title = {{KPPR: Exploiting Momentum Contrast for Point Cloud-Based Place Recognition}}, 39 | journal = ral, 40 | volume = {8}, 41 | number = {2}, 42 | pages = {592-599}, 43 | year = 2023, 44 | issn = {2377-3766}, 45 | doi = {10.1109/LRA.2022.3228174}, 46 | } 47 | ``` 48 | -------------------------------------------------------------------------------- /kppr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/kppr/edcd7f585e569cea34a3367efe8a33cba84a0a8e/kppr/__init__.py -------------------------------------------------------------------------------- /kppr/config/config.yaml: -------------------------------------------------------------------------------- 1 | experiment: 2 | id: tmp 3 | 4 | ##Training 5 | train: 6 | n_gpus: 1 7 | max_epoch: 30 8 | lr: 0.00001 9 | 10 | loss: 11 | type: entropy 12 | params: 13 | margin: 0.5 14 | alpha: 0.3 15 | 16 | ##Network 17 | feature_bank: 15000 18 | network_architecture: KPPR 19 | 20 | point_net: 21 | in_dim: 6 22 | out_dim: 256 23 | kpconv: 24 | in_channels: 256 25 | out_channels: 256 26 | radius: 0.05 27 | num_layer: 7 28 | num_neighbors: 16 29 | kernel_size: 3 30 | f_dscale: 2 31 | precompute_weights: True 32 | 33 | aggregation: 34 | method: "vlad" # transformer, perceiver, vlad 35 | vlad: 36 | feature_dim: 256 37 | out_dim: 256 38 | nr_center: 64 -------------------------------------------------------------------------------- /kppr/config/oxford_data.yaml: -------------------------------------------------------------------------------- 1 | dataset_loader: 'OxfordEmbeddingPad' 2 | data_dir: 'oxford_compressed/' 3 | train_queries: 'oxford_compressed/training_queries_baseline.pickle' 4 | test_queries: 
'oxford_compressed/test_queries_baseline.pickle' 5 | num_positives: 2 6 | num_negatives: 18 7 | batch_size: 16 8 | num_worker: 6 9 | get_negatives: False 10 | -------------------------------------------------------------------------------- /kppr/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/kppr/edcd7f585e569cea34a3367efe8a33cba84a0a8e/kppr/data/.gitkeep -------------------------------------------------------------------------------- /kppr/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/kppr/edcd7f585e569cea34a3367efe8a33cba84a0a8e/kppr/datasets/__init__.py -------------------------------------------------------------------------------- /kppr/datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | from pytorch_lightning import LightningDataModule 4 | import pickle5 5 | import numpy as np 6 | import os 7 | import random 8 | import tqdm 9 | import kppr.utils.utils as utils 10 | import pickle 11 | from diskcache import FanoutCache 12 | 13 | 14 | def getOxfordDataModule(cfg): 15 | return OxfordDataModule(cfg, data_class=eval(cfg['dataset_loader'])) 16 | 17 | 18 | cache = FanoutCache(directory=utils.CONFIG_DIR+"../data/cache", 19 | shards=64, 20 | timeout=1, 21 | size_limit=3e11) 22 | 23 | ################################################# 24 | #################### Oxford ##################### 25 | ################################################# 26 | 27 | 28 | @cache.memoize(typed=True) 29 | def dict2bool(query_dict): 30 | with open(utils.DATA_DIR + query_dict, 'rb') as handle: 31 | query_dict = pickle5.load(handle) 32 | query_keys = list(query_dict.keys()) 33 | n = len(query_keys) 34 | 35 | files = [] 36 | is_pos = np.zeros([n, n], dtype=bool) 37 | is_neg = 
np.zeros([n, n], dtype=bool) 38 | for k in tqdm.tqdm(query_keys): 39 | files.append(query_dict[k]['query']) 40 | pos_idx = np.array(query_dict[k]['positives'], dtype=int) 41 | is_pos[k, pos_idx] = True 42 | neg_idx = np.array(query_dict[k]['negatives'], dtype=int) 43 | is_neg[k, neg_idx] = True 44 | return query_keys, files, is_pos, is_neg 45 | 46 | 47 | class OxfordEmbeddingPad(Dataset): 48 | 49 | def __init__(self, query_dict, data_dir, num_pos=2, num_neg=18, return_only_query=False, get_negatives=False): 50 | super(Dataset, self).__init__() 51 | self.query_keys, self.files, self.is_pos, self.is_neg = dict2bool( 52 | query_dict) 53 | self.data_dir = utils.DATA_DIR+data_dir 54 | self.num_pos = num_pos 55 | self.num_neg = num_neg 56 | 57 | self.return_only_query = return_only_query 58 | print(f'Init Dataset done: {len(self.query_keys)}') 59 | self.get_negatives = get_negatives 60 | 61 | def __len__(self): 62 | return len(self.query_keys) 63 | 64 | def load_pc_file(self, filenames, path): 65 | return pad(loadNumpy(filenames[:-4], path)) 66 | 67 | def load_pc_files(self, filenames, path): 68 | p, m = zip(*[pad(loadNumpy(f[:-4], path)) for f in filenames]) 69 | return np.vstack(p), np.vstack(m) 70 | 71 | def __getitem__(self, index): 72 | query, query_mask = self.load_pc_file(self.files[index], self.data_dir) 73 | if self.return_only_query: 74 | return {'query': query, 'query_mask': query_mask, 'query_idx': index} 75 | 76 | pos = np.where(self.is_pos[index, :])[0] 77 | np.random.shuffle(pos) 78 | pos_files = [] 79 | 80 | act_num_pos = len(pos) 81 | pos_idx = [] 82 | if act_num_pos == 0: 83 | return self.__getitem__(index+1) 84 | for i in range(self.num_pos): 85 | pos_files.append( 86 | self.files[pos[i % act_num_pos]]) 87 | pos_idx.append(pos[i % act_num_pos]) 88 | positives, positives_mask = self.load_pc_files( 89 | pos_files, self.data_dir) 90 | neg_idx = self.is_neg[index, :] 91 | 92 | if self.get_negatives: 93 | neg_files = [] 94 | neg_indices = [] 95 | hard_neg = 
[] # have no hard neg 96 | if(len(hard_neg) == 0): 97 | neg = np.where(self.is_neg[index, :])[0] 98 | np.random.shuffle(neg) 99 | for i in range(self.num_neg): 100 | neg_files.append( 101 | self.files[neg[i]]) 102 | neg_indices.append(neg[i]) 103 | 104 | negatives, negatives_mask = self.load_pc_files( 105 | neg_files, self.data_dir) 106 | neighbors = [] 107 | for pos_i in pos: 108 | neighbors.append(pos_i) 109 | for neg_i in neg_indices: 110 | for pos_i in np.where(self.is_pos[neg_i, :])[0]: 111 | neighbors.append(pos_i) 112 | possible_negs = list(set(self.query_keys)-set(neighbors)) 113 | random.shuffle(possible_negs) 114 | 115 | if(len(possible_negs) == 0): 116 | return [query, positives, negatives, np.array([])] 117 | 118 | neg2, neg2_mask = self.load_pc_file( 119 | self.files[possible_negs[0]], self.data_dir) 120 | return {'query': query, 121 | 'positives': positives, 122 | 'negatives': negatives, 123 | 'neg2': neg2, 124 | 'is_pos': self.is_pos[index, :], 125 | 'query_idx': index, 126 | 'query_mask': query_mask, 127 | 'positives_mask': positives_mask, 128 | 'negatives_mask': negatives_mask, 129 | 'neg2_mask': neg2_mask, 130 | 'neg_idx': neg_idx, 131 | 'pos_idx': np.stack(pos_idx, 0), 132 | } 133 | 134 | else: 135 | return {'query': query, 136 | 'positives': positives, 137 | 'is_pos': self.is_pos[index, :], 138 | 'query_idx': index, 139 | 'query_mask': query_mask, 140 | 'positives_mask': positives_mask, 141 | 'pos_idx': np.stack(pos_idx, 0), 142 | 'neg_idx': neg_idx, 143 | } 144 | 145 | 146 | def loadNumpy(file, path=''): 147 | return np.load(os.path.join(path, file+'.npy'))[np.newaxis, ...].astype('float32') 148 | 149 | 150 | def pad(array, n_points=2000): 151 | """ array [n x m] -> [n_points x m] 152 | """ 153 | if len(array.shape) == 2: 154 | out = np.zeros([n_points, array.shape[-1]], dtype='float32') 155 | mask = np.ones([n_points], dtype=bool) 156 | l = min(n_points, array.shape[-2]) 157 | out[:l, :] = array[:l, :] 158 | mask[:l] = False 159 | return out, 
mask 160 | else: 161 | size = list(array.shape) 162 | size[-2] = n_points 163 | out = np.zeros(size, dtype='float32') 164 | mask = np.ones(size[:-1], dtype=bool) 165 | l = min(n_points, array.shape[-2]) 166 | out[..., :l, :] = array[..., :l, :] 167 | mask[..., :l] = False 168 | return out, mask 169 | 170 | 171 | class OxfordDataModule(LightningDataModule): 172 | def __init__(self, cfg, data_class): 173 | super().__init__() 174 | self.cfg = cfg 175 | self.data_class = data_class 176 | 177 | def prepare_data(self): 178 | # Augmentations 179 | pass 180 | 181 | def setup(self, stage=None): 182 | # Create datasets 183 | pass 184 | 185 | def train_dataloader(self, batch_size=None): 186 | batch_size = self.cfg['batch_size'] if batch_size is None else batch_size 187 | data_set = self.data_class( 188 | query_dict=self.cfg['train_queries'], 189 | data_dir=self.cfg['data_dir'], 190 | num_pos=self.cfg['num_positives'], 191 | num_neg=self.cfg['num_negatives'], 192 | get_negatives=self.cfg['get_negatives']) 193 | 194 | loader = DataLoader(data_set, batch_size=batch_size, 195 | num_workers=self.cfg['num_worker'], shuffle=True) 196 | return loader 197 | 198 | def val_dataloader(self, batch_size=None): 199 | batch_size = self.cfg['batch_size'] if batch_size is None else batch_size 200 | data_set = self.data_class( 201 | query_dict=self.cfg['test_queries'], 202 | data_dir=self.cfg['data_dir'], 203 | num_pos=self.cfg['num_positives'], 204 | num_neg=self.cfg['num_negatives']) 205 | 206 | loader = DataLoader(data_set, batch_size=batch_size, 207 | num_workers=self.cfg['num_worker']) 208 | return loader 209 | 210 | def test_dataloader(self, batch_size=None): 211 | batch_size = self.cfg['batch_size'] if batch_size is None else batch_size 212 | data_set = self.data_class( 213 | query_dict=self.cfg['test_queries'], 214 | data_dir=self.cfg['data_dir'], 215 | num_pos=self.cfg['num_positives'], 216 | num_neg=self.cfg['num_negatives']) 217 | 218 | loader = DataLoader(data_set, 
batch_size=batch_size, 219 | num_workers=self.cfg['num_worker']) 220 | return loader 221 | 222 | def val_latent_dataloader(self, batch_size=None): 223 | batch_size = self.cfg['batch_size'] if batch_size is None else batch_size 224 | data_set = self.data_class( 225 | query_dict=self.cfg['test_queries'], 226 | data_dir=self.cfg['data_dir'], 227 | num_pos=self.cfg['num_positives'], 228 | num_neg=self.cfg['num_negatives'], 229 | return_only_query=True) 230 | 231 | loader = DataLoader(data_set, batch_size=batch_size, 232 | num_workers=self.cfg['num_worker']) 233 | return loader 234 | 235 | 236 | ###################### 237 | ####### Test ####### 238 | ###################### 239 | 240 | def splitIndex(i, cum_sum): 241 | smaller = (i < cum_sum[1:]) 242 | bigger = i >= cum_sum[:-1] 243 | ind = np.argwhere(smaller & bigger).flat[0] 244 | return ind, i-cum_sum[ind] 245 | 246 | 247 | class OxfordQueryEmbLoaderPad(Dataset): 248 | 249 | def __init__(self, query_dict, data_dir): 250 | super().__init__() 251 | with open(query_dict, 'rb') as handle: 252 | self.query_dict = pickle.load(handle) 253 | self.nr_scans = [len(d) for d in self.query_dict] 254 | self.acc_cld = np.array(np.cumsum([0]+self.nr_scans)) 255 | self.dict_keys = [list(d.keys()) for d in self.query_dict] 256 | self.data_dir = data_dir 257 | 258 | def __len__(self): 259 | return self.acc_cld[-1] 260 | 261 | def __getitem__(self, index): 262 | seq, scan_idx = splitIndex(index, self.acc_cld) 263 | data = self.query_dict[seq][self.dict_keys[seq][scan_idx]] 264 | # print(data) 265 | data['points'], data['points_mask'] = self.load_pc_file( 266 | data['query'], self.data_dir) 267 | data['seq'] = seq 268 | data['idx'] = scan_idx 269 | return data 270 | 271 | def getTruePositives(self, seq, scan, target_seq): 272 | return self.query_dict[seq][self.dict_keys[seq][scan]][target_seq] 273 | 274 | def getScan(self, seq, scan): 275 | return self.query_dict[seq][self.dict_keys[seq][scan]]['query'] 276 | 277 | def load_pc_file(self, 
filenames, path): 278 | return pad(loadNumpy(filenames[:-4], path)) 279 | -------------------------------------------------------------------------------- /kppr/experiments/kppr/lightning_logs/version_0/business_evaluation_query.txt: -------------------------------------------------------------------------------- 1 | # #Top 1 percent recall: 2 | # 0.9208599328994751 3 | # #Top k: 4 | 0.874718845 5 | 0.920859933 6 | 0.941030681 7 | 0.952130616 8 | 0.957924962 9 | 0.964210331 10 | 0.968250871 11 | 0.973291337 12 | 0.975560188 13 | 0.978582978 14 | 0.980598271 15 | 0.982868373 16 | 0.984132290 17 | 0.985389888 18 | 0.987151444 19 | 0.987659097 20 | 0.988415420 21 | 0.989921749 22 | 0.990934372 23 | 0.991941929 24 | 0.992698193 25 | 0.992949486 26 | 0.993199468 27 | 0.994457066 28 | 0.994962156 29 | -------------------------------------------------------------------------------- /kppr/experiments/kppr/lightning_logs/version_0/oxford_evaluation_query.txt: -------------------------------------------------------------------------------- 1 | # #Top 1 percent recall: 2 | # 0.9708302617073059 3 | # #Top k: 4 | 0.915332258 5 | 0.948917568 6 | 0.963250577 7 | 0.971215606 8 | 0.975659788 9 | 0.979308546 10 | 0.981837273 11 | 0.983514249 12 | 0.985350549 13 | 0.986378133 14 | 0.987268388 15 | 0.988266587 16 | 0.989138961 17 | 0.989837766 18 | 0.990520656 19 | 0.991104603 20 | 0.991608024 21 | 0.992061913 22 | 0.992559493 23 | 0.992967844 24 | 0.993292272 25 | 0.993493021 26 | 0.993631423 27 | 0.993799448 28 | 0.993970156 29 | -------------------------------------------------------------------------------- /kppr/experiments/kppr/lightning_logs/version_0/residential_evaluation_query.txt: -------------------------------------------------------------------------------- 1 | # #Top 1 percent recall: 2 | # 0.9509729743003845 3 | # #Top k: 4 | 0.881684721 5 | 0.936864853 6 | 0.950972974 7 | 0.962414384 8 | 0.967765749 9 | 0.972477436 10 | 0.975153148 11 | 0.977162182 12 | 0.978504479 13 
| 0.981864870 14 | 0.982531548 15 | 0.983882904 16 | 0.985891938 17 | 0.986558557 18 | 0.986558557 19 | 0.987234235 20 | 0.988585591 21 | 0.989936948 22 | 0.989936948 23 | 0.991270244 24 | 0.991270244 25 | 0.991270244 26 | 0.992612660 27 | 0.992612660 28 | 0.993954957 29 | -------------------------------------------------------------------------------- /kppr/experiments/kppr/lightning_logs/version_0/university_evaluation_query.txt: -------------------------------------------------------------------------------- 1 | # #Top 1 percent recall: 2 | # 0.9800532460212708 3 | # #Top k: 4 | 0.932309330 5 | 0.969055831 6 | 0.977480829 7 | 0.980053246 8 | 0.985847473 9 | 0.987770557 10 | 0.992290974 11 | 0.992290974 12 | 0.992932022 13 | 0.993573129 14 | 0.993573129 15 | 0.994863451 16 | 0.994863451 17 | 0.994863451 18 | 0.994863451 19 | 0.994863451 20 | 0.994863451 21 | 0.995504379 22 | 0.996153772 23 | 0.996153772 24 | 0.996153772 25 | 0.996153772 26 | 0.996153772 27 | 0.996794820 28 | 0.996794820 29 | -------------------------------------------------------------------------------- /kppr/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PRBonn/kppr/edcd7f585e569cea34a3367efe8a33cba84a0a8e/kppr/models/__init__.py -------------------------------------------------------------------------------- /kppr/models/blocks.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | 5 | from torch.nn.init import kaiming_uniform_ 6 | from torch.nn.parameter import Parameter 7 | import numpy as np 8 | import math 9 | import opt_einsum as oe 10 | 11 | 12 | class VladNet(nn.Module): 13 | def __init__(self, feature_dim, nr_center=64, out_dim=256, norm=True): 14 | super().__init__() 15 | self.feature_dim = feature_dim 16 | self.nr_center = nr_center 17 | self.softmax = nn.Softmax(dim=-1) 18 | 
self.sm_center = nn.Linear(feature_dim, nr_center) 19 | self.center = nn.Parameter(torch.randn( 20 | feature_dim, nr_center) / feature_dim**0.5) 21 | self.feature_proj = nn.Linear( 22 | feature_dim*nr_center, out_dim, bias=False) 23 | 24 | def forward(self, x: torch.Tensor, mask=None): 25 | """Computes the Vlad [..., out] with learned clusters for an input x [...,n, fd] 26 | 27 | Args: 28 | x (torch.Tensor): Features [...,n,fd] 29 | """ 30 | a = self.sm_center( 31 | x) # quadratic distance from each point to each center 32 | a = self.softmax(a) # softmax to get the weights 33 | if mask is not None: 34 | a = a*torch.logical_not(mask) 35 | a_sum = a.sum(dim=-2, keepdim=True) 36 | center_weighted = self.center * a_sum # reweight the centers 37 | 38 | a = a.transpose(-2, -1) 39 | x_weighted = torch.matmul(a, x).transpose(-2, -1) 40 | vlad = (x_weighted - center_weighted) 41 | vlad = F.normalize(vlad, dim=-1, p=2) 42 | shape = x.shape[:-2]+(self.nr_center * self.feature_dim,) 43 | vlad = vlad.reshape(shape) 44 | vlad = F.normalize(vlad, dim=-1, p=2) 45 | vlad = self.feature_proj(vlad) 46 | return vlad 47 | 48 | 49 | class STNkd(nn.Module): 50 | def __init__(self, k=64, norm=True): 51 | super(STNkd, self).__init__() 52 | self.conv1 = nn.Linear(k, 64) 53 | self.conv2 = nn.Linear(64, 128) 54 | self.conv3 = nn.Linear(128, 1024) 55 | self.fc1 = nn.Linear(1024, 512) 56 | self.fc2 = nn.Linear(512, 256) 57 | self.fc3 = nn.Linear(256, k*k) 58 | self.relu = nn.ReLU() 59 | 60 | # exchanged Batchnorm1d by Layernorm 61 | self.bn1 = nn.LayerNorm(64) if norm else nn.Identity() 62 | self.bn2 = nn.LayerNorm(128) if norm else nn.Identity() 63 | self.bn3 = nn.LayerNorm(1024) if norm else nn.Identity() 64 | self.bn4 = nn.LayerNorm(512) if norm else nn.Identity() 65 | self.bn5 = nn.LayerNorm(256) if norm else nn.Identity() 66 | 67 | self.k = k 68 | 69 | def forward(self, x): 70 | x = F.relu(self.bn1(self.conv1(x))) 71 | x = F.relu(self.bn2(self.conv2(x))) 72 | x = 
F.relu(self.bn3(self.conv3(x))) 73 | x = torch.max(x, -2, keepdim=True)[0] 74 | 75 | x = F.relu(self.bn4(self.fc1(x))) 76 | x = F.relu(self.bn5(self.fc2(x))) 77 | x = self.fc3(x) 78 | 79 | iden = torch.eye(self.k, device=x.device, dtype=x.dtype) 80 | shape = x.shape[:-1]+(1,) 81 | iden = iden.repeat(*shape) 82 | x = x.view(iden.shape) + iden 83 | return x 84 | 85 | 86 | class PointNetFeat(nn.Module): 87 | def __init__(self, in_dim=3, out_dim=1024, feature_transform=False, norm=True): 88 | super(PointNetFeat, self).__init__() 89 | self.stn = STNkd(k=in_dim, norm=norm) 90 | self.conv1 = nn.Linear(in_dim, 64) 91 | self.conv2 = nn.Linear(64, 128) 92 | self.conv3 = nn.Linear(128, out_dim) 93 | self.bn1 = nn.LayerNorm(64) if norm else nn.Identity() 94 | self.bn2 = nn.LayerNorm(128) if norm else nn.Identity() 95 | self.bn3 = nn.LayerNorm(out_dim) if norm else nn.Identity() 96 | self.feature_transform = feature_transform 97 | if self.feature_transform: 98 | self.fstn = STNkd(k=64, norm=norm) 99 | 100 | def forward(self, x): 101 | x = F.relu(self.bn1(self.conv1(x))) 102 | 103 | if self.feature_transform: 104 | trans_feat = self.fstn(x) 105 | x = torch.matmul(x, trans_feat) 106 | else: 107 | trans_feat = None 108 | 109 | x = F.relu(self.bn2(self.conv2(x))) 110 | x = self.bn3(self.conv3(x)) 111 | return x 112 | 113 | ############################ 114 | ###### KPConv stuff ######## 115 | ############################ 116 | # mostly taken from Thomas Hugues repo https://github.com/HuguesTHOMAS/KPConv-PyTorch 117 | 118 | def knn(q_pts, s_pts, k, cosine_sim=False): 119 | if cosine_sim: 120 | sim = torch.einsum('...in,...jn->...ij', q_pts, s_pts) 121 | _, neighb_inds = torch.topk(sim, k, dim=-1, largest=True) 122 | return neighb_inds 123 | 124 | else: 125 | dist = ((q_pts.unsqueeze(-2) - s_pts.unsqueeze(-3))**2).sum(-1) 126 | _, neighb_inds = torch.topk(dist, k, dim=-1, largest=False) 127 | return neighb_inds 128 | 129 | 130 | def vector_gather(vectors: torch.Tensor, indices: 
torch.Tensor): 131 | """ 132 | Gathers (batched) vectors according to indices. 133 | Arguments: 134 | vectors: Tensor[B, N1, D] 135 | indices: Tensor[B, N2, K] 136 | Returns: 137 | Tensor[B,N2, K, D] 138 | """ 139 | 140 | # src 141 | vectors = vectors.unsqueeze(-2) 142 | shape = list(vectors.shape) 143 | shape[-2] = indices.shape[-1] 144 | vectors = vectors.expand(shape) 145 | 146 | # Do the magic 147 | shape = list(indices.shape)+[vectors.shape[-1]] 148 | indices = indices.unsqueeze(-1).expand(shape) 149 | out = torch.gather(vectors, dim=-3, index=indices) 150 | return out 151 | 152 | 153 | def gather(x, idx, method=2): 154 | """ 155 | implementation of a custom gather operation for faster backwards. 156 | :param x: input with shape [N, D_1, ... D_d] 157 | :param idx: indexing with shape [n_1, ..., n_m] 158 | :param method: Choice of the method 159 | :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d] 160 | """ 161 | 162 | if method == 0: 163 | return x[idx] 164 | elif method == 1: 165 | x = x.unsqueeze(1) 166 | x = x.expand((-1, idx.shape[-1], -1)) 167 | idx = idx.unsqueeze(2) 168 | idx = idx.expand((-1, -1, x.shape[-1])) 169 | return x.gather(0, idx) 170 | elif method == 2: 171 | for i, ni in enumerate(idx.size()[1:]): 172 | x = x.unsqueeze(i+1) 173 | new_s = list(x.size()) 174 | new_s[i+1] = ni 175 | x = x.expand(new_s) 176 | n = len(idx.size()) 177 | for i, di in enumerate(x.size()[n:]): 178 | idx = idx.unsqueeze(i+n) 179 | new_s = list(idx.size()) 180 | new_s[i+n] = di 181 | idx = idx.expand(new_s) 182 | return x.gather(0, idx) 183 | else: 184 | raise ValueError('Unkown method') 185 | 186 | 187 | class KPConv(nn.Module): 188 | 189 | def __init__(self, in_channels, out_channels, radius, kernel_size=3, KP_extent=None, p_dim=3): 190 | """ 191 | Initialize parameters for KPConvDeformable. 192 | :param in_channels: dimension of input features. 193 | :param out_channels: dimension of output features. 194 | :param radius: radius used for kernel point init. 
195 | :param kernel_size: Number of kernel points. 196 | :param KP_extent: influence radius of each kernel point. (float), default: None 197 | :param p_dim: dimension of the point space. Default: 3 198 | :param radial: bool if direction independend convolution 199 | :param align_kp: aligns the kernel points along the main directions of the local neighborhood 200 | """ 201 | super(KPConv, self).__init__() 202 | 203 | # Save parameters 204 | self.p_dim = p_dim # 1D for radial convolution 205 | 206 | self.K = kernel_size ** self.p_dim 207 | self.num_kernels = kernel_size 208 | self.in_channels = in_channels 209 | self.out_channels = out_channels 210 | self.radius = radius 211 | self.KP_extent = radius / (kernel_size-1) * \ 212 | self.p_dim**0.5 if KP_extent is None else KP_extent 213 | 214 | # Initialize weights 215 | self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32), 216 | requires_grad=True) 217 | 218 | # Reset parameters 219 | self.reset_parameters() 220 | 221 | # Initialize kernel points 222 | self.kernel_points = self.init_KP() 223 | 224 | return 225 | 226 | def reset_parameters(self): 227 | kaiming_uniform_(self.weights, a=math.sqrt(5)) 228 | return 229 | 230 | def init_KP(self): 231 | """ 232 | Initialize the kernel point positions in a grid 233 | :return: the tensor of kernel points 234 | """ 235 | 236 | K_points_numpy = self.getKernelPoints(self.radius, 237 | self.num_kernels, dim=self.p_dim) 238 | 239 | return Parameter(torch.tensor(K_points_numpy, dtype=torch.float32), 240 | requires_grad=False) 241 | 242 | 243 | def getKernelPoints(self, radius, num_points=3, dim=3): 244 | """[summary] 245 | 246 | Args: 247 | radius (float): radius 248 | num_points (int, optional): Number of kernel points per dimension. Defaults to 3. 
249 | 250 | Returns: 251 | [type]: returns num_points^3 kernel points 252 | """ 253 | xyz = np.linspace(-1, 1, num_points) 254 | if dim == 1: 255 | return xyz[:, None]*radius 256 | 257 | points = np.meshgrid(*(dim*[xyz])) 258 | points = [p.flatten() for p in points] 259 | points = np.vstack(points).T 260 | points /= dim**(0.5) # Normalizes to stay in unit sphere 261 | return points*radius 262 | 263 | def precompute_weights(self, q_pts, s_pts, neighb_inds): 264 | s_pts = torch.cat( 265 | (s_pts, torch.zeros_like(s_pts[..., :1, :]) + 1e6), -2) 266 | 267 | # Get neighbor points and features [n_points, n_neighbors, dim/ in_fdim] 268 | if len(neighb_inds.shape) < 3: 269 | neighbors = s_pts[neighb_inds, :] 270 | else: 271 | neighbors = vector_gather(s_pts, neighb_inds) 272 | 273 | # Center every neighborhood [n_points, n_neighbors, dim] 274 | neighbors = neighbors - q_pts.unsqueeze(-2) 275 | 276 | # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim] 277 | 278 | kernel_points = self.kernel_points 279 | neighbors.unsqueeze_(-2) 280 | 281 | differences = neighbors - kernel_points 282 | # Get the square distances [n_points, n_neighbors, n_kpoints] 283 | sq_distances = torch.sum(differences ** 2, dim=-1) 284 | # Get Kernel point influences [n_points, n_kpoints, n_neighbors] 285 | all_weights = torch.clamp( 286 | 1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0) 287 | return all_weights 288 | 289 | @staticmethod 290 | def gather_features(neighb_inds, x): 291 | x = torch.cat((x, torch.zeros_like(x[..., :1, :])), -2) 292 | if len(neighb_inds.shape) < 3: 293 | return gather(x, neighb_inds) 294 | else: 295 | return vector_gather(x, neighb_inds) 296 | 297 | def convolution(self, neighb_weights, neighb_x): 298 | fx = oe.contract('...nkl,...nki,...lio->...no', 299 | neighb_weights, neighb_x, self.weights) 300 | return fx 301 | 302 | def forward(self, q_pts, s_pts, neighb_inds, x): 303 | # Add a fake point/feature in the last row for shadow neighbors 304 | s_pts 
= torch.cat( 305 | (s_pts, torch.zeros_like(s_pts[..., :1, :]) + 1e6), -2) 306 | 307 | # Get neighbor points and features [n_points, n_neighbors, dim/ in_fdim] 308 | x = torch.cat((x, torch.zeros_like(x[..., :1, :])), -2) 309 | if len(neighb_inds.shape) < 3: 310 | neighbors = s_pts[neighb_inds, :] 311 | neighb_x = gather(x, neighb_inds) 312 | else: 313 | neighbors = vector_gather(s_pts, neighb_inds) 314 | neighb_x = vector_gather(x, neighb_inds) 315 | 316 | # Center every neighborhood [n_points, n_neighbors, dim] 317 | neighbors = neighbors - q_pts.unsqueeze(-2) 318 | 319 | # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim] 320 | 321 | # print('neighbors', neighbors.shape) 322 | kernel_points = self.kernel_points 323 | neighbors.unsqueeze_(-2) 324 | 325 | # print(kernel_points.shape,neighbors.shape) 326 | differences = neighbors - kernel_points 327 | # Get the square distances [n_points, n_neighbors, n_kpoints] 328 | sq_distances = torch.sum(differences ** 2, dim=-1) 329 | # Get Kernel point influences [n_points, n_kpoints, n_neighbors] 330 | all_weights = torch.clamp( 331 | 1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0) 332 | 333 | # fx = torch.einsum('...nkl,...nki,...lio->...no', 334 | # all_weights, neighb_x, self.weights) 335 | fx = oe.contract('...nkl,...nki,...lio->...no', 336 | all_weights, neighb_x, self.weights) 337 | return fx 338 | 339 | def __repr__(self): 340 | return 'KPConv(radius: {:.2f}, in_feat: {:d}, out_feat: {:d})'.format(self.radius, 341 | self.in_channels, 342 | self.out_channels) 343 | 344 | 345 | class ResnetKPConv(nn.Module): 346 | def __init__(self, in_channels, out_channels, radius, kernel_size=3, KP_extent=None, p_dim=3, f_dscale=2): 347 | super().__init__() 348 | 349 | self.ln1 = nn.LayerNorm(in_channels) 350 | self.relu = nn.LeakyReLU() 351 | 352 | self.kpconv = KPConv(in_channels=in_channels, 353 | out_channels=out_channels//f_dscale, 354 | radius=radius, 355 | kernel_size=kernel_size, 356 | 
KP_extent=KP_extent, 357 | p_dim=p_dim) 358 | 359 | self.ln2 = nn.LayerNorm(out_channels//f_dscale) 360 | self.lin = nn.Linear(out_channels//f_dscale, out_channels) 361 | 362 | self.in_projection = nn.Identity() if in_channels == out_channels else nn.Linear( 363 | in_channels, out_channels) 364 | 365 | def forward(self, q_pts, s_pts, neighb_inds, x): 366 | xr = self.relu(self.ln1(x)) 367 | xr = self.kpconv(q_pts, s_pts, neighb_inds, x) 368 | xr = self.relu(self.ln2(xr)) 369 | xr = self.lin(xr) 370 | 371 | return self.in_projection(x) + xr 372 | 373 | @torch.no_grad() 374 | def precompute_weights(self, q_pts, s_pts, neighb_inds): 375 | return self.kpconv.precompute_weights(q_pts, s_pts, neighb_inds) 376 | 377 | def fast_forward(self, neighb_weights, neighb_inds, x): 378 | xr = self.relu(self.ln1(x)) 379 | 380 | neighb_x = self.kpconv.gather_features(neighb_inds, xr) 381 | xr = self.kpconv.convolution(neighb_weights, neighb_x) 382 | 383 | xr = self.relu(self.ln2(xr)) 384 | xr = self.lin(xr) 385 | 386 | return self.in_projection(x) + xr 387 | 388 | -------------------------------------------------------------------------------- /kppr/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def repulsion_loss(sim_ap: torch.Tensor, sim_an: torch.Tensor, is_neg, max_dist=1): 6 | sim_an[~is_neg] = -1000 7 | dist_p = (1 - sim_ap.max(-1)[0])/2 8 | dist_n = (1 - sim_an.max(-1)[0])/2 9 | dist = torch.where(dist_p < dist_n, dist_p, dist_n) + 1e-5 10 | dist[dist > max_dist] = 1 11 | return -dist.log().mean() 12 | 13 | class EntropyContrastiveLoss(nn.Module): 14 | def __init__(self, margin, alpha=0) -> None: 15 | super().__init__() 16 | self.margin = margin 17 | self.alpha = alpha 18 | 19 | def forward(self, q, p, n, is_neg): 20 | sp = torch.einsum('...n,...n->...', q, p) 21 | sn = torch.einsum('...n,...n->...', q, n) 22 | l_rep = repulsion_loss(sp, sn, is_neg, max_dist=0.1) * 
self.alpha 23 | 24 | within_margin = sn > self.margin # only compute loss if similarity is over margin 25 | 26 | l_c = (1-sp).mean(-1) + (sn * is_neg * 27 | within_margin).sum(-1) / (is_neg * within_margin + 1e-5).sum(-1) 28 | l_c = l_c.mean() 29 | losses = {'l_cont': l_c, 'l_rep': l_rep} 30 | loss = l_c + l_rep 31 | return loss, losses 32 | 33 | def feature_bank_recall(q, p, n, top_k: list, is_neg=None): 34 | tk = [k-1 for k in top_k] 35 | sim_pos = torch.einsum('...n,...n->...', q, p) 36 | sim_pos, _ = sim_pos.max(-1) # get the easiest positive 37 | sim_neg = torch.einsum('...n,...n->...', q, n) * \ 38 | is_neg # set similarity of not neg to 0 39 | top_k_neg, i = torch.topk(sim_neg, top_k[-1], dim=-1) 40 | top_k_neg = top_k_neg[..., tk] 41 | recall = sim_pos[..., None] > top_k_neg 42 | recall = recall.float().mean(0) 43 | recall_dict = {k: r for k, r in zip(top_k, recall)} 44 | return recall_dict 45 | 46 | 47 | def feature_bank_recall_nn(q, p, n, top_k: list, is_neg=None): 48 | tk = [k-1 for k in top_k] 49 | sim_pos = 1/((q - p)**2).sum(-1) 50 | sim_pos, _ = sim_pos.max(-1) # get the easiest positive 51 | sim_neg = 1/((q - n.unsqueeze(-3))**2).sum(-1) * \ 52 | is_neg # set similarity of not neg to 0 53 | top_k_neg, i = torch.topk(sim_neg, top_k[-1], dim=-1) 54 | top_k_neg = top_k_neg[..., tk] 55 | recall = sim_pos[..., None] > top_k_neg 56 | recall = recall.float().mean(0) 57 | recall_dict = {k: r for k, r in zip(top_k, recall)} 58 | return recall_dict 59 | -------------------------------------------------------------------------------- /kppr/models/models.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import torch 3 | import torch.optim.lr_scheduler 4 | import torch.nn as nn 5 | import kppr.models.blocks as blocks 6 | import kppr.models.loss as pnloss 7 | from pytorch_lightning.core.lightning import LightningModule 8 | from copy import deepcopy 9 | 10 | import torch.nn.functional as F 
def getModel(model_name: str, config: dict, weights: str = None):
    """Instantiate the model class called ``model_name``.

    Args:
        model_name (str): Name of a model class defined in this module
            (should be a LightningModule, e.g. ``'KPPR'``).
        config (dict): Hyper-parameters passed to the model.
        weights (str, optional): Path to a checkpoint. If given, the model is
            restored with ``load_from_checkpoint`` instead of being freshly
            constructed.

    Returns:
        The model instance.

    Raises:
        ValueError: if ``model_name`` does not name a class in this module.
    """
    # Explicit lookup instead of eval()-ing a string read from a config file.
    model_cls = globals().get(model_name)
    if not isinstance(model_cls, type):
        raise ValueError(f'Unknown model architecture: {model_name!r}')
    if weights is None:
        return model_cls(config)
    print(weights)
    return model_cls.load_from_checkpoint(weights, hparams=config)

##################################
# Base Class
##################################


class KPPR(LightningModule):
    """LightningModule that trains KPPRNet with a momentum (EMA) key encoder.

    ``q_model`` (query encoder) is trained by gradient descent. ``k_model`` is
    an exponential moving average of it and embeds the positives. Embedded
    positives are pushed into a FeatureBank and served as negatives for later
    batches.
    """

    def __init__(self, hparams: dict, data_module=None):
        super().__init__()
        # NOTE: writes into the caller's dict so batch_size is saved together
        # with the other hyper-parameters.
        hparams['batch_size'] = hparams['data_config']['batch_size']
        self.save_hyperparameters(hparams)

        self.pnloss = pnloss.EntropyContrastiveLoss(
            **hparams['loss']['params'])
        # Networks: the key encoder starts as a frozen copy of the query encoder
        self.q_model = KPPRNet(hparams)
        self.k_model = deepcopy(self.q_model)
        self.k_model.requires_grad_(False)
        self.alpha = 0.999  # EMA momentum of the key encoder

        self.feature_bank = FeatureBank(
            size=hparams['feature_bank'], f_dim=256)
        self.feature_bank_val = FeatureBank(
            size=5000, f_dim=256)

        self.data_module = data_module

        self.top_k = [1, 5, 10]

    def forward(self, x, m):
        """Embed point clouds ``x`` with padding mask ``m`` using the query encoder."""
        return self.q_model(x, m)

    def getLoss(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, is_neg=None):
        """Return ``(loss, dict_of_partial_losses)`` from the contrastive loss."""
        return self.pnloss(anchor, positive, negative, is_neg=is_neg)

    def training_step(self, batch: dict, batch_idx):
        query = self.forward(batch['query'], batch['query_mask'])
        with torch.no_grad():
            # EMA update: k <- alpha * k + (1 - alpha) * q
            for param_q, param_k in zip(self.q_model.parameters(), self.k_model.parameters()):
                param_k.data = param_k.data * self.alpha + \
                    param_q.data * (1. - self.alpha)

            positives = self.k_model(
                batch['positives'], batch['positives_mask'])
            negatives, is_negative = self.feature_bank.getFeatures(
                batch['neg_idx'])

        loss, losses = self.getLoss(
            query, positives, negatives, is_neg=is_negative)
        self.feature_bank.addFeatures(positives, batch['pos_idx'])

        for k, v in losses.items():
            self.log(f'train/{k}', v)

        self.log('train_loss', loss)
        recall_dict = pnloss.feature_bank_recall_nn(
            query, positives, negatives, self.top_k, is_negative)
        for k, v in recall_dict.items():
            self.log(f'train/recall_{k}', v)
        return loss

    def validation_step(self, batch: dict, batch_idx):
        query = self.forward(batch['query'], batch['query_mask'])
        positives = self.forward(batch['positives'], batch['positives_mask'])
        negatives, is_negative = self.feature_bank_val.getFeatures(
            batch['neg_idx'])

        loss, losses = self.getLoss(
            query, positives, negatives, is_neg=is_negative)
        self.feature_bank_val.addFeatures(positives, batch['pos_idx'])

        self.log('val_loss', loss)

        for k, v in losses.items():
            self.log(f'val/{k}', v)

        recall_dict = pnloss.feature_bank_recall_nn(
            query, positives, negatives, self.top_k, is_negative)
        for k, v in recall_dict.items():
            self.log(f'val/recall_{k}', v)
        return loss

    def test_step(self, batch: dict, batch_idx):
        # A real exception: `assert False` disappears when run with `python -O`.
        raise NotImplementedError("test with provided test script!")

    def configure_optimizers(self):
        # Only the query encoder is optimized; the key encoder follows it by EMA.
        lr = self.hparams['train']['lr']

        optimizer = torch.optim.AdamW(
            self.q_model.parameters(), lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.hparams['train']['max_epoch'], eta_min=self.hparams['train']['lr']/1e3)
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return self.data_module.train_dataloader(batch_size=self.hparams['batch_size'])

    def val_dataloader(self):
        return self.data_module.val_dataloader(batch_size=self.hparams['batch_size'])

    def test_dataloader(self):
        return self.data_module.test_dataloader(batch_size=self.hparams['batch_size'])


#######################################################################################
######################### Network #####################################################
#######################################################################################

class KPPRNet(nn.Module):
    """PointNet features -> KPConv stack -> VLAD aggregation, L2-normalized."""

    def __init__(self, hparams) -> None:
        super().__init__()
        # PointNet
        self.pn = blocks.PointNetFeat(
            **hparams['point_net'])

        # ConvNet
        self.conv = ConvNet(**hparams['kpconv'])

        # 'method' only selects the parameter set; the aggregator is always VladNet.
        am = hparams['aggregation']['method']
        params = hparams['aggregation'][am]
        self.aggr = blocks.VladNet(**params)

    def forward(self, x, m):
        """Embed a batch of point clouds.

        Args:
            x: point tensor whose first three channels are used as coordinates.
            m: padding mask, unsqueezed to ``[..., 1]`` here. Presumably True
                marks padded points (ConvNet moves those far away) — confirm
                with the dataset.
        """
        coords = x[..., :3].clone()
        m = m.unsqueeze(-1)
        x = self.pn(x)

        x = self.conv(coords, x, m)

        x = self.aggr(x, mask=m)
        x = F.normalize(x, dim=-1)
        return x


class FeatureBank(nn.Module):
    """Fixed-size FIFO memory of features and their dataset indices.

    New entries are written to the front after rolling the old ones back, so
    the oldest entries fall off the end. Unused slots carry index -1 and are
    never reported as negatives by ``getFeatures``.
    """

    def __init__(self, size, f_dim) -> None:
        super().__init__()
        # Placeholder values for unused slots. getFeatures masks these slots
        # out through idx == -1.
        self.register_buffer('fb', torch.full([size, f_dim], 1e8))
        self.fb = nn.functional.normalize(self.fb, dim=0)

        self.register_buffer('idx', torch.full([size], -1,
                                                dtype=torch.long))
        self.size = size

    @torch.no_grad()
    def addFeatures(self, f, idx):
        """Store features ``f`` (last dim = f_dim) with their indices ``idx``."""
        f = f.reshape(-1, f.shape[-1])  # reshape also accepts non-contiguous input
        idx = idx.reshape(-1)
        # A batch larger than the bank would not fit into fb[:N]; keep only
        # as many rows as the bank can hold.
        f = f[:self.size]
        idx = idx[:self.size]
        n = f.shape[0]

        self.fb = torch.roll(self.fb, n, dims=0)
        self.fb[:n] = f.detach()

        self.idx = torch.roll(self.idx, n)
        self.idx[:n] = idx

    @torch.no_grad()
    def getFeatures(self, idx):
        """Return the stored features and a per-slot negative mask.

        Args:
            idx: tensor whose last dimension is indexed by dataset index
                (presumably a boolean "is negative" mask per sample — confirm
                with the dataset).

        Returns:
            ``(features, mask)``, where ``mask[..., j]`` is ``idx[..., bank_idx[j]]``
            and is forced to False for unused slots.
        """
        empty = self.idx < 0
        is_neg = idx[..., self.idx]
        # Unused slots hold -1, which would otherwise read the last column.
        is_neg[..., empty] = False
        return self.fb, is_neg


class ConvNet(nn.Module):
    """Stack of ResnetKPConv blocks that all share one kNN graph over the points."""

    def __init__(self, in_channels,
                 out_channels,
                 radius,
                 num_layer=3,
                 num_neighbors=32,
                 kernel_size=3,
                 KP_extent=None,
                 p_dim=3,
                 f_dscale=2,
                 precompute_weights=True):
        super().__init__()
        in_c = [in_channels] + num_layer*[out_channels]
        self.blocks = nn.ModuleList([blocks.ResnetKPConv(
            in_channels=c_in,
            out_channels=out_channels,
            radius=radius,
            kernel_size=kernel_size,
            KP_extent=KP_extent,
            p_dim=p_dim,
            f_dscale=f_dscale) for c_in in in_c[:num_layer]])
        self.num_neighbors = num_neighbors
        self.num_layer = num_layer
        self.precompute_weights = precompute_weights

    def forward(self, coords: torch.Tensor, features: torch.Tensor, mask: torch.Tensor = None):
        """Run all blocks on ``features`` located at ``coords``.

        Args:
            coords: point coordinates. They are not modified; a copy is masked.
            features: per-point input features.
            mask: optional boolean mask broadcastable to ``coords``. Masked
                points are moved to 1e6 before the kNN search.
        """
        if self.num_layer <= 0:
            return features

        # Private contiguous copy so the masking below never alters the caller's tensor.
        coords = coords.clone(memory_format=torch.contiguous_format)
        if mask is not None:
            # Move masked (padded) points far away from the real ones.
            coords[mask.expand_as(coords)] = 1e6
        idx = blocks.knn(coords, coords, self.num_neighbors)

        if self.precompute_weights:
            # The weights depend only on the geometry. They are computed once
            # with the first block's kernel points and reused by every block
            # (assumes all blocks share the same kernel points).
            neighb_weights = self.blocks[0].precompute_weights(
                coords, coords, idx)
            for block in self.blocks:
                features = block.fast_forward(
                    neighb_weights, idx, features)
        else:
            for block in self.blocks:
                features = block(coords, coords, idx, features)
        return features
# ------------------------------------------------------------------------------
# /kppr/scripts/vis_results.py:
# ------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt
import kppr.utils.utils as utils
import glob


if __name__ == '__main__':
    # Plot recall@k curves of the evaluation files found under experiments/.
    # With vis_all=False, only the run with the highest mean recall is plotted
    # for each top-level experiment directory.
    vis_all = False

    exp_dir = utils.CONFIG_DIR+'../experiments/'
    eval_files = glob.glob(exp_dir+'**/oxford_evaluation_query.txt', recursive=True)
    plt.figure()
    plt.xlabel('in top k')
    plt.ylabel('recall')
    if vis_all:
        for f in eval_files:
            label = f.split('/')[-5:-2]
            d = np.loadtxt(f)
            plt.plot(np.arange(d.shape[0])+1, d, label=label)
    else:
        exp = {}
        for f in eval_files:
            label = f.split('/')[-5:-2]
            label = label[0]
            d = np.loadtxt(f)

            results = {'x': np.arange(d.shape[0])+1, 'y': d,
                       'label': f.split('/')[-5:-2], 'mean': d.mean()}
            if label in exp:
                if exp[label]['mean'] < results['mean']:
                    exp[label] = results
            else:
                exp[label] = results
        for k in exp:
            plt.plot(exp[k]['x'], exp[k]['y'], label=exp[k]['label'])
    plt.legend()
    plt.grid()
    plt.show()
# ------------------------------------------------------------------------------
# /kppr/test.py:
# ------------------------------------------------------------------------------
import click
import torch
import tqdm
import kppr.datasets.datasets as datasets
import kppr.models.models as models
from torch.utils.data import DataLoader
from kppr.utils import utils
import numpy as np
import os
from kppr.models import blocks


def computeLatentVectors(data_loader, model):
    """Embed every sample of ``data_loader`` with ``model`` on the GPU.

    Returns:
        sequences: long tensor ``[N, 2]`` of ``(seq, idx)`` per sample.
        latents: the stacked (and squeezed) model outputs.
    """
    sequences = []
    latents = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader):
            latents.append(
                model(batch['points'].cuda(), batch['points_mask'].cuda()))
            sequences.append([batch['seq'].item(), batch['idx'].item()])
    sequences = torch.tensor(sequences)
    latents = torch.stack(latents).squeeze()
    return sequences, latents


@click.command()
# Add your options here
@click.option('--checkpoint',
              '-ckpt',
              type=str,
              help='path to checkpoint file (.ckpt)',
              required=True)
@click.option('--dataset',
              '-d',
              type=str,
              default='oxford',
              help='dataset',
              required=False)
@click.option('--base_dir',
              '-b',
              type=str,
              default=utils.DATA_DIR+'oxford_compressed',
              help='base directory of the dataset (contains the *_evaluation_*.pickle files)',
              required=False)
def main(checkpoint, dataset, base_dir):
    """Evaluate a KPPR checkpoint on the evaluation split of DATASET.

    Each query sequence is matched against every other database sequence.
    Recall@1..25 and recall@1% are averaged over all sequence pairs with at
    least one true positive. The result is written next to the checkpoint.
    """
    model = models.KPPR.load_from_checkpoint(checkpoint_path=checkpoint).cuda()
    model.eval()

    database_file = f'{base_dir}/{dataset}_evaluation_database.pickle'
    db_dataset = datasets.OxfordQueryEmbLoaderPad(
        query_dict=database_file, data_dir=base_dir)
    database_loader = DataLoader(dataset=db_dataset, batch_size=1,
                                 shuffle=False, num_workers=0,)
    db_seq, db_latents = computeLatentVectors(database_loader, model)
    print(db_latents.shape, db_seq.shape)

    query_file = f'{base_dir}/{dataset}_evaluation_query.pickle'
    q_dataset = datasets.OxfordQueryEmbLoaderPad(
        query_dict=query_file, data_dir=base_dir)
    query_loader = DataLoader(dataset=q_dataset, batch_size=1,
                              shuffle=False, num_workers=0)
    query_seq, query_latents = computeLatentVectors(query_loader, model)

    unique_seq = query_seq[:, 0].unique().tolist()
    print(unique_seq)
    top_k = 25
    top_k_recall = torch.zeros(top_k, device=query_latents.device)
    top_one_pct_recall = 0
    counter = 0
    for seq_i in tqdm.tqdm(unique_seq):  # get all queries from seq i
        q = query_latents[query_seq[:, 0] == seq_i]
        for seq_j in unique_seq:  # get all latents for each other seq
            if seq_i == seq_j:
                continue
            db = db_latents[db_seq[:, 0] == seq_j]
            one_pct_idx = max(int(np.round(db.shape[0]/100.0)), 1)-1
            # Retrieve enough neighbors for both recall@top_k and recall@1%.
            # Large databases have a 1% rank beyond top_k.
            num_nn = max(top_k, one_pct_idx + 1)
            query_nn = blocks.knn(q, db, k=num_nn)
            # Average over each sequence
            in_top_k = torch.zeros(top_k, device=query_latents.device)
            in_one_pct = 0

            num_pos = 0
            # for each query compute metrics
            for scan_id in range(q.shape[0]):
                true_positives = q_dataset.getTruePositives(
                    seq_i, scan_id, target_seq=seq_j)
                true_positives = torch.tensor(
                    true_positives, dtype=query_nn.dtype, device=query_nn.device)
                if len(true_positives) < 1:
                    continue
                num_pos += 1
                is_in = query_nn[scan_id:scan_id+1, :] == true_positives.unsqueeze(1)
                # 1 from the first rank that hits a true positive onwards
                is_in = is_in.any(dim=0).cumsum(0).clamp(max=1)
                in_top_k += is_in[:top_k]
                in_one_pct += is_in[one_pct_idx]
            if num_pos == 0:
                # No query of seq_i has a true positive in seq_j. Recall is
                # undefined for this pair (division by zero), so skip it.
                continue
            top_k_recall += in_top_k / num_pos
            top_one_pct_recall += in_one_pct/num_pos
            counter += 1

    if counter == 0:
        raise click.ClickException(
            'no (query, database) sequence pair with true positives found')
    top_k_recall /= counter
    top_one_pct_recall /= counter
    print(top_k_recall)
    print(top_one_pct_recall)
    out_f = os.path.dirname(os.path.abspath(checkpoint))
    out_name = os.path.join(
        out_f, os.path.splitext(os.path.basename(query_file))[0] + '.txt')
    np.savetxt(out_name, top_k_recall.cpu().numpy(), fmt='%.9f',
               header=f'#Top 1 percent recall:\n{top_one_pct_recall}\n#Top k:')


if __name__ == "__main__":
    with torch.no_grad():
        main()
# ------------------------------------------------------------------------------
# /kppr/train.py:
# ------------------------------------------------------------------------------
import click
from os.path import join, dirname, abspath
import subprocess
from pytorch_lightning import Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, ModelSummary
import yaml
import kppr.datasets.datasets as datasets
import kppr.models.models as models


def _git_commit_version():
    """Return the short hash of the checked-out commit of this code, or 'unknown'."""
    try:
        out = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD'],
            cwd=dirname(abspath(__file__)),  # the kppr checkout, not the CWD
            stderr=subprocess.DEVNULL)
    except (OSError, subprocess.CalledProcessError):
        # git missing or not a git checkout: training should still start.
        return 'unknown'
    return out.decode().strip()


@click.command()
# Add your options here
@click.option('--config',
              '-c',
              type=str,
              help='path to the config file (.yaml)',
              default=join(dirname(abspath(__file__)), 'config/config.yaml'))
@click.option('--data_config',
              '-dc',
              type=str,
              help='path to the config file (.yaml)',
              default=join(dirname(abspath(__file__)), 'config/oxford_data.yaml'))
@click.option('--weights',
              '-w',
              type=str,
              help='path to pretrained weights (.ckpt). Use this flag if you just want to load the weights from the checkpoint file without resuming training.',
              default=None)
@click.option('--checkpoint',
              '-ckpt',
              type=str,
              help='path to checkpoint file (.ckpt) to resume training.',
              default=None)
def main(config, data_config, weights, checkpoint):
    """Train the model described by CONFIG on the data described by DATA_CONFIG."""
    with open(config) as f:
        cfg = yaml.safe_load(f)
    with open(data_config) as f:
        data_cfg = yaml.safe_load(f)
    cfg['git_commit_version'] = _git_commit_version()
    cfg['data_config'] = data_cfg
    print(f"Start experiment {cfg['experiment']['id']}")
    # Load data and model
    data = datasets.getOxfordDataModule(data_cfg)

    model = models.getModel(
        cfg['network_architecture'], config=cfg,
        weights=weights)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_saver = ModelCheckpoint(monitor='val/recall_1',
                                       filename='best_{epoch:02d}',
                                       mode='max',
                                       save_last=True)

    tb_logger = pl_loggers.TensorBoardLogger('experiments/'+cfg['experiment']['id'],
                                             default_hp_metric=False)

    print('nr gpus:', cfg['train']['n_gpus'])
    # Setup trainer
    trainer = Trainer(gpus=cfg['train']['n_gpus'],
                      logger=tb_logger,
                      resume_from_checkpoint=checkpoint,
                      gradient_clip_val=0.2,
                      max_epochs=cfg['train']['max_epoch'],
                      callbacks=[lr_monitor, checkpoint_saver, ModelSummary(max_depth=2)],)

    # Train!
    trainer.fit(model, data)


if __name__ == "__main__":
    main()
# ------------------------------------------------------------------------------
# /kppr/utils/__init__.py:
# ------------------------------------------------------------------------------
# https://raw.githubusercontent.com/PRBonn/kppr/edcd7f585e569cea34a3367efe8a33cba84a0a8e/kppr/utils/__init__.py
# ------------------------------------------------------------------------------
# /kppr/utils/utils.py:
# ------------------------------------------------------------------------------
import os
import torch

CONFIG_DIR = os.path.dirname(
    os.path.realpath(__file__))+'/../config/'
DATA_DIR = os.path.dirname(
    os.path.realpath(__file__))+'/../data/'


def knn(q_pts, s_pts, k, cosine_sim=False):
    """Return the indices of the ``k`` neighbors in ``s_pts`` of each query.

    With ``cosine_sim`` the neighbors are the ``k`` largest dot products
    (a cosine similarity only if the inputs are normalized). Otherwise they
    are the ``k`` smallest squared Euclidean distances.
    """
    if cosine_sim:
        sim = torch.einsum('...in,...jn->...ij', q_pts, s_pts)
        _, neighb_inds = torch.topk(sim, k, dim=-1, largest=True)
        return neighb_inds
    else:
        dist = ((q_pts.unsqueeze(-2) - s_pts.unsqueeze(-3))**2).sum(-1)
        _, neighb_inds = torch.topk(dist, k, dim=-1, largest=False)
        return neighb_inds
# ------------------------------------------------------------------------------
# /requirements.txt:  (plain-text file, listed verbatim as comments)
# ------------------------------------------------------------------------------
# Click==7.0
# diskcache==5.4.0
# pytorch_lightning==1.8.4.post0
# matplotlib==3.1.2
# tqdm==4.64.0
# opt_einsum==3.3.0
# numpy==1.17.4
# torch==1.12.1+cu116
# pickle5==0.0.11
# PyYAML==6.0
# ------------------------------------------------------------------------------
# /setup.py:
# ------------------------------------------------------------------------------
from setuptools import setup, find_packages

pkg_name = 'kppr'
setup(name=pkg_name, version='1.0', packages=find_packages())