├── LICENSE ├── README.md ├── __pycache__ ├── buffer.cpython-37.pyc ├── loss.cpython-37.pyc └── utils.cpython-37.pyc ├── buffer.py ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── cambridge_landmarks.cpython-37.pyc │ ├── seven_scenes.cpython-37.pyc │ ├── twelve_scenes.cpython-37.pyc │ ├── twelve_scenes_copy.cpython-37.pyc │ └── utils.cpython-37.pyc ├── seven_scenes.py ├── twelve_scenes.py └── utils.py ├── environment.yml ├── eval.py ├── images └── pipeline1.jpg ├── loss.py ├── models ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── hscnet.cpython-37.pyc │ └── scrnet.cpython-37.pyc ├── hscnet.py └── scrnet.py ├── pnpransac ├── build │ └── temp.linux-x86_64-3.7 │ │ └── pnpransacpy.o ├── pnpransac.cpp ├── pnpransac.cpython-37m-x86_64-linux-gnu.so ├── pnpransac.h ├── pnpransacpy.cpp ├── pnpransacpy.pxd ├── pnpransacpy.pyx └── setup.py ├── train_CL_preds.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 AaltoVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICCV 2021] [Continual Learning for Image-Based Camera Localization](https://arxiv.org/pdf/2108.09112.pdf) 2 | [Shuzhe Wang](https://ffrivera0.github.io/)\*, [Zakaria Laskar](https://scholar.google.com/citations?hl=en&user=kd3XIUkAAAAJ&view_op=list_works&sortby=pubdate)\*, [Iaroslav Melekhov](https://imelekhov.com/), [Xiaotian Li](https://scholar.google.com/citations?user=lht2z_IAAAAJ&hl=en), [Juho Kannala](https://users.aalto.fi/~kannalj1/) 3 | 4 | \* Equal Contribution 5 | 6 | This is the PyTorch implementation of our paper, [Continual Learning for Image-Based Camera Localization](https://arxiv.org/pdf/2108.09112.pdf). In this paper, we approach the problem of visual localization in a continual learning setup, whereby the model is trained on scenes in an incremental manner. Under this setting, not all scenes are available at once during training; instead, they are encountered sequentially. The results show that our method is memory efficient and suffers only a slight performance degradation compared to joint training. 7 | 8 | ![pipeline](images/pipeline1.jpg) 9 | 10 | ## Setup 11 | The environment is similar to [HSCNet](https://github.com/AaltoVision/hscnet).
Python 3 and the following packages are required: 12 | ``` 13 | cython 14 | numpy 15 | pytorch 16 | opencv 17 | tqdm 18 | imgaug 19 | ``` 20 | 21 | It is recommended to use a conda environment: 22 | 1. Install Anaconda or Miniconda. 23 | 2. Create the environment: `conda env create -f environment.yml`. 24 | 3. Activate the environment: `conda activate hscnet`. 25 | 26 | To run the evaluation script, you will need to build the Cython module: 27 | 28 | ```bash 29 | cd ./pnpransac 30 | python setup.py build_ext --inplace 31 | ``` 32 | 33 | 34 | 35 | ## Data 36 | 37 | We run our experiments on [7-Scenes](https://www.microsoft.com/en-us/research/project/rgb-d-dataset-7-scenes/), [12-Scenes](https://graphics.stanford.edu/projects/reloc/), and also 19-Scenes, obtained by combining the former two. To train/evaluate our code, you need to download the datasets from their respective websites. You will also need an additional [data package](https://drive.google.com/drive/folders/19KyyoYy-2Nnc2Vu6yXGMvSJvo9Yfgfrh?usp=sharing) which contains the other files necessary for reproducing our results. 38 | 39 | 40 | 41 | ## Evaluation 42 | 43 | The trained models for the ***Buff-CS*** sampling method with buffer sizes 256 and 1024 can be downloaded [here](https://drive.google.com/drive/folders/1jYKRicvyq5-Jb81-s9NbcfFj7MqtMmAJ?usp=sharing). We will provide the models for the other sampling methods soon. 44 | 45 | To evaluate our method: 46 | 47 | ```bash 48 | python eval.py \ 49 | --model hscnet \ 50 | --dataset [i7S|i12S|i19S] \ 51 | --scene scene_name \ # for i12S, use apt1/kitchen, apt1/living ... 52 | --checkpoint /path/to/saved/model/ \ 53 | --data_path /path/to/data/ 54 | ``` 55 | 56 | 57 | 58 | ## Training 59 | 60 | You can train our network in the continual setting by running the following command: 61 | 62 | ```bash 63 | python train.py \ 64 | --model hscnet \ 65 | --dataset [i7S|i12S|i19S] \ 66 | --n_iter number_of_training_iterations \ # default 30000 67 | --data_path /path/to/data/ \ 68 | --dense_pred [False|True] \ # False: train without dense representation 69 | --exp_name experiment_name \ # set a name for your experiment 70 | --buffer_size [128|256|512|1024] \ 71 | --sampling [Random|Imgbal|CoverageS] 72 | ``` 73 | 74 | 75 | 76 | ## Acknowledgements 77 | 78 | We appreciate the open-source repositories [DSAC++](https://github.com/vislearn/LessMore) and [HSCNet](https://github.com/AaltoVision/hscnet). 79 | 80 | 81 | 82 | ## License 83 | 84 | Copyright (c) 2021 AaltoVision. 85 | This code is released under the [MIT License](LICENSE).
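A note on the `--sampling` options above: they presumably correspond to the buffer-update strategies implemented in `buffer.py` later in this repository (`add_buffer_dense` for random reservoir updates, `add_imb_buffer` for image-balanced updates, and `add_bal_buff` for the coverage-based ***Buff-CS*** updates). Below is a minimal, standalone sketch of the reservoir-style update underlying `Random` sampling; the function and variable names are illustrative only, not the repository's API.

```python
import random

def reservoir_update(buffer, item, n_seen, capacity):
    """Keep a uniform random subset of all items seen so far (reservoir sampling)."""
    if len(buffer) < capacity:
        buffer.append(item)           # buffer not full yet: always keep the item
    else:
        j = random.randrange(n_seen)  # pick a slot among all items seen so far
        if j < capacity:
            buffer[j] = item          # replace an entry with probability capacity / n_seen

buffer, capacity = [], 4
for n, frame in enumerate(['frame_%d' % i for i in range(10)], start=1):
    reservoir_update(buffer, frame, n, capacity)
print(buffer)  # 4 frames, each kept with (approximately) equal probability
```

In `buffer.py`, the kept frame names are additionally written to `train_buffer_<exp_name>.txt`, and, when dense predictions are enabled, the network outputs for each buffered frame are stored as pickle files so they can be replayed in later training cycles.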
86 | 87 | 88 | 89 | ## Citation 90 | 91 | Please consider citing our papers if you find this code useful for your research: 92 | 93 | ``` 94 | @inproceedings{wang2021continual, 95 | title={Continual learning for image-based camera localization}, 96 | author={Wang, Shuzhe and Laskar, Zakaria and Melekhov, Iaroslav and Li, Xiaotian and Kannala, Juho}, 97 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 98 | pages={3252--3262}, 99 | year={2021} 100 | } 101 | ``` 102 | 103 | 104 | -------------------------------------------------------------------------------- /__pycache__/buffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/__pycache__/buffer.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import torch 5 | import pickle 6 | import pdb 7 | from collections import ChainMap 8 | # random.seed(521) 9 | # deletelist = ['heads', 'office', 'pumpkin', 'redkitchen', 'stairs'] 10 | 11 | def get_img2subscenes(data_path, dataset): 12 | #### load them in a more efficient way 13 | if dataset == '7Scenes': 14 | scenes = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'] 15 | if dataset == '12Scenes': 16 | scenes = ['apt1/kitchen','apt1/living','apt2/bed', 17 | 'apt2/kitchen','apt2/living','apt2/luke','office1/gates362', 18 | 'office1/gates381','office1/lounge','office1/manolis', 19 | 'office2/5a','office2/5b'] 20 | if dataset == '19Scenes': 21 | scenes = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs', 'apt1/kitchen','apt1/living','apt2/bed', 22 | 'apt2/kitchen','apt2/living','apt2/luke','office1/gates362', 23 | 'office1/gates381','office1/lounge','office1/manolis', 24 | 'office2/5a','office2/5b'] 25 | for i in range(len(scenes)): 26 | scenes[i] = scenes[i].replace('/', '_') 27 | 28 | 29 | for i in range(len(scenes)): 30 | with open(data_path +'/{}/img2subscene/{}.pkl'.format(dataset, scenes[i]),'rb') as data: 31 | scenes[i] = pickle.load(data) 32 | 33 | img2subscenes = dict(ChainMap(*scenes)) 34 | 35 | return img2subscenes 36 | 37 | class createBuffer(): 38 | def __init__(self, buffer_size=256, data_path='', exp='exp_name', dataset = 'i7S'): 39 | 40 | if dataset == 'i7S': 41 | self.dataset = '7Scenes' 42 | if dataset == 'i12S': 43 | self.dataset = '12Scenes' 44 | if dataset == 'i19S': 45 | self.dataset = '19Scenes' 46 | 47 | 48 | self.buffer_scenes = [] 49 | self.buffer_size = buffer_size 50 | # with open('{}/{}/train.txt'.format(data_path, self.dataset), 'r') as f: 51 | # self.frames = f.readlines() 52 | self.buffer_list = [] 53 | self.buffer_fn 
= '{}/{}/train_buffer_{}.txt'.format(data_path,self.dataset, exp) 54 | print(self.buffer_fn) 55 | self.N = 0 56 | self.buff_id = 100000000 57 | self.buffer_class = dict() 58 | self.img2sub_bin = [] 59 | self.largest = 'awszfdeasqssq' 60 | self.dense_pred_path = '{}/{}/dense_pred_{}'.format(data_path,self.dataset, exp) 61 | if not os.path.exists(self.dense_pred_path): 62 | os.mkdir(self.dense_pred_path) 63 | self.img2subscenes = get_img2subscenes(data_path, self.dataset) 64 | 65 | 66 | def add_buffer_dense(self, frame, preds=None): 67 | 68 | frame = frame[0] 69 | self.N += 1 70 | if len(self.buffer_list) < self.buffer_size: 71 | self.buffer_list.append(frame) 72 | buff_id = len(self.buffer_list) - 1 73 | else: 74 | s = int(random.random() * self.N) 75 | if s < self.buffer_size: 76 | self.buffer_list[s] = frame 77 | buff_id = s 78 | else: 79 | self.buffer_list[0] = frame 80 | buff_id = 0 81 | # save dense file 82 | if preds is not None: 83 | coord_pred = preds[0].squeeze().data.cpu().numpy() 84 | lbl_1_onehot = preds[1].squeeze().data.cpu().numpy() 85 | lbl_2_onehot = preds[2].squeeze().data.cpu().numpy() 86 | pkl_save = {'coord_pred':coord_pred, 'lbl_1': lbl_1_onehot, 'lbl_2': lbl_2_onehot} 87 | pkl_file = open('{}/dense_pred_{}.pkl'.format(self.dense_pred_path, buff_id), 'wb') 88 | pickle.dump(pkl_save, pkl_file) 89 | 90 | # dump to buffer file 91 | buffer_training = open(self.buffer_fn, 'w+') 92 | for item in self.buffer_list: 93 | buffer_training.write('{}\n'.format(item)) 94 | 95 | def add_imb_buffer(self, fname, preds=None, nc=None): 96 | # frame: scene_name seq_num image_name 97 | frame = fname[0] 98 | scene = frame.split(' ')[0] 99 | 100 | # if 1st sample then init scene dict to contain indices of occupied positions 101 | if nc == 0: 102 | self.buffer_class[scene] = [] 103 | 104 | # if buffer available store 105 | if len(self.buffer_list) < self.buffer_size: 106 | self.buffer_list.append(frame) 107 | buff_id = len(self.buffer_list)-1 108 | self.buffer_class[scene].append(buff_id) 109 | else: 110 | # if current scene is not the largest 111 | if scene != self.largest: 112 | # get the indices alloted to largest class and sample 113 | largest_inst = self.buffer_class[self.largest] 114 | buff_id = random.sample(largest_inst,1)[0] 115 | 116 | self.buffer_list[buff_id] = frame 117 | self.buffer_class[scene].append(buff_id) # add buff_id to scene dict 118 | self.buffer_class[self.largest].remove(buff_id) # remove buff_id from largest_scene dict 119 | else: 120 | mc = len(self.buffer_class[scene]) 121 | uid = random.uniform(0,1) 122 | if uid <= mc/nc: 123 | self_inst = self.buffer_class[scene] 124 | buff_id = random.sample(self_inst,1)[0] 125 | 126 | self.buffer_list[buff_id] = frame 127 | # no need to add or remove as it is self-substituting 128 | 129 | else: 130 | return 0 131 | 132 | 133 | # TODO online processing 134 | # find largest class 135 | scene_num = [] 136 | scene_name = [] 137 | for sc in self.buffer_class: 138 | #print(self.buffer_class[sc]) 139 | scene_num.append(len(self.buffer_class[sc])) 140 | scene_name.append(sc) 141 | 142 | self.largest = scene_name[np.argsort(-np.array(scene_num))[0]] 143 | ''' 144 | 145 | self.N += 1 146 | if len(self.buffer_list) < self.buffer_size: 147 | self.buffer_list.append(frame) 148 | buff_id = len(self.buffer_list)-1 149 | self.buffer_class[scene].append(buff_id) 150 | else: 151 | s = int(random.random() * self.N) 152 | if s < self.buffer_size: 153 | if scene != self.largest: 154 | largest_inst = self.buffer_class[self.largest] 155 | buff_id = 
random.sample(largest_inst,1)[0] 156 | self.buffer_list[buff_id] = frame 157 | self.buffer_class[scene].append(buff_id) 158 | self.buffer_class[self.largest].remove(buff_id) 159 | else: 160 | # no need to add or remove as it is self-substituting 161 | self_inst = self.buffer_class[scene] 162 | buff_id = random.sample(self_inst,1)[0] 163 | self.buffer_list[buff_id] = frame 164 | 165 | else: 166 | self.buffer_list[0] = frame 167 | buff_id = 0 168 | 169 | 170 | scene_coverscore = [] 171 | scene_name = [] 172 | for sc in self.buffer_class: 173 | score_lists = [] 174 | for id in self.buffer_class[sc]: 175 | frame_name = self.buffer_list[id] 176 | score_list= list(self.img2subscenes[frame_name]) 177 | score_lists += score_list 178 | score_lists = set(score_lists) 179 | coverscore = len(score_lists) / 625 180 | scene_coverscore.append(coverscore) 181 | scene_name.append(sc) 182 | # print(scene_coverscore) 183 | self.largest = scene_name[np.argsort(-np.array(scene_coverscore))[0]] 184 | # print(self.largest) 185 | ''' 186 | 187 | # save dense file 188 | if preds is not None: 189 | coord_pred = preds[0].squeeze().data.cpu().numpy() 190 | lbl_1_onehot = preds[1].squeeze().data.cpu().numpy() 191 | lbl_2_onehot = preds[2].squeeze().data.cpu().numpy() 192 | pkl_save = {'coord_pred': coord_pred, 'lbl_1': lbl_1_onehot, 'lbl_2': lbl_2_onehot} 193 | pkl_file = open('{}/dense_pred_{}.pkl'.format(self.dense_pred_path, buff_id), 'wb') 194 | pickle.dump(pkl_save, pkl_file) 195 | 196 | # dump to buffer file 197 | buffer_training = open(self.buffer_fn, 'w+') 198 | for item in self.buffer_list: 199 | buffer_training.write('{}\n'.format(item)) 200 | 201 | def add_bal_buff(self, fname, preds=None, nc=None): 202 | # frame: scene_name seq_num image_name 203 | frame = fname[0] 204 | scene = frame.split(' ')[0] 205 | valid_subsc = np.array(list(self.img2subscenes[frame])) 206 | subsc_mask = -1*np.ones(625) 207 | subsc_mask[valid_subsc-1] = 1 208 | 209 | # if 1st sample then init scene dict to contain indices of occupied positions 210 | if nc == 0: 211 | self.buffer_class[scene] = [] 212 | 213 | # if buffer available store 214 | if len(self.buffer_list) < self.buffer_size: 215 | self.buffer_list.append(frame) 216 | buff_id = len(self.buffer_list)-1 217 | self.buffer_class[scene].append(buff_id) 218 | self.img2sub_bin.append(subsc_mask) 219 | else: 220 | # if current scene is not the largest 221 | if scene != self.largest: 222 | # get the indices alloted to largest class and sample 223 | largest_inst = self.buffer_class[self.largest] 224 | buff_id = random.sample(largest_inst,1)[0] 225 | 226 | ''' 227 | self.buffer_list[self.buff_id] = frame 228 | self.buffer_class[scene].append(self.buff_id) # add buff_id to scene dict 229 | self.img2sub_bin[self.buff_id] = subsc_mask 230 | self.buffer_class[self.largest].remove(self.buff_id) # remove buff_id from largest_scene dict 231 | buff_id = self.buff_id 232 | ''' 233 | self.buffer_list[buff_id] = frame 234 | self.buffer_class[scene].append(buff_id) # add buff_id to scene dict 235 | self.img2sub_bin[buff_id] = subsc_mask 236 | self.buffer_class[self.largest].remove(buff_id) # remove buff_id from largest_scene dict 237 | else: 238 | # mc = len(self.buffer_class[scene]) 239 | # uid = random.uniform(0, 1) 240 | # if uid <= mc / nc: 241 | 242 | # check if keep/drop (flag=1/0) 243 | flag = self.compute_subsc_difference(scene, subsc_mask) 244 | # make the replacement non deterministic for flag==0 items 245 | mc = len(self.buffer_class[scene]) 246 | uid = random.uniform(0, 1) 247 | if 
uid <= mc / nc: 248 | flag = 1 249 | 250 | if flag == 1: 251 | self_inst = self.buffer_class[scene] 252 | buff_id = random.sample(self_inst,1)[0] 253 | 254 | #self.buffer_list[self.buff_id] = frame 255 | #self.img2sub_bin[self.buff_id] = subsc_mask 256 | #buff_id = self.buff_id 257 | self.buffer_list[buff_id] = frame 258 | self.img2sub_bin[buff_id] = subsc_mask 259 | # no need to add or remove as it is self-substituting 260 | else: 261 | return 0 262 | 263 | # TODO online processing 264 | # find largest class 265 | scene_num = [] 266 | scene_name = [] 267 | for sc in self.buffer_class: 268 | #print(self.buffer_class[sc]) 269 | scene_num.append(len(self.buffer_class[sc])) 270 | scene_name.append(sc) 271 | 272 | self.largest = scene_name[np.argsort(-np.array(scene_num))[0]] 273 | 274 | # compute overlap score of the current largest scene 275 | if len(self.buffer_list) == self.buffer_size: 276 | # collect the binary img 2 subscene vectors 277 | # require : self.buffer_class to get scene to buffer list mapping 278 | # : img2sub_bin to get binary vectors from buffer list mapping 279 | self_inst = self.buffer_class[self.largest] 280 | bin_vecs = np.array([self.img2sub_bin[k] for k in self_inst]) 281 | S = bin_vecs@bin_vecs.T 282 | S = S.sum(1) 283 | self.buff_id = self_inst[np.argsort(-S)[0]] 284 | 285 | 286 | # save dense file 287 | if preds is not None: 288 | coord_pred = preds[0].squeeze().data.cpu().numpy() 289 | lbl_1_onehot = preds[1].squeeze().data.cpu().numpy() 290 | lbl_2_onehot = preds[2].squeeze().data.cpu().numpy() 291 | pkl_save = {'coord_pred': coord_pred, 'lbl_1': lbl_1_onehot, 'lbl_2': lbl_2_onehot} 292 | pkl_file = open('{}/dense_pred_{}.pkl'.format(self.dense_pred_path, buff_id), 'wb') 293 | pickle.dump(pkl_save, pkl_file) 294 | # dump to buffer file 295 | buffer_training = open(self.buffer_fn, 'w+') 296 | for item in self.buffer_list: 297 | buffer_training.write('{}\n'.format(item)) 298 | 299 | 300 | def compute_subsc_difference(self,scene, subsc_mask): 301 | ''' 302 | compute the difference between imcoming image and current sub scene 303 | scene: name of the current scene 304 | sub_mask: array 1 * 625 305 | ''' 306 | buff_subsc_lists = [] 307 | for id in self.buffer_class[scene]: 308 | frame_name = self.buffer_list[id] 309 | buff_subsc_list= list(self.img2subscenes[frame_name]) 310 | buff_subsc_lists += buff_subsc_list 311 | buff_subsc_lists = set(buff_subsc_lists) 312 | buff_subsc_valid = np.array(list(buff_subsc_lists)) 313 | buff_subsc_mask = np.array([0] * 625) 314 | try:buff_subsc_mask[buff_subsc_valid-1] = 1 315 | except: 316 | import pdb 317 | pdb.set_trace() 318 | diff = subsc_mask - buff_subsc_mask 319 | if 1 in diff: 320 | return True 321 | else: 322 | return False 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .seven_scenes import SevenScenes 2 | from .twelve_scenes import TwelveScenes 3 | 4 | def get_dataset(name): 5 | 6 | return { 7 | '7S' : SevenScenes, 8 | '12S' : TwelveScenes 9 | }[name] 10 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/__init__.cpython-37.pyc 
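The `get_dataset` factory in `datasets/__init__.py` above simply maps a short name to a dataset class. Below is a minimal usage sketch for building a training loader; the data path and experiment name are placeholders, and it assumes the buffer file `train_buffer_<exp>.txt` has already been created by the training script.

```python
from torch.utils.data import DataLoader
from datasets import get_dataset

# '7S' -> SevenScenes, '12S' -> TwelveScenes (see the factory above); the incremental
# variants ('i7S', 'i19S', ...) are selected through the `dataset` argument below.
DatasetCls = get_dataset('7S')
trainset = DatasetCls('/path/to/data/', dataset='i7S', scene='heads', split='train',
                      model='hscnet', aug='True',
                      Buffer=True, dense_pred_flag=True, exp='exp_name')
loader = DataLoader(trainset, batch_size=1, shuffle=True, num_workers=4)

for data_ori, data_buffer in loader:
    # data_ori: (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame)
    # data_buffer: the same fields for a replayed buffer frame, plus the stored
    #              dense predictions when dense_pred_flag is True
    break
```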
-------------------------------------------------------------------------------- /datasets/__pycache__/cambridge_landmarks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/cambridge_landmarks.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/seven_scenes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/seven_scenes.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/twelve_scenes.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/twelve_scenes.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/twelve_scenes_copy.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/twelve_scenes_copy.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/datasets/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/seven_scenes.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import random 5 | import numpy as np 6 | import cv2 7 | from torch.utils import data 8 | import pickle 9 | 10 | from .utils import * 11 | 12 | 13 | class SevenScenes(data.Dataset): 14 | def __init__(self, root, dataset='7S', scene='heads', split='train', 15 | model='hscnet', aug='True', Buffer= False, dense_pred_flag= False, exp='exp_name'): 16 | self.Buffer = Buffer 17 | #self.buffer_size = buffer_size 18 | self.intrinsics_color = np.array([[525.0, 0.0, 320.0], 19 | [0.0, 525.0, 240.0], 20 | [0.0, 0.0, 1.0]]) 21 | 22 | self.intrinsics_depth = np.array([[585.0, 0.0, 320.0], 23 | [0.0, 585.0, 240.0], 24 | [0.0, 0.0, 1.0]]) 25 | 26 | self.intrinsics_depth_inv = np.linalg.inv(self.intrinsics_depth) 27 | self.intrinsics_color_inv = np.linalg.inv(self.intrinsics_color) 28 | self.model = model 29 | self.dataset = dataset 30 | self.aug = aug 31 | 32 | self.root = os.path.join(root,'7Scenes') 33 | self.calibration_extrinsics = np.loadtxt(os.path.join(self.root, 34 | 'sensorTrans.txt')) 35 | self.scene = scene 36 | if self.dataset == '7S': 37 | self.scene_ctr = np.loadtxt(os.path.join(self.root, scene, 38 | 'translation.txt')) 39 | self.centers = np.load(os.path.join(self.root, scene, 40 | 'centers.npy')) 41 | else: 42 | self.scenes = ['chess','fire','heads','office','pumpkin', 43 | 'redkitchen','stairs'] 44 | self.transl = [[0,0,0],[10,0,0],[-10,0,0],[0,10,0],[0,-10,0], 45 | [0,0,10],[0,0,-10]] 46 | self.ids = [0,1,2,3,4,5,6] 47 | self.scene_data = {} 48 | for scene, t, d in 
zip(self.scenes, self.transl, self.ids): 49 | self.scene_data[scene] = (t, d, np.load(os.path.join(self.root, 50 | scene, 'centers.npy')), 51 | np.loadtxt(os.path.join(self.root, 52 | scene,'translation.txt'))) 53 | 54 | self.split = split 55 | self.obj_suffixes = ['.color.png','.pose.txt', '.depth.png', 56 | '.label.png'] 57 | self.obj_keys = ['color','pose', 'depth','label'] 58 | 59 | with open(os.path.join(self.root, '{}{}'.format(self.split, 60 | '.txt')), 'r') as f: 61 | self.frames = f.readlines() 62 | if self.dataset == '7S' or self.split == 'test': 63 | self.frames = [frame for frame in self.frames \ 64 | if self.scene in frame] 65 | 66 | if self.Buffer: 67 | if self.dataset == 'i7S': 68 | split_buffer = 'train_buffer_{}'.format(exp) 69 | with open(os.path.join(self.root, '{}{}'.format(split_buffer, '.txt')), 'r') as f: 70 | self.frames_buffer = f.readlines() 71 | self.dense_pred_prefix = 'dense_pred_{}'.format(exp) 72 | self.dense_pred_path = os.path.join(self.root, self.dense_pred_prefix) 73 | self.dense_pred_flag = dense_pred_flag 74 | 75 | self.scene_keys = {'chess': 0, 'fire': 1, 'heads': 2, 'office': 3, 'pumpkin': 4, 'redkitchen': 5, 76 | 'stairs': 6} 77 | self.buffer_scenes = dict() 78 | for i, frame in enumerate(self.frames_buffer): 79 | scene = frame.split(' ')[0] 80 | scene_key = self.scene_keys[scene] 81 | if scene_key not in self.buffer_scenes: 82 | self.buffer_scenes[scene_key] = [] 83 | else: 84 | self.buffer_scenes[scene_key].append(frame) 85 | if self.dataset == 'i19S': 86 | split_buffer = 'train_buffer_{}'.format(exp) 87 | root_19S = '/data/dataset/7_12_scenes/data/19Scenes' 88 | with open(os.path.join(root_19S, '{}{}'.format(split_buffer, '.txt')), 'r') as f: 89 | self.frames_buffer = f.readlines() 90 | self.dense_pred_prefix = 'dense_pred_{}'.format(exp) 91 | self.dense_pred_path = os.path.join(root_19S, self.dense_pred_prefix) 92 | self.dense_pred_flag = dense_pred_flag 93 | 94 | 95 | 96 | def __len__(self): 97 | return len(self.frames) 98 | 99 | def __getitem__(self, index): 100 | frame = self.frames[index].rstrip('\n') 101 | scene, seq_id, frame_id = frame.split(' ') 102 | #print(scene, seq_id, frame_id) 103 | 104 | if self.dataset!='7S': 105 | centers = self.scene_data[scene][2] 106 | scene_ctr = self.scene_data[scene][3] 107 | else: 108 | centers = self.centers 109 | scene_ctr = self.scene_ctr 110 | 111 | obj_files = ['{}{}'.format(frame_id, 112 | obj_suffix) for obj_suffix in self.obj_suffixes] 113 | obj_files_full = [os.path.join(self.root, scene, 114 | seq_id, obj_file) for obj_file in obj_files] 115 | objs = {} 116 | for key, data in zip(self.obj_keys, obj_files_full): 117 | objs[key] = data 118 | 119 | img = cv2.imread(objs['color']) 120 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 121 | 122 | pose = np.loadtxt(objs['pose']) 123 | 124 | pose[0:3,3] = pose[0:3,3] - scene_ctr 125 | 126 | if self.dataset != '7S' and (self.model != 'hscnet' \ 127 | or self.split == 'test'): 128 | pose[0:3,3] = pose[0:3,3] + np.array(self.scene_data[scene][0]) 129 | 130 | if self.split == 'test': 131 | img, pose = to_tensor_query(img, pose) 132 | return img, pose 133 | 134 | lbl = cv2.imread(objs['label'],-1) 135 | ctr_coord = centers[np.reshape(lbl,(-1))-1,:] 136 | 137 | ctr_coord = np.reshape(ctr_coord,(480,640,3)) * 1000 138 | 139 | depth = cv2.imread(objs['depth'],-1) 140 | 141 | pose[0:3,3] = pose[0:3,3] * 1000 142 | 143 | depth[depth==65535] = 0 144 | depth = depth * 1.0 145 | depth = get_depth(depth, self.calibration_extrinsics, 146 | self.intrinsics_color, 
self.intrinsics_depth_inv) 147 | coord, mask = get_coord(depth, pose, self.intrinsics_color_inv) 148 | # set flag for data 149 | img, coord, ctr_coord, mask, lbl = data_aug(img, coord, ctr_coord, 150 | mask, lbl, self.aug) 151 | 152 | if self.model == 'hscnet': 153 | coord = coord - ctr_coord 154 | 155 | coord = coord[4::8,4::8,:] 156 | mask = mask[4::8,4::8].astype(np.float16) 157 | lbl = lbl[4::8,4::8].astype(np.float16) 158 | 159 | if self.dataset=='7S': 160 | lbl_1 = (lbl - 1) // 25 161 | else: 162 | lbl_1 = (lbl - 1)//25 + 25*self.scene_data[scene][1] 163 | lbl_2 = ((lbl - 1) % 25) 164 | 165 | if self.dataset=='7S': 166 | N1=25 167 | if self.dataset=='i7S': 168 | N1=175 169 | if self.dataset=='i19S': 170 | N1=475 171 | 172 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh = to_tensor(img, 173 | coord, mask, lbl_1, lbl_2, N1) 174 | if self.Buffer: 175 | data_buffer = self.buffer(index) 176 | else: 177 | data_buffer = () 178 | data_ori = (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame) 179 | 180 | return data_ori, data_buffer 181 | 182 | def buffer(self, index): 183 | max_length = len(self.frames_buffer) 184 | if index >= max_length: 185 | index = index % max_length 186 | frame = self.frames_buffer[index].rstrip('\n') 187 | scene, seq_id, frame_id = frame.split(' ') 188 | if self.dataset!='7S': 189 | centers = self.scene_data[scene][2] 190 | scene_ctr = self.scene_data[scene][3] 191 | else: 192 | centers = self.centers 193 | scene_ctr = self.scene_ctr 194 | obj_files = ['{}{}'.format(frame_id, 195 | obj_suffix) for obj_suffix in self.obj_suffixes] 196 | obj_files_full = [os.path.join(self.root, scene, 197 | seq_id, obj_file) for obj_file in obj_files] 198 | objs = {} 199 | for key, data in zip(self.obj_keys, obj_files_full): 200 | objs[key] = data 201 | 202 | img = cv2.imread(objs['color']) 203 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | 205 | pose = np.loadtxt(objs['pose']) 206 | 207 | pose[0:3,3] = pose[0:3,3] - scene_ctr 208 | 209 | if self.dataset != '7S' and (self.model != 'hscnet' \ 210 | or self.split == 'test'): 211 | pose[0:3,3] = pose[0:3,3] + np.array(self.scene_data[scene][0]) 212 | 213 | if self.split == 'test': 214 | img, pose = to_tensor_query(img, pose) 215 | return img, pose 216 | 217 | lbl = cv2.imread(objs['label'],-1) 218 | ctr_coord = centers[np.reshape(lbl,(-1))-1,:] 219 | 220 | ctr_coord = np.reshape(ctr_coord,(480,640,3)) * 1000 221 | 222 | depth = cv2.imread(objs['depth'],-1) 223 | 224 | pose[0:3,3] = pose[0:3,3] * 1000 225 | 226 | depth[depth==65535] = 0 227 | depth = depth * 1.0 228 | depth = get_depth(depth, self.calibration_extrinsics, 229 | self.intrinsics_color, self.intrinsics_depth_inv) 230 | coord, mask = get_coord(depth, pose, self.intrinsics_color_inv) 231 | # comment the data augmentation 232 | # img, coord, ctr_coord, mask, lbl = data_aug(img, coord, ctr_coord, 233 | # mask, lbl, self.aug) 234 | 235 | if self.model == 'hscnet': 236 | coord = coord - ctr_coord 237 | 238 | coord = coord[4::8,4::8,:] 239 | mask = mask[4::8,4::8].astype(np.float16) 240 | lbl = lbl[4::8,4::8].astype(np.float16) 241 | 242 | if self.dataset=='7S': 243 | lbl_1 = (lbl - 1) // 25 244 | else: 245 | lbl_1 = (lbl - 1)//25 + 25*self.scene_data[scene][1] 246 | lbl_2 = ((lbl - 1) % 25) 247 | 248 | if self.dataset=='7S': 249 | N1=25 250 | if self.dataset=='i7S': 251 | N1=175 252 | if self.dataset=='i19S': 253 | N1=475 254 | 255 | if self.dense_pred_flag: 256 | # load the dense preds from earlier cycles 257 | pkl_file = 
open('{}/{}_{}.pkl'.format(self.dense_pred_path, 'dense_pred', index), 'rb') 258 | preds = pickle.load(pkl_file) 259 | #preds = np.load('{}/{}_{}.npz'.format(self.dense_pred_path, self.dense_pred_prefix, index)) 260 | dense_pred_lbl_2 = preds['lbl_2'] 261 | dense_pred_lbl_1 = preds['lbl_1'] 262 | coord_pred = preds['coord_pred'] 263 | dense_pred_lbl_2 = torch.from_numpy(dense_pred_lbl_2) 264 | dense_pred_lbl_1 = torch.from_numpy(dense_pred_lbl_1) 265 | coord_pred = torch.from_numpy(coord_pred) 266 | dense_pred = (dense_pred_lbl_1, dense_pred_lbl_2,coord_pred ) 267 | 268 | else: 269 | dense_pred = () 270 | 271 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh = to_tensor(img, 272 | coord, mask, lbl_1, lbl_2, N1) 273 | 274 | return (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame, dense_pred) 275 | 276 | 277 | def get_sampling_prob(self): 278 | scene_nums = [] 279 | # get the number of samples per scene in buffer 280 | for i, sc in self.buffer_scenes: 281 | scene_nums.append(len(self.buffer_scenes[sc])) 282 | 283 | scene_nums = np.array(scene_nums) 284 | weight = 1/scene_nums 285 | scene_prob = weight/weight.sum() 286 | 287 | return scene_prob 288 | 289 | 290 | 291 | -------------------------------------------------------------------------------- /datasets/twelve_scenes.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import random 5 | import numpy as np 6 | import cv2 7 | from torch.utils import data 8 | import pickle 9 | 10 | from .utils import * 11 | 12 | 13 | class TwelveScenes(data.Dataset): 14 | def __init__(self, root, dataset='12S', scene='apt2/bed', split='train', 15 | model='hscnet', aug='True', Buffer= False, dense_pred_flag= False, exp='exp_name'): 16 | self.Buffer = Buffer 17 | # self.buffer_size = buffer_size 18 | self.intrinsics_color = np.array([[572.0, 0.0, 320.0], 19 | [0.0, 572.0, 240.0], 20 | [0.0, 0.0, 1.0]]) 21 | 22 | self.intrinsics_color_inv = np.linalg.inv(self.intrinsics_color) 23 | 24 | 25 | self.model = model 26 | self.dataset = dataset 27 | self.aug = aug 28 | self.root = os.path.join(root,'12Scenes') 29 | self.scene = scene 30 | if self.dataset == '12S': 31 | self.centers = np.load(os.path.join(self.root, scene, 32 | 'centers.npy')) 33 | else: 34 | self.scenes = ['apt1/kitchen','apt1/living','apt2/bed', 35 | 'apt2/kitchen','apt2/living','apt2/luke','office1/gates362', 36 | 'office1/gates381','office1/lounge','office1/manolis', 37 | 'office2/5a','office2/5b'] 38 | self.transl = [[0,-20,0],[0,-20,0],[20,0,0],[20,0,0],[25,0,0], 39 | [20,0,0],[-20,0,0],[-25,5,0],[-20,0,0],[-20,-5,0],[0,20,0], 40 | [0,20,0]] 41 | if self.dataset == 'i12S': 42 | self.ids = [0,1,2,3,4,5,6,7,8,9,10,11] 43 | else: 44 | self.ids = [7,8,9,10,11,12,13,14,15,16,17,18] 45 | self.scene_data = {} 46 | for scene, t, d in zip(self.scenes, self.transl, self.ids): 47 | self.scene_data[scene] = (t, d, np.load(os.path.join(self.root, 48 | scene, 'centers.npy'))) 49 | 50 | self.split, scene = split.split('_') 51 | 52 | 53 | self.obj_suffixes = ['.color.jpg', '.pose.txt', '.depth.png', 54 | '.label.png'] 55 | self.obj_keys = ['color', 'pose', 'depth', 'label'] 56 | 57 | if self.dataset == '12S' or self.split == 'test': 58 | with open(os.path.join(self.root, self.scene, 59 | '{}{}'.format(self.split, '.txt')), 'r') as f: 60 | self.frames = f.readlines() 61 | else: 62 | self.frames = [] 63 | 64 | with open(os.path.join(self.root, scene, 65 | '{}{}'.format(self.split, '.txt')), 'r') as f: 66 | frames 
= f.readlines() 67 | self.frames = [scene + ' ' + frame for frame in frames ] 68 | 69 | 70 | if self.Buffer: 71 | if self.dataset == 'i12S': 72 | split_buffer = 'train_buffer_{}'.format(exp) 73 | with open(os.path.join(self.root, '{}{}'.format(split_buffer, '.txt')), 'r') as f: 74 | self.frames_buffer = f.readlines() 75 | self.dense_pred_prefix = 'dense_pred_{}'.format(exp) 76 | self.dense_pred_path = os.path.join(self.root, self.dense_pred_prefix) 77 | self.dense_pred_flag = dense_pred_flag 78 | 79 | if self.dataset == 'i19S': 80 | split_buffer = 'train_buffer_{}'.format(exp) 81 | root_19S = '/data/dataset/7_12_scenes/data/19Scenes' 82 | with open(os.path.join(root_19S, '{}{}'.format(split_buffer, '.txt')), 'r') as f: 83 | self.frames_buffer = f.readlines() 84 | self.dense_pred_prefix = 'dense_pred_{}'.format(exp) 85 | self.dense_pred_path = os.path.join(root_19S, self.dense_pred_prefix) 86 | self.dense_pred_flag = dense_pred_flag 87 | 88 | 89 | 90 | def __len__(self): 91 | return len(self.frames) 92 | 93 | def __getitem__(self, index): 94 | frame = self.frames[index].rstrip('\n') 95 | if self.dataset != '12S' and self.split == 'train': 96 | scene, frame = frame.split(' ') 97 | centers = self.scene_data[scene][2] 98 | else: 99 | scene = self.scene 100 | if self.split == 'train': 101 | centers = self.centers 102 | 103 | obj_files = ['{}{}'.format(frame, 104 | obj_suffix) for obj_suffix in self.obj_suffixes] 105 | obj_files_full = [os.path.join(self.root, scene, 'data', 106 | obj_file) for obj_file in obj_files] 107 | objs = {} 108 | for key, data in zip(self.obj_keys, obj_files_full): 109 | objs[key] = data 110 | 111 | img = cv2.imread(objs['color']) 112 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 113 | img = cv2.resize(img, (640, 480)) 114 | 115 | pose = np.loadtxt(objs['pose']) 116 | if self.dataset != '12S' and (self.model != 'hscnet' \ 117 | or self.split == 'test'): 118 | pose[0:3,3] = pose[0:3,3] + np.array(self.scene_data[scene][0]) 119 | 120 | if self.split == 'test': 121 | img, pose = to_tensor_query(img, pose) 122 | return img, pose 123 | 124 | lbl = cv2.imread(objs['label'],-1) 125 | 126 | ctr_coord = centers[np.reshape(lbl,(-1))-1,:] 127 | ctr_coord = np.reshape(ctr_coord,(480,640,3)) * 1000 128 | 129 | depth = cv2.imread(objs['depth'],-1) 130 | 131 | pose[0:3,3] = pose[0:3,3] * 1000 132 | 133 | coord, mask = get_coord(depth, pose, self.intrinsics_color_inv) 134 | 135 | img, coord, ctr_coord, mask, lbl = data_aug(img, coord, ctr_coord, 136 | mask, lbl, self.aug) 137 | 138 | if self.model == 'hscnet': 139 | coord = coord - ctr_coord 140 | 141 | coord = coord[4::8,4::8,:] 142 | mask = mask[4::8,4::8].astype(np.float16) 143 | lbl = lbl[4::8,4::8].astype(np.float16) 144 | 145 | if self.dataset=='12S': 146 | lbl_1 = (lbl - 1) // 25 147 | else: 148 | lbl_1 = (lbl - 1) // 25 + 25*self.scene_data[scene][1] 149 | lbl_2 = ((lbl - 1) % 25) 150 | 151 | if self.dataset=='12S': 152 | N1=25 153 | if self.dataset=='i12S': 154 | N1=300 155 | if self.dataset=='i19S': 156 | N1=475 157 | 158 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh = to_tensor(img, 159 | coord, mask, lbl_1, lbl_2, N1) 160 | 161 | if self.Buffer: 162 | data_buffer = self.buffer(index) 163 | else: 164 | data_buffer = () 165 | data_ori = (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame) 166 | 167 | return data_ori, data_buffer 168 | 169 | def buffer(self, index): 170 | max_length = len(self.frames_buffer) 171 | if index >= max_length: 172 | index = index % max_length 173 | frame = 
self.frames_buffer[index].rstrip('\n') 174 | if self.dataset != '12S' and self.split == 'train': 175 | if self.dataset == 'i12S': 176 | scene, frame_id = frame.split(' ') 177 | centers = self.scene_data[scene][2] 178 | if self.dataset == 'i19S': 179 | scene_name = frame.split(' ')[0] 180 | if scene_name in self.scenes: 181 | scene, frame_id = frame.split(' ') 182 | centers = self.scene_data[scene][2] 183 | else: 184 | scene, seq_id, frame_id = frame.split(' ') 185 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame, dense_pred = self.load_from_7S(frame, index) 186 | 187 | return (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame, dense_pred) 188 | 189 | 190 | else: 191 | scene = self.scene 192 | if self.split == 'train': 193 | centers = self.centers 194 | 195 | obj_files = ['{}{}'.format(frame_id, 196 | obj_suffix) for obj_suffix in self.obj_suffixes] 197 | obj_files_full = [os.path.join(self.root, scene, 'data', 198 | obj_file) for obj_file in obj_files] 199 | objs = {} 200 | for key, data in zip(self.obj_keys, obj_files_full): 201 | objs[key] = data 202 | img = cv2.imread(objs['color']) 203 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 204 | img = cv2.resize(img, (640, 480)) 205 | 206 | pose = np.loadtxt(objs['pose']) 207 | if self.dataset != '12S' and (self.model != 'hscnet' \ 208 | or self.split == 'test'): 209 | pose[0:3,3] = pose[0:3,3] + np.array(self.scene_data[scene][0]) 210 | 211 | if self.split == 'test': 212 | img, pose = to_tensor_query(img, pose) 213 | return img, pose 214 | 215 | lbl = cv2.imread(objs['label'],-1) 216 | 217 | ctr_coord = centers[np.reshape(lbl,(-1))-1,:] 218 | ctr_coord = np.reshape(ctr_coord,(480,640,3)) * 1000 219 | 220 | depth = cv2.imread(objs['depth'],-1) 221 | 222 | pose[0:3,3] = pose[0:3,3] * 1000 223 | 224 | coord, mask = get_coord(depth, pose, self.intrinsics_color_inv) 225 | 226 | img, coord, ctr_coord, mask, lbl = data_aug(img, coord, ctr_coord, 227 | mask, lbl, self.aug) 228 | 229 | if self.model == 'hscnet': 230 | coord = coord - ctr_coord 231 | 232 | coord = coord[4::8,4::8,:] 233 | mask = mask[4::8,4::8].astype(np.float16) 234 | lbl = lbl[4::8,4::8].astype(np.float16) 235 | 236 | if self.dataset=='12S': 237 | lbl_1 = (lbl - 1) // 25 238 | else: 239 | lbl_1 = (lbl - 1) // 25 + 25*self.scene_data[scene][1] 240 | lbl_2 = ((lbl - 1) % 25) 241 | 242 | if self.dataset=='12S': 243 | N1=25 244 | if self.dataset=='i12S': 245 | N1=300 246 | if self.dataset=='i19S': 247 | N1=475 248 | 249 | if self.dense_pred_flag: 250 | # load the dense preds from earlier cycles 251 | pkl_file = open('{}/{}_{}.pkl'.format(self.dense_pred_path, 'dense_pred', index), 'rb') 252 | preds = pickle.load(pkl_file) 253 | #preds = np.load('{}/{}_{}.npz'.format(self.dense_pred_path, self.dense_pred_prefix, index)) 254 | dense_pred_lbl_2 = preds['lbl_2'] 255 | dense_pred_lbl_1 = preds['lbl_1'] 256 | coord_pred = preds['coord_pred'] 257 | dense_pred_lbl_2 = torch.from_numpy(dense_pred_lbl_2) 258 | dense_pred_lbl_1 = torch.from_numpy(dense_pred_lbl_1) 259 | coord_pred = torch.from_numpy(coord_pred) 260 | dense_pred = (dense_pred_lbl_1, dense_pred_lbl_2,coord_pred ) 261 | 262 | else: 263 | dense_pred = () 264 | 265 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh = to_tensor(img, 266 | coord, mask, lbl_1, lbl_2, N1) 267 | 268 | return (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame, dense_pred) 269 | 270 | 271 | 272 | def load_from_7S(self,frame, index): 273 | scene_name, seq_id, frame_id = frame.split(' ') 274 | intrinsics_color = np.array([[525.0, 0.0, 
320.0], 275 | [0.0, 525.0, 240.0], 276 | [0.0, 0.0, 1.0]]) 277 | 278 | intrinsics_depth = np.array([[585.0, 0.0, 320.0], 279 | [0.0, 585.0, 240.0], 280 | [0.0, 0.0, 1.0]]) 281 | 282 | intrinsics_depth_inv = np.linalg.inv(intrinsics_depth) 283 | intrinsics_color_inv = np.linalg.inv(intrinsics_color) 284 | root = '/data/dataset/7_12_scenes/data/7Scenes' 285 | calibration_extrinsics = np.loadtxt(os.path.join(root, 286 | 'sensorTrans.txt')) 287 | 288 | 289 | scenes = ['chess','fire','heads','office','pumpkin', 290 | 'redkitchen','stairs'] 291 | transl = [[0,0,0],[10,0,0],[-10,0,0],[0,10,0],[0,-10,0], 292 | [0,0,10],[0,0,-10]] 293 | ids = [0,1,2,3,4,5,6] 294 | scene_data = {} 295 | for scene, t, d in zip(scenes, transl, ids): 296 | scene_data[scene] = (t, d, np.load(os.path.join(root,scene, 'centers.npy')),np.loadtxt(os.path.join(root, scene,'translation.txt'))) 297 | 298 | obj_suffixes = ['.color.png','.pose.txt', '.depth.png','.label.png'] 299 | obj_keys = ['color','pose', 'depth','label'] 300 | 301 | 302 | centers = scene_data[scene_name][2] 303 | scene_ctr = scene_data[scene_name][3] 304 | 305 | obj_files = ['{}{}'.format(frame_id, 306 | obj_suffix) for obj_suffix in obj_suffixes] 307 | obj_files_full = [os.path.join(root, scene_name, 308 | seq_id, obj_file) for obj_file in obj_files] 309 | objs = {} 310 | for key, data in zip(obj_keys, obj_files_full): 311 | objs[key] = data 312 | 313 | img = cv2.imread(objs['color']) 314 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 315 | 316 | pose = np.loadtxt(objs['pose']) 317 | 318 | pose[0:3,3] = pose[0:3,3] - scene_ctr 319 | if self.dataset != '7S' and (self.model != 'hscnet' \ 320 | or self.split == 'test'): 321 | pose[0:3,3] = pose[0:3,3] + np.array(scene_data[scene_name][0]) 322 | 323 | lbl = cv2.imread(objs['label'],-1) 324 | ctr_coord = centers[np.reshape(lbl,(-1))-1,:] 325 | 326 | ctr_coord = np.reshape(ctr_coord,(480,640,3)) * 1000 327 | 328 | depth = cv2.imread(objs['depth'],-1) 329 | 330 | pose[0:3,3] = pose[0:3,3] * 1000 331 | 332 | depth[depth==65535] = 0 333 | depth = depth * 1.0 334 | depth = get_depth(depth, calibration_extrinsics, 335 | intrinsics_color, intrinsics_depth_inv) 336 | coord, mask = get_coord(depth, pose, intrinsics_color_inv) 337 | # comment the data augmentation 338 | # img, coord, ctr_coord, mask, lbl = data_aug(img, coord, ctr_coord, 339 | # mask, lbl, self.aug) 340 | 341 | coord = coord - ctr_coord 342 | 343 | coord = coord[4::8,4::8,:] 344 | mask = mask[4::8,4::8].astype(np.float16) 345 | lbl = lbl[4::8,4::8].astype(np.float16) 346 | 347 | lbl_1 = (lbl - 1)//25 + 25*scene_data[scene_name][1] 348 | lbl_2 = ((lbl - 1) % 25) 349 | 350 | N1=475 351 | 352 | if self.dense_pred_flag: 353 | # load the dense preds from earlier cycles 354 | pkl_file = open('{}/{}_{}.pkl'.format(self.dense_pred_path, 'dense_pred', index), 'rb') 355 | preds = pickle.load(pkl_file) 356 | #preds = np.load('{}/{}_{}.npz'.format(self.dense_pred_path, self.dense_pred_prefix, index)) 357 | dense_pred_lbl_2 = preds['lbl_2'] 358 | dense_pred_lbl_1 = preds['lbl_1'] 359 | coord_pred = preds['coord_pred'] 360 | dense_pred_lbl_2 = torch.from_numpy(dense_pred_lbl_2) 361 | dense_pred_lbl_1 = torch.from_numpy(dense_pred_lbl_1) 362 | coord_pred = torch.from_numpy(coord_pred) 363 | dense_pred = (dense_pred_lbl_1, dense_pred_lbl_2,coord_pred ) 364 | 365 | else: 366 | dense_pred = () 367 | 368 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh = to_tensor(img, 369 | coord, mask, lbl_1, lbl_2, N1) 370 | 371 | return (img, coord, mask, lbl_1, lbl_2, lbl_1_oh, 
lbl_2_oh, frame, dense_pred) 372 | 373 | 374 | -------------------------------------------------------------------------------- /datasets/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import numpy as np 5 | import random 6 | from imgaug import augmenters as iaa 7 | 8 | def get_depth(depth, calibration_extrinsics, intrinsics_color, 9 | intrinsics_depth_inv): 10 | """Return the calibrated depth image (7-Scenes). 11 | Calibration parameters from DSAC (https://github.com/cvlab-dresden/DSAC) 12 | are used. 13 | """ 14 | img_height, img_width = depth.shape[0], depth.shape[1] 15 | depth_ = np.zeros_like(depth) 16 | x = np.linspace(0, img_width-1, img_width) 17 | y = np.linspace(0, img_height-1, img_height) 18 | xx, yy = np.meshgrid(x, y) 19 | xx = np.reshape(xx, (1, -1)) 20 | yy = np.reshape(yy, (1, -1)) 21 | ones = np.ones_like(xx) 22 | pcoord_depth = np.concatenate((xx, yy, ones), axis=0) 23 | depth = np.reshape(depth, (1, img_height*img_width)) 24 | ccoord_depth = np.dot(intrinsics_depth_inv, pcoord_depth) * depth 25 | ccoord_depth[1,:] = - ccoord_depth[1,:] 26 | ccoord_depth[2,:] = - ccoord_depth[2,:] 27 | ccoord_depth = np.concatenate((ccoord_depth, ones), axis=0) 28 | ccoord_color = np.dot(calibration_extrinsics, ccoord_depth) 29 | ccoord_color = ccoord_color[0:3,:] 30 | ccoord_color[1,:] = - ccoord_color[1,:] 31 | ccoord_color[2,:] = depth 32 | 33 | pcoord_color = np.dot(intrinsics_color, ccoord_color) 34 | pcoord_color = pcoord_color[:,pcoord_color[2,:]!=0] 35 | pcoord_color[0,:] = pcoord_color[0,:]/pcoord_color[2,:]+0.5 36 | pcoord_color[0,:] = pcoord_color[0,:].astype(int) 37 | pcoord_color[1,:] = pcoord_color[1,:]/pcoord_color[2,:]+0.5 38 | pcoord_color[1,:] = pcoord_color[1,:].astype(int) 39 | pcoord_color = pcoord_color[:,pcoord_color[0,:]>=0] 40 | pcoord_color = pcoord_color[:,pcoord_color[1,:]>=0] 41 | pcoord_color = pcoord_color[:,pcoord_color[0,:] ThreadRand::generators; 36 | bool ThreadRand::initialised = false; 37 | 38 | void ThreadRand::forceInit(unsigned seed) 39 | { 40 | initialised = false; 41 | init(seed); 42 | } 43 | 44 | void ThreadRand::init(unsigned seed) 45 | { 46 | #pragma omp critical 47 | { 48 | if(!initialised) 49 | { 50 | unsigned nThreads = omp_get_max_threads(); 51 | 52 | for(unsigned i = 0; i < nThreads; i++) 53 | { 54 | generators.push_back(std::mt19937()); 55 | generators[i].seed(i+seed); 56 | } 57 | 58 | initialised = true; 59 | } 60 | } 61 | } 62 | 63 | int ThreadRand::irand(int min, int max, int tid) 64 | { 65 | std::uniform_int_distribution dist(min, max); 66 | 67 | unsigned threadID = omp_get_thread_num(); 68 | if(tid >= 0) threadID = tid; 69 | 70 | if(!initialised) init(); 71 | 72 | return dist(ThreadRand::generators[threadID]); 73 | } 74 | 75 | double ThreadRand::drand(double min, double max, int tid) 76 | { 77 | std::uniform_real_distribution dist(min, max); 78 | 79 | unsigned threadID = omp_get_thread_num(); 80 | if(tid >= 0) threadID = tid; 81 | 82 | if(!initialised) init(); 83 | 84 | return dist(ThreadRand::generators[threadID]); 85 | } 86 | 87 | double ThreadRand::dgauss(double mean, double stdDev, int tid) 88 | { 89 | std::normal_distribution dist(mean, stdDev); 90 | 91 | unsigned threadID = omp_get_thread_num(); 92 | if(tid >= 0) threadID = tid; 93 | 94 | if(!initialised) init(); 95 | 96 | return dist(ThreadRand::generators[threadID]); 97 | } 98 | 99 | int irand(int incMin, int excMax, int tid) 100 | { 101 | return ThreadRand::irand(incMin, 
excMax - 1, tid); 102 | } 103 | 104 | double drand(double incMin, double incMax,int tid) 105 | { 106 | return ThreadRand::drand(incMin, incMax, tid); 107 | } 108 | 109 | int igauss(int mean, int stdDev, int tid) 110 | { 111 | return (int) ThreadRand::dgauss(mean, stdDev, tid); 112 | } 113 | 114 | double dgauss(double mean, double stdDev, int tid) 115 | { 116 | return ThreadRand::dgauss(mean, stdDev, tid); 117 | } 118 | 119 | namespace poseSolver { 120 | 121 | std::pair getInvHyp(const std::pair& hyp) 122 | { 123 | cv::Mat_ hypR, trans = cv::Mat_::eye(4, 4); 124 | cv::Rodrigues(hyp.first, hypR); 125 | 126 | hypR.copyTo(trans.rowRange(0,3).colRange(0,3)); 127 | trans(0, 3) = hyp.second.at(0, 0); 128 | trans(1, 3) = hyp.second.at(0, 1); 129 | trans(2, 3) = hyp.second.at(0, 2); 130 | 131 | trans = trans.inv(); 132 | 133 | std::pair invHyp; 134 | cv::Rodrigues(trans.rowRange(0,3).colRange(0,3), invHyp.first); 135 | invHyp.second = cv::Mat_(1, 3); 136 | invHyp.second.at(0, 0) = trans(0, 3); 137 | invHyp.second.at(0, 1) = trans(1, 3); 138 | invHyp.second.at(0, 2) = trans(2, 3); 139 | 140 | return invHyp; 141 | } 142 | 143 | double calcAngularDistance(const std::pair & h1, const std::pair & h2) 144 | { 145 | cv::Mat r1, r2; 146 | cv::Rodrigues(h1.first, r1); 147 | cv::Rodrigues(h2.first, r2); 148 | 149 | cv::Mat rotDiff= r2 * r1.t(); 150 | double trace = cv::trace(rotDiff)[0]; 151 | 152 | trace = std::min(3.0, std::max(-1.0, trace)); 153 | return 180*acos((trace-1.0)/2.0)/CV_PI; 154 | } 155 | 156 | double maxLoss(const std::pair & h1, const std::pair & h2) 157 | { 158 | // measure loss of inverted poses (camera pose instead of scene pose) 159 | std::pair invH1 = getInvHyp(h1); 160 | std::pair invH2 = getInvHyp(h2); 161 | 162 | double rotErr = calcAngularDistance(invH1, invH2); 163 | double tErr = cv::norm(invH1.second - invH2.second); 164 | 165 | return std::max(rotErr, tErr * 100); 166 | } 167 | 168 | inline bool safeSolvePnP( 169 | const std::vector& objPts, 170 | const std::vector& imgPts, 171 | const cv::Mat& camMat, 172 | const cv::Mat& distCoeffs, 173 | cv::Mat& rot, 174 | cv::Mat& trans, 175 | bool extrinsicGuess, 176 | int methodFlag) 177 | { 178 | if(rot.type() == 0) rot = cv::Mat_::zeros(1, 3); 179 | if(trans.type() == 0) trans= cv::Mat_::zeros(1, 3); 180 | 181 | if(!cv::solvePnP(objPts, imgPts, camMat, distCoeffs, rot, trans, extrinsicGuess,methodFlag)) 182 | { 183 | rot = cv::Mat_::zeros(1, 3); 184 | trans = cv::Mat_::zeros(1, 3); 185 | return false; 186 | } 187 | return true; 188 | } 189 | 190 | PnPRANSAC::PnPRANSAC () { 191 | this->camMat = cv::Mat_::eye(3, 3); 192 | } 193 | 194 | PnPRANSAC::PnPRANSAC (float fx, float fy, float cx, float cy) { 195 | this->camMat = cv::Mat_::eye(3, 3); 196 | this->camMat(0,0) = fx; 197 | this->camMat(1,1) = fy; 198 | this->camMat(0,2) = cx; 199 | this->camMat(1,2) = cy; 200 | } 201 | 202 | PnPRANSAC::~PnPRANSAC () {} 203 | 204 | void PnPRANSAC::camMatUpdate(float fx, float fy, float cx, float cy){ 205 | this->camMat = cv::Mat_::eye(3, 3); 206 | this->camMat(0,0) = fx; 207 | this->camMat(1,1) = fy; 208 | this->camMat(0,2) = cx; 209 | this->camMat(1,2) = cy; 210 | } 211 | 212 | double* PnPRANSAC::RANSACLoop( 213 | float* imgPts_, 214 | float* objPts_, 215 | int nPts, 216 | int objHyps) 217 | { 218 | int inlierThreshold2D = 10; 219 | int refSteps = 100; 220 | 221 | std::vector imgPts(nPts); 222 | std::vector objPts(nPts); 223 | #pragma omp parallel for 224 | for(unsigned i=0; i> sampledImgPts(objHyps); 231 | std::vector> sampledObjPts(objHyps); 232 | 
std::vector> rotHyp(objHyps); 233 | std::vector> tHyp(objHyps); 234 | std::vector scores(objHyps); 235 | std::vector> reproDiff(objHyps); 236 | 237 | // sample hypotheses 238 | #pragma omp parallel for 239 | for(int h = 0; h < objHyps; h++) 240 | while(true) 241 | { 242 | std::vector projections; 243 | std::vector alreadyChosen(nPts,0); 244 | sampledImgPts[h].clear(); 245 | sampledObjPts[h].clear(); 246 | 247 | for(int j = 0; j < 4; j++) 248 | { 249 | int idx = irand(0, nPts); 250 | 251 | if(alreadyChosen[idx] > 0) 252 | { 253 | j--; 254 | continue; 255 | } 256 | 257 | alreadyChosen[idx] = 1; 258 | 259 | sampledImgPts[h].push_back(imgPts[idx]); // 2D location in the original RGB image 260 | sampledObjPts[h].push_back(objPts[idx]); // 3D object coordinate 261 | 262 | } 263 | 264 | if(!safeSolvePnP(sampledObjPts[h], sampledImgPts[h], this->camMat, cv::Mat(), rotHyp[h], tHyp[h], false, CV_P3P)) 265 | { 266 | continue; 267 | 268 | } 269 | 270 | cv::projectPoints(sampledObjPts[h], rotHyp[h], tHyp[h], this->camMat, cv::Mat(), projections); 271 | 272 | // check reconstruction, 4 sampled points should be reconstructed perfectly 273 | bool foundOutlier = false; 274 | for(unsigned j = 0; j < sampledImgPts[h].size(); j++) 275 | { 276 | if(cv::norm(sampledImgPts[h][j] - projections[j]) < inlierThreshold2D) 277 | continue; 278 | foundOutlier = true; 279 | break; 280 | } 281 | if(foundOutlier) 282 | continue; 283 | else{ 284 | // compute reprojection error and hypothesis score 285 | std::vector projections; 286 | cv::projectPoints(objPts, rotHyp[h], tHyp[h], this->camMat, cv::Mat(), projections); 287 | std::vector diff(nPts); 288 | float score = 0.; 289 | for(unsigned pt = 0; pt < imgPts.size(); pt++) 290 | { 291 | float err = cv::norm(imgPts[pt] - projections[pt]); 292 | diff[pt] = err; 293 | score = score + (1. / (1. + std::exp(-(0.5*(err-inlierThreshold2D))))); 294 | } 295 | reproDiff[h] = diff; 296 | scores[h] = score; 297 | break; 298 | } 299 | } 300 | 301 | int hypIdx = std::min_element(scores.begin(),scores.end()) - scores.begin(); // select winning hypothesis 302 | 303 | double convergenceThresh = 0.01; // stop refinement if 6D pose vector converges 304 | 305 | std::vector localDiff = reproDiff[hypIdx]; 306 | 307 | for(int rStep = 0; rStep < refSteps; rStep++) 308 | { 309 | // collect inliers 310 | std::vector localImgPts; 311 | std::vector localObjPts; 312 | 313 | for(int pt = 0; pt < nPts; pt++) 314 | { 315 | if(localDiff[pt] < inlierThreshold2D) 316 | { 317 | localImgPts.push_back(imgPts[pt]); 318 | localObjPts.push_back(objPts[pt]); 319 | } 320 | } 321 | 322 | if(localImgPts.size() < 4) 323 | break; 324 | 325 | // recalculate pose 326 | cv::Mat_ rotNew = rotHyp[hypIdx].clone(); 327 | cv::Mat_ tNew = tHyp[hypIdx].clone(); 328 | 329 | if(!safeSolvePnP(localObjPts, localImgPts, this->camMat, cv::Mat(), rotNew, tNew, true, (localImgPts.size() > 4) ? 
CV_ITERATIVE : CV_P3P)) 330 | break; //abort if PnP fails 331 | std::pair hypNew; 332 | std::pair hypOld; 333 | hypNew.first = rotNew; 334 | hypNew.second = tNew; 335 | 336 | hypOld.first = rotHyp[hypIdx]; 337 | hypOld.second = tHyp[hypIdx]; 338 | if(maxLoss(hypNew, hypOld) < convergenceThresh) 339 | break; // convergned 340 | 341 | rotHyp[hypIdx] = rotNew; 342 | tHyp[hypIdx] = tNew; 343 | 344 | // recalculate pose errors 345 | std::vector projections; 346 | cv::projectPoints(objPts, rotHyp[hypIdx], tHyp[hypIdx], this->camMat, cv::Mat(), projections); 347 | std::vector diff(nPts); 348 | 349 | #pragma omp parallel for 350 | for(unsigned pt = 0; pt < imgPts.size(); pt++) 351 | { 352 | float err = cv::norm(imgPts[pt] - projections[pt]); 353 | diff[pt] = err; 354 | } 355 | localDiff = diff; 356 | } 357 | 358 | static double pose[6]; 359 | for (int i = 0; i < 3; i++) 360 | pose[i] = rotHyp[hypIdx](0,i); 361 | for (int i = 3; i < 6; i++) 362 | pose[i] = tHyp[hypIdx](0,i-3); 363 | return pose; 364 | } 365 | } 366 | 367 | -------------------------------------------------------------------------------- /pnpransac/pnpransac.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AaltoVision/CL_HSCNet/7b55bcd79f8e985999ee5cd471626a24de58e4f0/pnpransac/pnpransac.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /pnpransac/pnpransac.h: -------------------------------------------------------------------------------- 1 | /* 2 | PnP-RANSAC implementation based on DSAC++ 3 | Code: https://github.com/vislearn/LessMore 4 | Paper: https://arxiv.org/abs/1711.10228 5 | */ 6 | 7 | /* 8 | Copyright (c) 2016, TU Dresden 9 | Copyright (c) 2017, Heidelberg University 10 | All rights reserved. 11 | Redistribution and use in source and binary forms, with or without 12 | modification, are permitted provided that the following conditions are met: 13 | * Redistributions of source code must retain the above copyright 14 | notice, this list of conditions and the following disclaimer. 15 | * Redistributions in binary form must reproduce the above copyright 16 | notice, this list of conditions and the following disclaimer in the 17 | documentation and/or other materials provided with the distribution. 18 | * Neither the name of the TU Dresden, Heidelberg University nor the 19 | names of its contributors may be used to endorse or promote products 20 | derived from this software without specific prior written permission. 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL TU DRESDEN OR HEIDELBERG UNIVERSITY BE LIABLE FOR ANY 25 | DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND 28 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | */ 32 | 33 | #pragma once 34 | 35 | #include 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include "opencv2/opencv.hpp" 41 | 42 | /** Classes and methods for generating random numbers in multi-threaded programs. */ 43 | 44 | /** 45 | * @brief Provides random numbers for multiple threads. 46 | * 47 | * Singelton class. Holds a random number generator for each thread and gives random numbers for the current thread. 48 | */ 49 | class ThreadRand 50 | { 51 | public: 52 | /** 53 | * @brief Returns a random integer (uniform distribution). 54 | * 55 | * @param min Minimum value of the random integer (inclusive). 56 | * @param max Maximum value of the random integer (exclusive). 57 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 58 | * @return int Random integer value. 59 | */ 60 | static int irand(int min, int max, int tid = -1); 61 | 62 | /** 63 | * @brief Returns a random double value (uniform distribution). 64 | * 65 | * @param min Minimum value of the random double (inclusive). 66 | * @param max Maximum value of the random double (inclusive). 67 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 68 | * @return double Random double value. 69 | */ 70 | static double drand(double min, double max, int tid = -1); 71 | 72 | /** 73 | * @brief Returns a random double value (Gauss distribution). 74 | * 75 | * @param mean Mean of the Gauss distribution to sample from. 76 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 77 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 78 | * @return double Random double value. 79 | */ 80 | static double dgauss(double mean, double stdDev, int tid = -1); 81 | 82 | /** 83 | * @brief Re-Initialize the object with the given seed. 84 | * 85 | * @param seed Seed to initialize the random number generators (seed is incremented by one for each generator). 86 | * @return void 87 | */ 88 | static void forceInit(unsigned seed); 89 | 90 | private: 91 | /** 92 | * @brief List of random number generators. One for each thread. 93 | * 94 | */ 95 | static std::vector generators; 96 | /** 97 | * @brief True if the class has been initialized already 98 | */ 99 | static bool initialised; 100 | /** 101 | * @brief Initialize class with the given seed. 102 | * 103 | * Method will create a random number generator for each thread. The given seed 104 | * will be incremented by one for each generator. This methods is automatically 105 | * called when this calss is used the first time. 106 | * 107 | * @param seed Optional parameter. Seed to be used when initializing the generators. Will be incremented by one for each generator. 108 | * @return void 109 | */ 110 | static void init(unsigned seed = 1305); 111 | }; 112 | 113 | /** 114 | * @brief Returns a random integer (uniform distribution). 115 | * 116 | * This method used the ThreadRand class. 117 | * 118 | * @param min Minimum value of the random integer (inclusive). 119 | * @param max Maximum value of the random integer (exclusive). 120 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 121 | * @return int Random integer value. 122 | */ 123 | int irand(int incMin, int excMax, int tid = -1); 124 | /** 125 | * @brief Returns a random double value (uniform distribution). 126 | * 127 | * This method used the ThreadRand class. 
128 | * 129 | * @param min Minimum value of the random double (inclusive). 130 | * @param max Maximum value of the random double (inclusive). 131 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 132 | * @return double Random double value. 133 | */ 134 | double drand(double incMin, double incMax, int tid = -1); 135 | 136 | /** 137 | * @brief Returns a random integer value (Gauss distribution). 138 | * 139 | * This method used the ThreadRand class. 140 | * 141 | * @param mean Mean of the Gauss distribution to sample from. 142 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 143 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 144 | * @return double Random integer value. 145 | */ 146 | int igauss(int mean, int stdDev, int tid = -1); 147 | 148 | /** 149 | * @brief Returns a random double value (Gauss distribution). 150 | * 151 | * This method used the ThreadRand class. 152 | * 153 | * @param mean Mean of the Gauss distribution to sample from. 154 | * @param stdDev Standard deviation of the Gauss distribution to sample from. 155 | * @param tid Optional parameter. ID of the thread to use. If not given, the method will obtain the thread ID itself. 156 | * @return double Random double value. 157 | */ 158 | double dgauss(double mean, double stdDev, int tid = -1); 159 | 160 | namespace poseSolver { 161 | 162 | /** 163 | * @brief Inverts a given transformation. 164 | * @param hyp Input transformation. 165 | * @return Inverted transformation. 166 | */ 167 | std::pair getInvHyp(const std::pair& hyp); 168 | 169 | /** 170 | * @brief Maximum of translational error (cm) and rotational error (deg) between two pose hypothesis. 171 | * @param h1 Pose 1. 172 | * @param h2 Pose 2. 173 | * @return Loss. 174 | */ 175 | double maxLoss(const std::pair& h1, const std::pair& h2); 176 | 177 | /** 178 | * @brief Calculates the rotational distance in degree between two transformations. 179 | * Translation will be ignored. 180 | * 181 | * @param h1 Transformation 1. 182 | * @param h2 Transformation 2. 183 | * @return Angle in degree. 184 | */ 185 | double calcAngularDistance(const std::pair& h1, const std::pair& h2); 186 | 187 | /** 188 | * @brief Wrapper around the OpenCV PnP function that returns a zero pose in case PnP fails. See also documentation of cv::solvePnP. 189 | * @param objPts List of 3D points. 190 | * @param imgPts Corresponding 2D points. 191 | * @param camMat Calibration matrix of the camera. 192 | * @param distCoeffs Distortion coefficients. 193 | * @param rot Output parameter. Camera rotation. 194 | * @param trans Output parameter. Camera translation. 195 | * @param extrinsicGuess If true uses input rot and trans as initialization. 196 | * @param methodFlag Specifies the PnP algorithm to be used. 197 | * @return True if PnP succeeds. 
198 | */ 199 | inline bool safeSolvePnP( 200 | const std::vector& objPts, 201 | const std::vector& imgPts, 202 | const cv::Mat& camMat, 203 | const cv::Mat& distCoeffs, 204 | cv::Mat& rot, 205 | cv::Mat& trans, 206 | bool extrinsicGuess, 207 | int methodFlag); 208 | 209 | class PnPRANSAC{ 210 | public: 211 | cv::Mat_ camMat; 212 | PnPRANSAC(); 213 | 214 | PnPRANSAC(float fx, float fy, float cx, float cy); 215 | 216 | ~PnPRANSAC(); 217 | 218 | void camMatUpdate(float fx, float fy, float cx, float cy); 219 | 220 | double* RANSACLoop(float* imgPts, float* objPts, int nPts, int objHyps); 221 | }; 222 | 223 | } 224 | 225 | -------------------------------------------------------------------------------- /pnpransac/pnpransacpy.pxd: -------------------------------------------------------------------------------- 1 | from libcpp cimport bool 2 | from libcpp.vector cimport vector 3 | 4 | cdef extern from "pnpransac.cpp": 5 | pass 6 | 7 | cdef extern from "pnpransac.h" namespace "poseSolver": 8 | cdef cppclass PnPRANSAC: 9 | PnPRANSAC() except + 10 | PnPRANSAC(float, float, float, float) except + 11 | void camMatUpdate(float, float, float, float) 12 | double* RANSACLoop(float*, float*, int, int) -------------------------------------------------------------------------------- /pnpransac/pnpransacpy.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c++ 2 | 3 | import numpy as np 4 | cimport numpy as np 5 | 6 | from pnpransacpy cimport PnPRANSAC 7 | 8 | cdef class pnpransac: 9 | cdef PnPRANSAC c_pnpransac 10 | 11 | def __cinit__(self, float fx, float fy, float cx, float cy): 12 | self.c_pnpransac = PnPRANSAC(fx, fy, cx, cy) 13 | 14 | def update_camMat(self, float fx, float fy, float cx, float cy): 15 | self.c_pnpransac.camMatUpdate(fx, fy, cx, cy) 16 | 17 | def RANSAC_loop(self, np.ndarray[double, ndim=2, mode="c"] img_pts, 18 | np.ndarray[double, ndim=2, mode="c"] obj_pts, int n_hyp): 19 | cdef float[:, :] img_pts_ = img_pts.astype(np.float32) 20 | cdef float[:, :] obj_pts_ = obj_pts.astype(np.float32) 21 | cdef int n_pts 22 | n_pts = img_pts_.shape[0] 23 | assert img_pts_.shape[0] == obj_pts_.shape[0] 24 | cdef double* pose 25 | pose = self.c_pnpransac.RANSACLoop(&img_pts_[0,0], &obj_pts_[0,0], 26 | n_pts, n_hyp) 27 | rot = np.array([pose[0],pose[1],pose[2]]) 28 | transl = np.array([pose[3],pose[4],pose[5]]) 29 | return rot, transl -------------------------------------------------------------------------------- /pnpransac/setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from distutils.core import setup 4 | from distutils.extension import Extension 5 | 6 | from Cython.Distutils import build_ext 7 | from Cython.Build import cythonize 8 | 9 | # where to find opencv headers and libraries 10 | cv_include_dir = os.path.join(sys.prefix, 'include') 11 | cv_library_dir = os.path.join(sys.prefix, 'lib') 12 | 13 | ext_modules = [ 14 | Extension( 15 | "pnpransac", 16 | sources=["pnpransacpy.pyx"], 17 | language="c++", 18 | include_dirs=[cv_include_dir], 19 | library_dirs=[cv_library_dir], 20 | libraries=['opencv_core','opencv_calib3d'], 21 | extra_compile_args=['-fopenmp','-std=c++11'], 22 | ) 23 | ] 24 | 25 | setup( 26 | name='pnpransac', 27 | cmdclass={'build_ext': build_ext}, 28 | ext_modules=cythonize(ext_modules), 29 | ) -------------------------------------------------------------------------------- /train_CL_preds.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import sys 4 | import os 5 | import random 6 | import argparse 7 | from pathlib import Path 8 | import torch 9 | from torch.utils import data 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from models import get_model 14 | from datasets import get_dataset 15 | from loss import * 16 | from utils import * 17 | from buffer import createBuffer 18 | import time 19 | 20 | def train(args): 21 | # prepare datasets 22 | if args.dataset == 'i19S': 23 | datasetSs = get_dataset('7S') 24 | datasetTs = get_dataset('12S') 25 | else: 26 | if args.dataset in ['7S', 'i7S']: 27 | dataset_get = get_dataset('7S') 28 | if args.dataset in ['12S', 'i12S']: 29 | dataset_get = get_dataset('12S') 30 | 31 | 32 | # loss 33 | reg_loss = EuclideanLoss() 34 | if args.model == 'hscnet': 35 | cls_loss = CELoss() 36 | if args.dataset in ['i7S', 'i12S', 'i19S']: 37 | w1, w2, w3 = 1, 1, 100000 38 | else: 39 | w1, w2, w3 = 1, 1, 10 40 | 41 | # prepare model and optimizer 42 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 43 | model = get_model(args.model, args.dataset) 44 | model.init_weights() 45 | model.to(device) 46 | 47 | optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, eps=1e-8, 48 | betas=(0.9, 0.999)) 49 | 50 | # resume from existing or start a new session 51 | if args.resume is not None: 52 | if os.path.isfile(args.resume): 53 | print("Loading model and optimizer from checkpoint '{}'".format\ 54 | (args.resume)) 55 | checkpoint = torch.load(args.resume, map_location=device) 56 | model.load_state_dict(checkpoint['model_state']) 57 | optimizer.load_state_dict(checkpoint['optimizer_state']) 58 | print("Loaded checkpoint '{}' (epoch{})".format(args.resume, 59 | checkpoint['epoch'])) 60 | save_path = Path(args.resume) 61 | args.save_path = save_path.parent 62 | #start_epoch = checkpoint['epoch'] + 1 63 | else: 64 | print("No checkpoint found at '{}'".format(args.resume)) 65 | sys.exit() 66 | else: 67 | if args.dataset in ['i7S', 'i12S', 'i19S']: 68 | model_id = "{}-{}-{}-initlr{}-iters{}-bsize{}-aug{}-{}".format(\ 69 | args.exp_name, args.dataset, args.model, args.init_lr, args.n_iter, 70 | args.batch_size, int(args.aug), args.train_id) 71 | else: 72 | model_id = "{}-{}-{}-initlr{}-iters{}-bsize{}-aug{}-{}".format(\ 73 | args.exp_name, args.dataset, args.scene.replace('/','.'), 74 | args.model, args.init_lr, args.n_iter, args.batch_size, 75 | int(args.aug), args.train_id) 76 | save_path = Path(model_id) 77 | args.save_path = 'checkpoints'/save_path 78 | args.save_path.mkdir(parents=True, exist_ok=True) 79 | start_epoch = 1 80 | 81 | # Continual learning over scenes 82 | buffer = createBuffer(data_path=args.data_path, exp=args.exp_name, buffer_size=args.buffer_size, dataset= args.dataset) 83 | if args.dataset == 'i7S': 84 | scenes = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'] 85 | if args.dataset == 'i12S': 86 | scenes = ['apt1/kitchen','apt1/living','apt2/bed', 87 | 'apt2/kitchen','apt2/living','apt2/luke','office1/gates362', 88 | 'office1/gates381','office1/lounge','office1/manolis', 89 | 'office2/5a','office2/5b'] 90 | if args.dataset == 'i19S': 91 | scenes = ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs', 'apt1/kitchen','apt1/living','apt2/bed', 92 | 'apt2/kitchen','apt2/living','apt2/luke','office1/gates362', 93 | 'office1/gates381','office1/lounge','office1/manolis', 94 | 
'office2/5a','office2/5b'] 95 | 96 | 97 | for i,scene in enumerate(scenes): 98 | # if not first scene 99 | 100 | if args.dataset in ['i7S', 'i12S']: 101 | if i > 0: 102 | dataset = dataset_get(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 103 | model=args.model, aug=args.aug, Buffer=True, dense_pred_flag=args.dense_pred, exp=args.exp_name) 104 | else: 105 | dataset = dataset_get(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 106 | model=args.model, aug=args.aug, Buffer=False, exp=args.exp_name) 107 | 108 | trainloader = data.DataLoader(dataset, batch_size=args.batch_size, 109 | num_workers=4, shuffle=True) 110 | 111 | 112 | buffer_dataset = dataset_get(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 113 | model=args.model, aug=False, Buffer=False, dense_pred_flag=args.dense_pred, exp=args.exp_name) 114 | buffer_trainloader = data.DataLoader(buffer_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True) 115 | 116 | if args.dataset == 'i19S': 117 | if i == 0: 118 | dataset = datasetSs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 119 | model=args.model, aug=args.aug, Buffer=False, exp=args.exp_name) 120 | buffer_dataset = datasetSs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 121 | model=args.model, aug=False, Buffer=False, dense_pred_flag=args.dense_pred, exp=args.exp_name) 122 | if i >0 and i < 7: 123 | dataset = datasetSs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 124 | model=args.model, aug=args.aug, Buffer=True, dense_pred_flag=args.dense_pred, exp=args.exp_name) 125 | buffer_dataset = datasetSs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 126 | model=args.model, aug=False, Buffer=False, dense_pred_flag=args.dense_pred, exp=args.exp_name) 127 | if i >= 7: 128 | dataset = datasetTs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 129 | model=args.model, aug=args.aug, Buffer=True, dense_pred_flag=args.dense_pred, exp=args.exp_name) 130 | buffer_dataset = datasetTs(args.data_path, args.dataset, args.scene, split='train_{}'.format(scene), 131 | model=args.model, aug=False, Buffer=False, dense_pred_flag=args.dense_pred, exp=args.exp_name) 132 | 133 | trainloader = data.DataLoader(dataset, batch_size=args.batch_size, num_workers=4, shuffle=True) 134 | 135 | buffer_trainloader = data.DataLoader(buffer_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True) 136 | 137 | # start training 138 | 139 | args.n_epoch = int(np.ceil(args.n_iter * args.batch_size / len(dataset))) 140 | 141 | #for epoch in range(start_epoch, start_epoch + args.n_epoch+1): 142 | for epoch in range(1, args.n_epoch+1): 143 | lr = args.init_lr 144 | 145 | model.train() 146 | train_loss_list = [] 147 | coord_loss_list = [] 148 | if args.model == 'hscnet': 149 | lbl_1_loss_list = [] 150 | lbl_2_loss_list = [] 151 | 152 | for _, (data_ori, data_buffer) in enumerate(tqdm(trainloader)): 153 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, _ = data_ori 154 | 155 | if mask.sum() == 0: 156 | continue 157 | optimizer.zero_grad() 158 | 159 | img = img.to(device) 160 | coord = coord.to(device) 161 | mask = mask.to(device) 162 | train_loss, coord_loss, lbl_1_loss, lbl_2_loss = loss(img, coord, mask, lbl_1, lbl_2, lbl_1_oh, 163 | lbl_2_oh, model, reg_loss, cls_loss, device, w1, w2, w3) 164 | 165 | 166 | 167 | # compute loss for buffer if not first scene 168 | if i > 0 : 169 | # sample a random 
minibatch from buffer dataloader 170 | img_buff, coord_buff, mask_buff, lbl_1_buff, lbl_2_buff, lbl_1_oh_buff, lbl_2_oh_buff, _, dense_pred = data_buffer 171 | 172 | if mask_buff.sum() == 0: 173 | continue 174 | img_buff = img_buff.to(device) 175 | coord_buff = coord_buff.to(device) 176 | mask_buff = mask_buff.to(device) 177 | 178 | buff_loss = loss_buff_DK(img_buff, coord_buff, mask_buff, lbl_1_buff, lbl_2_buff, lbl_1_oh_buff, 179 | lbl_2_oh_buff, model, reg_loss, cls_loss, device, w1, w2, w3, dense_pred=dense_pred) 180 | 181 | train_loss+= 1 * buff_loss 182 | 183 | 184 | coord_loss_list.append(coord_loss.item()) 185 | if args.model == 'hscnet': 186 | lbl_1_loss_list.append(lbl_1_loss.item()) 187 | lbl_2_loss_list.append(lbl_2_loss.item()) 188 | train_loss_list.append(train_loss.item()) 189 | train_loss.backward() 190 | optimizer.step() 191 | 192 | with open(args.save_path/args.log_summary, 'a') as logfile: 193 | if args.model == 'hscnet': 194 | logtt = 'task {}:Epoch {}/{} - lr: {} - reg_loss: {} - cls_loss_1: {}' \ 195 | ' - cls_loss_2: {} - train_loss: {} '.format(scene, 196 | epoch, args.n_epoch, lr, np.mean(coord_loss_list), 197 | np.mean(lbl_1_loss_list), np.mean(lbl_2_loss_list), 198 | np.mean(train_loss_list)) 199 | else: 200 | logtt = 'Epoch {}/{} - lr: {} - reg_loss: {} - train_loss: {}' \ 201 | '\n'.format( 202 | epoch, args.n_epoch, lr, np.mean(coord_loss_list), 203 | np.mean(train_loss_list)) 204 | print(logtt) 205 | logfile.write(logtt) 206 | 207 | if epoch % int(np.floor(args.n_epoch / 1.)) == 0: 208 | save_state(args.save_path, epoch, model, optimizer) 209 | 210 | #start_epoch = epoch 211 | 212 | # add buffer data 213 | with torch.no_grad(): 214 | for i, (data_ori, data_buffer) in enumerate(tqdm(buffer_trainloader)): 215 | img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, frame = data_ori 216 | 217 | if mask.sum() == 0: 218 | continue 219 | optimizer.zero_grad() 220 | 221 | img = img.to(device) 222 | coord = coord.to(device) 223 | mask = mask.to(device) 224 | 225 | if args.dense_pred: 226 | # predictions 227 | lbl_1 = lbl_1.to(device) 228 | lbl_2 = lbl_2.to(device) 229 | lbl_1_oh = lbl_1_oh.to(device) 230 | lbl_2_oh = lbl_2_oh.to(device) 231 | coord_pred, lbl_2_pred, lbl_1_pred = model(img, lbl_1_oh, 232 | lbl_2_oh) 233 | preds = (coord_pred, lbl_1_pred, lbl_2_pred) 234 | if args.sampling == 'CoverageS': 235 | buffer.add_bal_buff(frame, preds, i) 236 | if args.sampling == 'Imgbal': 237 | buffer.add_imb_buffer(frame, preds, i) 238 | if args.sampling == 'Random': 239 | buffer.add_buffer_dense(frame, preds) 240 | else: 241 | if args.sampling == 'CoverageS': 242 | buffer.add_bal_buff(frame, nc=i) 243 | if args.sampling == 'Imgbal': 244 | buffer.add_imb_buffer(frame, nc=i) 245 | if args.sampling == 'Random': 246 | buffer.add_buffer_dense(frame) 247 | 248 | 249 | save_state(args.save_path, epoch, model, optimizer) 250 | 251 | def loss(img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, model, reg_loss, cls_loss, device, w1, w2, w3, dense_pred=None): 252 | 253 | lbl_1 = lbl_1.to(device) 254 | lbl_2 = lbl_2.to(device) 255 | lbl_1_oh = lbl_1_oh.to(device) 256 | lbl_2_oh = lbl_2_oh.to(device) 257 | coord_pred, lbl_2_pred, lbl_1_pred = model(img,lbl_1_oh, 258 | lbl_2_oh) 259 | 260 | lbl_1_loss = cls_loss(lbl_1_pred, lbl_1 , mask ) 261 | lbl_2_loss = cls_loss(lbl_2_pred, lbl_2 , mask ) 262 | coord_loss = reg_loss(coord_pred, coord , mask ) 263 | 264 | train_loss = w3*coord_loss + w1*lbl_1_loss + w2*lbl_2_loss 265 | 266 | 267 | return train_loss, coord_loss, lbl_1_loss, lbl_2_loss 
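Illustrative note (not part of the original source): the buffer term added to `train_loss` in `train()` above comes from `loss_buff_DK` defined below, where the dense predictions stored in the buffer act as a teacher for the current (student) model. The minimal sketch that follows shows only the gating idea for the coordinate term, using dummy tensors; `gated_coord_distillation` and the unmasked Euclidean loss are illustrative stand-ins, not functions from this repository.

```python
import torch
import torch.nn as nn

def gated_coord_distillation(coord_pred, dense_pred_coord, coord_gt, w3=100000):
    """Sketch of the coordinate part of loss_buff_DK: distill towards the stored
    (teacher) prediction only while the teacher still beats the current student."""
    euclid = lambda a, b: torch.norm(a - b, dim=1).mean()  # simplified, unmasked Euclidean loss
    teacher_err = euclid(dense_pred_coord, coord_gt)       # stored prediction vs. ground truth
    student_err = euclid(coord_pred, coord_gt)             # current prediction vs. ground truth
    distill = nn.MSELoss()(coord_pred, dense_pred_coord)   # match the stored prediction
    # the label distillation terms are always kept; only the coordinate term is gated
    return 0.5 * w3 * distill if student_err > teacher_err else torch.zeros(())

# dummy example
coord_gt = torch.randn(2, 3, 60, 80)
teacher = coord_gt + 0.01 * torch.randn_like(coord_gt)   # stored prediction, close to GT
student = coord_gt + 0.10 * torch.randn_like(coord_gt)   # current prediction, further away
print(float(gated_coord_distillation(student, teacher, coord_gt)))  # non-zero: teacher is better
```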
268 | 269 | def loss_buff(img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, model, reg_loss, cls_loss, device, w1, w2, w3, dense_pred=None): 270 | lbl_1 = lbl_1.to(device) 271 | lbl_2 = lbl_2.to(device) 272 | lbl_1_oh = lbl_1_oh.to(device) 273 | lbl_2_oh = lbl_2_oh.to(device) 274 | coord_pred, lbl_2_pred, lbl_1_pred = model(img,lbl_1_oh, 275 | lbl_2_oh) 276 | 277 | lbl_1_loss = cls_loss(lbl_1_pred, lbl_1, mask) 278 | lbl_2_loss = cls_loss(lbl_2_pred, lbl_2, mask) 279 | coord_loss = reg_loss(coord_pred, coord, mask) 280 | 281 | train_loss = w3 * coord_loss + w1 * lbl_1_loss + w2 * lbl_2_loss 282 | 283 | if dense_pred: 284 | dense_pred_lbl_1 = dense_pred[0].to(device) 285 | dense_pred_lbl_2 = dense_pred[1].to(device) 286 | dense_pred_coord = dense_pred[2].to(device) 287 | 288 | L2_loss = nn.MSELoss() 289 | 290 | buff_lbl_1_loss = L2_loss(lbl_1_pred, dense_pred_lbl_1) 291 | buff_lbl_2_loss = L2_loss(lbl_2_pred, dense_pred_lbl_2) 292 | buff_coord_loss = L2_loss(coord_pred, dense_pred_coord) 293 | 294 | train_loss += 0.5 * (w1 * buff_lbl_1_loss + w2 * buff_lbl_2_loss + w3 * buff_coord_loss) 295 | 296 | #return buff_lbl_1_loss, buff_lbl_2_loss 297 | return train_loss 298 | 299 | 300 | def loss_buff_DK(img, coord, mask, lbl_1, lbl_2, lbl_1_oh, lbl_2_oh, model, reg_loss, cls_loss, device, w1, w2, w3, dense_pred=None): 301 | ## teacher loss as a upper bound ## 302 | lbl_1 = lbl_1.to(device) 303 | lbl_2 = lbl_2.to(device) 304 | lbl_1_oh = lbl_1_oh.to(device) 305 | lbl_2_oh = lbl_2_oh.to(device) 306 | coord_pred, lbl_2_pred, lbl_1_pred = model(img,lbl_1_oh, 307 | lbl_2_oh) 308 | 309 | lbl_1_loss = cls_loss(lbl_1_pred, lbl_1, mask) 310 | lbl_2_loss = cls_loss(lbl_2_pred, lbl_2, mask) 311 | coord_loss = reg_loss(coord_pred, coord, mask) 312 | # student VS gt loss 313 | 314 | if dense_pred: 315 | train_loss = 0.5 * (w3 * coord_loss + w1 * lbl_1_loss + w2 * lbl_2_loss) 316 | dense_pred_lbl_1 = dense_pred[0].to(device) 317 | dense_pred_lbl_2 = dense_pred[1].to(device) 318 | dense_pred_coord = dense_pred[2].to(device) 319 | 320 | L2_loss = nn.MSELoss() 321 | 322 | buff_lbl_1_loss = L2_loss(lbl_1_pred, dense_pred_lbl_1) 323 | buff_lbl_2_loss = L2_loss(lbl_2_pred, dense_pred_lbl_2) 324 | buff_coord_loss = L2_loss(coord_pred, dense_pred_coord) 325 | 326 | # teacher loss 327 | buff_teacher_loss = reg_loss(dense_pred_coord, coord, mask) 328 | buff_student_loss = reg_loss(coord_pred, coord, mask) 329 | if buff_student_loss > buff_teacher_loss: 330 | train_loss += 0.5 * (w1 * buff_lbl_1_loss + w2 * buff_lbl_2_loss + w3 * buff_coord_loss) 331 | else: 332 | train_loss += 0.5 * (w1 * buff_lbl_1_loss + w2 * buff_lbl_2_loss) 333 | else: 334 | train_loss = (w3 * coord_loss + w1 * lbl_1_loss + w2 * lbl_2_loss) 335 | return train_loss 336 | 337 | 338 | 339 | if __name__ == '__main__': 340 | parser = argparse.ArgumentParser(description="Hscnet") 341 | parser.add_argument('--model', nargs='?', type=str, default='hscnet', 342 | choices=('hscnet', 'scrnet'), 343 | help='Model to use [\'hscnet, scrnet\']') 344 | parser.add_argument('--dataset', nargs='?', type=str, default='7S', 345 | choices=('7S', '12S', 'i7S', 'i12S', 'i19S', 346 | 'Cambridge'), help='Dataset to use') 347 | parser.add_argument('--scene', nargs='?', type=str, default='heads', 348 | help='Scene') 349 | parser.add_argument('--n_iter', nargs='?', type=int, default=30000, 350 | help='# of iterations (to reproduce the results from ' \ 351 | 'the paper, 300K for 7S and 12S, 600K for ' \ 352 | 'Cambridge, 900K for the combined scenes)') 353 | 
parser.add_argument('--init_lr', nargs='?', type=float, default=5e-5, 354 | help='Initial learning rate') 355 | parser.add_argument('--batch_size', nargs='?', type=int, default=1, 356 | help='Batch size') 357 | parser.add_argument('--aug', nargs='?', type=str2bool, default=True, 358 | help='w/ or w/o data augmentation') 359 | parser.add_argument('--resume', nargs='?', type=str, default=None, 360 | help='Path to saved model to resume from') 361 | parser.add_argument('--data_path', required=True, type=str, 362 | help='Path to dataset') 363 | parser.add_argument('--log-summary', default='progress_log_summary.txt', 364 | metavar='PATH', 365 | help='txt where to save per-epoch stats') 366 | parser.add_argument('--train_id', nargs='?', type=str, default='', 367 | help='An identifier string') 368 | parser.add_argument('--dense_pred', nargs='?', type=str2bool, default=False, 369 | help='store dense predictions in buffer') 370 | parser.add_argument('--exp_name', nargs='?', type=str, default='exp', 371 | help='name of the experiment') 372 | parser.add_argument('--buffer_size', nargs='?', type=int, default=1024, 373 | help='number of samples kept in the replay buffer') 374 | parser.add_argument('--sampling', nargs='?', type=str, default='Random', 375 | help='choose from Random, Imgbal, CoverageS') 376 | 377 | args = parser.parse_args() 378 | 379 | if args.dataset == '7S': 380 | if args.scene not in ['chess', 'heads', 'fire', 'office', 'pumpkin', 381 | 'redkitchen','stairs']: 382 | print('Selected scene is not valid.') 383 | sys.exit() 384 | 385 | if args.dataset == '12S': 386 | if args.scene not in ['apt1/kitchen', 'apt1/living', 'apt2/bed', 387 | 'apt2/kitchen', 'apt2/living', 'apt2/luke', 388 | 'office1/gates362', 'office1/gates381', 389 | 'office1/lounge', 'office1/manolis', 390 | 'office2/5a', 'office2/5b']: 391 | print('Selected scene is not valid.') 392 | sys.exit() 393 | 394 | if args.dataset == 'Cambridge': 395 | if args.scene not in ['GreatCourt', 'KingsCollege', 'OldHospital', 396 | 'ShopFacade', 'StMarysChurch']: 397 | print('Selected scene is not valid.') 398 | sys.exit() 399 | 400 | seed = 0 401 | np.random.seed(seed) 402 | torch.manual_seed(seed) 403 | random.seed(seed) 404 | if args.dense_pred: 405 | print('Dense predictions will be stored in the buffer!') 406 | train(args) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import os 4 | import torch 5 | 6 | def str2bool(v): 7 | return v.lower() in ('yes', 'true', 't', '1') 8 | 9 | def adjust_lr(optimizer, init_lr, c_iter, n_iter): 10 | lr = init_lr * (0.5 ** ((c_iter + 200000 - n_iter) // 50000 + 1 if (c_iter 11 | + 200000 - n_iter) >= 0 else 0)) 12 | for param_group in optimizer.param_groups: 13 | param_group['lr'] = lr 14 | return lr 15 | 16 | def save_state(savepath, epoch, model, optimizer): 17 | state = {'epoch': epoch, 18 | 'model_state': model.state_dict(), 19 | 'optimizer_state': optimizer.state_dict()} 20 | filepath = os.path.join(savepath, 'model.pkl') 21 | torch.save(state, filepath) --------------------------------------------------------------------------------
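For reference, the step schedule implemented by `adjust_lr` in `utils.py` above keeps the learning rate at `init_lr` until the last 200K iterations, then halves it and halves it again after every further 50K iterations. The snippet below is a standalone sketch of that behaviour; the helper name `lr_schedule` and the values `init_lr=5e-5`, `n_iter=300000` are only example inputs taken from the training-script defaults and the iteration counts mentioned in its help text.

```python
def lr_schedule(init_lr, c_iter, n_iter):
    # same expression as utils.adjust_lr, without the optimizer update
    off = c_iter + 200000 - n_iter
    return init_lr * (0.5 ** (off // 50000 + 1 if off >= 0 else 0))

if __name__ == "__main__":
    init_lr, n_iter = 5e-5, 300000
    for it in (0, 99999, 100000, 150000, 200000, 250000):
        print(it, lr_schedule(init_lr, it, n_iter))
    # roughly: 5e-05, 5e-05, 2.5e-05, 1.25e-05, 6.25e-06, 3.125e-06
```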